Link to Colab File : https://colab.research.google.com/drive/1kqVBkbwWwI9tRVGm5qXdcCgk--FYgwh7?usp=sharing

In [None]:
import numpy.random as npr
import jax
import numpy as np
from jax import jit, grad
import jax.numpy as jnp
from sklearn.datasets import load_boston
from sklearn.datasets import load_digits
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import KFold
from functools import partial
from jax.scipy.special import logsumexp

In [None]:
class NeuralNetwork():
  """
  Fully connected feed-forward network trained with mini-batch gradient descent.
  Gradients are obtained with jax.grad and each update step is jit compiled.
  """
  def __init__(self, layers_sizes, layers_activation, type_, frac = 0.01):
    """
    Parameters :
      > layers_sizes : Python List. Number of neurons in each layer (input layer first, output layer last).
      > layers_activation : Python List. Type of activation for each hidden layer.
                          Available Options are {sigmoid, relu, identity}.
                          Any other name is treated as identity. Entries beyond the
                          number of hidden layers are ignored.
      > type_ : int. 0 for Regression, 1 for Classification.
      > frac : float. Multiplicative factor for the randomly drawn initial weights.

    Raises :
      > ValueError if fewer than two layers are given, or if there are fewer
        activations than hidden layers.
    """
    if len(layers_sizes) < 2:
      raise ValueError("layers_sizes needs at least an input and an output layer")
    hidden_count = len(layers_sizes) - 2
    if len(layers_activation) < hidden_count:
      raise ValueError(
          "layers_activation needs one entry per hidden layer "
          f"(expected at least {hidden_count}, got {len(layers_activation)})")
    self.layers_sizes = list(layers_sizes)
    self.layers_activation = list(layers_activation)
    self.num_layers = len(layers_sizes)
    self.type_ = type_
    self.frac = frac

  def __initialise_params(self):
    """
    This function randomly initialises the weights for each layer.

    Returns :
      > Python List of [W, b] pairs, one per layer transition.
        W has shape (fan_out, fan_in), b has shape (fan_out,).
        Values are drawn from numpy's global RNG and scaled by self.frac.
    """
    sizes = self.layers_sizes
    parameters = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
      # W is drawn before b so that a seeded numpy RNG reproduces the same init.
      Wi = self.frac*npr.randn(fan_out, fan_in)
      bi = self.frac*npr.randn(fan_out)
      parameters.append([Wi, bi])
    return parameters

  def __GetBatches(self, batch_size, X, y):
    """
     Parameters :
      > batch_size : int. Number of samples per batch.
      > X : numpy array. One sample per row.
      > y : numpy array. One label / target per sample (any shape, it is flattened).
     Returns :
      > batches : python list. Contains tuples of (x,y), each tuple is a batch.
                  The last batch is smaller when the number of samples is not a
                  multiple of batch_size. For classification y is one-hot encoded.
    """
    X = np.asarray(X)
    y = np.asarray(y).reshape(-1)
    batches = []
    for start in range(0, X.shape[0], batch_size):
      stop = start + batch_size
      X_t = X[start:stop]
      y_t = y[start:stop]
      if self.type_ == 1:
        # Creating One-Hot Representation in case of classification
        y_t = self.__oneHot(y_t)
      batches.append((X_t, y_t))
    return batches

  def __oneHot(self, y, num_classes = None):
    """
    Parameters :
      > y : numpy array of class indices.
      > num_classes : int. The number of classes that y can take.
                      Defaults to the size of the output layer.

    Returns :
      > One Hot Encoded representation of y, shape = (y.shape[0],num_classes).
    """
    if num_classes is None:
      # The encoding must match the width of the network output for the loss to line up.
      num_classes = self.layers_sizes[-1]
    labels = np.asarray(y).astype(int)
    one_hot = np.zeros((len(labels), num_classes))
    one_hot[np.arange(len(labels)), labels] = 1
    return one_hot

  def __relu(self, x):
    """
    Parameters :
      > x : array.
    Returns :
      > array of element wise relu of x.
    """
    return jax.nn.relu(x)

  def __sigmoid(self, x):
    """
    Parameters :
      > x : array.
    Returns :
      > array of element wise sigmoid of x.
    """
    # jax.nn.sigmoid is the library implementation of 1/(1+exp(-x)).
    return jax.nn.sigmoid(x)

  def __identity(self, x):
    """
    Parameters :
      > x : array.
    Returns :
      > The same array, unchanged.
    """
    return x

  def __ForwardPass(self, parameters, X_cur):
    """
    Parameters :
      > parameters : Python List. Weights of the model.
      > X_cur : The input of the current batch.

    Returns :
      > The output of the final layer of the Network.
        In case of classification, log of softmax is returned.
        In case of regression, the output is flattened to a 1D array.
    """
    A = X_cur
    activation_functions = self.layers_activation
    # Every layer except the last one is a hidden layer with an activation.
    for i, (W, b) in enumerate(parameters[:-1]):
      Z = jnp.dot(A, W.T) + b
      name = activation_functions[i]
      if name == "relu":
        A = self.__relu(Z)
      elif name == "sigmoid":
        A = self.__sigmoid(Z)
      else:
        A = self.__identity(Z)
    W, b = parameters[-1]
    A = jnp.dot(A, W.T) + b
    if self.type_ == 1:
      # log-softmax computed via logsumexp for numerical stability
      return A - logsumexp(A, axis=1, keepdims=True)
    # Regression: a 1D Array is returned.
    return A.reshape(-1)

  def cost_reg(self, parameters, batch):
    """
    Parameters :
      > parameters : Python List. Weights of the model.
      > batch : Python tuple. The X and y values for the current batch.

    Returns :
      > Mean of Sum of squared errors

    The function is used to generate gradients in the regression case.
    """
    X_cur, y_cur = batch
    n = len(X_cur)
    y_hat = self.__ForwardPass(parameters, X_cur)
    # Flatten the targets so they line up element-wise with the 1D prediction.
    residual = jnp.reshape(y_cur, (-1,)) - y_hat
    return jnp.sum(residual*residual)/n

  def cost_clas(self, parameters, batch):
    """
    Parameters :
      > parameters : Python List. Weights of the model.
      > batch : Python tuple. The X and one-hot encoded y values for the current batch.

    Returns :
      > Mean of cross-entropy loss.

    The function is used to generate gradients in the classification case.
    """
    X_cur, y_cur = batch
    log_probs = self.__ForwardPass(parameters, X_cur)
    return -jnp.mean(jnp.sum(log_probs*y_cur, axis=1))

  @partial(jit, static_argnums=(0,))
  def __update(self, parameters, batch, lr):
    """
    Parameter :
      > parameters : Python List, Weights of the model.
      > batch : Python tuple. The X and y values for the current batch.
      > lr : float. Learning Rate.

    Returns :
      > Python List of [W, b] pairs after one gradient descent step.

    This function is jit compiled with self as a static argument, so
    self.type_ is read when the function is traced.
    Reference : https://github.com/google/jax/issues/1251
    """
    cost = self.cost_reg if self.type_ == 0 else self.cost_clas
    grads = grad(cost)(parameters, batch)
    # Build new parameters instead of mutating the inputs.
    return [[W - lr*dW, b - lr*db]
            for (W, b), (dW, db) in zip(parameters, grads)]

  def fit(self, X, y, batch_size, epochs = 150, lr = 0.01, lr_type = "constant"):
    """
    Parameters :
      > X : numpy array. Training Data.
      > y : numpy array. Training Data labels.
      > batch_size : int. Number of samples in each batch.
      > epochs : int. Number of epochs.
      > lr : float.  Learning Rate.
      > lr_type : string. if "constant" learning rate remains constant throughout the learning process
                            elif "inverse" learning rate decreases inversely with the epochs,
                            i.e. epoch k (counting from 1) uses lr/k.

    Raises :
      > ValueError if batch_size is smaller than 1.

    The learnt weights are stored in self.parameters.
    """
    if batch_size < 1:
      raise ValueError("batch_size must be a positive integer")
    X = np.array(X)
    y = np.array(y).reshape(-1)
    parameters = self.__initialise_params()
    batches = self.__GetBatches(batch_size, X, y)
    for epoch in range(epochs):
      # Derive the rate from the base lr each epoch; dividing the previous
      # epoch's rate again would decay factorially instead of inversely.
      lr_cur = lr/(epoch+1) if lr_type == "inverse" else lr
      for batch in batches:
        parameters = self.__update(parameters, batch, lr_cur)
    self.parameters = parameters

  def predict(self, X):
    """
    Parameters :
      > X : Numpy Array.

    Returns :
      > The output value or class depending on regression or classification respectively.

    fit must have been called first, since the weights are read from self.parameters.
    """
    parameters = self.parameters
    y_hat = self.__ForwardPass(parameters, X)
    if self.type_ == 1:
      # Classification, picking the class with the highest log value of probability.
      return jnp.argmax(y_hat, axis = 1)
    return y_hat

