In [1]:
import jax
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
import flax.linen as nn
import jax.numpy as jnp
from typing import Sequence
import optax
import numpy as np
from flax.training import train_state, checkpoints, early_stopping
from flax.metrics import tensorboard

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307), (0.3081)),
    transforms.Lambda(lambda x: torch.flatten(x))
])

In [11]:
class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


In [12]:
@jax.jit
def apply_model(state, data, labels):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, data)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

In [None]:
def kl_divergence(output, y):
    return -jnp.mean(y * output - (1.0 - y) * (jnp.exp(output - 1.0)), axis=0)[0]


In [None]:
@jax.jit
def apply_student_model(state, data, labels):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, data)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = 0.5 * jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) + 0.5 * 

In [13]:
@jax.jit
def update_model(state, grads): 
    return state.apply_gradients(grads=grads)

In [14]:
def train_epoch(state, train_dt, rng):
    epoch_loss = []
    epoch_accuracy = []

    for batch_idx, (data, target) in enumerate(train_dt):
        data, target = data.numpy(), target.numpy()
        data, target = jnp.float32(data), jnp.float32(target)

        grads, loss, accuracy = apply_model(state, data, target)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    train_loss = np.mean(epoch_loss)
    train_accuracy = np.mean(epoch_accuracy)

    return state, train_loss, train_accuracy

In [15]:
def create_train_state(rng, config):
    mlp = MLP([1200, 1200, 10])
    params = mlp.init(rng, jnp.ones([1, 784]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=mlp.apply, params=params, tx=tx
    )

In [31]:
def train_and_evaluate(config, workdir):
    train = MNIST(root='data/', train=True, download=True, transform=transform)
    test = MNIST(root='data/', train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(test, **test_kwargs)

    rng = jax.random.PRNGKey(0)

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    early_stop = early_stopping.EarlyStopping(min_delta=1e-3, patience=2)
    best_score = 0

    for epoch in range(1, config['num_epoch'] + 1):
        rng, init_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_loader, rng)
        _, early_stop = early_stop.update(train_loss)
        test_dt, test_labels = next(iter(test_loader))
        _, test_loss, test_accuracy = apply_model(state, test_dt.numpy(), test_labels.numpy())
        print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
        % (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100))
        
        summary_writer.scalar('train_loss', train_loss, epoch)
        summary_writer.scalar('train_accuracy', train_accuracy, epoch)
        summary_writer.scalar('test_loss', test_loss, epoch)
        summary_writer.scalar('test_accuracy', test_accuracy, epoch)
        
        if test_accuracy > best_score:
            checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, overwrite=True, step=epoch, prefix='mlp_1200_', )

        if early_stop.should_stop:
            print('Met early stopping criteria, breaking...')
            break

    summary_writer.flush()
    return state
    

In [32]:
config = {'num_epoch': 10}
train_kwargs = {'batch_size': 8}
test_kwargs = {'batch_size': 120}
learning_rate = 1e-3
num_epoch = 1
CKPT_DIR = 'ckpts'

In [34]:
state = train_and_evaluate(config, 'logs/')

epoch:  1, train_loss: 0.2349, train_accuracy: 93.36, test_loss: 0.1829, test_accuracy: 93.33
epoch:  2, train_loss: 0.1368, train_accuracy: 96.31, test_loss: 0.2224, test_accuracy: 93.33
epoch:  3, train_loss: 0.1088, train_accuracy: 97.09, test_loss: 0.0729, test_accuracy: 96.67
epoch:  4, train_loss: 0.0941, train_accuracy: 97.57, test_loss: 0.0394, test_accuracy: 99.17
epoch:  5, train_loss: 0.0896, train_accuracy: 97.72, test_loss: 0.1100, test_accuracy: 99.17
epoch:  6, train_loss: 0.0763, train_accuracy: 98.07, test_loss: 0.0374, test_accuracy: 99.17
epoch:  7, train_loss: 0.0762, train_accuracy: 98.14, test_loss: 0.0400, test_accuracy: 97.50
epoch:  8, train_loss: 0.0753, train_accuracy: 98.22, test_loss: 0.0477, test_accuracy: 97.50
epoch:  9, train_loss: 0.0675, train_accuracy: 98.34, test_loss: 0.0168, test_accuracy: 99.17
epoch: 10, train_loss: 0.0653, train_accuracy: 98.49, test_loss: 0.0317, test_accuracy: 98.33


## Train student using distillation