In [30]:
import pandas as pd
import jax
import jax.numpy as jnp

In [31]:
# read csv
insurance_data = pd.read_csv("insurance.csv")

In [32]:
# insurance_data.head()
# separate de classification data from the numerical
insurance_data['sex'] = insurance_data['sex'].map({'female': 0, 'male': 1})
insurance_data['smoker'] = insurance_data['smoker'].map({'no': 0, 'yes': 1})
insurance_data = pd.get_dummies(insurance_data, columns=['region'])

In [33]:
print(insurance_data)

      age  sex     bmi  children  smoker      charges  region_northeast  \
0      19    0  27.900         0       1  16884.92400             False   
1      18    1  33.770         1       0   1725.55230             False   
2      28    1  33.000         3       0   4449.46200             False   
3      33    1  22.705         0       0  21984.47061             False   
4      32    1  28.880         0       0   3866.85520             False   
...   ...  ...     ...       ...     ...          ...               ...   
1333   50    1  30.970         3       0  10600.54830             False   
1334   18    0  31.920         0       0   2205.98080              True   
1335   18    0  36.850         0       0   1629.83350             False   
1336   21    0  25.800         0       0   2007.94500             False   
1337   61    0  29.070         0       1  29141.36030             False   

      region_northwest  region_southeast  region_southwest  
0                False             Fal

In [34]:
# get objective value (charges)
X_wo_charges = insurance_data.drop('charges', axis=1).astype(float).values
y_charges = insurance_data['charges'].astype(float).values

In [35]:
mid_vector = jnp.ones((X_wo_charges.shape[0], 1))
#print(mid_vector)

In [36]:
X_aug = jnp.concatenate([mid_vector, jnp.array(X_wo_charges)], axis=1)
y = jnp.array(y_charges)
print(X_aug.shape)

(1338, 10)


In [37]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv

In [38]:
print(beta)

[ 68552.94       256.86377   -131.25607    339.21402    475.42694
  23848.598   -80492.38    -80845.27    -81527.51    -81452.27   ]


In [39]:
# linear regression model
y_hat = X_aug @ beta

In [40]:
print(y_hat)

[25293.742   3448.4062  6706.703  ...  4149.0156  1246.5234 37085.9   ]


In [42]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')

The average error is: 4170.916015625
MSE: 36501892.0
RMSE: 6041.6796875


In [43]:
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

R² is : 0.7509130239486694
