In [None]:
# Import required libraries

from jax import grad, jit
import jax.numpy as jnp

# Create a Class to perform Linear Regression

class LinearRegressionJAX():
    """
    Gradient Descent Linear Regression using JAX.

    LinearRegression fits a linear model with coefficients w = (w1, ..., wp)
    and intercept b by gradient descent on the root-mean-squared error
    (RMSE) between the observed targets in the dataset and the targets
    predicted by the linear approximation ``X @ w + b``. RMSE is a monotonic
    function of the residual sum of squares, so both share the same minimizer.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Training predictors.
    y : array-like of shape (n_samples,) or (n_samples, 1)
        Training targets. A 1-D target is reshaped to a column vector.
    l_rate : float, default=700
        Gradient-descent step size.
    n_iter : int, default=15000
        Number of gradient-descent steps taken by ``fit``.
    """

    # Initialise parameters via Constructor

    def __init__(self, X, y, l_rate=700, n_iter=15000):
        y = jnp.asarray(y)
        # forward() returns shape (n_samples, 1); subtracting a 1-D y of shape
        # (n_samples,) would broadcast to (n_samples, n_samples) and silently
        # optimize the wrong objective, so force y to be a column vector.
        if y.ndim == 1:
            y = y.reshape((-1, 1))
        self.w = jnp.zeros((X.shape[1], 1))
        self.b = 0.
        self.l_rate = l_rate
        self.n_iter = n_iter
        self.X = X
        self.y = y

    # Compute target variable

    def forward(self, X, w, b):
        """Return the linear predictions ``X @ w + b``.

        With ``w`` of shape (n_features, 1) the result has shape (n_samples, 1).
        """
        return jnp.dot(X, w) + b

    # Create a Loss Function (Objective Function - RMSE)

    def loss_fn(self, w, b, X, y):
        """Root-mean-squared error of ``forward(X, w, b)`` against ``y``.

        Written as a pure function of ``(w, b, X, y)`` so that ``jax.grad``
        can differentiate it with respect to ``w`` (argnum 0) and ``b``
        (argnum 1).
        """
        err = self.forward(X, w, b) - y
        return jnp.sqrt(jnp.mean(jnp.square(err)))

    # Calculate the Gradient on the Loss Function

    def fit(self, verbose=False, log_every=1000):
        """Run ``n_iter`` steps of gradient descent on ``loss_fn``.

        Parameters
        ----------
        verbose : bool, default=False
            If True, print the current weights, bias and loss every
            ``log_every`` iterations and after the final iteration.
            Printing on every iteration floods the notebook output.
        log_every : int, default=1000
            Logging interval used when ``verbose`` is True (values < 1 are
            treated as 1).

        Returns
        -------
        tuple
            ``(w, b, loss)``: the fitted weights, bias, and final training RMSE.
        """
        # A single jit-compiled function returns the gradients w.r.t. both
        # w (argnum 0) and b (argnum 1), so each step needs one
        # forward/backward pass instead of two.
        grad_fn = jit(grad(self.loss_fn, argnums=(0, 1)))
        log_every = max(1, int(log_every))

        # Compute and Update the Weights & Bias with the Learning Rate
        for i in range(self.n_iter):
            dW, db = grad_fn(self.w, self.b, self.X, self.y)
            self.w = self.w - self.l_rate * dW
            self.b = self.b - self.l_rate * db
            if verbose and ((i + 1) % log_every == 0 or i + 1 == self.n_iter):
                print("Weight:", self.w, "Bias:", self.b,
                      "Loss:", self.loss_fn(self.w, self.b, self.X, self.y))
        return self.w, self.b, self.loss_fn(self.w, self.b, self.X, self.y)

    def predict(self, X):
        """Return predictions ``X @ w + b`` using the current parameters.

        Call ``fit`` first; before that, ``w`` and ``b`` are still zeros.
        """
        return self.forward(X, self.w, self.b)

