<a href="https://colab.research.google.com/github/guiOsorio/Learning_JAX/blob/master/MNIST_FLAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Flax and JAX
!pip install --upgrade -q "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install --upgrade -q git+https://github.com/google/flax.git

In [2]:
import jax
from jax import lax, random, jit, numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state

import optax

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import functools
from typing import Sequence, Callable, Any, Optional

import numpy as np
import matplotlib.pyplot as plt

In [3]:
def custom_transform(x):
    # A couple of modifications here compared to tutorial #3 since we're using a CNN
    # Input: (28, 28) uint8 [0, 255] torch.Tensor, Output: (28, 28, 1) float32 [0, 1] np array
    return np.expand_dims(np.array(x, dtype=np.float32), axis=2) / 255.

def custom_collate_fn(batch):
    """Provides us with batches of numpy arrays and not PyTorch's tensors."""
    transposed_data = list(zip(*batch))

    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])

    return imgs, labels

mnist_img_size = (28, 28, 1)
batch_size = 128

train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

# optimization - loading the whole dataset into memory
train_images = jnp.array(train_dataset.data)
train_lbls = jnp.array(train_dataset.targets)

# np.expand_dims is to convert shape from (10000, 28, 28) -> (10000, 28, 28, 1)
# We don't have to do this for training images because custom_transform does it for us.
test_images = np.expand_dims(jnp.array(test_dataset.data), axis=3)
test_lbls = jnp.array(test_dataset.targets)

print(train_images.shape, train_lbls.shape)
print('---')
print(test_images.shape, test_lbls.shape)
print('---')

for img, lbl in train_loader:
  print(img.shape)
  print(lbl.shape)
  break

# try stacking features in train_loader



(60000, 28, 28) (60000,)
---
(10000, 28, 28, 1) (10000,)
---
(128, 28, 28, 1)
(128,)


In [4]:
# 1. Implementation with nn.compact
class NN_compact(nn.Module):

  @nn.compact
  def __call__(self, x):
    # Flatten x
    x = x.reshape((x.shape[0], -1))

    # Linear + relu
    x = nn.Dense(features=100)(x)
    x = nn.relu(x)

    # Linear + relu
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)

    # Linear + softmax
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

In [5]:
# TRAINING

# Compute loss and update - this will be computed many times, so it's best to jit it
@jit
def training_state(state, imgs, gt_labels):

  def crossEntropy_loss(params):
    logits = NN_compact().apply({'params': params}, imgs) # logits is a vector of probabilities predicted by the model (the highest value in the vector is the prediction)
    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10) # one hot encoded vector of labels
    # logits.shape and one_hot_gt_labels shape is (batch_size, num_classes)
    loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1)) # axis=-1 means sum over rows
    return loss, logits
  
  (loss, logits), grads = jax.value_and_grad(crossEntropy_loss, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads) # update state params based on grads calculated

  ## Accuracy
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }

  return state, metrics

# One epoch - need to add metrics part
def train_one_epoch(state, dataloader):
  batch_metrics = []
  for cnt, (imgs, labels) in enumerate(dataloader):
    state, metrics = training_state(state, imgs, labels)
    batch_metrics.append(metrics)

  batch_metrics_np = jax.device_get(batch_metrics)  # pull from the accelerator onto host (CPU)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
  }

  return state, epoch_metrics_np

# Train State initializer
def create_train_state(key, lr, momentum):
  # Create model
  NN = NN_compact()
  # Initialize parameters
  params = NN.init(key, jnp.ones([1, *mnist_img_size]))['params']
  # Initialize optimizer
  sgd_opt = optax.sgd(lr, momentum)

  return train_state.TrainState.create(apply_fn=NN.apply, params=params, tx=sgd_opt)

In [6]:
# EVALUATION

# Run one evaluation on test set
@jit
def eval_step(state, imgs, gt_labels):
  logits = NN_compact().apply({'params': state.params}, imgs)
  one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
  loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics

def evaluate_model(state, test_imgs, test_labels):
  metrics = eval_step(state, test_imgs, test_labels)
  metrics = jax.device_get(metrics) # pull from accelerator to CPU
  metrics = jax.tree_map(lambda x: x.item(), metrics) # get scalar value from array
  return metrics

In [7]:
# FIT

from flax.training import train_state
seed = 0
lr = 0.05
momentum = 0.9
n_epochs = 4

train_state = create_train_state(jax.random.PRNGKey(seed), lr, momentum)

for epoch in range(n_epochs):
  print(f'EPOCH {epoch+1}')

  train_state, train_metrics = train_one_epoch(train_state, train_loader)
  print(f'Train accuracy: {train_metrics["accuracy"]}, Train loss: {train_metrics["loss"]}')

  test_metrics = evaluate_model(train_state, test_images, test_lbls)
  print(f'Test accuracy: {test_metrics["accuracy"]}, Test loss: {test_metrics["loss"]}')
  print(' ')

