# transformers: PEFT with LoRA

In [None]:
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor
)
import evaluate
from datasets import load_dataset
from transformers import (
    set_seed,
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer
)
from peft import (
    LoraConfig,
    LoraModel,
    get_peft_model
)
from peft.tuners.lora.layer import LoraLayer

In [None]:
# Seed the global RNGs via transformers.set_seed so later random draws in this
# notebook (e.g. the np.random example-image sampling) are repeatable.
# NOTE(review): TrainingArguments further down passes seed=42, which the Trainer
# presumably uses to reseed before training; confirm which seed is intended.
set_seed(123)

## Load data

In [None]:
# Take the first 5000 Food-101 training images and hold out 20% of them as a
# validation split; `ds` is a DatasetDict with 'train' and 'test' entries.
ds = (
    load_dataset('food101', split='train[:5000]')
    .train_test_split(test_size=0.2)
)

print(ds)

In [None]:
# Human-readable class names of the 'label' feature; position i holds the name of
# integer label i (this list is used to build id2label/label2id further down).
label_names = ds['train'].features['label'].names

print(label_names)

In [None]:
# Show a 4x4 grid of randomly chosen training images, each titled with its class name.
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 5))
random_ids = np.random.choice(len(ds['train']), size=axes.size, replace=False).tolist()
for random_idx, ax in zip(random_ids, axes.ravel()):
    # Fetch the row once: every dataset lookup materializes the whole row
    # (including decoding the image), so two separate lookups doubled that work.
    sample = ds['train'][random_idx]
    pil_image, label_idx = sample['image'], sample['label']
    ax.imshow(np.asarray(pil_image))
    ax.set_title(label_names[label_idx])
    ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)
fig.tight_layout()

## Initialize model

In [None]:
# Hugging Face Hub checkpoint to fine-tune.
model_name = 'google/vit-base-patch16-224'
# Alternative backbone. NOTE(review): when switching, confirm its image processor
# config provides the size keys that the transform cell below reads.
# model_name = 'facebook/dinov2-small-imagenet1k-1-layer'

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
# The string form is passed to from_pretrained(device_map=...) and also drives
# use_cpu in the training arguments; `device` is the torch.device equivalent.
device_map = 'cpu'
if torch.cuda.is_available():
    device_map = 'cuda:0'

device = torch.device(device_map)

In [None]:
# Load the image processor (resize / normalization config) that matches the checkpoint.
processor = AutoImageProcessor.from_pretrained(model_name)

# Load the pretrained backbone with a classification head sized for our labels.
# ignore_mismatched_sizes=True allows the checkpoint's original head to be
# replaced by a new one with len(label_names) outputs instead of erroring.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    device_map=device_map,
    num_labels=len(label_names),  # one output per class in label_names
    id2label=dict(enumerate(label_names)),
    label2id={label: idx for idx, label in enumerate(label_names)},
    ignore_mismatched_sizes=True
)
model = model.eval()

print('Model device: {}'.format(model.device))
print('Model dtype: {}'.format(model.dtype))
# get_memory_footprint() returns bytes. Divide by 2**30 so the number really is in
# GiB, as the label says; the former 1e-9 factor gave GB (about 7% larger values).
print('Memory footprint: {:.2f} GiB'.format(model.get_memory_footprint() / 2**30))

In [None]:
# LoRA adapter settings: rank-16 adapters (alpha 16.0, 10% dropout) injected into
# every module whose name matches 'query' or 'value'. The 'classifier' module is
# listed in modules_to_save, so it is unfrozen, trained in full and saved with
# the adapter.
config = LoraConfig(
    target_modules=['query', 'value'],
    modules_to_save=['classifier'],
    r=16,
    lora_alpha=16.0,
    lora_dropout=0.1,
    bias='none',  # no bias terms are made trainable
    init_lora_weights=True
)

In [None]:
# Wrap the model with LoRA adapters as described by `config`. From here on `model`
# refers to the PEFT wrapper rather than the bare transformers model.
model = get_peft_model(model, config)  # also works for plain (non-transformers) PyTorch models

# Report trainable vs. total parameter counts of the wrapped model.
model.print_trainable_parameters()

In [None]:
# Sanity-check the wrapping. The wrapper's base_model should be a LoraModel, and
# at least one submodule should have been replaced by a LoRA layer.
is_lora_model = isinstance(model.base_model, LoraModel)
# Generator instead of a list, so any() stops at the first LoRA layer it finds.
has_lora_layers = any(isinstance(m, LoraLayer) for m in model.modules())

if is_lora_model and has_lora_layers:
    print('LoRA model is correctly initialized')
else:
    print('LoRA model not correctly initialized')

## Set up training

In [None]:
# Build torchvision pipelines that mirror the processor's size and normalization.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