In [None]:
# Import Required Libraries for Preprocessing data

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd

# Seed for the train/test split so that re-running the notebook reproduces
# the same partition (and therefore the same fitted weights and metrics).
RANDOM_STATE = 42

# Import Dataset

df = pd.read_csv("ParisHousing.csv")
print(df.shape)

# Slice Predictors & Target Variables
# (every column except the last is a predictor; the last column is the target)

X = df.iloc[:, :-1]
y = df.iloc[:, -1]

# Convert data to JAX array

X = jnp.array(X)
y = jnp.array(y)

# Make the target a column vector (n_samples, 1) to match the (n_samples, 1)
# output of X @ w; a 1-D target would broadcast to an (n, n) residual matrix.
y = y.reshape((-1, 1))

# Create Train-Test Split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=RANDOM_STATE
)

# Sanity checks: the split neither drops nor duplicates rows.
assert X_train.shape[0] + X_test.shape[0] == X.shape[0]
assert X_train.shape[0] == y_train.shape[0]
assert X_test.shape[0] == y_test.shape[0]

print("Training dataset : ", X_train.shape, y_train.shape)
print("Testing dataset : ", X_test.shape, y_test.shape)


# Standardize features by removing the mean and scaling to unit variance. (Mean = 0 , SD = 1)
# The scaler is fit on the training split only, so no test-set statistics
# leak into training.

ss = StandardScaler().fit(X_train)

X_train_ss = ss.transform(X_train)
X_test_ss = ss.transform(X_test)

(10000, 17)
Training dataset :  (7500, 16) (7500, 1)
Testing dataset :  (2500, 16) (2500, 1)


In [None]:
# Build the gradient-descent model on the standardized training features,
# then train it (the returned (w, b, loss) tuple is the cell's output).
lr_jax = LinearRegressionJAX(X=X_train_ss, y=y_train)
lr_jax.fit()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 [-3.0960583e+01]
 [ 7.1238884e+01]
 [ 6.6033058e+01]
 [-2.4197609e+00]
 [ 1.7415897e+00]
 [ 4.7380764e+01]
 [ 2.6776665e+01]
 [-1.7781540e+01]] Bias: 5006779.0 Loss: 1890.8757
Weight: [[ 2.8784022e+06]
 [ 1.6728735e+01]
 [ 1.5112915e+03]
 [ 1.4912629e+03]
 [ 1.5663845e+03]
 [-4.8171024e+00]
 [ 1.4338940e+02]
 [ 2.0287373e+00]
 [-3.0960281e+01]
 [ 7.1238846e+01]
 [ 6.6033028e+01]
 [-2.4195714e+00]
 [ 1.7416059e+00]
 [ 4.7380863e+01]
 [ 2.6776722e+01]
 [-1.7781502e+01]] Bias: 5006779.0 Loss: 1890.8756
Weight: [[ 2.8784022e+06]
 [ 1.6728729e+01]
 [ 1.5112915e+03]
 [ 1.4912629e+03]
 [ 1.5663846e+03]
 [-4.8171792e+00]
 [ 1.4338940e+02]
 [ 2.0286257e+00]
 [-3.0960192e+01]
 [ 7.1238762e+01]
 [ 6.6032974e+01]
 [-2.4196355e+00]
 [ 1.7415576e+00]
 [ 4.7380798e+01]
 [ 2.6776701e+01]
 [-1.7781498e+01]] Bias: 5006779.0 Loss: 1890.8757
