In [1]:
import jax
from jax import numpy as jnp
from jax import grad, jit, value_and_grad
from flax import linen as nn
from typing import Sequence

# Create PRNGKey (PRNG State)
key = jax.random.PRNGKey(0)

# Create model

In [2]:
class MLP(nn.Module):
    layer_sizes: Sequence[int] = None
    
    def setup(self):
        self.layers = [nn.Dense(features=size) for size in self.layer_sizes[1:]]
    
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            x = nn.relu(x)
        return self.layers[-1](x)


In [3]:
layer_sizes = [784, 512, 512, 10]

# Create model
model = MLP(layer_sizes)

# Using `init` and dummy_x to create model parameters
key, init_key = jax.random.split(key)  # init_key used for initialization
dummy_x = jax.random.uniform(init_key, (784, ))
key, init_key = jax.random.split(key)

params = model.init(init_key, dummy_x)

# params

# Optimizer and learning rate scheduler methods

In [4]:
import optax

lr = 1e-3
lr_decay_fn = optax.linear_schedule(
        init_value=lr,
        end_value=1e-5,
        transition_steps=200,
)

optimizer = optax.adam(
            learning_rate=lr_decay_fn,
)

In [5]:
# Creat a random batch data, shape=(32, 784)
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))

model.apply(params, random_batched_flattened_images).shape

(32, 10)

# TrainState

Encapsulate the state of the training process into one class and manage it uniformly

In [6]:
from flax.training import train_state

In [7]:
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

# Data loading

In [8]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Sampler, SequentialSampler
from torchvision.datasets import MNIST


class FlattenAndCast(object):  
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))


# DataLoader returns numpy array，not torch Tensor
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())

In [9]:
class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

In [10]:
# With the help of torchvision and NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=128, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)

# Training process

In [11]:
def train_step(state, x, y):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn(params, x)
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss

    grad_fn = value_and_grad(loss_fn)  # `value_and_grad` return the loss while performing a grad 
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

jit_train_step = jit(train_step, donate_argnums=(0,))  # donate_argnums is used for buffer reuse, in which case the buffers for input and output states are reused


@jax.jit
def apply_model(state, x):
    """Computes gradients and loss for a single batch."""
    
    logits = state.apply_fn(state.params, x)
    return jnp.argmax(logits, -1)


In [12]:
def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for x, y in loader:
        y_pred = apply_model(state, x)
        total_num += len(x)
        total_acc += jnp.sum(y_pred == y)
    return total_acc / total_num

In [13]:
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        state, loss = jit_train_step(state, x, y)
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
              epoch, idx, loss, train_acc, eval_acc))

Epoch 0 - batch_idx 0, loss 59.9576530456543, Training set acc 0.28944647312164307, eval set accuracy 0.2947999835014343
Epoch 0 - batch_idx 100, loss 0.8403108716011047, Training set acc 0.9234442114830017, eval set accuracy 0.9197999835014343
Epoch 0 - batch_idx 200, loss 0.11547716706991196, Training set acc 0.938969075679779, eval set accuracy 0.9287999868392944
Epoch 0 - batch_idx 300, loss 1.2561315298080444, Training set acc 0.9425748586654663, eval set accuracy 0.9311999678611755
Epoch 0 - batch_idx 400, loss 0.6336281895637512, Training set acc 0.9441940784454346, eval set accuracy 0.9339999556541443
Epoch 1 - batch_idx 0, loss 0.5855568647384644, Training set acc 0.9455128908157349, eval set accuracy 0.9337999820709229
Epoch 1 - batch_idx 100, loss 0.5183431506156921, Training set acc 0.9467315077781677, eval set accuracy 0.9347999691963196
Epoch 1 - batch_idx 200, loss 0.22533944249153137, Training set acc 0.9475494623184204, eval set accuracy 0.9355999827384949
Epoch 1 - ba