# Processors describe their target size in different ways. ViT-style configs give
# an explicit {'height', 'width'}, while others (e.g. DINOv2 checkpoints) give
# {'shortest_edge'} plus a separate crop_size. Support both, so that switching
# model_name does not raise KeyError here. ViT behavior is unchanged: it resizes
# and crops to 'height'.
if 'height' in processor.size:
    resize_size = crop_size = processor.size['height']
else:
    resize_size = processor.size['shortest_edge']
    crop_size = (getattr(processor, 'crop_size', None) or {}).get('height', resize_size)

# Training: random scale/aspect-ratio crop plus horizontal flip, for augmentation.
transform_train = Compose([
    RandomResizedCrop(crop_size),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize
])

# Validation: deterministic resize of the shorter side, then a center crop.
transform_val = Compose([
    Resize(resize_size),
    CenterCrop(crop_size),
    ToTensor(),
    normalize
])

def preprocess_train(batch):
    '''Add augmented training tensors to a batch dict.

    Each image in ``batch['image']`` is converted to RGB and passed through
    ``transform_train``. The resulting list is stored under
    ``batch['pixel_values']``, and the same (mutated) dict is returned.
    '''
    rgb_images = (image.convert('RGB') for image in batch['image'])
    batch['pixel_values'] = [transform_train(rgb) for rgb in rgb_images]
    return batch

def preprocess_val(batch):
    '''Add deterministic validation tensors to a batch dict.

    Each image in ``batch['image']`` is converted to RGB and passed through
    ``transform_val``. The resulting list is stored under
    ``batch['pixel_values']``, and the same (mutated) dict is returned.
    '''
    rgb_images = (image.convert('RGB') for image in batch['image'])
    batch['pixel_values'] = [transform_val(rgb) for rgb in rgb_images]
    return batch

In [None]:
# Training uses the 'train' split; the held-out 'test' split serves as validation.
ds_train, ds_val = ds['train'], ds['test']

# Attach on-access transforms (run whenever rows are fetched). set_transform acts
# in place, so ds['train'] / ds['test'] (the same objects) are affected too.
ds_val.set_transform(preprocess_val)
ds_train.set_transform(preprocess_train)

In [None]:
# Batch-assembly function handed to the Trainer as its data collator.
def collate_fn(examples):
    '''Combine a list of transformed examples into one batch dict.

    Each example's ``pixel_values`` tensor is stacked along a new leading
    dimension, and the integer ``label`` fields are gathered into a tensor
    stored under the ``labels`` key.
    '''
    images = []
    targets = []
    for item in examples:
        images.append(item['pixel_values'])
        targets.append(item['label'])
    return {'pixel_values': torch.stack(images), 'labels': torch.tensor(targets)}

In [None]:
# Accuracy metric loaded from the `evaluate` library; compute_metrics uses it.
metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    '''Return the accuracy of the predictions gathered by the Trainer.

    The predicted class is the arg-max of ``eval_pred.predictions`` over axis 1,
    which is compared against ``eval_pred.label_ids``. The dict produced by
    ``metric.compute`` is returned as-is.
    '''
    scores, references = eval_pred.predictions, eval_pred.label_ids
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

In [None]:
# Training configuration for the Trainer.
# Compute the run name first: reusing the same quote character inside an f-string
# replacement field (f'...{s.split('/')}...') is a SyntaxError before Python 3.12.
run_name = model_name.split('/')[-1]
# Keep bf16 on CPU (as before) and on GPUs that support it. TrainingArguments
# raises for bf16=True on GPUs without bf16 support.
use_bf16 = device_map == 'cpu' or torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir=f'../run/{run_name}-finetuned-lora-food101',
    overwrite_output_dir=False,
    remove_unused_columns=False,  # keep 'image' so the set_transform hooks can read it
    seed=42,
    use_cpu=(device_map == 'cpu'),
    bf16=use_bf16,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    learning_rate=5e-3,
    optim='adamw_torch',
    weight_decay=0.,
    # drop_last also applies to the evaluation dataloader. With True, the
    # 1000-image validation split (not divisible by 16) silently lost its last
    # 8 images from every reported metric.
    dataloader_drop_last=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=(device_map != 'cpu'),  # pinned memory only helps host->GPU copies
    dataloader_persistent_workers=False,
    eval_on_start=True,
    eval_strategy='epoch',
    save_strategy='best',
    logging_strategy='steps',
    logging_steps=100,
    load_best_model_at_end=True,
    push_to_hub=False
)

In [None]:
# Combine model, data splits, batch collation, metrics and processor into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    processing_class=processor
)

## Train model

In [None]:
# Run fine-tuning with the Trainer configured above. `results` keeps the value
# returned by trainer.train() so it can be inspected in later cells.
results = trainer.train()