# Contrastive learning on MNIST

This notebook provides an example of contrastive learning applied to MNIST. The model can be trained with `python scripts/main.py fit --config config/mnist.yaml`, and the experiment can be monitored in the meantime with `tensorboard --logdir run/`. Once training has completed, the final model is loaded and analyzed below.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import seed_everything

from contralearn import (
    MNISTDataModule,
    ConvEmbedding,
    embed_loader
)

In [None]:
# Fix the global random seeds via Lightning so that results are reproducible.
SEED = 111111
_ = seed_everything(SEED)

## MNIST data

In [None]:
# Build the MNIST data module. mean/std are passed as None, which presumably
# disables input normalization so pixel values stay in [0, 1] (the sample plot
# uses vmin=0, vmax=1) -- TODO confirm against MNISTDataModule.
mnist = MNISTDataModule(data_dir='../run/data/', mean=None, std=None, batch_size=32)

mnist.prepare_data()       # download data if not yet done
mnist.setup(stage='test')  # create only the test split

In [None]:
# Take a single batch of test images/labels for a quick visual sanity check.
test_loader = mnist.test_dataloader()
test_batches = iter(test_loader)
x_batch, y_batch = next(test_batches)

In [None]:
# Show the first few test images together with their class names.
# The number of panels filled is capped at the batch length, so a small
# batch_size no longer raises an IndexError; leftover panels are hidden.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(5, 4.5))
n_shown = min(axes.size, len(x_batch))
for idx, ax in enumerate(axes.ravel()):
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
    if idx >= n_shown:
        ax.axis('off')  # no image left for this panel
        continue
    image = x_batch[idx, 0].numpy()  # channel 0 (MNIST images presumably single-channel)
    ax.imshow(image, cmap='gray', vmin=0, vmax=1)
    ax.set_title(mnist.test_set.classes[int(y_batch[idx])])
fig.tight_layout()

## Two-dimensional embedding

In [None]:
# Load the trained embedding model from its best checkpoint.
# NOTE(review): the path is tied to the first logged run (version_0); adjust it
# if training has been run more than once.
ckpt_file = '../run/mnist/version_0/checkpoints/best.ckpt'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine (and places the weights directly on the chosen device).
emb = ConvEmbedding.load_from_checkpoint(ckpt_file, map_location=device)

emb = emb.eval()  # evaluation mode (affects dropout / batch-norm layers, if any)
emb = emb.to(device)

In [None]:
# Embed the whole test set with the trained model; the labels are returned as
# well so the embedding can be coloured by digit class in the plot.
# NOTE(review): assumes embed_loader handles device placement and disables
# gradient tracking internally -- confirm in contralearn.
embeddings, labels = embed_loader(emb, test_loader, return_labels=True)

In [None]:
# Scatter plot of the two-dimensional test-set embedding, coloured by digit class.
# Only every SUBSAMPLE-th point of each class is drawn to reduce overplotting.
SUBSAMPLE = 2

# Move the data to the CPU once: .numpy() fails on GPU tensors (the model was
# moved to `device`), while .cpu() is a no-op for tensors already on the CPU.
embeddings_cpu = embeddings.detach().cpu()
labels_cpu = torch.as_tensor(labels).cpu()

fig, ax = plt.subplots(figsize=(5, 5))
for idx in range(10):  # ten digit classes, one tab10 colour each
    class_points = embeddings_cpu[labels_cpu == idx][::SUBSAMPLE]
    ax.scatter(
        class_points[:, 0].numpy(),
        class_points[:, 1].numpy(),
        color=plt.cm.tab10(idx),
        alpha=0.3,
        edgecolors='none',
        label='{}'.format(idx)
    )
ax.set_aspect('equal', adjustable='datalim')
ax.set(xlabel='embedding dim. 1', ylabel='embedding dim. 2', title='Test-set embedding')
ax.legend(loc='center left')
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()