In [1]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

from ml_collections import ConfigDict

import jax
import numpy as np
import jax.numpy as jnp

import optax
from flax import linen as nn
from flax.training import train_state

from rich.progress import Progress
from rich.jupyter import print

from plotly.subplots import make_subplots
import plotly.graph_objs as go

from IPython.display import display

In [2]:
# Pin JAX to the CPU backend for this notebook.
# NOTE(review): presumably to sidestep the Apple Metal plugin that announces itself
# further down ("Metal device set to: Apple M1 Pro") — confirm that is the intent.
jax.config.update('jax_platform_name', 'cpu')

## Config

In [3]:
# Run hyper-parameters; later cells read them from `config`.
config = ConfigDict({
    "seed": 0,              # seeds the PRNG key used for parameter initialisation
    "epochs": 2,            # full passes over the training set
    "batch_size": 100,      # examples per batch, shared by train and test loaders
    "learning_rate": 1e-4,  # Adam step size
})

## Data

In [4]:
class ToNumPy:
  """Torchvision-style transform: convert an image (e.g. a PIL image) to a float32 NumPy array.

  Args:
    scale: Multiplier applied to the pixel values after conversion. The default
      ``1.0`` keeps the raw values (presumably 0..255 for MNIST's 8-bit images);
      pass ``1 / 255`` to map them into ``[0, 1]``.
  """

  def __init__(self, scale: float = 1.0):
    self.scale = scale

  def __call__(self, pic):
    # Plain NumPy dtype: this runs inside the DataLoader, JAX is not needed here.
    arr = np.array(pic, dtype=np.float32)  # np.array copies, so the caller's data is never mutated
    if self.scale != 1.0:
      arr *= self.scale  # in place, so the result stays float32
    return arr

In [5]:
def one_hot(x, k=10, dtype=jnp.float32):
  """One-hot encode integer class labels.

  Assumes ``x`` is 1-D (it is indexed as ``x[:, None]``). Row ``i`` of the
  result has a 1 at column ``x[i]`` and 0 elsewhere; labels outside
  ``[0, k)`` give an all-zero row. For such 1-D integer input this matches
  ``jax.nn.one_hot(x, k, dtype=dtype)``.

  Returns:
    A ``(len(x), k)`` JAX array of ``dtype``.
  """
  labels_col = x[:, None]      # (n,) -> (n, 1) so it broadcasts against the class ids
  class_ids = jnp.arange(k)    # (k,)
  return jnp.array(labels_col == class_ids, dtype)

In [6]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays (DataLoader ``collate_fn``).

    * ndarray samples are stacked along a new leading axis;
    * tuple/list samples are split column-wise and each column is collated
      recursively, giving a list with one entry per field;
    * anything else is wrapped with ``np.array``.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        return [numpy_collate(column) for column in zip(*batch)]
    return np.array(batch)
      
def dict_collate(batch):
    """Collate ``(image, label)`` samples into the dict consumed by the train/eval loops.

    Returns a dict with:
      * ``"input"``: stacked images reshaped to ``(batch, 28, 28, 1)``,
      * ``"target"``: one-hot encoded labels (via ``one_hot``),
      * ``"target_nums"``: the raw integer labels.
    """
    images, labels = numpy_collate(batch)
    return {
        "input": images.reshape(-1, 28, 28, 1),  # add the trailing single-channel axis
        "target": one_hot(labels),
        "target_nums": labels,
    }

