#Neural Network (NN) using pure JAX

In [11]:
#Importing all the necessary libraries

import jax.numpy as jnp
from jax import grad
from jax import random

In [12]:
#Parameters

layer_sizes = [5,2,3]  #Hidden units in each layer (total 3 layers)

key = random.PRNGKey(0)
key, *keys = random.split(key, len(layer_sizes))

key, *keys = random.split(key,3)  #len(layer_sizes) = 3
inputs = random.normal(keys[0],(8,5))
targets = random.normal(keys[1], (8,3))
batch = (inputs,targets)

In [13]:
key

array([538105296,  96102591], dtype=uint32)

In [14]:
#subkeys
keys

[array([3126261553, 3539587250], dtype=uint32),
 array([1660104999, 2332457458], dtype=uint32)]

In [15]:
#length of subkeys
len(keys)

2

In [16]:
#Initialization function
def init_layer(key, n_in, n_out):
  k1,k2 = random.split(key)
  w = random.normal(key, (n_in,n_out))
  b = random.normal(key,(n_out,))
  return w,b


#Utility function for predicting output
def predict(params, inputs):
  for w,b in params:
    outputs = jnp.dot(inputs,w)+b
    inputs = jnp.tanh(outputs)
  return outputs

#Utility function for calcultaing loss
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params,inputs)
  return jnp.sum((predictions-targets)**2)

In [17]:
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
params

[(DeviceArray([[-3.3101373 , -0.54429305],
               [ 0.50362825, -0.09967706],
               [ 1.928792  ,  0.2380951 ],
               [-1.0271916 ,  0.38994825],
               [ 0.6086258 , -1.0101831 ]], dtype=float32),
  DeviceArray([-0.24502717, -0.8321449 ], dtype=float32)),
 (DeviceArray([[-0.48060644, -0.57022816,  0.28993058],
               [ 0.68064284, -0.76739717,  1.851759  ]], dtype=float32),
  DeviceArray([-1.9195611,  1.9178468, -1.3020369], dtype=float32))]

In [18]:
#Loss
loss(params,batch)

DeviceArray(113.28918, dtype=float32)

In [19]:
step_size = 1e-2

#Training for 20 epochs

for _ in range(20):
  grads = grad(loss)(params,batch)
  params = [(w-step_size*dw, b-step_size*db) for (w,b),(dw,db) in zip(params,grads)]

print(loss(params,batch))


11.241646
