#Importing Libraries

In [100]:
try :
  import jax
except ModuleNotFoundError:
  %pip install --upgrade -qq pip jax jaxlib
  import jax
try :
  import flax
except ModuleNotFoundError:
  %pip install --upgrade -qq git+https://github.com/google/flax.git
  import flax

from jax import lax, random, numpy as jnp
import sklearn
from jax import grad
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from flax.core import freeze, unfreeze
from flax import linen as nn
import numpy as np
import optax
from sklearn.metrics import classification_report

from jax.config import config
config.enable_omnistaging()

see https://github.com/google/jax/blob/main/design_notes/omnistaging.md
  "enable_omnistaging() is a no-op in JAX versions 0.2.12 and higher;\n"


#Initialisatiion of Global variables

In [101]:
#Define number of Samples
n_samples = 10000

#Initialise learning rate
lr = 0.1

#Initialise number of epochs 
epochs = 75

#Initialise loss at which epoch is to be printed (In this case it will be multiples of 5)
log_period_epoch = 5

#Initialise number of classes
number_of_class = 3 #2 for make_moons and 3 for make_blobs

#Number of Neurons in input layer
layer1 = 2

#Number of Neurons in output layer
layer2 = number_of_class

#Dimensions of weights and biases
x_dim = 2  
y_dim = number_of_class

#Can be any number
RS = 0
seed = 23

#Function to convert data into categories

