# Finetuning foundation models

In this notebook we finetune ViT using different training strategies:
* Head finetuning
* Early stopping
* LoRA

**Goal.** The goal of this notebook is to gain practical skills in using pretrained transformer models.

You need the following extra libraries beyond PyTorch:
* torchvision
* transformers
* pytorch_lightning (training loop and logging)
* peft (another HuggingFace library that implements LoRA)

In [None]:
# Uncomment to install PyTorch Lightning and PEFT into the kernel's environment.
# %pip install pytorch_lightning peft

In [None]:
import gc
import torch
import torchvision
import transformers
import pytorch_lightning as pl
from matplotlib import pyplot as plt

# Checkpoint identifier passed to every `from_pretrained` call in this notebook.
MODEL = "facebook/vit-mae-base"

def clean_memory():
    """Remove every model bound to a global name and release the memory it held.

    Scans the global namespace for ``torch.nn.Module`` instances, deletes those
    names, runs the garbage collector and, when CUDA is available, returns the
    cached GPU blocks to the driver so the next experiment starts from a clean
    allocator state.
    """
    namespace = globals()
    # Collect the names first: a dict must not change size while it is iterated.
    # LightningModule subclasses torch.nn.Module, so this single check covers both.
    stale_names = [name for name, value in namespace.items()
                   if isinstance(value, torch.nn.Module)]
    for name in stale_names:
        del namespace[name]
    gc.collect()
    if torch.cuda.is_available():
        # Dropping Python references is not enough: the caching allocator keeps
        # the freed blocks reserved until they are explicitly released.
        torch.cuda.empty_cache()

In [None]:
def fetch_resources():
    """Download the dataset and the pretrained checkpoint, then print the model size."""
    # Instantiating the dataset with download=True fetches it into the "cifar100" directory.
    torchvision.datasets.CIFAR100(root="cifar100", train=True, download=True)
    backbone = transformers.ViTMAEModel.from_pretrained(MODEL)
    total = 0
    for weight in backbone.parameters():
        total += weight.numel()
    print("Num parameters:", total)

fetch_resources()

