In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
%config InlineBackend.figure_format = "retina"

from sklearn.compose import make_column_transformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline

# Data

## Read CSV

In [2]:
# Show the dataset's accompanying .names description file (it begins with the citation request).
!cat data/breast-cancer-wisconsin.names

Citation Request:
   This breast cancer databases was obtained from the University of Wisconsin
   Hospitals, Madison from Dr. William H. Wolberg.  If you publish results
   when using this database, then please include this information in your
   acknowledgements.  Also, please cite one or more of:

   1. O. L. Mangasarian and W. H. Wolberg: "Cancer diagnosis via linear 
      programming", SIAM News, Volume 23, Number 5, September 1990, pp 1 & 18.

   2. William H. Wolberg and O.L. Mangasarian: "Multisurface method of 
      pattern separation for medical diagnosis applied to breast cytology", 
      Proceedings of the National Academy of Sciences, U.S.A., Volume 87, 
      December 1990, pp 9193-9196.

   3. O. L. Mangasarian, R. Setiono, and W.H. Wolberg: "Pattern recognition 
      via linear programming: Theory and application to medical diagnosis", 
      in: "Large-scale numerical optimization", Thomas F. Coleman and Yuying
      Li, editors, SIAM Publications,

In [3]:
# Load and clean the raw data in one idempotent chain.
# - The file has no header row, so columns are integer-labelled; column 0 becomes the index
#   (presumably a sample ID — confirm against the .names file).
# - Missing values are encoded as "?" and parsed to NaN.
# - Column 10 is the class label; it is recoded 2 -> 0 and 4 -> 1
#   (presumably 2 = benign, 4 = malignant per the .names file — confirm).
df = (
    pd.read_csv(
        "data/breast-cancer-wisconsin.data", header=None, index_col=0, na_values="?"
    )
    .dropna()  # Drop rows with any missing value
    .replace({10: {2: 0, 4: 1}})  # Recode the target column to 0/1
)
# `replace` leaves unmapped values untouched, so fail loudly on any unexpected class code.
assert df[10].isin([0, 1]).all(), "target column 10 must be 0/1 after recoding"
df

0,1,2,3,4,5,6,7,8,9,10
1000025,5,1,1,1,2,1.0,3,1,1,0
1002945,5,4,4,5,7,10.0,3,2,1,0
1015425,3,1,1,1,2,2.0,3,1,1,0
1016277,6,8,8,1,3,4.0,3,7,1,0
1017023,4,1,1,3,2,1.0,3,1,1,0
...,...,...,...,...,...,...,...,...,...,...
776715,3,1,1,1,3,2.0,1,1,1,0
841769,2,1,1,1,2,1.0,1,1,1,0
888820,5,10,10,3,7,3.0,8,10,2,1
897471,4,8,6,4,3,4.0,10,6,1,1


In [4]:
# Class balance of the recoded target column.
df.loc[:, 10].value_counts()

0    444
1    239
Name: 10, dtype: int64

## Create feature matrix X and target vector y

In [5]:
# Whole cleaned frame as a float32 tensor; the target is still its last column at this point.
X = torch.tensor(df.to_numpy(), dtype=torch.float32)
X

tensor([[ 5.,  1.,  1.,  ...,  1.,  1.,  0.],
        [ 5.,  4.,  4.,  ...,  2.,  1.,  0.],
        [ 3.,  1.,  1.,  ...,  1.,  1.,  0.],
        ...,
        [ 5., 10., 10.,  ..., 10.,  2.,  1.],
        [ 4.,  8.,  6.,  ...,  6.,  1.,  1.],
        [ 4.,  8.,  8.,  ...,  4.,  1.,  1.]])

In [6]:
# Split the last column off as the target `y`, then append a constant column of ones to `X`
# so the bias is learned as an ordinary weight.
# Built from `df` (not by re-slicing `X`) so re-running this cell is idempotent; re-slicing
# `X` in place would make a second run treat the bias column as the target.
Xy = torch.from_numpy(df.values).float()
X, y = Xy[:, :-1], Xy[:, -1]
X = torch.hstack((X, torch.ones((len(X), 1), dtype=X.dtype)))

In [7]:
# Re-display X after the split: the target column is removed and the last column is the appended ones (bias) column.
X

tensor([[ 5.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 5.,  4.,  4.,  ...,  2.,  1.,  1.],
        [ 3.,  1.,  1.,  ...,  1.,  1.,  1.],
        ...,
        [ 5., 10., 10.,  ..., 10.,  2.,  1.],
        [ 4.,  8.,  6.,  ...,  6.,  1.,  1.],
        [ 4.,  8.,  8.,  ...,  4.,  1.,  1.]])

# Training

In [8]:
# Seeded random initialisation: one trainable weight per column of X (features plus the bias column).
torch.manual_seed(0)
n_weights = X.shape[1]
w = torch.nn.Parameter(torch.randn(n_weights))
w

Parameter containing:
tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380,
        -0.7193, -0.4033], requires_grad=True)

In [9]:
# Raw linear scores X·w (one per row) under the initial weights.
torch.matmul(X, w)

tensor([ 3.1590e+00, -1.9156e+01, -1.3216e+00, -1.2858e+01,  2.7549e+00,
        -2.1051e+01, -1.5592e+01, -3.6428e+00, -5.1477e+00,  9.2125e-01,
        -1.9204e+00, -1.8673e+00, -5.2832e-01, -5.8022e+00, -1.2188e+01,
        -6.0866e+00,  1.2147e+00,  1.6180e+00, -1.6199e+01,  4.7000e+00,
        -6.0867e+00, -1.4620e+00, -3.2632e-01, -3.0050e+00, -3.9881e+00,
         4.6477e-01,  2.7557e+00, -1.8673e+00, -8.1692e+00,  7.5820e-01,
        -1.4640e+00, -1.2245e+01, -8.9554e-01, -2.5051e+00, -1.8673e+00,
        -3.0241e+00,  7.1045e+00, -7.3007e+00, -1.3623e+01,  5.5873e+00,
        -3.4225e+01, -1.2857e+01, -6.6288e+00, -4.1276e+00, -2.0139e+01,
        -3.4083e+00,  2.7549e+00, -1.5130e+01, -1.3180e+01, -1.7618e+00,
         4.9905e-01, -2.2243e+01, -1.0867e+01, -3.9335e+00, -1.5841e+01,
        -1.1384e+00, -1.7611e+01, -5.8934e-01,  1.9460e-01, -4.8069e+00,
        -2.9717e+01, -3.7116e-01, -3.4083e+00,  9.2789e-01,  1.6180e+00,
        -1.5951e+01, -1.6591e+01, -2.1669e+00, -1.6