In [None]:
def rmse(y, y_hat):
  """
  Parameters :
    > y : Ground Truth. Any array-like; it is flattened before comparison.
    > y_hat : Prediction. Any array-like with the same number of elements as y.

  Returns :
    > Root Mean squared error as a Python float.

  Raises :
    > ValueError if the inputs are empty or differ in number of elements.
  """
  truth = np.asarray(y, dtype=float).reshape(-1)
  pred = np.asarray(y_hat, dtype=float).reshape(-1)
  if truth.size != pred.size:
    raise ValueError("y and y_hat must contain the same number of elements")
  if truth.size == 0:
    raise ValueError("rmse is undefined for empty input")
  # Flattening first keeps (n,1) vs (n,) inputs from broadcasting to (n,n).
  return float(np.sqrt(np.mean((pred - truth)**2)))

def accuracy(y,y_hat):
  """
  Parameters :
    > y : Ground Truth. Any array-like of labels; it is flattened before comparison.
    > y_hat : Prediction. Any array-like with the same number of elements as y.
  Returns :
    > Accuracy, i.e. the fraction of positions where y and y_hat are equal, as a Python float.
  Raises :
    > ValueError if the inputs are empty or differ in number of elements.
  """
  truth = np.asarray(y).reshape(-1)
  pred = np.asarray(y_hat).reshape(-1)
  if truth.size != pred.size:
    raise ValueError("y and y_hat must contain the same number of elements")
  if truth.size == 0:
    raise ValueError("accuracy is undefined for empty input")
  return float(np.mean(truth == pred))

