# Image classification

In [None]:
!pip install torch==1.11.*
!pip install torchvision==0.12.*
!pip install fairscale==0.4.*
!pip install einops==0.4.*
!pip install tokenizers==0.12.*
!pip install jsonargparse==4.7.1
!pip install pytorch-lightning==1.6.2

In [None]:
!git clone https://github.com/krasserm/perceiver-io.git -b wip-enhancements

In [None]:
# Download model checkpoints
!wget -nc -O logs.zip https://martin-krasser.com/perceiver/logs-update-3.zip
!unzip -qo logs.zip

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

# Make the repository cloned into ./perceiver-io importable. This must run
# before the `perceiver` imports below.
# NOTE(review): append() puts the clone LAST on sys.path, so an already
# installed `perceiver` package would shadow it. Use sys.path.insert(0, ...)
# if that ever matters.
sys.path.append("perceiver-io")

from torchvision.datasets import MNIST
from perceiver.data.mnist_preproc import MnistPreprocessor
from perceiver.model.lightning import LitImageClassifier

In [None]:
# MNIST test split (downloaded into .cache when not already present) and the
# preprocessor that converts its images into model input.
mnist = MNIST(train=False, root='.cache', download=True)
mnist_preproc = MnistPreprocessor()

In [None]:
ckpt_path = 'logs/img_clf/version_0/checkpoints/epoch=018-val_loss=0.092.ckpt'

# Restore the Lightning module from the checkpoint and keep only the network
# it wraps. The temporary wrapper name is dropped so the notebook namespace
# ends up the same as before.
lit_module = LitImageClassifier.load_from_checkpoint(ckpt_path)
model = lit_module.model
del lit_module
# Switch to evaluation mode; the trailing ';' suppresses the notebook echo.
model.eval();

In [None]:
cols, rows = 3, 3
# The first rows * cols test images, with their labels discarded.
imgs = [image for image, _label in (mnist[idx] for idx in range(rows * cols))]

In [None]:
# Run the classifier without recording an autograd graph. preds holds, per
# image, the index of the largest logit along dim 1.
with torch.no_grad():
    batch = mnist_preproc.preprocess_batch(imgs)
    logits = model(batch)
    preds = torch.argmax(logits, dim=1)

In [None]:
# Lay the images out on a rows x cols grid, each titled with its predicted class.
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
fig, axes = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
for ax, img, pred in zip(axes.flat, imgs, preds):
    ax.axis('off')
    ax.set_title(f'Prediction: {pred}')
    ax.imshow(np.array(img), cmap='gray')