In [1]:
import jax
from jax import numpy as jnp
from jax import grad, jit, value_and_grad
import flax
from flax import linen as nn
from typing import Sequence, Any

# Create PRNGKey (PRNG State)
# This is the root key of the notebook: later cells never consume it directly,
# they `jax.random.split` it and rebind `key` to the remaining half.
key = jax.random.PRNGKey(0)

# Create model

In [2]:
class MLP(nn.Module):
    """Multi-layer perceptron with two regularised hidden stages and a 10-way head.

    Each hidden stage is Dense(512) -> ReLU -> Dropout -> BatchNorm; the final
    layer is a plain Dense(10) that returns unnormalised logits.
    """

    def setup(self):
        # The attribute names are also the sub-module names in the variable
        # tree (e.g. `batch_stats/norm1`), so they must stay stable.
        self.layer1 = nn.Dense(features=512)
        self.dropout1 = nn.Dropout(rate=0.3)
        self.norm1 = nn.BatchNorm()

        self.layer2 = nn.Dense(features=512)
        self.dropout2 = nn.Dropout(rate=0.4)
        self.norm2 = nn.BatchNorm()

        self.layer3 = nn.Dense(features=10)

    def __call__(self, x, train:bool = True):
        """Run the forward pass.

        Args:
            x: input features; the last axis is the feature axis.
            train: when True, dropout is active and BatchNorm uses (and updates)
                batch statistics; when False, dropout is disabled and BatchNorm
                uses its running averages.

        Returns:
            Logits with 10 entries along the last axis.
        """
        inference = not train
        hidden_stages = (
            (self.layer1, self.dropout1, self.norm1),
            (self.layer2, self.dropout2, self.norm2),
        )
        for dense, dropout, norm in hidden_stages:
            x = nn.relu(dense(x))
            x = dropout(x, deterministic=inference)
            x = norm(x, use_running_average=inference)

        return self.layer3(x)

In [3]:
# Create model
model = MLP()

# Using `init` and dummy_x to generate model parameters
# (this first `init_key` is only used to draw the dummy input).
key, init_key = jax.random.split(key)
# NOTE(review): dummy_x is a single 784-vector with no leading batch axis.
# That is enough for shape inference, but train-mode BatchNorm then has no
# batch axis to average over — confirm whether a batched dummy input,
# e.g. shape (batch, 784), was intended.
dummy_x = jax.random.uniform(init_key, (784, ))

# Fresh keys: one for parameter initialisation, one for the dropout layers.
key, init_key, drop_key = jax.random.split(key, 3)

# `init` traces one forward pass and returns the variable collections
# (the next cell shows them: 'params' and 'batch_stats').
variables = model.init({"params": init_key, "dropout": drop_key}, dummy_x, train=True)

In [4]:
variables.keys()

frozen_dict_keys(['params', 'batch_stats'])

In [5]:
variables['batch_stats'].keys()

frozen_dict_keys(['norm1', 'norm2'])

In [6]:
variables['batch_stats']['norm1'].keys()

frozen_dict_keys(['mean', 'var'])

In [7]:
# Draw a fresh dropout key for the demonstration forward pass in the next cell.
key, drop_key = jax.random.split(key)

In [8]:
# Forward pass in training mode. Because 'batch_stats' is declared mutable,
# `apply` returns a pair: the model output and the updated mutable collections.
# NOTE(review): the all-zero `y` displayed below is presumably caused by
# dummy_x having no batch axis (train-mode BatchNorm would normalise every
# activation against itself) — confirm with a batched input.
y, non_trainable_params = model.apply(variables, dummy_x, train=True, rngs={"dropout": drop_key},
                                      mutable=['batch_stats']) 

In [9]:
y

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [10]:
non_trainable_params.keys()

frozen_dict_keys(['batch_stats'])

# Optimizer and learning rate scheduler methods

In [12]:
import optax

# Initial learning rate.
lr = 1e-3
# Linear schedule from `lr` down to 1e-5 over 200 transition steps.
# NOTE(review): 200 steps is shorter than one epoch of the training loop
# (which logs batch_idx up to 400 within an epoch), so most of training
# presumably runs at the final value — confirm this is intended.
lr_decay_fn = optax.linear_schedule(
        init_value=lr,
        end_value=1e-5,
        transition_steps=200,
)

# Adam driven by the schedule above instead of a constant learning rate.
optimizer = optax.adam(
            learning_rate=lr_decay_fn,
)

# TrainState

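Bundle everything the training loop has to carry from one step to the next (the parameters, the optimizer, and the BatchNorm statistics) into a single `TrainState` object, so it can be passed to and returned from the training step as one unit.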

In [13]:
from flax.training import train_state

In [14]:
class CustomTrainState(train_state.TrainState):
    """TrainState extended with the model's BatchNorm `batch_stats` collection."""
    # Running statistics of the BatchNorm layers; carried alongside the
    # parameters but not passed to the optimizer as `params`.
    batch_stats: flax.core.FrozenDict[str, Any]

# For comparison, the same call on the stock TrainState (left disabled;
# presumably it has no `batch_stats` field, hence the subclass — TODO confirm):
# state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=optimizer,
#                                      batch_stats=variables["batch_stats"])

# Build the initial training state from the variables returned by `model.init`.
state = CustomTrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
    batch_stats=variables['batch_stats'],
)

# Data loading

In [15]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Sampler, SequentialSampler
from torchvision.datasets import MNIST


class FlattenAndCast(object):
    """Dataset transform: convert an image-like object to a flat float32 NumPy vector."""

    def __call__(self, pic):
        """Return `pic` as a newly allocated 1-D float32 array."""
        as_float = np.array(pic, dtype=jnp.float32)
        return as_float.ravel()


# Collate function that makes the DataLoader return NumPy arrays instead of torch Tensors
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    The type of the first sample decides how the batch is combined:
    * ndarray samples are stacked along a new leading axis;
    * tuple/list samples are transposed field-by-field and each field is
      collated recursively, giving a list with one entry per field;
    * anything else is handed to `np.array` as is.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

class JAXRandomSampler(Sampler):
    """Sampler that yields a random permutation of dataset indices drawn with a JAX PRNG key.

    The stored key is split on every call to `__iter__`, so each pass over the
    sampler uses a different sub-key.
    """

    def __init__(self, data_source, rng_key):
        self.rng_key = rng_key
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # Keep one half of the split for future passes, shuffle with the other.
        carried_key, shuffle_key = jax.random.split(self.rng_key)
        self.rng_key = carried_key
        order = jax.random.permutation(shuffle_key, jnp.arange(len(self)))
        return iter(order.tolist())

In [16]:
class NumpyLoader(DataLoader):
    """DataLoader whose shuffling is driven by a JAX PRNG key.

    Args:
        dataset: dataset to load from.
        rng_key: JAX PRNG key used for shuffling; required when `shuffle` is
            True and ignored otherwise.
        batch_size: number of samples per batch.
        shuffle: when True, indices come from a `JAXRandomSampler`; when
            False, from a `SequentialSampler`.
        **kwargs: forwarded unchanged to `DataLoader` (e.g. `collate_fn`,
            `num_workers`, `drop_last`).

    Raises:
        ValueError: if `shuffle` is True but no `rng_key` was supplied.
    """

    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            # Fail at construction time with a clear message; otherwise the
            # missing key only surfaces later, when the sampler tries to
            # split `None` during iteration.
            if rng_key is None:
                raise ValueError(
                    "NumpyLoader(shuffle=True) requires `rng_key` (a JAX PRNG key)."
                )
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)

        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

In [17]:
# With the help of torchvision and NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=128, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)

# Training process

In [18]:
def train_step(state, x, y, dropout_key):
    """Run one optimisation step on a single batch.

    Args:
        state: training state providing `apply_fn`, `params` and `batch_stats`.
        x: batch of inputs.
        y: batch of integer class labels (one-hot encoded into 10 classes here).
        dropout_key: PRNG key for the dropout layers in this step.

    Returns:
        A `(new_state, loss)` pair: the state after applying the gradients and
        storing the updated batch statistics, and the mean cross-entropy loss.
    """
    def loss_fn(params):
        model_vars = {"params": params, "batch_stats": state.batch_stats}
        logits, mutated = state.apply_fn(
            model_vars,
            x,
            train=True,
            rngs={"dropout": dropout_key},
            mutable=["batch_stats"],
        )
        targets = jax.nn.one_hot(y, 10)
        per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
        # The mutated collections ride along as the auxiliary output.
        return jnp.mean(per_example), mutated

    # `value_and_grad` returns the loss (plus aux) together with the gradients.
    loss_and_grad = value_and_grad(loss_fn, has_aux=True)
    (loss, mutated), grads = loss_and_grad(state.params)
    updated_state = state.apply_gradients(grads=grads, batch_stats=mutated["batch_stats"])
    return updated_state, loss


# Compiled training step. `donate_argnums=(0,)` donates the input state so its
# buffers can be reused for the output state.
jit_train_step = jit(train_step, donate_argnums=(0,))


@jax.jit
def apply_model(state, x):
    """Predict class indices for a batch in inference mode.

    The model is applied with `train=False`, using the parameters and batch
    statistics stored in `state`; the result is the argmax over the last axis
    of the logits.
    """
    model_vars = {"params": state.params, "batch_stats": state.batch_stats}
    logits = state.apply_fn(model_vars, x, train=False)
    return jnp.argmax(logits, axis=-1)


In [19]:
def eval_model(state, loader):
    """Return the classification accuracy of `state` over every batch in `loader`.

    Accuracy is the number of predictions equal to the labels divided by the
    number of examples seen.
    """
    hits = 0.
    seen = 0.
    for inputs, labels in loader:
        seen += len(inputs)
        predictions = apply_model(state, inputs)
        hits += jnp.sum(predictions == labels)
    return hits / seen

In [20]:
# Train for 5 epochs; every 100th batch, report accuracy on both loaders.
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        # A fresh dropout key for every optimisation step.
        key, dropout_key = jax.random.split(key)
        state, loss = jit_train_step(state, x, y, dropout_key)
        if idx % 100 != 0:
            continue
        # Evaluation pass over the full training and evaluation loaders.
        train_acc = eval_model(state, train_loader)
        eval_acc = eval_model(state, eval_loader)
        print(f"Epoch {epoch} - batch_idx {idx}, loss {loss}, "
              f"Training set acc {train_acc}, eval set accuracy {eval_acc}")

Epoch 0 - batch_idx 0, loss 2.559518337249756, Training set acc 0.3179420530796051, eval set accuracy 0.31070002913475037
Epoch 0 - batch_idx 100, loss 0.3981797695159912, Training set acc 0.9382011890411377, eval set accuracy 0.9367000460624695
Epoch 0 - batch_idx 200, loss 0.29799991846084595, Training set acc 0.9520065784454346, eval set accuracy 0.9492000341415405
Epoch 0 - batch_idx 300, loss 0.22030052542686462, Training set acc 0.9536759257316589, eval set accuracy 0.9513000249862671
Epoch 0 - batch_idx 400, loss 0.22531506419181824, Training set acc 0.9540432095527649, eval set accuracy 0.950700044631958
Epoch 1 - batch_idx 0, loss 0.2441655695438385, Training set acc 0.954594075679779, eval set accuracy 0.9508000612258911
Epoch 1 - batch_idx 100, loss 0.14692620933055878, Training set acc 0.9552618265151978, eval set accuracy 0.9508000612258911
Epoch 1 - batch_idx 200, loss 0.10268256068229675, Training set acc 0.9557124972343445, eval set accuracy 0.9513000249862671
Epoch 1 -