# Insurance

In [44]:
import pandas as pd
import jax
import jax.numpy as jnp

## Load data

In [45]:
# read csv
insurance_data = pd.read_csv("insurance.csv")

In [None]:
# insurance_data.head()
# separate de classification data from the numerical
insurance_data['sex'] = insurance_data['sex'].map({'female': 0, 'male': 1})
insurance_data['smoker'] = insurance_data['smoker'].map({'no': 0, 'yes': 1})
# insurance_data['smoker_bmi'] = insurance_data['smoker'] * insurance_data['bmi'] # by adding this column the estimation is better, going from 0.75 to 0.85
insurance_data = pd.get_dummies(insurance_data, columns=['region'])

In [47]:
print(insurance_data)

      age  sex     bmi  children  smoker      charges  smoker_bmi  \
0      19    0  27.900         0       1  16884.92400       27.90   
1      18    1  33.770         1       0   1725.55230        0.00   
2      28    1  33.000         3       0   4449.46200        0.00   
3      33    1  22.705         0       0  21984.47061        0.00   
4      32    1  28.880         0       0   3866.85520        0.00   
...   ...  ...     ...       ...     ...          ...         ...   
1333   50    1  30.970         3       0  10600.54830        0.00   
1334   18    0  31.920         0       0   2205.98080        0.00   
1335   18    0  36.850         0       0   1629.83350        0.00   
1336   21    0  25.800         0       0   2007.94500        0.00   
1337   61    0  29.070         0       1  29141.36030       29.07   

      region_northeast  region_northwest  region_southeast  region_southwest  
0                False             False             False              True  
1            

## Separate target

In [48]:
# get objective value (charges)
X_wo_charges = insurance_data.drop('charges', axis=1).astype(float).values
y_charges = insurance_data['charges'].astype(float).values

## Augmented vetor

In [49]:
mid_vector = jnp.ones((X_wo_charges.shape[0], 1))
#print(mid_vector)

In [50]:
X_aug = jnp.concatenate([mid_vector, jnp.array(X_wo_charges)], axis=1)
y = jnp.array(y_charges)
print(X_aug.shape)

(1338, 11)


In [51]:
XtX = X_aug.T @ X_aug # @ is the dot product, and is equal to use jnp.dot()
Xty = X_aug.T @ y
beta = jnp.linalg.solve(XtX, Xty) # linalg is linear algebra. It uses decomposition to solve the equation instead to use the direct inv

In [52]:
print(beta)

[ 18742.87        263.63226    -500.42142      23.544441    516.34717
 -20414.88       1443.0771   -20966.58     -21552.49     -22177.121
 -22198.336   ]


## Prediction

In [53]:
# linear regression model
y_hat = X_aug @ beta

In [54]:
print(y_hat)

[22057.406   2122.1504  5773.037  ...  2178.7422  2688.2578 35491.758 ]


## Calculate error

In [55]:
res = y - y_hat
print(f'The average error is: {jnp.mean(jnp.abs(res))}')
mse = jnp.mean(jnp.power(res, 2))
rmse = jnp.sqrt(mse)
print(f'MSE: {mse}')
print(f'RMSE: {rmse}')

The average error is: 2898.850830078125
MSE: 23312316.0
RMSE: 4828.28271484375


In [56]:
rss = jnp.sum(jnp.square(res))
y_med = jnp.mean(y)
ss_total = jnp.sum(jnp.square(y - y_med))
r2 = 1 - (rss / ss_total)
print(f'R² is : {r2}')

R² is : 0.8409180045127869
