In [74]:
import pandas as pd
import jax
import jax.numpy as jnp
from sklearn.datasets import load_breast_cancer

In [75]:
data = load_breast_cancer()
bc_data = pd.DataFrame(data.data, columns=data.feature_names)
bc_data['target'] = data.target
bc_data['target'] = bc_data['target'].astype(int)
bc_data['diagnosis'] = bc_data['target'].map({ 0: 'malign', 1: 'benign'})

In [76]:
bc_data.head()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,target,diagnosis
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0,malign
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0,malign
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0,malign
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0,malign
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0,malign


In [77]:
X_wo_diagnosis = bc_data.drop(['target', 'diagnosis'], axis=1).astype(float).values
y_diagnosis = bc_data['target'].astype(float).values


In [110]:
# creates te seed for random numbers befor spliting the data
seed = jax.random.PRNGKey(73)
index = jnp.arange(len(X_wo_diagnosis))
index_shuffle = jax.random.permutation(seed, index)

In [111]:
X_shuffle = X_wo_diagnosis[index_shuffle]
y_shuffle = y_diagnosis[index_shuffle]

In [None]:
# separate %90 data for training
len(X_wo_diagnosis)
X_training = jnp.array([])
y_training = jnp.array([])
for i in range(512):
    X_training[i] = X_wo_diagnosis[i].copy()
    #y_training = y_training[i].copy()

IndexError: index is out of bounds for axis 0 with size 0

In [103]:
X_training

array([1.481e+01, 1.470e+01, 9.466e+01, 6.807e+02, 8.472e-02, 5.016e-02,
       3.416e-02, 2.541e-02, 1.659e-01, 5.348e-02, 2.182e-01, 6.232e-01,
       1.677e+00, 2.072e+01, 6.708e-03, 1.197e-02, 1.482e-02, 1.056e-02,
       1.580e-02, 1.779e-03, 1.561e+01, 1.758e+01, 1.017e+02, 7.602e+02,
       1.139e-01, 1.011e-01, 1.101e-01, 7.955e-02, 2.334e-01, 6.142e-02])

In [78]:
mid_vector = jnp.ones((X_wo_diagnosis.shape[0], 1))
X_aug = jnp.concatenate([mid_vector, jnp.array(X_wo_diagnosis)], axis=1)
y = jnp.array(y_diagnosis)
print(X_aug.shape)

(569, 31)


In [79]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv

In [80]:
# beta

In [81]:
y_hat = X_aug @ beta
y_hat

Array([-4.46199775e-02,  1.57956362e-01, -1.31289959e-01, -1.91767693e-01,
        1.65051579e-01,  3.01050663e-01,  1.50126934e-01,  3.67954075e-01,
        3.01634848e-01, -1.54515386e-01,  4.52479720e-01,  1.40955001e-01,
        3.58264744e-01,  5.42231441e-01,  3.74186397e-01,  6.39593601e-03,
        1.99750304e-01, -1.62641257e-02, -6.85929060e-02,  6.72656178e-01,
        1.02791131e+00,  1.12658191e+00,  2.65199423e-01,  1.52276754e-02,
       -1.56616807e-01, -1.81586415e-01,  7.12091327e-02,  1.05504870e-01,
        4.54057753e-03,  4.64961380e-01, -6.38715029e-02,  8.48129094e-02,
       -1.09492987e-01,  4.44120169e-03,  8.23694244e-02,  1.06506452e-01,
        4.10221487e-01,  1.00699592e+00,  7.27141500e-01,  4.37157243e-01,
        6.74050629e-01,  5.68757951e-01, -2.03502417e-01,  3.31049949e-01,
        3.41685176e-01, -7.03067482e-02,  1.29127371e+00,  1.86356544e-01,
        8.88967276e-01,  6.88415885e-01,  9.57031012e-01,  9.28875744e-01,
        9.27657366e-01,  

In [82]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

The average error is: 0.18057730793952942
MSE: 0.05275509133934975
RMSE: 0.2296847701072693
R² is : 0.7743242979049683


In [83]:
y_labeled = ["benign" if val > 0.5 else "malign" for val in y_hat]

In [84]:
y_labeled

['malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'malign',
 'benign',
 'malign',
 'malign',
 'benign',
 'benign',
 'malign',
 'malign',
 'malign',
 'benign',
 'malign',
 'benign',
 'malign',
 'benign',
 'benign',
 'benign',