# Breast cancer

In [32]:
import pandas as pd
import jax
import jax.numpy as jnp
from sklearn.datasets import load_breast_cancer

## Load data

In [33]:
data = load_breast_cancer()
bc_data = pd.DataFrame(data.data, columns=data.feature_names)
bc_data['target'] = data.target
bc_data['target'] = bc_data['target'].astype(int)
bc_data['diagnosis'] = bc_data['target'].map({ 0: 'malign', 1: 'benign'})

In [34]:
bc_data.head()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target,diagnosis
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0,malign
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0,malign
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0,malign
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0,malign
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0,malign


In [35]:
X_wo_diagnosis = bc_data.drop(['target', 'diagnosis'], axis=1).astype(float).values
y_diagnosis = bc_data['target'].astype(float).values

## Create seed and separate data

In [36]:
# creates te seed for random numbers befor spliting the data
seed = jax.random.PRNGKey(73)
index = jnp.arange(len(X_wo_diagnosis))
index_shuffle = jax.random.permutation(seed, index)

In [37]:
X_shuffle = X_wo_diagnosis[index_shuffle]
y_shuffle = y_diagnosis[index_shuffle]
split_index = int(len(X_shuffle) * 0.9) # calculate %90 of the total data

In [38]:
# separate %90 data for training
X_training = X_shuffle[:split_index]
y_training = y_shuffle[:split_index]

# data for test
X_test = X_shuffle[split_index:]
y_test = y_shuffle[split_index:]

In [39]:
X_training

array([[1.316e+01, 2.054e+01, 8.406e+01, ..., 4.195e-02, 2.687e-01,
        7.429e-02],
       [1.189e+01, 1.835e+01, 7.732e+01, ..., 1.138e-01, 3.397e-01,
        8.365e-02],
       [1.791e+01, 2.102e+01, 1.244e+02, ..., 1.964e-01, 3.245e-01,
        1.198e-01],
       ...,
       [1.499e+01, 2.211e+01, 9.753e+01, ..., 1.308e-01, 3.163e-01,
        9.251e-02],
       [1.218e+01, 1.408e+01, 7.725e+01, ..., 1.852e-02, 2.293e-01,
        6.037e-02],
       [1.025e+01, 1.618e+01, 6.652e+01, ..., 9.744e-02, 2.608e-01,
        9.702e-02]])

## Augmented Vector on training data

In [40]:
mid_vector = jnp.ones((X_training.shape[0], 1))
X_aug = jnp.concatenate([mid_vector, jnp.array(X_training)], axis=1)
y = jnp.array(y_training)
print(X_aug.shape)

(512, 31)


In [41]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv

## Estimation

In [42]:
y_hat = X_aug @ beta
y_hat

Array([ 8.28862488e-01,  8.42438281e-01,  2.52577960e-02,  8.26173723e-01,
        1.32786953e+00,  1.12005162e+00,  1.05855572e+00,  9.84911680e-01,
        9.93545890e-01,  6.69406533e-01,  8.54399383e-01,  5.46065569e-02,
        8.19324315e-01,  1.12186074e+00,  8.99714470e-01,  3.62811685e-01,
        2.38534570e-01,  7.06417561e-01,  1.07010692e-01,  7.19575524e-01,
        9.36500072e-01,  7.12437272e-01, -1.77493513e-01,  4.37846571e-01,
        8.91910315e-01,  6.46080256e-01,  6.74300969e-01, -5.51638156e-02,
        1.04399335e+00,  1.14485419e+00,  7.96178758e-01,  1.66867018e-01,
        9.75974321e-01,  1.03899467e+00,  3.32901001e-01,  1.02462196e+00,
        1.03420722e+00,  9.93962884e-02,  1.11280918e+00,  1.03312635e+00,
        1.06138778e+00,  8.68105769e-01,  1.01698124e+00,  8.82826686e-01,
       -1.72298789e-01,  7.67955422e-01, -1.84031427e-01,  2.21648455e-01,
        8.20448220e-01,  1.09110057e+00,  9.42608118e-01,  8.43752682e-01,
        1.07633471e+00,  

## Calculate error

In [43]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

The average error is: 0.17776329815387726
MSE: 0.05146130174398422
RMSE: 0.22685083746910095
R² is : 0.7770018577575684


In [44]:
y_labeled = ["benign" if val > 0.5 else "malign" for val in y_hat]
y_labeled

['benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',

## Test data

In [45]:
mid_vector_test = jnp.ones((X_test.shape[0], 1))
X_aug_test = jnp.concatenate([mid_vector_test, X_test], axis=1)
print(X_aug.shape)

(512, 31)


In [46]:
y_hat_test = X_aug_test @ beta
y_hat_test

Array([ 0.42548677,  0.45497972, -0.17376447,  0.91161203,  0.41655397,
        0.92263204,  0.86933666,  0.9924647 ,  0.04543105,  0.94417846,
       -0.04643899,  0.98597777,  0.1416974 ,  0.987213  ,  0.81024724,
        0.8740764 ,  0.70222694,  0.5655428 ,  0.2628293 ,  0.0621005 ,
        1.0111322 ,  0.73150223,  0.68480504,  0.6275369 ,  0.7651396 ,
       -0.00597739,  0.84721446,  0.7214324 , -0.44602343,  0.5742017 ,
        0.2277796 ,  0.3692384 ,  0.65716195,  0.83835953,  0.05405688,
       -0.0490433 ,  1.0100409 ,  0.8706694 ,  0.9212559 ,  0.14195749,
        0.02478746, -0.19030227,  1.1477607 ,  0.89556116,  0.15860477,
        0.42866194,  0.7360686 ,  0.8329816 ,  0.35593632,  0.76913184,
       -0.20512974,  0.19616866,  0.9754083 ,  1.082237  ,  1.1123945 ,
        1.3019018 ,  0.30131048], dtype=float32)

In [47]:
y_labeled_test = ["benign" if val > 0.5 else "malign" for val in y_hat_test]
y_labeled_test

['malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign']

## Test error

In [48]:
res_test = y_test - y_hat_test
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse_test = jnp.mean(jnp.power(res, 2))
rmse_test = jnp.sqrt(mse_test)
print(f'MSE: {mse_test}')
print(f'RMSE: {rmse_test}')
rss_test = jnp.sum(jnp.square(res_test))
y_med_test = jnp.mean(y_test)
ss_total_test = jnp.sum(jnp.square(y_test - y_med_test))
r2_test = 1 - (rss_test / ss_total_test)
print(f'R² is : {r2}')

The average error is: 0.17776329815387726
MSE: 0.05146130174398422
RMSE: 0.22685083746910095
R² is : 0.7770018577575684


## Precision and Accuracy

### Confusion matrix

In [49]:
# aplying binning, where 1 is benign and 0 malign
y_pred = jnp.where(y_hat_test > 0.5, 1, 0)

# true positive
tp = jnp.sum((y_pred == 1) & (y_test == 1))

#true negative
tn = jnp.sum((y_pred == 0) & (y_test == 0))

#false positive
fp = jnp.sum((y_pred == 1) & (y_test == 0))

#false negative
fn = jnp.sum((y_pred == 0) & (y_test == 1))

In [50]:
presicion = tp / (tp + fp)
accuracy = (tp + tn) / (tp + tn + fp + fn)
print(f'Precision: {presicion}')
print(f'Accuracy: {accuracy}')

Precision: 0.9090909361839294
Accuracy: 0.9473684430122375