EPOCH 1
Train accuracy: 0.9029781222343445, Train loss: 0.3185702860355377
Test accuracy: 0.9451999664306641, Test loss: 25.09831428527832
 
EPOCH 2
Train accuracy: 0.9630575776100159, Train loss: 0.12100706249475479
Test accuracy: 0.957099974155426, Test loss: 20.958974838256836
 
EPOCH 3
Train accuracy: 0.9734575152397156, Train loss: 0.08256715536117554
Test accuracy: 0.9715999960899353, Test loss: 13.875977516174316
 
EPOCH 4
Train accuracy: 0.9799178838729858, Train loss: 0.06296219676733017
Test accuracy: 0.9711999893188477, Test loss: 15.750566482543945
 


In [8]:
# 2. Implementation with setup()
class NN_setup(nn.Module):

  def setup(self):
    self.dense1 = nn.Dense(features=100)
    self.dense2 = nn.Dense(features=256)
    self.dense3 = nn.Dense(features=10)

  def __call__(self, x):
    # Flatten x
    x = x.reshape((x.shape[0], -1))
    
    # Linear + relu
    x = self.dense1(x)
    x = nn.relu(x)

    # Linear + relu
    x = self.dense2(x)
    x = nn.relu(x)

    # Linear + softmax
    x = self.dense3(x)
    x = nn.log_softmax(x)
    return x

In [9]:
# TRAINING 2

# Compute loss and update - this will be computed many times, so it's best to jit it
@jit
def training_state(state, imgs, gt_labels):

  def crossEntropy_loss(params):
    logits = NN_setup().apply({'params': params}, imgs) # logits is a vector of probabilities predicted by the model (the highest value in the vector is the prediction)
    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10) # one hot encoded vector of labels
    # logits.shape and one_hot_gt_labels shape is (batch_size, num_classes)
    loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1)) # axis=-1 means sum over rows
    return loss, logits
  
  (loss, logits), grads = jax.value_and_grad(crossEntropy_loss, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads) # update state params based on grads calculated

  ## Accuracy
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }

  return state, metrics

# One epoch - need to add metrics part
def train_one_epoch(state, dataloader):
  batch_metrics = []
  for cnt, (imgs, labels) in enumerate(dataloader):
    state, metrics = training_state(state, imgs, labels)
    batch_metrics.append(metrics)

  batch_metrics_np = jax.device_get(batch_metrics)  # pull from the accelerator onto host (CPU)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
  }

  return state, epoch_metrics_np

# Train State initializer
def create_train_state(key, lr, momentum):
  # Create model
  NN = NN_setup()
  # Initialize parameters
  params = NN.init(key, jnp.ones([1, *mnist_img_size]))['params']
  # Initialize optimizer
  sgd_opt = optax.sgd(lr, momentum)

  return train_state.TrainState.create(apply_fn=NN.apply, params=params, tx=sgd_opt)

In [10]:
# EVALUATION 2

# Run one evaluation on test set
@jit
def eval_step(state, imgs, gt_labels):
  logits = NN_setup().apply({'params': state.params}, imgs)
  one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
  loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics

def evaluate_model(state, test_imgs, test_labels):
  metrics = eval_step(state, test_imgs, test_labels)
  metrics = jax.device_get(metrics) # pull from accelerator to CPU
  metrics = jax.tree_map(lambda x: x.item(), metrics) # get scalar value from array
  return metrics

In [11]:
# FIT 2

from flax.training import train_state
seed = 0
lr = 0.05
momentum = 0.9
n_epochs = 4

train_state = create_train_state(jax.random.PRNGKey(seed), lr, momentum)

for epoch in range(n_epochs):
  print(f'EPOCH {epoch+1}')

  train_state, train_metrics = train_one_epoch(train_state, train_loader)
  print(f'Train accuracy: {train_metrics["accuracy"]}, Train loss: {train_metrics["loss"]}')

  test_metrics = evaluate_model(train_state, test_images, test_lbls)
  print(f'Test accuracy: {test_metrics["accuracy"]}, Test loss: {test_metrics["loss"]}')
  print(' ')

EPOCH 1
Train accuracy: 0.9056156277656555, Train loss: 0.3106502890586853
Test accuracy: 0.958899974822998, Test loss: 17.96934700012207
 
EPOCH 2
Train accuracy: 0.9637753963470459, Train loss: 0.1184086874127388
Test accuracy: 0.964199960231781, Test loss: 16.922075271606445
 
EPOCH 3
Train accuracy: 0.9745426177978516, Train loss: 0.08304621279239655
Test accuracy: 0.9711999893188477, Test loss: 14.090845108032227
 
EPOCH 4
Train accuracy: 0.9806023240089417, Train loss: 0.06245078146457672
Test accuracy: 0.9631999731063843, Test loss: 18.837289810180664
 


