# Breast cancer

In [1]:
import pandas as pd
import jax
import jax.numpy as jnp
from sklearn.datasets import load_breast_cancer

## Load data

In [16]:
data = load_breast_cancer()
bc_data = pd.DataFrame(data.data, columns=data.feature_names)
bc_data['target'] = data.target
bc_data['target'] = bc_data['target'].astype(int)
bc_data['diagnosis'] = bc_data['target'].map({ 0: 'malign', 1: 'benign'})

In [17]:
bc_data.head()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target,diagnosis
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0,malign
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0,malign
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0,malign
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0,malign
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0,malign


In [18]:
X_wo_diagnosis = bc_data.drop(['target', 'diagnosis'], axis=1).astype(float).values
y_diagnosis = bc_data['target'].astype(float).values


## Create seed and separate data

In [19]:
# creates te seed for random numbers befor spliting the data
seed = jax.random.PRNGKey(73)
index = jnp.arange(len(X_wo_diagnosis))
index_shuffle = jax.random.permutation(seed, index)

In [20]:
X_shuffle = X_wo_diagnosis[index_shuffle]
y_shuffle = y_diagnosis[index_shuffle]
split_index = int(len(X_shuffle) * 0.9) # calculate %90 of the total data

In [21]:
# separate %90 data for training
X_training = X_shuffle[:split_index]
y_training = y_shuffle[:split_index]

# data for test
X_test = X_shuffle[split_index:]
y_test = y_shuffle[split_index:]

In [22]:
X_training

array([[1.316e+01, 2.054e+01, 8.406e+01, ..., 4.195e-02, 2.687e-01,
        7.429e-02],
       [1.189e+01, 1.835e+01, 7.732e+01, ..., 1.138e-01, 3.397e-01,
        8.365e-02],
       [1.791e+01, 2.102e+01, 1.244e+02, ..., 1.964e-01, 3.245e-01,
        1.198e-01],
       ...,
       [1.499e+01, 2.211e+01, 9.753e+01, ..., 1.308e-01, 3.163e-01,
        9.251e-02],
       [1.218e+01, 1.408e+01, 7.725e+01, ..., 1.852e-02, 2.293e-01,
        6.037e-02],
       [1.025e+01, 1.618e+01, 6.652e+01, ..., 9.744e-02, 2.608e-01,
        9.702e-02]], shape=(512, 30))

## Augmented Vector on training data

In [23]:
mid_vector = jnp.ones((X_training.shape[0], 1))
X_aug = jnp.concatenate([mid_vector, jnp.array(X_training)], axis=1)
y = jnp.array(y_training)
print(X_aug.shape)

(512, 31)


In [24]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv

In [25]:
# beta

## Estimation

In [26]:
y_hat = X_aug @ beta
y_hat

Array([ 0.82844055,  0.8429376 ,  0.02284688,  0.8262625 ,  1.3274504 ,
        1.1208482 ,  1.0590539 ,  0.98465073,  0.9930228 ,  0.6694769 ,
        0.85400534,  0.05246198,  0.8209833 ,  1.120676  ,  0.899624  ,
        0.36353505,  0.23890275,  0.7065736 ,  0.10731661,  0.72036433,
        0.93616056,  0.71322906, -0.176146  ,  0.4383917 ,  0.89278895,
        0.6465437 ,  0.67445755, -0.05486581,  1.04343   ,  1.1448171 ,
        0.79571164,  0.16710466,  0.9760175 ,  1.0401738 ,  0.33235234,
        1.0243273 ,  1.0344948 ,  0.09863782,  1.1117816 ,  1.0327653 ,
        1.0611835 ,  0.8683651 ,  1.0174767 ,  0.882746  , -0.17078933,
        0.7681125 , -0.185377  ,  0.2215352 ,  0.82078195,  1.0899382 ,
        0.9422432 ,  0.8436861 ,  1.0760268 ,  0.26926082,  0.8747866 ,
        0.98797965,  0.4854245 ,  0.7396861 ,  0.95997447,  0.6832626 ,
        0.26098517, -0.23445717,  0.3155302 ,  0.15713772,  1.203392  ,
        1.002231  ,  0.5928813 ,  0.5097378 ,  0.9032904 ,  0.25

## Calculate error

In [27]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

The average error is: 0.17772424221038818
MSE: 0.05146179720759392
RMSE: 0.2268519252538681
R² is : 0.7769997119903564


In [28]:
y_labeled = ["benign" if val > 0.5 else "malign" for val in y_hat]

In [29]:
y_labeled

['benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',

## Test data

In [30]:
mid_vector = jnp.ones((X_test.shape[0], 1))
X_aug = jnp.concatenate([mid_vector, jnp.array(X_test)], axis=1)
y = jnp.array(y_test)
print(X_aug.shape)

(57, 31)


In [31]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv
y_hat = X_aug @ beta
y_hat

Array([ 1.01814270e-01,  2.75936425e-01,  5.84830642e-02,  9.73012567e-01,
        2.52860904e-01,  1.18215895e+00,  1.07472634e+00,  1.09981275e+00,
       -1.00225210e-02,  1.10297942e+00, -1.11077607e-01,  1.03498721e+00,
        5.17105460e-02,  8.51636767e-01,  9.64229405e-01,  1.06419969e+00,
        8.80276322e-01,  1.34629250e-01,  1.00183487e-01,  1.05677366e-01,
        1.06290317e+00,  9.92781460e-01,  6.64166451e-01, -7.29326010e-02,
        8.41149211e-01, -7.24412203e-02,  9.78482127e-01,  7.30124533e-01,
       -1.33021235e-01,  8.36759150e-01,  3.03998351e-01,  3.45954895e-02,
        3.46915573e-01,  6.83600903e-01,  2.31331587e-02,  1.41391635e-01,
        1.13254309e+00,  9.80163574e-01,  6.55666709e-01,  1.01782620e-01,
       -1.60896122e-01, -2.70078957e-01,  1.19922972e+00,  9.51718152e-01,
        1.82604194e-01,  1.44466579e-01,  8.65396917e-01,  8.75677228e-01,
       -3.10599804e-04,  7.52910733e-01, -1.89059794e-01, -8.81506205e-02,
        1.02304208e+00,  

In [32]:
y_labeled = ["benign" if val > 0.5 else "malign" for val in y_hat]
y_labeled

['malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign']

## Test error

In [33]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

The average error is: 0.13349540531635284
MSE: 0.0266413614153862
RMSE: 0.1632218062877655
R² is : 0.8931385278701782