def precprocess(X, scale = 15):
    """
    Parameters :
    > X : numpy array (or array-like). The first axis indexes the samples.
    > scale : number. Every value is divided by this. Defaults to 15.

    Returns :
    > Input array row-wise flattened to shape (len(X), -1) and divided by scale.
      Values in [0, scale) are mapped to [0, 1).

    NOTE(review): scikit-learn documents the digits pixels as integers 0..16, so
    with the default scale the largest value would be 16/15 - confirm whether 16
    was intended.
    """
    X = np.asarray(X)
    n = len(X)
    return X.reshape((n,-1))/scale

In [None]:
# K fold Cross Validation, Regression case on boston Dataset.
# NOTE: load_boston was removed in scikit-learn 1.2, so this cell needs an older release.

n_splits = 3
a = [13, 64, 128, 128, 64, 1]        # layer sizes, input layer first, output layer last
b = ["relu","relu","relu","relu"]    # one activation per hidden layer
data = load_boston()
# NOTE(review): the scaler is fitted on the whole dataset before splitting, so the
# test folds influence the feature scaling used for training.
scaler = MinMaxScaler()
scaler.fit(data["data"])
X = scaler.transform(data["data"])
y = data["target"]
kf = KFold(n_splits)
err = 0
for train_index, test_index in kf.split(X):
  # Test-Train Split loop: a fresh network is trained for every fold.
  X_train, y_train = X[train_index], y[train_index]
  X_test, y_test = X[test_index], y[test_index]
  NN = NeuralNetwork(layers_sizes = a, layers_activation = b, type_ = 0)
  NN.fit(X_train,y_train,50, epochs = 1000, lr = 0.001)
  y_hat = NN.predict(X_test)
  err += rmse(y_test,y_hat)