In [44]:
# 3. Implementation with batch norm and dropout
class NN_regularized(nn.Module):

  @nn.compact 
  def __call__(self, x, train: bool):
    # Flatten x
    x = x.reshape((x.shape[0], -1))

    # Linear + dropout + relu
    x = nn.Dense(features=100)(x)
    x = nn.Dropout(0.2, deterministic=not train)(x)
    x = nn.relu(x)

    # Linear + batch norm + relu
    x = nn.Dense(features=256)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)

    # Linear + softmax
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

In [45]:
# TRAINING 3

# Compute loss and update - this will be computed many times, so it's best to jit it
@jit
def training_state(state, imgs, gt_labels):

  def crossEntropy_loss(params, batch_stats):
    logits, updates = NN_regularized().apply({'params': params, 'batch_stats': batch_stats}, imgs, train=True, rngs={'dropout': jax.random.PRNGKey(0)}, mutable=['batch_stats'])
    # logits is a vector of probabilities predicted by the model (the highest value in the vector is the prediction)
    one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10) # one hot encoded vector of labels
    # logits.shape and one_hot_gt_labels shape is (batch_size, num_classes)
    loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1)) # axis=-1 means sum over rows
    return loss, (logits, updates)
  
  (loss, (logits, updates)), grads = jax.value_and_grad(crossEntropy_loss, argnums=0, has_aux=True)(state.params, state.batch_stats)
  state = state.apply_gradients(grads=grads) # update state params based on grads calculated
  state = state.replace(batch_stats=updates['batch_stats']) # update state batch_stats variables

  ## Accuracy
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }

  return state, metrics

# One epoch - need to add metrics part
def train_one_epoch(state, dataloader):
  batch_metrics = []
  for cnt, (imgs, labels) in enumerate(dataloader):
    state, metrics = training_state(state, imgs, labels)
    batch_metrics.append(metrics)

  batch_metrics_np = jax.device_get(batch_metrics)  # pull from the accelerator onto host (CPU)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
  }

  return state, epoch_metrics_np

def create_train_state(key, lr, momentum):
  # Create model
  NN = NN_regularized()
  # Initialize parameters
  variables = NN.init(key, jnp.ones([1, *mnist_img_size]), train=False)
  params = variables['params']
  batch_stats_v = variables['batch_stats']
  del variables

  class TrainState_stats(train_state.TrainState):
    batch_stats: Any

  state = TrainState_stats.create(
    apply_fn=NN.apply,
    params=params,
    batch_stats=batch_stats_v,
    tx=optax.sgd(lr, momentum)
  )

  return state

In [46]:
# EVALUATION 3

# Run one evaluation on test set
@jit
def eval_step(state, imgs, gt_labels):
  logits = NN_regularized().apply({'params': state.params, 'batch_stats': state.batch_stats}, imgs, rngs={'dropout': jax.random.PRNGKey(0)}, train=False)
  one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
  loss = -jnp.mean(jnp.sum(logits * one_hot_gt_labels, axis=-1))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == gt_labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics

def evaluate_model(state, test_imgs, test_labels):
  metrics = eval_step(state, test_imgs, test_labels)
  metrics = jax.device_get(metrics) # pull from accelerator to CPU
  metrics = jax.tree_map(lambda x: x.item(), metrics) # get scalar value from array
  return metrics

In [48]:
# FIT 3

from flax.training import train_state
seed = 0
lr = 0.01 # lower learning rate with batch norm
momentum = 0.9
n_epochs = 4

train_state = create_train_state(jax.random.PRNGKey(seed), lr, momentum)

for epoch in range(n_epochs):
  print(f'EPOCH {epoch+1}')

  train_state, train_metrics = train_one_epoch(train_state, train_loader)
  print(f'Train accuracy: {train_metrics["accuracy"]}, Train loss: {train_metrics["loss"]}')

  test_metrics = evaluate_model(train_state, test_images, test_lbls)
  print(f'Test accuracy: {test_metrics["accuracy"]}, Test loss: {test_metrics["loss"]}')
  print(' ')

EPOCH 1
Train accuracy: 0.8876368999481201, Train loss: 0.3800266981124878
Test accuracy: 0.9483999609947205, Test loss: 19.098968505859375
 
EPOCH 2
Train accuracy: 0.9444945454597473, Train loss: 0.18573138117790222
Test accuracy: 0.9564999938011169, Test loss: 17.746009826660156
 
EPOCH 3
Train accuracy: 0.9553619027137756, Train loss: 0.14886736869812012
Test accuracy: 0.9545999765396118, Test loss: 18.444652557373047
 
EPOCH 4
Train accuracy: 0.961471676826477, Train loss: 0.12642443180084229
Test accuracy: 0.9562999606132507, Test loss: 16.861549377441406
 
