![JAX](https://jax.readthedocs.io/en/latest/_static/jax_logo_250px.png)

"JAX is Autograd and XLA, brought together for high-performance numerical computing."
* https://jax.readthedocs.io/en/latest

You may need to install `jaxlib`.  If you've installed TensorFlow, chances are that `jax` was already installed as one of its dependencies, but `jaxlib` is still required here.

Acknowledgements for this notebook go to:
* https://colindcarroll.com/2019/04/06/exercises-in-automatic-differentiation-using-autograd-and-jax/
* https://coax.readthedocs.io/en/latest/examples/linear_regression/jax.html
* https://coderzcolumn.com/tutorials/artificial-intelligence/guide-to-create-simple-neural-networks-using-jax

In [None]:
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import grad, vmap

## Simple automatic differentiation

In [None]:
# Plot tanh and its derivative over [-4, 4].
my_func = jnp.tanh
x = jnp.linspace(-4, 4, 1000)

fig, ax = plt.subplots()

# The function itself accepts the whole array at once.
ax.plot(x, my_func(x))

# grad(my_func)(x) would not work on the array as a whole: the gradient
# has to be vectorized with vmap so it is applied to every element of x.
ax.plot(x, vmap(grad(my_func))(x))

In [None]:
tgrad = grad(my_func)

In [None]:
tgrad(0.)

In [None]:
# Compare cos(x) with its first derivative and its negated second derivative.
y = jnp.cos
d1y = grad(y)    # first derivative
d2y = grad(d1y)  # second derivative: the gradient of the gradient

x = jnp.linspace(-2 * jnp.pi, 2 * jnp.pi, 1000)

fig, ax = plt.subplots()
ax.plot(x, y(x), 'k-', lw=4)
ax.plot(x, vmap(d1y)(x), 'b-')
# The negated second derivative is drawn as a white dashed line over the thick black curve.
ax.plot(x, -vmap(d2y)(x), 'w--')

## Linear Regression

In [None]:
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

In [None]:
# Create our dataset: make_regression generates a random regression problem,
# here with three feature columns.
X, y = make_regression(n_features=3)
# Hold out part of the samples for evaluation (train_test_split's default proportions).
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:
X_train.shape

In [None]:
X_test.shape

In [None]:
import matplotlib.pyplot as plt

# One panel per feature column: target against that feature,
# training samples in blue and test samples in green.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for i, panel in enumerate(ax):
    panel.scatter(X_train[:, i], y_train, c='b')
    panel.scatter(X_test[:, i], y_test, c='g')
fig.tight_layout()
plt.show()

### What steps do we need?

* Using the current parameters, calculate $\hat{y} = wx + b$
* Calculate the resulting loss score: $J = \mathrm{mean}\left((\hat{y} - y_{actual})^2\right)$
* Update the weights using gradient descent: $w_{i+1} = w_{i} - \alpha\frac{\partial J}{\partial w_i}$
  * and update both $w$ and $b$ this way

In [None]:
# Initialize the model: a zero weight for every feature column of X and a zero bias.
params = dict(w=jnp.zeros(X.shape[1:]), b=0.)

In [None]:
jnp.array([1.,2.,3.])

In [None]:
params

In [None]:
def forward(params, X):
    """Return the linear model's predictions, X . w + b.

    `params` is a dict holding the weights under 'w' and the bias under 'b'.
    """
    weights, bias = params['w'], params['b']
    weighted_sum = jnp.dot(X, weights)
    return weighted_sum + bias

In [None]:
def loss_fn(params, X, y):
    """Return the mean squared error between forward(params, X) and y."""
    residuals = forward(params, X) - y
    return jnp.square(residuals).mean()  # mse

In [None]:
grad_fn = grad(loss_fn)

In [None]:
grad_fn

In [None]:
def update(params, grads, learning_rate=0.05):
    """Take one gradient-descent step and return the updated parameters.

    Args:
        params: dict of parameters (here 'w' and 'b').
        grads: dict with the same keys as `params`, holding the gradient of
            the loss with respect to each parameter.
        learning_rate: step size; defaults to the 0.05 used throughout.

    Returns:
        A new dict with `p - learning_rate * g` for every entry. The input
        dict is left untouched, so callers must use the return value.
    """
    # Whole-array arithmetic replaces the per-element `.at[i].set(...)` loop.
    # For arbitrarily nested parameters, jax.tree_util.tree_map does the same.
    return {
        name: value - learning_rate * grads[name]
        for name, value in params.items()
    }

In [None]:
# the main training loop: 100 gradient-descent steps on the full training set
for _ in range(100):
    # Gradient and loss are both evaluated at the parameters going into this step.
    grads = grad_fn(params, X_train, y_train)
    loss = loss_fn(params, X_train, y_train)
    print(loss)

    params = update(params, grads)

In [None]:
params

In [None]:
from sklearn.metrics import r2_score

test_preds = forward(params, X_test)
train_preds = forward(params, X_train)
print("Test  MSE Score : {:.2f}".format(loss_fn(params, X_test, y_test)))
print("Train MSE Score : {:.2f}".format(loss_fn(params, X_train, y_train)))
# r2_score expects (y_true, y_pred). R^2 is not symmetric in its arguments,
# so the true targets must be passed first.
print("Test  R^2 Score : {:.2f}".format(r2_score(y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(y_train, train_preds.squeeze())))

In [None]:
from sklearn.linear_model import LinearRegression

In [None]:
model = LinearRegression()

In [None]:
model.fit(X_train, y_train)

In [None]:
model.coef_, model.intercept_

In [None]:
from sklearn.metrics import mean_squared_error

test_preds = model.predict(X_test)
train_preds = model.predict(X_train)
# scikit-learn metrics take (y_true, y_pred). MSE is symmetric, but R^2 is not,
# so the true targets are passed first everywhere for consistency.
print("Test  MSE Score : {:.2f}".format(mean_squared_error(y_test, test_preds)))
print("Train MSE Score : {:.2f}".format(mean_squared_error(y_train, train_preds)))
print("Test  R^2 Score : {:.2f}".format(r2_score(y_test, test_preds)))
print("Train R^2 Score : {:.2f}".format(r2_score(y_train, train_preds)))

## Classification

### What steps do we need *for logistic regression*?

* Using the current parameters, calculate $\hat{y} = 1 / (1 + e^{-(wx+b)})$
* Calculate the resulting loss score: $J = \mathrm{mean}\left(- y_{actual} \log(\hat{y}) - (1 - y_{actual})\log(1 - \hat{y})\right)$
* Update the weights using gradient descent: $w_{i+1} = w_{i} - \alpha\frac{\partial J}{\partial w_i}$
  * and update both $w$ and $b$ this way

In [None]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the breast-cancer data as a feature matrix X and a label vector y.
X, y = datasets.load_breast_cancer(return_X_y=True)

# 80/20 split, stratified on the labels, with a fixed random_state for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, stratify=y, random_state=123)

In [None]:
# Standardize each feature column using statistics of the training split only;
# the test split is scaled with those same training statistics.
mean, std = X_train.mean(axis=0), X_train.std(axis=0)

X_train, X_test = (X_train - mean) / std, (X_test - mean) / std

In [None]:
X_train.shape, y_train.shape

In [None]:
X_test.shape, y_test.shape

In [None]:
# model weights: a zero weight for every feature column of X and a zero bias
params = dict(w=jnp.zeros(X.shape[1:]), b=0.)

In [None]:
X.shape

In [None]:
params['w'].shape

In [None]:
def forward(params, X):
    """Return logistic-regression probabilities, sigmoid(X . w + b).

    `params` is a dict holding the weights under 'w' and the bias under 'b'.
    """
    logits = jnp.dot(X, params['w']) + params['b']
    # jax.nn.sigmoid is used instead of 1 / (1 + exp(-z)). The hand-written
    # form overflows exp() for large negative logits, which turns the
    # gradient into NaN.
    return jax.nn.sigmoid(logits)

In [None]:
forward(params,X_train)

In [None]:
def loss_fn(params, X, y):
    """Return the mean binary cross-entropy of forward(params, X) against labels y."""
    probs = forward(params, X)
    positive_term = y * jnp.log(probs)
    negative_term = (1 - y) * jnp.log(1 - probs)
    return jnp.mean(-positive_term - negative_term)

In [None]:
grad_fn = grad(loss_fn)

In [None]:
def update(params, grads, learning_rate=0.05):
    """Take one gradient-descent step and return the updated parameters.

    Args:
        params: dict of parameters (here 'w' and 'b').
        grads: dict with the same keys as `params`, holding the gradient of
            the loss with respect to each parameter.
        learning_rate: step size; defaults to the 0.05 used throughout.

    Returns:
        A new dict with `p - learning_rate * g` for every entry. The input
        dict is left untouched, so callers must use the return value.
    """
    # Whole-array arithmetic replaces the per-element `.at[i].set(...)` loop.
    # For arbitrarily nested parameters, jax.tree_util.tree_map does the same.
    return {
        name: value - learning_rate * grads[name]
        for name, value in params.items()
    }

In [None]:
# the main training loop: 100 gradient-descent steps on the full training set
for _ in range(100):
    grads = grad_fn(params, X_train, y_train)

    # NOTE(review): the printed loss is evaluated on the test split, whereas the
    # linear-regression loop reports the training loss. Confirm this is intended.
    loss = loss_fn(params, X_test, y_test)
    print(loss)

    params = update(params, grads)

In [None]:
params

In [None]:
# Turn predicted probabilities into hard 0/1 labels with a 0.5 cut-off.
train_preds = jnp.where(forward(params, X_train) > 0.5, 1, 0)
test_preds = jnp.where(forward(params, X_test) > 0.5, 1, 0)

In [None]:
train_preds[:5]

In [None]:
y_train[:5]

In [None]:
test_preds[:5]

In [None]:
y_test[:5]

In [None]:
from sklearn.metrics import accuracy_score

print("Train Accuracy : {:.2f}".format(accuracy_score(y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(y_test, test_preds)))

In [None]:
from sklearn.linear_model import LogisticRegression

In [None]:
model = LogisticRegression()

In [None]:
model.fit(X_train, y_train)

In [None]:
model.coef_

In [None]:
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

In [None]:
print("Train Accuracy : {:.2f}".format(accuracy_score(y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(y_test, test_preds)))

# Neural Network Time

## Regression

In [None]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

#from sklearn.datasets import load_boston
#from sklearn.datasets import fetch_california_housing
#X,Y = fetch_california_housing(return_X_y=True)

In [None]:
  
# The sklearn loader below is kept for reference only; the raw table is
# downloaded and parsed directly instead.
#X, Y = datasets.load_boston(return_X_y=True)

import pandas as pd
import numpy as np

data_url = "http://lib.stat.cmu.edu/datasets/boston"
# Raw string: "\s" is an invalid escape in an ordinary string literal (newer
# Python versions warn about it). r"\s+" passes the regex through unchanged.
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
# Rows are read in pairs: every even row supplies one group of feature columns,
# the first two columns of the following odd row supply the rest, and the third
# column of that odd row is the target.
X = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
Y = raw_df.values[1::2, 2]

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, random_state=123)

#from tensorflow.keras.datasets import boston_housing
#(train_data, train_targets), (test_data, test_targets) = boston_housing.load_data()

In [None]:
# Cast every split to a float32 JAX array.
X_train, X_test, Y_train, Y_test = (
    jnp.array(split, dtype=jnp.float32)
    for split in (X_train, X_test, Y_train, Y_test)
)

In [None]:
samples, features = X_train.shape

In [None]:
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

In [None]:
# Standardize each feature column using statistics of the training split only;
# the test split is scaled with those same training statistics.
mean, std = X_train.mean(axis=0), X_train.std(axis=0)

X_train, X_test = (X_train - mean) / std, (X_test - mean) / std

In [None]:
def InitializeWeights(layer_sizes, seed, n_features=None):
    """Draw uniform(-1, 1) weights and biases for a fully connected network.

    Args:
        layer_sizes: number of units in each layer, in order.
        seed: a JAX PRNG key, e.g. `jax.random.PRNGKey(123)`.
        n_features: width of the network input. When omitted, the
            notebook-level global `features` is used, as before.

    Returns:
        A list with one `[w, b]` pair per layer, where `w` has shape
        `(units, fan_in)` and `b` has shape `(units,)`.
    """
    if n_features is None:
        n_features = features  # notebook global taken from X_train.shape

    weights = []
    key = seed
    fan_in = n_features
    for units in layer_sizes:
        # A JAX key must not be reused: the same key with the same shape gives
        # the same numbers, which made same-shaped layers identical. Split off
        # fresh sub-keys for every draw instead.
        key, w_key, b_key = jax.random.split(key, 3)

        w = jax.random.uniform(key=w_key,
                               shape=(units, fan_in),
                               minval=-1.0, maxval=1.0,
                               dtype=jnp.float32)
        b = jax.random.uniform(key=b_key,
                               shape=(units,),
                               minval=-1.0, maxval=1.0,
                               dtype=jnp.float32)

        weights.append([w, b])
        fan_in = units  # the next layer consumes this layer's outputs

    return weights

In [None]:
# Build a 64-64-1 network and print each layer's weight and bias shapes.
seed = jax.random.PRNGKey(123)
weights = InitializeWeights([64,64,1], seed)

for w in weights:
    print(*(param.shape for param in w))

In [None]:
def Relu(x):
    """Rectified linear unit: element-wise max(0, x)."""
    floor = jnp.zeros_like(x)
    return jnp.maximum(x, floor)

In [None]:
# Example 1

x = jnp.array([-1,0,1,-2,4,-6,5])
Relu(x)

In [None]:
# Example 2

x = jnp.array([[-1,0,1,-2,4,-6,5],
               [1,2,4,-5,-6,7,9]])
Relu(x)

In [None]:
def LinearLayer(weights, input_data, activation):
    """Apply one dense layer: activation(input_data . w^T + b).

    `weights` is a `(w, b)` pair; `activation` is any callable applied to the
    pre-activation output.
    """
    kernel, bias = weights
    pre_activation = jnp.dot(input_data, jnp.transpose(kernel)) + bias
    return activation(pre_activation)

In [None]:
len(weights)

In [None]:
weights[0][0].shape, weights[0][1].shape

In [None]:
def ForwardPass(weights, input_data):
    """Run the network forward and return its predictions.

    Every layer except the last uses Relu; the last layer is purely linear
    (identity activation), as befits a regression output.
    """
    activations = input_data
    for layer in weights[:-1]:
        activations = LinearLayer(layer, activations, Relu)

    def identity(x):
        return x

    preds = LinearLayer(weights[-1], activations, identity)

    # squeeze() drops every dimension of length 1 (see e.g.
    # help(jnp.zeros(3).squeeze)). Print preds.shape here to inspect the
    # shape before squeezing.
    return preds.squeeze()

In [None]:
preds = ForwardPass(weights, X_train)

preds.shape

In [None]:
def MeanSquaredErrorLoss(weights, input_data, actual):
    """Return the mean squared difference between `actual` and the network's predictions."""
    residuals = actual - ForwardPass(weights, input_data)
    return jnp.power(residuals, 2).mean()

In [None]:
def CalculateGradients(weights, input_data, actual):
    """Return the gradient of MeanSquaredErrorLoss with respect to `weights`.

    grad() differentiates with respect to the first argument, so the result
    mirrors the nested structure of `weights`.
    """
    return grad(MeanSquaredErrorLoss)(weights, input_data, actual)

In [None]:
def TrainModel(weights, X, Y, learning_rate, epochs):
    """Train with full-batch gradient descent, updating `weights` in place.

    Runs `epochs` steps over the whole of (X, Y) and prints the MSE measured
    before the update on every 100th epoch. Nothing is returned; the trained
    values are left in the `weights` list passed in.
    """
    for epoch in range(epochs):
        current_loss = MeanSquaredErrorLoss(weights, X, Y)
        layer_grads = CalculateGradients(weights, X, Y)

        ## Update each layer's [w, b] pair in place
        for layer, (w_grad, b_grad) in zip(weights, layer_grads):
            layer[0] = layer[0] - learning_rate * w_grad  ## Weights
            layer[1] = layer[1] - learning_rate * b_grad  ## Biases

        if epoch % 100 == 0:  ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(current_loss))

In [None]:
# Hyperparameters for the full-batch regression run.
seed = jax.random.PRNGKey(42)  # PRNG key used to draw the initial weights
learning_rate = jnp.array(1/1e3)  # gradient-descent step size (0.001)
epochs = 1500
layer_sizes = [64,64,1]  # units per layer; the final layer is the output

weights = InitializeWeights(layer_sizes, seed)

# TrainModel updates `weights` in place, so the trained values are read from `weights` afterwards.
TrainModel(weights, X_train, Y_train, learning_rate, epochs)

## Making predictions

In [None]:
train_preds = ForwardPass(weights, X_train)

train_preds[:5], Y_train[:5]

In [None]:
test_preds = ForwardPass(weights, X_test)

test_preds[:5], Y_test[:5]

In [None]:
print("Test  MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_train, Y_train)))

In [None]:
# r2_score expects (y_true, y_pred). R^2 is not symmetric in its arguments,
# so the true targets must be passed first.
print("Test  R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))

## Supplementary

In the below, the material is extended to consider:
* training with batches of data
* neural net for classification
  * this can again be thought of rather simply: use a different loss function, a different activation function for the final layer, and a different metric for assessing performance

In [None]:
# Earlier we found better performance when training for only 130 epochs,
# since training for longer (as in the 1500-epoch run above) overfits.

# This time we also train in batches rather than using all our data at once

def UpdateWeights(learning_rate, weights, gradients):
    """Apply one gradient-descent step to every layer, modifying `weights` in place.

    `weights` and `gradients` are parallel lists of `[w, b]` pairs.
    """
    for layer, (w_grad, b_grad) in zip(weights, gradients):
        layer[0] = layer[0] - learning_rate * w_grad  ## Update Weights
        layer[1] = layer[1] - learning_rate * b_grad  ## Update Biases

def TrainModelInBatches(weights, X, Y, learning_rate, epochs, batch_size=32):
    """Train with mini-batch gradient descent, updating `weights` in place.

    Args:
        weights: list of `[w, b]` pairs, modified in place.
        X, Y: training inputs and targets, sliced along their first axis.
        learning_rate: gradient-descent step size.
        epochs: number of passes over the data.
        batch_size: number of samples per batch; the last batch of an epoch
            may be smaller.

    On every 100th epoch, prints the mean of that epoch's per-batch losses.
    """
    n_samples = X.shape[0]

    for i in range(epochs):
        losses = [] ## Record loss of each batch

        # Step through the data with plain Python integers. The previous
        # scheme used (n_samples // batch_size) + 1 batches, which produced an
        # empty final batch, and hence a NaN loss and NaN gradients, whenever
        # n_samples was an exact multiple of batch_size.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size  # slicing clips this at n_samples

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss = MeanSquaredErrorLoss(weights, X_batch, Y_batch) ## Loss of batch
            gradients = CalculateGradients(weights, X_batch, Y_batch)
            losses.append(loss) ## Record Loss

            UpdateWeights(learning_rate, weights, gradients) ## Update Weights

        if i % 100 == 0: ## Print MSE every 100 epochs
            print("MSE : {:.2f}".format(jnp.array(losses).mean()))

In [None]:
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e3)
epochs = 130
layer_sizes = [64,64,1]

weights = InitializeWeights(layer_sizes, seed)

TrainModelInBatches(weights, X_train, Y_train, learning_rate, epochs, batch_size=32)

In [None]:
test_preds = ForwardPass(weights, X_test)
train_preds = ForwardPass(weights, X_train)
print("Test  MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_test, Y_test)))
print("Train MSE Score : {:.2f}".format(MeanSquaredErrorLoss(weights, X_train, Y_train)))
# r2_score expects (y_true, y_pred). R^2 is not symmetric in its arguments,
# so the true targets must be passed first.
print("Test  R^2 Score : {:.2f}".format(r2_score(Y_test, test_preds.squeeze())))
print("Train R^2 Score : {:.2f}".format(r2_score(Y_train, train_preds.squeeze())))

# Classification

In [None]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, Y = datasets.load_breast_cancer(return_X_y=True)

# 80/20 split, stratified on the labels, with a fixed random_state.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

# Cast every split to a float32 JAX array.
X_train, X_test, Y_train, Y_test = (
    jnp.array(split, dtype=jnp.float32)
    for split in (X_train, X_test, Y_train, Y_test)
)

classes = jnp.unique(Y)
samples, features = X_train.shape

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

In [None]:
samples, features, classes

In [None]:
# Same as Neural Network Regression Section

# Standardize each feature column using statistics of the training split only;
# the test split is scaled with those same training statistics.
mean, std = X_train.mean(axis=0), X_train.std(axis=0)

X_train, X_test = (X_train - mean) / std, (X_test - mean) / std

In [None]:
# Same as Neural Network Regression Section

def InitializeWeights(layer_sizes, seed, n_features=None):
    """Draw uniform(-1, 1) weights and biases for a fully connected network.

    Args:
        layer_sizes: number of units in each layer, in order.
        seed: a JAX PRNG key, e.g. `jax.random.PRNGKey(42)`.
        n_features: width of the network input. When omitted, the
            notebook-level global `features` is used, as before.

    Returns:
        A list with one `[w, b]` pair per layer, where `w` has shape
        `(units, fan_in)` and `b` has shape `(units,)`.
    """
    if n_features is None:
        n_features = features  # notebook global taken from X_train.shape

    weights = []
    key = seed
    fan_in = n_features
    for units in layer_sizes:
        # A JAX key must not be reused: the same key with the same shape gives
        # the same numbers, which made same-shaped layers identical. Split off
        # fresh sub-keys for every draw instead.
        key, w_key, b_key = jax.random.split(key, 3)

        w = jax.random.uniform(key=w_key,
                               shape=(units, fan_in),
                               minval=-1.0, maxval=1.0,
                               dtype=jnp.float32)
        b = jax.random.uniform(key=b_key,
                               shape=(units,),
                               minval=-1.0, maxval=1.0,
                               dtype=jnp.float32)

        weights.append([w, b])
        fan_in = units  # the next layer consumes this layer's outputs

    return weights

In [None]:
# Same as Neural Network Regression Section

def Relu(x):
    """Rectified linear unit: element-wise max(0, x)."""
    floor = jnp.zeros_like(x)
    return jnp.maximum(x, floor)

In [None]:
# New!
# We need this for the final output layer

def Sigmoid(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)), applied element-wise.

    Delegates to jax.nn.sigmoid. The hand-written formula overflows exp() for
    large negative inputs, which makes the gradient NaN.
    """
    return jax.nn.sigmoid(x)

In [None]:
# Same as Neural Network Regression Section

def LinearLayer(weights, input_data, activation):
    """Apply one dense layer: activation(input_data . w^T + b).

    `weights` is a `(w, b)` pair; `activation` is any callable applied to the
    pre-activation output.
    """
    kernel, bias = weights
    pre_activation = jnp.dot(input_data, jnp.transpose(kernel)) + bias
    return activation(pre_activation)

In [None]:
# The activation function for the output layer is new here

def ForwardPass(weights, input_data):
    """Run the classifier forward and return its squeezed output.

    Every layer except the last uses Relu. The last layer uses Sigmoid rather
    than the identity activation, so each output lies between 0 and 1.
    """
    activations = input_data
    for layer in weights[:-1]:
        activations = LinearLayer(layer, activations, Relu)

    output = LinearLayer(weights[-1], activations, Sigmoid)
    return output.squeeze()

In [None]:
# Rather than MSE for Loss Function....
# def MeanSquaredErrorLoss(weights, input_data, actual):
#     preds = ForwardPass(weights, input_data)
#     return jnp.power(actual - preds, 2).mean()

# We use the negative log-likelihood loss instead, i.e. the binary cross-entropy
def NegLogLoss(weights, input_data, actual):
    """Return the mean binary cross-entropy between the network's output and `actual`."""
    probs = ForwardPass(weights, input_data)
    positive_term = actual * jnp.log(probs)
    negative_term = (1 - actual) * jnp.log(1 - probs)
    return jnp.mean(-positive_term - negative_term)

In [None]:
def CalculateGradients(weights, input_data, actual):
    """Return the gradient of NegLogLoss with respect to `weights`.

    Previously, for regression, the differentiated loss was
    MeanSquaredErrorLoss; for classification it is NegLogLoss. grad()
    differentiates with respect to the first argument, so the result mirrors
    the nested structure of `weights`.
    """
    return grad(NegLogLoss)(weights, input_data, actual)

In [None]:
def TrainModel(weights, X, Y, learning_rate, epochs):
    """Train the classifier with full-batch gradient descent, updating `weights` in place.

    Runs `epochs` steps over the whole of (X, Y) and prints the NegLogLoss
    measured before the update on every 100th epoch. Nothing is returned; the
    trained values are left in the `weights` list passed in.
    """
    for epoch in range(epochs):
        # NegLogLoss takes the place of the regression section's MSE loss.
        current_loss = NegLogLoss(weights, X, Y)
        layer_grads = CalculateGradients(weights, X, Y)

        ## Update each layer's [w, b] pair in place
        for layer, (w_grad, b_grad) in zip(weights, layer_grads):
            layer[0] = layer[0] - learning_rate * w_grad  ## Weights
            layer[1] = layer[1] - learning_rate * b_grad  ## Biases

        if epoch % 100 == 0:  ## Print NegLogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(current_loss))

In [None]:
# Same set-up as the Neural Network Regression section, but with a larger
# learning rate (0.01) and a different layer layout.

seed = jax.random.PRNGKey(42)  # PRNG key used to draw the initial weights
learning_rate = jnp.array(1/1e2)  # gradient-descent step size (0.01)
epochs = 1500
layer_sizes = [5,10,15,1]  # units per layer; the final layer is the single output unit

weights = InitializeWeights(layer_sizes, seed)

# TrainModel updates `weights` in place, so the trained values are read from `weights` afterwards.
TrainModel(weights, X_train, Y_train, learning_rate, epochs)

## Making predictions

In [None]:
train_preds = ForwardPass(weights, X_train)

train_preds = (train_preds > 0.5).astype(jnp.float32)

train_preds[:5], Y_train[:5]

In [None]:
test_preds = ForwardPass(weights, X_test)

test_preds = (test_preds > 0.5).astype(jnp.float32)

test_preds[:5], Y_test[:5]

In [None]:
print("Test  NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_train, Y_train)))

In [None]:
print("Train Accuracy : {:.2f}".format(accuracy_score(Y_train, train_preds)))
print("Test  Accuracy : {:.2f}".format(accuracy_score(Y_test, test_preds)))

In [None]:
def UpdateWeights(learning_rate, weights, gradients):
    """Apply one gradient-descent step to every layer, modifying `weights` in place.

    `weights` and `gradients` are parallel lists of `[w, b]` pairs.
    """
    for layer, (w_grad, b_grad) in zip(weights, gradients):
        layer[0] = layer[0] - learning_rate * w_grad  ## Update Weights
        layer[1] = layer[1] - learning_rate * b_grad  ## Update Biases

def TrainModelInBatches(weights, X, Y, learning_rate, epochs, batch_size=32):
    """Train the classifier with mini-batch gradient descent, updating `weights` in place.

    Args:
        weights: list of `[w, b]` pairs, modified in place.
        X, Y: training inputs and labels, sliced along their first axis.
        learning_rate: gradient-descent step size.
        epochs: number of passes over the data.
        batch_size: number of samples per batch; the last batch of an epoch
            may be smaller.

    On every 100th epoch, prints the mean of that epoch's per-batch losses.
    """
    n_samples = X.shape[0]

    for i in range(epochs):
        losses = [] ## Record loss of each batch

        # Step through the data with plain Python integers. The previous
        # scheme used (n_samples // batch_size) + 1 batches, which produced an
        # empty final batch, and hence a NaN loss and NaN gradients, whenever
        # n_samples was an exact multiple of batch_size.
        for start in range(0, n_samples, batch_size):
            end = start + batch_size  # slicing clips this at n_samples

            X_batch, Y_batch = X[start:end], Y[start:end] ## Single batch of data

            loss = NegLogLoss(weights, X_batch, Y_batch)
            gradients = CalculateGradients(weights, X_batch, Y_batch)
            losses.append(loss) ## Record Loss

            UpdateWeights(learning_rate, weights, gradients) ## Update Weights

        if i % 100 == 0: ## Print LogLoss every 100 epochs
            print("NegLogLoss : {:.2f}".format(jnp.array(losses).mean()))

In [None]:
seed = jax.random.PRNGKey(42)
learning_rate = jnp.array(1/1e3)
epochs = 1000

layer_sizes = [5,10,15,1]

weights = InitializeWeights(layer_sizes, seed)

TrainModelInBatches(weights, X_train, Y_train, learning_rate, epochs, batch_size=16)

In [None]:
# ForwardPass returns sigmoid outputs, not class labels. accuracy_score
# compares labels and rejects continuous predictions, so threshold at 0.5
# first (the same cut-off used in the "Making predictions" cells).
test_preds = (ForwardPass(weights, X_test) > 0.5).astype(jnp.float32)
train_preds = (ForwardPass(weights, X_train) > 0.5).astype(jnp.float32)
print("Test  NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_test, Y_test)))
print("Train NegLogLoss Score : {:.2f}".format(NegLogLoss(weights, X_train, Y_train)))
# scikit-learn metrics take (y_true, y_pred).
print("Test  Accuracy Score : {:.2f}".format(accuracy_score(Y_test, test_preds)))
print("Train Accuracy Score : {:.2f}".format(accuracy_score(Y_train, train_preds)))