print(f"Average {n_splits} Fold rmse : ", err/n_splits)



Average 3 Fold rmse :  4.8850346


In [None]:
# K fold Cross Validation, Regression case on boston Dataset.
# NOTE: load_boston was removed in scikit-learn 1.2, so this cell needs an older release.

n_splits = 3
a = [13, 64, 32, 16, 8, 1]                  # layer sizes, input layer first, output layer last
# Only the hidden layers (four here) take an activation; the extra entry is unused.
b = ["relu","relu","relu","relu","relu"]
data = load_boston()
# NOTE(review): the scaler is fitted on the whole dataset before splitting, so the
# test folds influence the feature scaling used for training.
scaler = MinMaxScaler()
scaler.fit(data["data"])
X = scaler.transform(data["data"])
y = data["target"]
kf = KFold(n_splits)
err = 0
for train_index, test_index in kf.split(X):
  # Test-Train Split loop: a fresh network is trained for every fold.
  X_train, y_train = X[train_index], y[train_index]
  X_test, y_test = X[test_index], y[test_index]
  NN = NeuralNetwork(layers_sizes = a, layers_activation = b, type_ = 0)
  NN.fit(X_train,y_train,50, epochs = 1000, lr = 0.001)
  y_hat = NN.predict(X_test)
  err += rmse(y_test,y_hat)
print(f"Average {n_splits} Fold rmse : ", err/n_splits)

Average 3 Fold rmse :  5.34297


In [None]:
# K fold Cross Validation, Classification case on digits Dataset.

n_splits = 3
a = [64, 128, 128, 10]                # layer sizes, input layer first, output layer last
# Only the hidden layers (two here) take an activation; the extra entry is unused.
b = ["sigmoid","sigmoid","sigmoid"]
data = load_digits()
X = precprocess(data["data"])
y = data["target"]
kf = KFold(n_splits)
acc = 0
for train_index, test_index in kf.split(X):
  # Test-Train Split loop: a fresh network is trained for every fold.
  X_train, y_train = X[train_index], y[train_index]
  X_test, y_test = X[test_index], y[test_index]
  NN = NeuralNetwork(layers_sizes = a, layers_activation = b, type_ = 1, frac = 1)
  NN.fit(X_train,y_train, 50, epochs = 1000, lr = 0.02)
  y_hat = NN.predict(X_test)
  acc += accuracy(y_test, y_hat)
print(f"Average {n_splits} Fold acc : ", acc/n_splits)

Average 3 Fold acc :  0.9220923761825265


In [None]:
# K fold Cross Validation, Classification case on digits Dataset.

n_splits = 3
a = [64, 128, 128, 64, 32, 10]        # layer sizes, input layer first, output layer last
# Only the hidden layers (four here) take an activation; the extra entry is unused.
b = ["sigmoid","sigmoid","sigmoid","sigmoid","sigmoid"]
data = load_digits()
X = precprocess(data["data"])
y = data["target"]
kf = KFold(n_splits)
acc = 0
for train_index, test_index in kf.split(X):
  # Test-Train Split loop: a fresh network is trained for every fold.
  X_train, y_train = X[train_index], y[train_index]
  X_test, y_test = X[test_index], y[test_index]
  NN = NeuralNetwork(layers_sizes = a, layers_activation = b, type_ = 1, frac = 1)
  NN.fit(X_train,y_train, 50, epochs = 1000, lr = 0.02)
  y_hat = NN.predict(X_test)
  acc += accuracy(y_test, y_hat)
print(f"Average {n_splits} Fold acc : ", acc/n_splits)

Average 3 Fold acc :  0.9276572064552031


#### Reference : 
Examples given in the official JAX GitHub repository : https://github.com/google/jax/tree/master/examples