# transformers: Finetuning

In [None]:
# enable the IPython autoreload extension; mode 2 re-imports modules
# before each cell is executed
%load_ext autoreload
%autoreload 2
# select the inline matplotlib backend so figures render in the notebook
%matplotlib inline

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch import seed_everything
from transformers.image_utils import load_image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from hf_utils import (
    CIFAR10DataModule,
    LightningHFImageClassif,
    LightningHFImageClassifLoRA
)

In [None]:
# seed the random number generators for reproducible runs
# (the value returned by seed_everything is not needed)
_ = seed_everything(seed=123)

## Load image

In [None]:
# fetch the example image from its URL
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = load_image(image=url)

In [None]:
# plot the loaded image on a single axis
fig, ax = plt.subplots(figsize=(6, 4))

# convert to an array first, then draw it with square pixels
image_array = np.asarray(image)
ax.imshow(image_array)
ax.set_aspect('equal', adjustable='box')

fig.tight_layout()

## Load model

In [None]:
# set model name: the identifier handed to `from_pretrained` (and to the
# Lightning wrapper further below); keep exactly one assignment active and
# leave the other candidates commented out
# model_name = 'microsoft/resnet-18'
# model_name = 'google/vit-base-patch16-224'
# model_name = 'facebook/dinov2-small'
model_name = 'facebook/dinov2-small-imagenet1k-1-layer'

In [None]:
# create preprocessor
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# initialize model (the commented-out keyword arguments are the ones to
# enable when the classification head should be replaced by one with custom
# labels; they need a `label_names` sequence to be defined first)
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    device_map='auto',
    # num_labels=len(label_names),  # set number of target labels
    # id2label={idx: label for idx, label in enumerate(label_names)},
    # label2id={label: idx for idx, label in enumerate(label_names)},
    # ignore_mismatched_sizes=True
)
model = model.eval()

# report where the model lives, its precision and its size
print(f'Model device: {model.device}')
print(f'Model dtype: {model.dtype}')

# get_memory_footprint() returns bytes; divide by 2**30 so that the number
# actually matches the binary 'GiB' unit in the label (1e-9 would give GB)
print(f'Memory footprint: {model.get_memory_footprint() / 2**30:.2f} GiB')

## Run model

In [None]:
# turn the image into PyTorch model inputs and keep a handle on the pixels
preprocessed_images = processor(image, return_tensors='pt')
x = preprocessed_images['pixel_values']

# forward pass in inference mode, on whatever device the model sits on
with torch.inference_mode():
    model_inputs = preprocessed_images.to(model.device)
    outputs = model(**model_inputs)

# move the logits back to the CPU
logits = outputs.logits.to('cpu')

print(f'Images shape: {x.shape}')
print(f'Logits shape: {logits.shape}')

In [None]:
# pick the highest-scoring class per image and look up its label name
label_ids = logits.argmax(dim=-1)
id2label = model.config.id2label
labels = [id2label[class_id] for class_id in label_ids.tolist()]

print(labels)

## Load data

In [None]:
# configure the CIFAR-10 data module
cifar = CIFAR10DataModule(
    batch_size=32,
    num_workers=0,
    data_dir='../run/data/',
    img_size=(224, 224),
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
)

# download the data (if not yet done), then create the test set
cifar.prepare_data()
cifar.setup(stage='test')

In [None]:
# take the first batch of the test loader and unpack images and targets
test_loader = cifar.test_dataloader()
batch = next(iter(test_loader))
x_batch, y_batch = batch['pixel_values'], batch['labels']

In [None]:
# show example images (first 16 of the batch) with their ground-truth labels
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for idx, ax in enumerate(axes.ravel()):
    # NOTE: deliberately not named `image`, which would overwrite the example
    # picture loaded at the top of the notebook and break re-running the
    # cells that still depend on it
    img = cifar.renormalize(x_batch[idx]).permute(1, 2, 0).numpy()
    label = cifar.label_names[y_batch[idx].item()]
    ax.imshow(img)
    ax.set_title(label)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Load Lightning model (optionally from checkpoint)

In [None]:
# use the first CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# build the Lightning LoRA wrapper for the selected model
model = LightningHFImageClassifLoRA(
    model_name,
    data_dir=None,
    num_labels=None,
    # learning-rate arguments
    lr=1e-4,
    lr_schedule='constant',
    lr_interval='epoch',
    lr_warmup=0,
    lr_cycles=1,
    freeze_backbone=True,
    # LoRA arguments
    lora_rank=16,
    lora_alpha=None,
    lora_dropout=None,
    lora_bias='none',
    # layers to apply LoRA to (linear, conv, MHA, etc.)
    lora_target_modules=['query', 'value'],
)

# evaluation mode, then move to the chosen device
model = model.eval().to(device)

In [None]:
# load model from checkpoint (plain string: there is nothing to interpolate);
# switch both the path and the class below to use the LoRA run instead
ckpt_file = '../run/finetune/version_0/checkpoints/last.ckpt'
# ckpt_file = '../run/lora/version_0/checkpoints/last.ckpt'

if Path(ckpt_file).is_file():
    model = LightningHFImageClassif.load_from_checkpoint(ckpt_file, map_location=None)
    # model = LightningHFImageClassifLoRA.load_from_checkpoint(ckpt_file, map_location=None)

    model = model.eval()
    model = model.to(device)
else:
    # say so explicitly: otherwise the following cells silently run with the
    # previously initialized model and the predictions are easy to misread
    print(f'Checkpoint not found: {ckpt_file} (keeping the previously initialized model)')

## Run Lightning model on a CIFAR-10 batch

In [None]:
# draw one batch from a new test dataloader and split it into its two fields
test_loader = cifar.test_dataloader()
batch = next(iter(test_loader))
x_batch, y_batch = (batch[key] for key in ('pixel_values', 'labels'))

In [None]:
# forward pass in inference mode; inputs go to the model's device and the
# resulting logits are brought back to the CPU
with torch.inference_mode():
    device_logits = model(x_batch.to(model.device))
    y_logits = device_logits.cpu()

print(f'Images shape: {x_batch.shape}')
print(f'Logits shape: {y_logits.shape}')

In [None]:
# pick the highest-scoring class per image and translate it to a label name
# using the config of the wrapped model
label_ids = y_logits.argmax(dim=-1)
id2label = model.model.config.id2label
labels = [id2label[class_id] for class_id in label_ids.tolist()]

print(labels)

In [None]:
# show predictions: first 16 images of the batch, titled with the predicted label
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 5))
for idx, ax in enumerate(axes.ravel()):
    # NOTE: deliberately not named `image`, which would overwrite the example
    # picture loaded at the top of the notebook and break re-running the
    # cells that still depend on it
    img = cifar.renormalize(x_batch[idx]).permute(1, 2, 0).numpy()
    label = labels[idx]
    ax.imshow(img)
    ax.set_title(label)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.suptitle('Predictions')
fig.tight_layout()