In [102]:
def one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k """
    return np.array(x[:, None] == np.arange(k), dtype)

#Defifining MC Dropout 


In [103]:
class MCDropout(nn.Dropout):
  def call(self, inputs):
    return super().call(inputs, training = True)

#MLP using FLAX with/without MC Dropout

In [115]:
#Uncomment the line below for using make_moons 
#X,y = sklearn.datasets.make_moons(n_samples=n_samples, shuffle=True,random_state=None)

#Uncomment the line below for using make_blobs
X,y = sklearn.datasets.make_blobs(n_samples=n_samples, shuffle=True,random_state=None)

n_feat = X.shape[1]

#Dividing into train and test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=RS
)

#Using one hot encoding to divide output into classes
y_train = one_hot(y_train, number_of_class)
y_test = one_hot(y_test, number_of_class)


In [105]:
#Defining Multi Layer Perceptron
class mlp(nn.Module): 
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=layer1, use_bias = True)(x)
        x = nn.relu(x)

        #Uncomment line below to use MC Dropout
        #x = MCDropout(rate=0.2)(x, deterministic=True)

        x = nn.Dense(features=layer2, use_bias = True)(x)
        x = nn.softmax(x)
        return x
      
model = mlp()

In [106]:
xs = X_train
ys = y_train

#Initialising weights and biases randomly
key, w_key, b_key = random.split(random.PRNGKey(seed), num=3)
W = random.normal(w_key, (x_dim, y_dim))  # weight
b = random.normal(b_key, (y_dim,))  # bias
true_params = freeze({'params': {'bias': b, 'kernel': W}})


In [107]:
#Defining Loss function as binary cross entropy
def binary_cross_entropy(xs, ys):   
    def temp_cross_entropy(params):
        def temp(x, y):
            pred = model.apply(params, x)           
            #Uncomment the line below for using make_moons 
            #temp_loss = -y[:,0]*jnp.log(pred[:,0])-y[:,1]*jnp.log(pred[:,1])

            #Uncomment the line below for using make_blobs
            temp_loss = -y[:,0]*jnp.log(pred[:,0])-y[:,1]*jnp.log(pred[:,1])-y[:,2]*jnp.log(pred[:,2])
            return temp_loss
        return jnp.mean(temp(xs, ys))
    return jax.jit(temp_cross_entropy)  

loss = binary_cross_entropy(xs, ys)
value_and_grad_fn = jax.value_and_grad(loss)

#Training without optax

In [108]:
params = model.init(key, xs)
for epoch in range(epochs):
    loss, grads = value_and_grad_fn(params)
    params = jax.tree_multimap(lambda p, g: p - lr * g, params, grads)
    if epoch % log_period_epoch == 0:
        print(f'epoch {epoch}, loss = {loss}')





epoch 0, loss = 2.889653205871582
epoch 5, loss = 0.5823025703430176
epoch 10, loss = 0.41217637062072754
epoch 15, loss = 0.3339276611804962
epoch 20, loss = 0.2888086438179016
epoch 25, loss = 0.2569538950920105
epoch 30, loss = 0.23192162811756134
epoch 35, loss = 0.2112276554107666
epoch 40, loss = 0.1936613917350769
epoch 45, loss = 0.17850534617900848
epoch 50, loss = 0.16529275476932526
epoch 55, loss = 0.15367868542671204
epoch 60, loss = 0.14340732991695404
epoch 65, loss = 0.1342749148607254
epoch 70, loss = 0.12611515820026398


In [109]:
#Predicting output of test dataset
y_pred = model.apply(params, X_test)

#Rounding off the probabilities
y_pred = jnp.where(y_pred < 0.5, y_pred, 1.0)
y_pred = jnp.where(y_pred >= 0.5, y_pred, 0.0)

#Printing classification report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      0.99      1.00       834
           1       1.00      1.00      1.00       833
           2       1.00      1.00      1.00       833

   micro avg       1.00      1.00      1.00      2500
   macro avg       1.00      1.00      1.00      2500
weighted avg       1.00      1.00      1.00      2500
 samples avg       1.00      1.00      1.00      2500



  _warn_prf(average, modifier, msg_start, len(result))


#Training with optax

In [110]:
#Using adam as an optimizer
opt_adam = optax.adam(learning_rate=lr)
opt_state = opt_adam.init(params)

In [111]:
params = model.init(key, xs)  

for epoch in range(epochs):
    loss, grads = value_and_grad_fn(params)
    updates, opt_state = opt_adam.update(grads, opt_state)  
    params = optax.apply_updates(params, updates)
    if epoch % log_period_epoch == 0:
        print(f'epoch {epoch}, loss = {loss}')

epoch 0, loss = 2.889653205871582
epoch 5, loss = 0.6834388375282288
epoch 10, loss = 0.2309263050556183
epoch 15, loss = 0.10921990871429443
epoch 20, loss = 0.05437207594513893
epoch 25, loss = 0.031396765261888504
epoch 30, loss = 0.02037673443555832
epoch 35, loss = 0.014530753716826439
epoch 40, loss = 0.011463300324976444
epoch 45, loss = 0.009741807356476784
epoch 50, loss = 0.008639251813292503
epoch 55, loss = 0.007859257981181145
epoch 60, loss = 0.00728395814076066
epoch 65, loss = 0.006848851218819618
epoch 70, loss = 0.00650702603161335


In [112]:
#Predicting output of test dataset
y_pred = model.apply(params, X_test)

#Rounding off the probabilities
y_pred = jnp.where(y_pred < 0.5, y_pred, 1.0)
y_pred = jnp.where(y_pred >= 0.5, y_pred, 0.0)

#Printing classification report
print(classification_report(y_test, y_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       834
           1       1.00      1.00      1.00       833
           2       1.00      1.00      1.00       833

   micro avg       1.00      1.00      1.00      2500
   macro avg       1.00      1.00      1.00      2500
weighted avg       1.00      1.00      1.00      2500
 samples avg       1.00      1.00      1.00      2500



#Important References

Youtube References
1. [MC Dropout](https://youtu.be/eHT0raFtl1Q?t=181)
2. [Machine Learning with FLAX - FROM ZERO TO HERO](https://youtu.be/5eUSmJvK8WA)
3. [Future of ML research in JAX/FLAX](https://youtu.be/7Zau-5ozWfg)
4. [MLP Mixer in Flax and Pytorch](https://youtu.be/HqytB2GUbHA)

Python Notebooks
5. [FLAX Linen tutorial](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)
6. [FLAX basics](https://colab.research.google.com/github/BertrandRdp/flax/blob/master/docs/notebooks/flax_basics.ipynb)