# Finetuning

In [None]:
# Automatically reload edited project modules (such as `utils`) before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import sys
# Make the parent directory importable so `from utils import CIFAR10DataModule`
# resolves. '..' is relative to the current working directory, so this assumes
# the notebook is started from its own folder.
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification
)

from utils import CIFAR10DataModule

## Load the pretrained model

In [None]:
# Hugging Face Hub identifier of the pretrained checkpoint; both the image
# processor and the classification model are loaded from it.
ckpt_name = "google/vit-base-patch16-224"

In [None]:
# Image processor that ships with the checkpoint. Its `size`, `image_mean` and
# `image_std` are reused to build the torchvision transforms.
processor = AutoImageProcessor.from_pretrained(
    ckpt_name,
)

print(processor)  # display the preprocessing configuration

In [None]:
# Load the pretrained checkpoint with a 10-way classification head.
# The checkpoint was trained on a different dataset, so its head has a
# different output size. `ignore_mismatched_sizes=True` lets the mismatched
# weights be re-initialized instead of raising an error.
model = AutoModelForImageClassification.from_pretrained(
    ckpt_name, num_labels=10, ignore_mismatched_sizes=True
)

# Switch to inference mode (e.g. dropout off). eval() works in place and
# returns the same module, so no rebinding is needed.
model.eval()

print(f'Number of parameters: {model.num_parameters()}')

## Load the data

In [None]:
# Build torchvision pipelines that mirror the checkpoint's preprocessing:
# the same spatial size and the same per-channel normalization statistics.
size = (processor.size['height'], processor.size['width'])
mean, std = processor.image_mean, processor.image_std

# Training: random crop + horizontal flip augmentation, then tensor + normalize.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Evaluation: deterministic resize only, then the same tensor + normalize steps.
val_transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# The test split is preprocessed exactly like the validation split.
test_transform = val_transform

In [None]:
# Wrap CIFAR-10 in the project's data module, with one transform per split.
cifar10 = CIFAR10DataModule(
    data_dir=None,
    transform=None,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    batch_size=32,
    num_workers=0,
    random_state=42,
)

# Fetch the dataset files if they are not present yet, then build only the
# test split, which is the only split used below.
cifar10.prepare_data()
cifar10.setup(stage='test')

In [None]:
# Take the first mini-batch from the test split and split it into
# model inputs and integer class labels.
test_loader = cifar10.test_dataloader()
batch = next(iter(test_loader))
x_batch, y_batch = batch['pixel_values'], batch['label']

In [None]:
# show example images
# Invert `Normalize(mean=mean, std=std)` from the transforms cell so the
# pixels are back in [0, 1]. Using the processor's own statistics, instead of
# a hard-coded `/ 2 + 0.5` (only right when mean == std == 0.5), keeps this
# correct if `ckpt_name` changes. Clamping removes floating-point overshoot
# that would otherwise make `imshow` warn about clipping.
img_mean = torch.tensor(mean).view(-1, 1, 1)
img_std = torch.tensor(std).view(-1, 1, 1)

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
# zip stops at the shorter sequence, so a batch with fewer than 16 images
# leaves the remaining axes empty instead of raising IndexError.
for ax, image, label_id in zip(axes.ravel(), x_batch, y_batch):
    # each image is channel-first (C, H, W), as implied by the permute below
    image = (image * img_std + img_mean).clamp(0, 1)
    ax.imshow(image.permute(1, 2, 0).numpy())
    ax.set_title(cifar10.label_names[label_id.item()])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Run model

In [None]:
# Re-fetch the first test batch, using the same loader call as in the data section.
test_loader = cifar10.test_dataloader()
batch = next(iter(test_loader))

x_batch = batch['pixel_values']  # normalized images fed to the model
y_batch = batch['label']         # ground-truth class ids

In [None]:
# Forward pass without recording an autograd graph (inference only).
with torch.no_grad():
    outputs = model(x_batch)

# Raw class scores; expected to be (batch, 10) given num_labels=10 -- printed below to confirm.
y_logits = outputs.logits

print(f'Images shape: {x_batch.shape}', f'Logits shape: {y_logits.shape}', sep='\n')

In [None]:
# Pick the highest-scoring class per image and map it to its CIFAR-10 name.
# The 10-way head was freshly initialized (`ignore_mismatched_sizes=True`),
# so before any finetuning these predictions are not expected to be meaningful.
y_idx = torch.argmax(y_logits, dim=-1)
y_label = [cifar10.id2label[class_id] for class_id in y_idx.tolist()]

print(y_label)

In [None]:
# show predictions
# Invert `Normalize(mean=mean, std=std)` from the transforms cell so the
# pixels are back in [0, 1]. Using the processor's own statistics, instead of
# a hard-coded `/ 2 + 0.5` (only right when mean == std == 0.5), keeps this
# correct if `ckpt_name` changes. Clamping removes floating-point overshoot
# that would otherwise make `imshow` warn about clipping.
img_mean = torch.tensor(mean).view(-1, 1, 1)
img_std = torch.tensor(std).view(-1, 1, 1)

fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
# zip stops at the shorter sequence, so a batch with fewer than 16 images
# leaves the remaining axes empty instead of raising IndexError.
for ax, image, label in zip(axes.ravel(), x_batch, y_label):
    # each image is channel-first (C, H, W), as implied by the permute below
    image = (image * img_std + img_mean).clamp(0, 1)
    ax.imshow(image.permute(1, 2, 0).numpy())
    ax.set_title(label)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.suptitle('Predictions')
fig.tight_layout()