In [None]:
class Data(pl.LightningDataModule):
    """CIFAR-100 data module producing batches for the pretrained model.

    Args:
        num_workers: Number of concurrent readers and preprocessors per loader.
        batch_size: Number of images in a batch.
        transform: Optional callable applied to every image. When ``None``
            (the default) the image processor published with ``MODEL`` is used.
    """

    def __init__(self, num_workers=4, batch_size=16, transform=None):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        if transform is None:
            # Only load the processor when the caller did not supply a transform.
            image_processor = transformers.AutoImageProcessor.from_pretrained(MODEL)
            # The processor is called on one image; keep entry 0 of its "pixel_values".
            transform = lambda x: image_processor(images=x)["pixel_values"][0]
        self.transform = transform

    def _make_loader(self, train):
        """Build the DataLoader for the training (``train=True``) or test split."""
        dataset = torchvision.datasets.CIFAR100(root="cifar100", train=train,
                                                transform=self.transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            shuffle=train,  # Reshuffle the training set every epoch; keep validation order fixed.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            drop_last=train,  # Drop the truncated last batch during training only.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(train=True)

    def val_dataloader(self):
        """Return the validation loader (CIFAR-100 test split, original order)."""
        return self._make_loader(train=False)

def check_data():
    """Sanity check: read one validation batch and print its shapes and dtypes."""
    loader = Data().val_dataloader()  # Val loader.
    images, labels = next(iter(loader))
    print("Images batch:", images.shape, images.dtype)
    print("Labels batch:", labels.shape, labels.dtype)

check_data()

In [None]:
class Module(pl.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy and logs loss and accuracy.

    Args:
        model: Network to train; its call result must expose a ``logits`` attribute.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, lr):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        """Run the wrapped model and return only its logits."""
        outputs = self.model(x)
        return outputs.logits

    def configure_optimizers(self):
        """Create an Adam optimizer over the parameters that still require gradients."""
        # Frozen parameters are left out of the optimizer.
        trainable = list(filter(lambda weight: weight.requires_grad, self.parameters()))
        return torch.optim.Adam(trainable, lr=self.lr)

    def common_step(self, batch):
        """Compute the loss and the batch accuracy for an ``(images, labels)`` batch."""
        images, labels = batch
        logits = self(images)
        loss = self.loss(logits, labels)
        with torch.no_grad():
            # Index of the largest logit along dimension 1 is the predicted class.
            hits = (logits.argmax(1) == labels).sum().item()
        return loss, hits / labels.numel()

    def training_step(self, batch):
        """Log the training metrics to the progress bar and return the loss."""
        loss, accuracy = self.common_step(batch)
        for metric_name, metric_value in (("train/loss", loss), ("train/accuracy", accuracy)):
            self.log(metric_name, metric_value, prog_bar=True)
        return loss

    def validation_step(self, batch):
        """Log the validation metrics with epoch-level aggregation."""
        loss, accuracy = self.common_step(batch)
        for metric_name, metric_value in (("val/loss", loss), ("val/accuracy", accuracy)):
            self.log(metric_name, metric_value, on_epoch=True)

# Head fine-tuning

In [None]:
# Load the TensorBoard notebook extension and open it on the directory the loggers in this notebook write to.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
def finetune_head():
    """Train only a freshly attached classification head on top of a frozen pretrained model."""
    backbone = transformers.ViTForImageClassification.from_pretrained(MODEL)
    # Freeze every loaded weight; only the head attached below stays trainable.
    for weight in backbone.parameters():
        weight.requires_grad = False
    # Add more layers for better flexibility.
    head_layers = [
        torch.nn.Linear(768, 256, bias=False),
        torch.nn.BatchNorm1d(256),
        torch.nn.ReLU(inplace=True),
        torch.nn.Linear(256, 100),
    ]
    backbone.classifier = torch.nn.Sequential(*head_layers)
    # Use a typical learning rate for supervised training.
    lightning_module = Module(backbone, lr=1e-4)
    num_trainable = sum(weight.numel() for weight in lightning_module.parameters()
                        if weight.requires_grad)
    print("Num trainable parameters:", num_trainable)
    tb_logger = pl.loggers.TensorBoardLogger("lightning_logs", name="HeadTuning")
    trainer = pl.Trainer(max_epochs=10, precision="16-mixed", logger=tb_logger)
    # We can increase batch size, because the backbone doesn't require gradients and needs less memory.
    trainer.fit(lightning_module, Data(batch_size=128))

clean_memory()
finetune_head()

# Early stopping

In [None]:
# Show TensorBoard for ./lightning_logs/, the directory the run below logs to.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
def early_stopping():
    """Fine-tune the whole model (nothing frozen) with a small learning rate for 20 epochs.

    No stopping callback is passed to the trainer, so the run always goes the
    full ``max_epochs``; the curves are logged under the "EarlyStopping Long" run name.
    """
    network = transformers.ViTForImageClassification.from_pretrained(MODEL)
    # Swap the classifier for a 100-way linear layer.
    network.classifier = torch.nn.Sequential(torch.nn.Linear(768, 100))
    lightning_module = Module(network, lr=2e-5)
    num_trainable = sum(weight.numel() for weight in lightning_module.parameters()
                        if weight.requires_grad)
    print("Num trainable parameters:", num_trainable)
    tb_logger = pl.loggers.TensorBoardLogger("lightning_logs", name="EarlyStopping Long")
    trainer = pl.Trainer(max_epochs=20, precision="16-mixed", logger=tb_logger)
    trainer.fit(lightning_module, Data())

clean_memory()
early_stopping()

# LoRA finetuning

Train a low-rank weight update. The peft library provides the necessary tools.

LoRA updates a weight tensor as:

```
scaling = alpha / r
weight += (lora_B @ lora_A) * scaling 
```

The main parameters are:
* the scaling weight ```alpha```
* the factorization rank ```r```
* the list of modules to update


In [None]:
# Show TensorBoard for ./lightning_logs/ so the LoRA run below can be compared with the other logged runs.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

In [None]:
import peft

def finetune_lora():
    """Fine-tune with LoRA adapters on the `query`/`value` modules plus a fully trained classifier."""
    base_model = transformers.ViTForImageClassification.from_pretrained(MODEL)
    base_model.classifier = torch.nn.Sequential(torch.nn.Linear(768, 100))
    # Print the architecture so the module names targeted below can be looked up.
    print(base_model)
    lora_config = peft.LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["query", "value"],  # Modules to train with LoRA.
        modules_to_save=["classifier"],  # Modules to train without LoRA.
    )
    peft_model = peft.get_peft_model(base_model, lora_config)
    peft_model.print_trainable_parameters()
    lightning_module = Module(peft_model, lr=2e-5)
    num_trainable = sum(weight.numel() for weight in lightning_module.parameters()
                        if weight.requires_grad)
    print("Num trainable parameters:", num_trainable)
    tb_logger = pl.loggers.TensorBoardLogger("lightning_logs", name="LoRA (lr 2e-5, QV, r128)")
    trainer = pl.Trainer(max_epochs=2, precision="16-mixed", logger=tb_logger)
    trainer.fit(lightning_module, Data())
    # Merge the adapters back into the model.
    # NOTE(review): the merged model is only bound to a local name and is discarded on return.
    peft_model = peft_model.merge_and_unload()

clean_memory()
finetune_lora()