In [7]:
def get_datasets(batch_size: int, shaffle: bool = False, root: str = ".", shuffle: bool | None = None):
    """Build MNIST train/test ``DataLoader``s that yield ``dict_collate`` batches.

    Args:
        batch_size: Number of examples per batch for both loaders.
        shaffle: Legacy (misspelled) name for ``shuffle``; kept so existing
            callers keep working.
        root: Directory where MNIST is downloaded/cached.
        shuffle: Whether to shuffle the *training* loader. Overrides
            ``shaffle`` when given.

    Returns:
        ``(loader_train, loader_test)``. The test loader is never shuffled:
        evaluation does not benefit from it and a fixed order keeps evaluation
        deterministic.
    """
    if shuffle is None:
        shuffle = shaffle
    ds_train = MNIST(root=root, train=True, download=True, transform=ToNumPy())
    ds_test = MNIST(root=root, train=False, download=True, transform=ToNumPy())
    loader_train = DataLoader(ds_train, batch_size=batch_size, shuffle=shuffle, collate_fn=dict_collate)
    loader_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False, collate_fn=dict_collate)
    return loader_train, loader_test

## Model

In [8]:
class CNN(nn.Module):
  """Convolutional classifier that maps an NHWC image batch to 10 logits per example.

  Two stages of 3x3 conv -> ReLU -> 2x2 average pool (32, then 64 channels),
  then a 256-unit ReLU dense layer and a 10-unit output layer.
  """

  @nn.compact
  def __call__(self, x):
    for n_features in (32, 64):
      x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten everything except the batch axis
    x = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(x)

In [9]:
def create_train_state(rng, config):
  """Initialise CNN parameters and wrap them with an Adam optimizer in a ``TrainState``.

  Args:
    rng: PRNG key used for parameter initialisation.
    config: Run config; only ``config.learning_rate`` is read.
  """
  model = CNN()
  dummy_batch = jnp.ones([1, 28, 28, 1])  # shape-only input used to trace the parameter shapes
  variables = model.init(rng, dummy_batch)
  optimizer = optax.adam(config.learning_rate)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=variables['params'], tx=optimizer)

In [10]:
@jax.jit
def apply_model(state, images, labels):
    """Evaluate one batch: gradients, mean softmax cross-entropy, and accuracy in %.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        images: Input batch fed to ``state.apply_fn``.
        labels: One-hot targets (class taken as the argmax of the last axis).

    Returns:
        ``(grads, loss, accuracy)``. ``state`` itself is not updated.
    """
    def compute_loss(params):
        logits = state.apply_fn({'params': params}, images)
        per_example = optax.softmax_cross_entropy(logits=logits, labels=labels)
        return jnp.mean(per_example), logits

    loss_and_grad = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, logits), grads = loss_and_grad(state.params)
    hits = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
    accuracy = jnp.mean(hits) * 100
    return grads, loss, accuracy

In [11]:
@jax.jit
def update_model(state, grads):
    """Apply one optimizer step with ``grads`` and return the new train state."""
    next_state = state.apply_gradients(grads=grads)
    return next_state

## Training

### Visualization

