# transformers: Transfer learning

In [None]:
# enable the autoreload extension; mode 2 reloads imported modules
# before each cell is executed, so edits to local code are picked up
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook
%matplotlib inline

import sys
# make the parent directory importable
# (presumably where the local hf_models package lives - TODO confirm)
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import seed_everything
from hf_models import (
    CIFAR10DataModule,
    LightningImageClassifier
)

In [None]:
# seed the random number generators with a fixed value
# (the seed returned by the call is discarded)
_ = seed_everything(seed=123)

## Import data

In [None]:
# collect the data module settings in one place
datamodule_kwargs = dict(
    cache_dir='../run/data',
    img_size=224,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    batch_size=32,
    num_workers=0,
)

# set up the CIFAR-10 data module
cifar10 = CIFAR10DataModule(**datamodule_kwargs)

# download data if not yet done
cifar10.prepare_data()

# create test set
cifar10.setup(stage='test')

In [None]:
# fetch the first batch from the test dataloader
test_loader = cifar10.test_dataloader()
batch = next(iter(test_loader))

# unpack the image tensors and the corresponding labels
x_batch, y_batch = batch['pixel_values'], batch['labels']

In [None]:
# plot the first 16 images of the batch in a 4x4 grid, titled with their labels
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for idx, ax in enumerate(axes.flat):
    # move the first tensor axis last for imshow
    # (assumes channel-first images - TODO confirm)
    hwc = x_batch[idx].permute(1, 2, 0).numpy()
    # undo the normalization, which presumably used the mean/std of 0.5
    # passed to the data module
    image = hwc / 2 + 0.5
    label = cifar10.label_names[y_batch[idx].item()]
    ax.imshow(image)
    ax.set_title(label)
    # hide ticks and axis labels
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
fig.tight_layout()

## Load model

In [None]:
# load model from checkpoint
ckpt_file = '../run/transfer/version_0/checkpoints/step=2.ckpt'

# use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# map_location remaps the stored tensors onto `device`, so that a checkpoint
# saved on a GPU can also be restored on a CPU-only machine
model = LightningImageClassifier.load_from_checkpoint(
    ckpt_file,
    map_location=device
)

# switch to evaluation mode and make sure the model sits on the target device
model = model.eval()
model = model.to(device)

## Run model

In [None]:
# draw the first batch from a fresh test dataloader
test_loader = cifar10.test_dataloader()
batch_iter = iter(test_loader)
batch = next(batch_iter)

# split the batch into images and labels
x_batch, y_batch = (batch[key] for key in ('pixel_values', 'labels'))

In [None]:
# run model without tracking gradients
# the model was moved to `device`, so the inputs have to be moved there too
# (otherwise the forward pass fails on GPU machines); the logits are brought
# back to the CPU, where x_batch and y_batch are kept
with torch.no_grad():
    y_logits = model(x_batch.to(device)).cpu()

print(f'Images shape: {x_batch.shape}')
print(f'Logits shape: {y_logits.shape}')

In [None]:
# pick the highest-scoring class index per image
y_idx = torch.argmax(y_logits, dim=-1)

# translate class indices into label names
# NOTE(review): the example-image cell looks labels up via cifar10.label_names,
# this cell via cifar10.id2label - confirm that both exist and agree
y_label = [cifar10.id2label[class_id] for class_id in y_idx.tolist()]

print(y_label)

In [None]:
# plot the first 16 images of the batch in a 4x4 grid, titled with the predicted labels
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for idx, ax in enumerate(axes.flatten()):
    # move the first tensor axis last for imshow
    # (assumes channel-first images - TODO confirm)
    pixels = x_batch[idx].permute(1, 2, 0).numpy()
    # undo the normalization, which presumably used the mean/std of 0.5
    # passed to the data module
    image = pixels / 2 + 0.5
    label = y_label[idx]
    ax.imshow(image)
    ax.set_title(label)
    # hide ticks and axis labels
    ax.set(xlabel='', ylabel='')
    ax.set(xticks=[], yticks=[])
fig.suptitle('Predictions')
fig.tight_layout()