Weight: [[ 2.8784022e+06]
 [ 1.6728615e+01]
 [ 1.5112915e+03]
 [ 1.4912629e+03]
 [ 1.5663846e+03]
 [

(Array([[ 2.8784022e+06],
        [ 1.6728651e+01],
        [ 1.5112914e+03],
        [ 1.4912628e+03],
        [ 1.5663848e+03],
        [-4.8170528e+00],
        [ 1.4338950e+02],
        [ 2.0287626e+00],
        [-3.0959951e+01],
        [ 7.1238876e+01],
        [ 6.6033058e+01],
        [-2.4195147e+00],
        [ 1.7415797e+00],
        [ 4.7380806e+01],
        [ 2.6776741e+01],
        [-1.7781317e+01]], dtype=float32),
 Array(5006779., dtype=float32, weak_type=True),
 Array(1890.8757, dtype=float32))

In [None]:
# Predictions on the standardized held-out test features.
y_pred = lr_jax.predict(X=X_test_ss)

[[ 2.8784022e+06]
 [ 1.6728651e+01]
 [ 1.5112914e+03]
 [ 1.4912628e+03]
 [ 1.5663848e+03]
 [-4.8170528e+00]
 [ 1.4338950e+02]
 [ 2.0287626e+00]
 [-3.0959951e+01]
 [ 7.1238876e+01]
 [ 6.6033058e+01]
 [-2.4195147e+00]
 [ 1.7415797e+00]
 [ 4.7380806e+01]
 [ 2.6776741e+01]
 [-1.7781317e+01]] 5006779.0


In [None]:
# Flatten the (n, 1) column vectors to 1-D for the comparison table.
# New names are used instead of overwriting y_test / y_pred, so this cell is
# safe to re-run and the split arrays keep their original shape.
y_test_flat = jnp.ravel(jnp.array(y_test))
y_pred_test_flat = jnp.ravel(jnp.array(y_pred))

# y_pred is rebound later in the notebook (to training-set predictions);
# fail with a clear message instead of building a misaligned table.
assert y_test_flat.shape == y_pred_test_flat.shape, (
    "Actual/Predicted length mismatch - is y_pred the test-set prediction?"
)

data = {'Actual': y_test_flat, 'Predicted': y_pred_test_flat}
comp_df = pd.DataFrame(data)

In [None]:
# Side-by-side actual vs. predicted prices for the test split.
comp_df[["Actual", "Predicted"]]

Unnamed: 0,Actual,Predicted
0,4.835242e+06,4833565.50
1,1.009529e+06,1009013.00
2,1.664806e+06,1664259.50
3,9.363558e+06,9363626.00
4,6.430465e+06,6431387.00
...,...,...
2495,2.188346e+05,218980.00
2496,9.406433e+06,9404879.00
2497,4.763586e+06,4764767.00
2498,2.047409e+06,2047026.25


In [None]:
# Test-set RMSE from the comparison table - the same metric loss_fn optimizes.
jnp.sqrt(
    jnp.square(
        jnp.array(comp_df["Predicted"]) - jnp.array(comp_df["Actual"])
    ).mean()
)

Array(1916.9418, dtype=float32)

In [None]:
# Predicted prices column of the comparison table.
comp_df["Predicted"]

0       4833565.50
1       1009013.00
2       1664259.50
3       9363626.00
4       6431387.00
           ...    
2495     218980.00
2496    9404879.00
2497    4764767.00
2498    2047026.25
2499    9885945.00
Name: Predicted, Length: 2500, dtype: float32

In [None]:
# Predictions on the standardized training features.
# NOTE(review): this rebinds y_pred (previously the test-set predictions);
# re-running the comparison-table cell after this one would pair test
# actuals with training predictions. Consider a distinct name such as
# y_pred_train if no later cell depends on y_pred.
y_pred = lr_jax.predict(X=X_train_ss)

[[ 2.8784022e+06]
 [ 1.6728651e+01]
 [ 1.5112914e+03]
 [ 1.4912628e+03]
 [ 1.5663848e+03]
 [-4.8170528e+00]
 [ 1.4338950e+02]
 [ 2.0287626e+00]
 [-3.0959951e+01]
 [ 7.1238876e+01]
 [ 6.6033058e+01]
 [-2.4195147e+00]
 [ 1.7415797e+00]
 [ 4.7380806e+01]
 [ 2.6776741e+01]
 [-1.7781317e+01]] 5006779.0