In [12]:
class Monitor:
    """Live training dashboard built from two plotly ``FigureWidget``s.

    * ``fig``: per-step training loss (left) and accuracy in % (right), with a
      dashed vertical line at every epoch boundary.
    * ``table``: one row per evaluated epoch with its accuracy in %.

    Metrics are stored as plain Python floats, so the widgets never receive
    JAX/NumPy scalars.

    Args:
        config: Run configuration; only ``config.epochs`` is read.
        train_data: Training loader; ``len(train_data)`` is taken as the number
            of steps per epoch.
    """

    def __init__(self, config, train_data):
        self.train_loss = []
        self.train_acc = []
        self.eval_loss = []  # not filled by this class
        self.eval_acc = []

        self.steps_per_epoch = len(train_data)
        self.steps_total = self.steps_per_epoch * config.epochs

        # Fixed x-range so the axes do not rescale as points arrive.
        self.x_range = [1, self.steps_total]

        self._create_fig()
        self._create_tbl()

    def _create_fig(self, use_y_log: bool = True):
        """Create the 1x2 loss/accuracy figure; log-scale both y axes if ``use_y_log``."""
        subplots = make_subplots(rows=1, cols=2)
        subplots.add_scatter(x=[], y=[], row=1, col=1, name="loss")
        subplots.add_scatter(x=[], y=[], row=1, col=2, name="accuracy")

        self.fig = go.FigureWidget(subplots)

        self.fig.layout.title = "Training"
        self.fig.update_xaxes(title_text="Step", range=self.x_range, row=1, col=1)
        self.fig.update_xaxes(title_text="Step", range=self.x_range, row=1, col=2)
        # Axis titles are always set; only the scale depends on the flag.
        y_scale = {"type": "log"} if use_y_log else {}
        self.fig.update_yaxes(title_text="Loss", row=1, col=1, **y_scale)
        self.fig.update_yaxes(title_text="Accuracy, %", row=1, col=2, **y_scale)

    def _create_tbl(self):
        """Create the (initially empty) per-epoch evaluation table widget."""
        header = dict(values=['Epoch', 'Accuracy,%'], fill_color='paleturquoise', align='left')
        cells = dict(values=[[], []], fill_color='lavender', align='left')
        table = go.Table(header=header, cells=cells)

        self.table = go.FigureWidget(data=[table])
        self.table.layout.update(title="Evaluation")

    def update_table(self):
        """Re-render the table from ``eval_acc``, numbering epochs from 1."""
        epochs = list(range(1, len(self.eval_acc) + 1))
        self.table.data[0].cells.values = [epochs, self.eval_acc]

    def update_fig(self):
        """Push all recorded training points to the figure in one batched update."""
        steps = list(range(1, len(self.train_loss) + 1))
        with self.fig.batch_update():
            scatt_loss = self.fig.data[0]
            scatt_acc = self.fig.data[1]
            scatt_loss.x = steps
            scatt_loss.y = tuple(self.train_loss)
            scatt_acc.x = steps
            scatt_acc.y = tuple(self.train_acc)

    def _draw_line(self, step):
        """Mark an epoch boundary at ``step`` with a dashed vertical line."""
        self.fig.add_vline(x=step, line_width=3, line_dash="dash", line_color="green")

    def add_eval_acc(self, val):
        """Record one epoch's evaluation accuracy (%) and refresh the table."""
        self.eval_acc.append(float(val))
        self.update_table()

    def add_train_loss_acc(self, loss, acc):
        """Record one training step's loss/accuracy, refresh the plot, and mark epoch ends."""
        # float() pulls device scalars to the host once, so plotly only sees Python numbers.
        self.train_loss.append(float(loss))
        self.train_acc.append(float(acc))
        self.update_fig()

        curr_step = len(self.train_loss)
        if curr_step % self.steps_per_epoch == 0:
            self._draw_line(curr_step)

    def show(self):
        """Display the training figure followed by the evaluation table."""
        display(self.fig)
        display(self.table)

### Training and evaluation loops

In [13]:
def train_epoch(state, loader_train, epoch, monitor = None, progress = None):
    """Run one pass over ``loader_train``, taking one optimizer step per batch.

    Args:
        state: Current train state.
        loader_train: Iterable of batches with ``"input"`` and ``"target"`` keys.
        epoch: Epoch number, used only for the progress-bar label.
        monitor: Optional object receiving ``add_train_loss_acc(loss, acc)`` per step.
        progress: Optional progress object whose ``track`` wraps the loader.

    Returns:
        The train state after the last batch.
    """
    batches = loader_train
    if progress:
        batches = progress.track(batches, description=f"Train {epoch:02d}.")

    for batch in batches:
        grads, loss, accuracy = apply_model(state, batch["input"], batch["target"])
        state = update_model(state, grads)
        if monitor is None:
            continue
        monitor.add_train_loss_acc(loss, accuracy)
    return state

In [14]:
@jax.jit
def _eval_batch(state, images, labels):
    """Forward pass only: count correct top-1 predictions in one batch (no gradients)."""
    logits = state.apply_fn({'params': state.params}, images)
    return jnp.sum(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))


