# transformers: Transfer learning

In [None]:
# Notebook setup: automatically reload edited Python modules before each cell
# executes (autoreload mode 2), and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
# Make modules in the parent directory importable (presumably where the
# local `hf_utils` module imported below lives — confirm).
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import seed_everything

from hf_utils import (
    CIFAR10DataModule,
    LightningImgClassif
)

In [None]:
RANDOM_SEED = 123  # fixed seed so repeated runs give the same results

# Seed the random number generators through Lightning; binding the return
# value to `_` keeps the seed from being echoed as the cell's output.
_ = seed_everything(RANDOM_SEED)

## Load data

In [None]:
# Configure the CIFAR-10 data module. The mean/std tuples are the widely used
# ImageNet channel statistics and img_size is presumably the resize target
# expected by the pretrained backbone (confirm in hf_utils).
datamodule_kwargs = dict(
    data_dir='../run/data/',
    img_size=(224, 224),
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    batch_size=32,
    num_workers=0,
)
cifar = CIFAR10DataModule(**datamodule_kwargs)

cifar.prepare_data()  # download the data if not yet done
cifar.setup(stage='test')  # create the test set

In [None]:
# Take the first batch from the test loader and split it into the image
# tensor and the ground-truth label ids.
test_loader = cifar.test_dataloader()
batch = next(iter(test_loader))
x_batch, y_batch = batch['pixel_values'], batch['labels']

In [None]:
# Display the first 16 images of the batch in a 4x4 grid, each titled with
# its ground-truth class name.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for idx, ax in enumerate(axes.flat):
    # renormalize() presumably undoes the input normalization; permute moves
    # the channel axis last for imshow (assumes CHW tensors — confirm)
    img = cifar.renormalize(x_batch[idx]).permute(1, 2, 0).numpy()
    true_name = cifar.label_names[y_batch[idx].item()]
    ax.imshow(img)
    ax.set_title(true_name)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Load model

In [None]:
# Load the trained classifier from the last checkpoint of the `transfer` run.
ckpt_file = '../run/transfer/version_0/checkpoints/last.ckpt'

# Prefer the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads on a CPU-only machine (and vice versa).
model = LightningImgClassif.load_from_checkpoint(ckpt_file, map_location=device)

# Put the model in evaluation mode and make sure its parameters live on `device`.
model = model.eval()
model = model.to(device)

## Run model

In [None]:
# Draw a batch from the test loader to run inference on, unpacking the
# image tensor and the ground-truth label ids in one step.
test_loader = cifar.test_dataloader()
batch = next(iter(test_loader))
x_batch, y_batch = (batch[key] for key in ('pixel_values', 'labels'))

In [None]:
# Forward pass without gradient tracking. The model lives on `device`, so the
# inputs must be moved there as well (otherwise this fails on a CUDA machine).
# Only a moved copy is passed in: `x_batch` itself stays on the CPU so it can
# still be converted with `.numpy()` for plotting. The logits are brought back
# to the CPU so downstream code works with CPU tensors like `y_batch`.
with torch.no_grad():
    y_logits = model(x_batch.to(device)).cpu()

print(f'Images shape: {x_batch.shape}')
print(f'Logits shape: {y_logits.shape}')

In [None]:
# Pick the highest-scoring class for every image and translate each class id
# into its label via the data module's id2label mapping.
label_ids = torch.argmax(y_logits, dim=-1)
labels = [cifar.id2label[class_id] for class_id in label_ids.tolist()]

print(labels)

In [None]:
# Display the first 16 images of the batch in a 4x4 grid, each titled with
# the label predicted by the model.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for pos, ax in enumerate(axes.flat):
    ax.imshow(cifar.renormalize(x_batch[pos]).permute(1, 2, 0).numpy())
    ax.set_title(labels[pos])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.suptitle('Predictions')
fig.tight_layout()