In [1]:
import jax.numpy as jnp
import pandas as pd
import jax

In [11]:
train = pd.read_csv('./sample_data/california_housing_train.csv')
test = pd.read_csv('./sample_data/california_housing_test.csv')

In [7]:
train

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value
0,-114.31,34.19,15.0,5612.0,1283.0,1015.0,472.0,1.4936,66900.0
1,-114.47,34.40,19.0,7650.0,1901.0,1129.0,463.0,1.8200,80100.0
2,-114.56,33.69,17.0,720.0,174.0,333.0,117.0,1.6509,85700.0
3,-114.57,33.64,14.0,1501.0,337.0,515.0,226.0,3.1917,73400.0
4,-114.57,33.57,20.0,1454.0,326.0,624.0,262.0,1.9250,65500.0
...,...,...,...,...,...,...,...,...,...
16995,-124.26,40.58,52.0,2217.0,394.0,907.0,369.0,2.3571,111400.0
16996,-124.27,40.69,36.0,2349.0,528.0,1194.0,465.0,2.5179,79000.0
16997,-124.30,41.84,17.0,2677.0,531.0,1244.0,456.0,3.0313,103600.0
16998,-124.30,41.80,19.0,2672.0,552.0,1298.0,478.0,1.9797,85800.0


In [42]:
learning_rate = 0.1
threshold = 0.0001

In [64]:
y_mean = train['median_house_value'].mean()
y_std = train['median_house_value'].std()
y_df = train['median_house_value']
X_df = train.drop(columns=['median_house_value'])
X_df_mean = X_df.mean()
X_df_std = X_df.std()
X_df = (X_df-X_df.mean())/X_df.std()
y_df = (y_df - y_df.mean()) / y_df.std()

In [38]:
X = jnp.array(X_df.values)
y = jnp.array(y_df.values).reshape(-1, 1)
def loss_fn(w,b,X,y):
  pred = jnp.dot(X,w)+ b
  return jnp.mean(jnp.square(pred-y))

In [46]:
@jax.jit
def train_step(w,b,X,y,lr):
  dw, db = jax.grad(loss_fn, argnums=(0,1))(w,b,X,y)
  new_w = w-lr*dw
  new_b = b-lr*db
  return new_w,new_b

In [49]:
w=jnp.zeros((8,1))
b = 0.0
epoch=0
loss = loss_fn(w,b,X,y)
while loss > threshold and epoch < 1000:
  print(f"Training epoch: {epoch}/1000")
  w, b = train_step(w,b,X,y,learning_rate)
  loss = loss_fn(w,b,X,y)
  print(f"current_loss {loss}")
  epoch+=1

Training epoch: 0/1000
current_loss 0.8078879714012146
Training epoch: 1/1000
current_loss 0.6853088140487671
Training epoch: 2/1000
current_loss 0.6051297783851624
Training epoch: 3/1000
current_loss 0.5520364046096802
Training epoch: 4/1000
current_loss 0.5163998007774353
Training epoch: 5/1000
current_loss 0.49208101630210876
Training epoch: 6/1000
current_loss 0.4751429259777069
Training epoch: 7/1000
current_loss 0.4630489647388458
Training epoch: 8/1000
current_loss 0.45415717363357544
Training epoch: 9/1000
current_loss 0.4474004805088043
Training epoch: 10/1000
current_loss 0.4420818090438843
Training epoch: 11/1000
current_loss 0.43774375319480896
Training epoch: 12/1000
current_loss 0.4340849220752716
Training epoch: 13/1000
current_loss 0.43090519309043884
Training epoch: 14/1000
current_loss 0.4280712902545929
Training epoch: 15/1000
current_loss 0.42549368739128113
Training epoch: 16/1000
current_loss 0.4231117069721222
Training epoch: 17/1000
current_loss 0.42088386416435

In [59]:
y_test_df = test['median_house_value']
X_test_df = test.drop(columns=['median_house_value'])
X_test_df = (X_test_df - X_df.mean())/X_df.std()
y_test_df = (y_test_df - y_test_df.mean())/y_test_df.std()

In [60]:
X_test = jnp.array(X_test_df.values)
y_test = jnp.array(y_test_df.values).reshape(-1,1)
y_pred = jnp.dot(X_test,w) + b
y_pred

Array([[-750.0598 ],
       [-301.80548],
       [-766.70984],
       ...,
       [-231.45834],
       [  43.02465],
       [-338.29468]], dtype=float32)

In [61]:
y_pred = (y_pred * y_std) + y_mean
y_pred

Array([[-86787456.],
       [-34797236.],
       [-88718592.],
       ...,
       [-26638110.],
       [  5197462.],
       [-39029392.]], dtype=float32)

In [67]:

X_test_df = (test.drop(columns=['median_house_value']) - X_df_mean) / X_df_std
X_test = jnp.array(X_test_df.values)

y_test_actual = jnp.array(test['median_house_value'].values).reshape(-1, 1)

y_pred_norm = jnp.dot(X_test, w) + b
y_pred_usd = (y_pred_norm * y_std) + y_mean

In [68]:
y_pred_usd

Array([[352747.66],
       [212695.33],
       [272465.4 ],
       ...,
       [ 88368.49],
       [146450.88],
       [456648.8 ]], dtype=float32)

In [71]:

y_pred_usd = (jnp.dot(X_test, w) + b) * y_std + y_mean
y_test_actual = (y_test * y_std) + y_mean

mae = jnp.mean(jnp.abs(y_pred_usd - y_test_actual))

ss_res = jnp.sum(jnp.square(y_test_actual - y_pred_usd))
ss_tot = jnp.sum(jnp.square(y_test_actual - jnp.mean(y_test_actual)))
r2 = 1 - (ss_res / ss_tot)

print(f"Accuracy Report:")
print(f"- Average Error: ${mae:,.2f}")
print(f"- Variance Explained (R2): {r2:.2%}")

Accuracy Report:
- Average Error: $51,496.49
- Variance Explained (R2): 61.95%