def eval_epoch(state, loader_test, epoch, monitor = None, progress = None):
    """Compute accuracy (%) of ``state`` over every example in ``loader_test``.

    Accuracy is total correct / total examples. This is correct even when the
    last batch is smaller than the others, unlike averaging per-batch means.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        loader_test: Iterable of batches with ``"input"`` and one-hot ``"target"``.
        epoch: Epoch number, used only for the progress-bar label.
        monitor: Optional object receiving ``add_eval_acc(accuracy)``.
        progress: Optional progress object whose ``track`` wraps the loader.

    Returns:
        The accuracy in percent as a float, or ``nan`` if the loader is empty.
    """
    if progress:
        loader_test = progress.track(loader_test, description=f" Eval {epoch:02d}.")
    n_correct = 0  # becomes a device scalar; synced to the host once at the end
    n_seen = 0
    for data in loader_test:
        n_correct = n_correct + _eval_batch(state, data["input"], data["target"])
        n_seen += len(data["target"])
    acc_avg = 100.0 * int(n_correct) / n_seen if n_seen else float("nan")
    if monitor is not None:
        monitor.add_eval_acc(acc_avg)
    return acc_avg

### Run

In [15]:
# MNIST train/test loaders yielding {"input", "target", "target_nums"} batch dicts.
loader_train, loader_test = get_datasets(batch_size=config.batch_size)

In [16]:
# Derive a dedicated key for parameter initialisation from the configured seed;
# `rng` (unused below) is kept for any further sampling.
rng, init_rng = jax.random.split(jax.random.PRNGKey(config.seed))
state = create_train_state(init_rng, config)

Metal device set to: Apple M1 Pro


In [17]:
# Build the live training plot and evaluation table, then render them now;
# the training loop below updates them in place.
monitor = Monitor(config, train_data=loader_train)
monitor.show()

FigureWidget({
    'data': [{'name': 'loss',
              'type': 'scatter',
              'uid': 'b8c38684-a222-4951-aa4e-dded8c81e75e',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'name': 'accuracy',
              'type': 'scatter',
              'uid': '3e195307-4323-465e-b258-bd0fb843f008',
              'x': [],
              'xaxis': 'x2',
              'y': [],
              'yaxis': 'y2'}],
    'layout': {'template': '...',
               'title': {'text': 'Training'},
               'xaxis': {'anchor': 'y', 'domain': [0.0, 0.45], 'range': [1, 1200], 'title': {'text': 'Step'}},
               'xaxis2': {'anchor': 'y2', 'domain': [0.55, 1.0], 'range': [1, 1200], 'title': {'text': 'Step'}},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Loss'}, 'type': 'log'},
               'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'title': {'text': 'Accuracy, %'}, 'type': 'log'}}
}

FigureWidget({
    'data': [{'cells': {'align': 'left', 'fill': {'color': 'lavender'}, 'values': [[], []]},
              'header': {'align': 'left', 'fill': {'color': 'paleturquoise'}, 'values': ['Epoch', 'Accuracy,%']},
              'type': 'table',
              'uid': 'a0fe0cda-8d46-487e-8b11-fae6deb21e54'}],
    'layout': {'template': '...', 'title': {'text': 'Evaluation'}}
})

In [18]:
# Alternate one training pass and one evaluation pass per epoch, sharing a single
# progress display and reporting metrics to `monitor`.
with Progress() as progress:
    for epoch in range(1, config.epochs + 1):
        state = train_epoch(state, loader_train, epoch, monitor=monitor, progress=progress)
        eval_epoch(state, loader_test, epoch, monitor=monitor, progress=progress)

Output()

In [19]:
# Record the run date and the Python/IPython/package versions so results can be tied to an environment.
%load_ext watermark
%watermark -d -u -v -iv

Last updated: 2023-07-29

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.14.0

flax  : 0.7.0
jax   : 0.4.11
numpy : 1.25.1
optax : 0.1.5
plotly: 5.15.0

