# PEFT Tutorial
*(A bulk of the material of this tutorial is taken from Sebastian Raschka's [Code Lora from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).)*

In [1]:
import os
import time
from functools import partial

import lightning as L
import torch
import torch.nn.functional as F
from custom_lightning_module import CustomLightningModule
from datasets import load_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Configuring Paths


In [2]:
DATASET_DIR = "../../data/imdb/"
SAVED_MODEL_DIR = "/projects/fta_bootcamp/trained_models/peft_demo/"
OUTPUT_DIR = "../../scratch/peft/" # main directory of the the demo output
CHECKPOINT_DIR = f"{OUTPUT_DIR}checkpoints" # where to save checkpoints
MODEL_NAME = "distilbert-base-uncased"

## Our Custom LoRA Layer <a id="LoRA_Anchor"></a>

In [35]:
torch.set_float32_matmul_precision("medium")

class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        ### TODO: TODO: implement the forward pass of lora ###
        return self.alpha*torch.matmul(x, torch.matmul(self.W_a, self.W_b))


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha,
        )

    def forward(self, x):
        ### TODO: TODO: implement the forward pass of lora layer ###
        return self.linear(x) + self.lora(x)

In [36]:
torch.manual_seed(123)

# a simple linear layer with 10 inputs and 1 output
# requires_grad=False makes it non-trainable
with torch.no_grad():
    linear_layer = torch.nn.Linear(10, 1)

# a simple example input
x = torch.rand((1, 10))

linear_layer(x)

tensor([[-0.5745]], grad_fn=<AddmmBackward0>)

In [37]:
lora_layer = LinearWithLoRA(linear=linear_layer, rank=8, alpha=1)
lora_layer(x)

tensor([[-0.5745]], grad_fn=<AddBackward0>)

In [38]:
lora_layer.lora.W_b = torch.nn.Parameter(lora_layer.lora.W_b + 0.01 * x[0])
lora_layer(x)

tensor([[-0.5863, -0.5758, -0.5779, -0.5800, -0.5814, -0.5766, -0.5887, -0.5811,
         -0.5859, -0.5892]], grad_fn=<AddBackward0>)

## Loading the Dataset into DataFrames

In [39]:
imdb_dataset = load_dataset(
    "csv",
    data_files={
        "train": os.path.join(DATASET_DIR, "train.csv"),
        "validation": os.path.join(DATASET_DIR, "val.csv"),
        "test": os.path.join(DATASET_DIR, "test.csv"),
    },
)

print(imdb_dataset)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 35000
    })
    validation: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 10000
    })
})


## Loading Tokenizer

In [40]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print("Tokenizer input max length:", tokenizer.model_max_length)
print("Tokenizer vocabulary size:", tokenizer.vocab_size)



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Tokenizer input max length: 512
Tokenizer vocabulary size: 30522


## Tokenizing Data

In [41]:
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)

imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
del imdb_dataset
imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Map:   0%|          | 0/35000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

## Setting Up DataLoaders

In [63]:
class IMDBDataset(Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows

In [64]:
train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=12,
    shuffle=True,
    num_workers=4,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=12,
    num_workers=4,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=12,
    num_workers=4,
)

## Counting Number of Trainable Parameters Function

In [65]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

## Finetunning Last Two Layers

In [66]:
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2)

print(f"Total number of trainable parameters for the base model: {count_parameters(model):,}" )

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total number of trainable parameters for the base model: 66,955,010


Freeze all the layers:

In [67]:
for param in model.parameters():
    param.requires_grad = False

Unfreeze the last two layers:

In [68]:
for param in model.pre_classifier.parameters():
    param.requires_grad = True

for param in model.classifier.parameters():
    param.requires_grad = True

print(f"Total number of trainable parameters: {count_parameters(model):,}" )

Total number of trainable parameters: 592,130


In [69]:
lightning_model = CustomLightningModule(model)
callbacks = [
    ModelCheckpoint(
        dirpath=CHECKPOINT_DIR,
        filename="last_two",
        save_top_k=1, # save top 1 model
        mode="max",
        monitor="val_acc",
    ),
]

logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    max_epochs=3,
    callbacks=callbacks,
    accelerator="gpu",
    precision="16-mixed",
    devices=1,
    logger=logger,
    log_every_n_steps=10,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [70]:
# Comment cell below if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(model=lightning_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

end = time.time()
elapsed = end - start
print(f"Time elapsed {elapsed/60:.2f} min")

Missing logger folder: logs/my-model
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                                | Params
-----------------------------------------------------------------
0 | model    | DistilBertForSequenceClassification | 67.0 M
1 | val_acc  | MulticlassAccuracy                  | 0     
2 | test_acc | MulticlassAccuracy                  | 0     
-----------------------------------------------------------------
592 K     Trainable params
66.4 M    Non-trainable params
67.0 M    Total params
267.820   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


Time elapsed 11.46 min


In [None]:
# Load from a saved model
lightning_model = CustomLightningModule.load_from_checkpoint(checkpoint_path="/projects/fta_bootcamp/trained_models/peft_demo//last_two.ckpt", model=model)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")

## Enter LoRA!

In [44]:
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2)

for param in model.parameters():
    param.requires_grad = False

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Let's use our [LoRA layer](#LoRA_Anchor) implementation from before. Here's our current model *before* adding LoRA layers:

In [45]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

Now let's wrap the query and value layers of transformer blocks with LoRA.

In [55]:
def apply_adaptation_layer(model, adaptation_layer, lora_r, lora_alpha, config):
    assign_lora = partial(adaptation_layer5, rank=lora_r, alpha=lora_alpha)

    for layer in model.distilbert.transformer.layer:
        if config.get("lora_query"):
            pass ### TODO: TODO: look at the model architecture and and use assign_lora function ###
        if config.get("lora_key"):
            pass ### TODO: TODO: look at the model architecture and and use assign_lora function ###
        if config.get("lora_value"):
            pass ### TODO: TODO: look at the model architecture and and use assign_lora function ###
        if config.get("lora_projection"):
            pass ### TODO: TODO: look at the model architecture and and use assign_lora function ###
        if config.get("lora_mlp"):
            ### look at the model architecture and and use assign_lora function (make sure you apply to both linear layers in fnn) ###
            pass
    if config.get("lora_head"):
        ### look at the model architecture and and use assign_lora function. Apply to both pre_classifier and classifier layers) ###
        pass

In [56]:
config = {
    "lora_query": True,
    "lora_key": False,
    "lora_value": True,
    "lora_projection": False,
    "lora_mlp": False,
    "lora_head": False,
}
apply_adaptation_layer(model, adaptation_layer=LinearWithLoRA, lora_r=8, lora_alpha=16, config=config)

Let's look at the model after the LoRA layers are added:

In [57]:
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [58]:
# Check if linear layers are frozen
for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")

distilbert.embeddings.word_embeddings.weight: False
distilbert.embeddings.position_embeddings.weight: False
distilbert.embeddings.LayerNorm.weight: False
distilbert.embeddings.LayerNorm.bias: False
distilbert.transformer.layer.0.attention.q_lin.weight: False
distilbert.transformer.layer.0.attention.q_lin.bias: False
distilbert.transformer.layer.0.attention.k_lin.weight: False
distilbert.transformer.layer.0.attention.k_lin.bias: False
distilbert.transformer.layer.0.attention.v_lin.weight: False
distilbert.transformer.layer.0.attention.v_lin.bias: False
distilbert.transformer.layer.0.attention.out_lin.weight: False
distilbert.transformer.layer.0.attention.out_lin.bias: False
distilbert.transformer.layer.0.sa_layer_norm.weight: False
distilbert.transformer.layer.0.sa_layer_norm.bias: False
distilbert.transformer.layer.0.ffn.lin1.weight: False
distilbert.transformer.layer.0.ffn.lin1.bias: False
distilbert.transformer.layer.0.ffn.lin2.weight: False
distilbert.transformer.layer.0.ffn.lin2.bi

In [59]:
print(f"Total number of trainable parameters: {count_parameters(model):,}" )

Total number of trainable parameters: 0


## Fine-Tune with LoRA

In [60]:
lightning_model = CustomLightningModule(model)
callbacks = [
    ModelCheckpoint(
        dirpath=CHECKPOINT_DIR,
        filename="lora",
        save_top_k=1, # save top 1 model
        mode="max",
        monitor="val_acc",
    ),
]
logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    max_epochs=3,
    callbacks=callbacks,
    accelerator="gpu",
    precision="16-mixed",
    devices=1,
    logger=logger,
    log_every_n_steps=10,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [61]:
# Comment cell below if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(model=lightning_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

end = time.time()
elapsed = end - start
print(f"Time elapsed {elapsed/60:.2f} min")

NameError: name 'train_loader' is not defined

In [62]:
# Load from a saved model
lightning_model = CustomLightningModule.load_from_checkpoint(checkpoint_path="/projects/fta_bootcamp/trained_models/peft_demo/lora.ckpt", model=model)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")

RuntimeError: Error(s) in loading state_dict for CustomLightningModule:
	Missing key(s) in state_dict: "model.distilbert.transformer.layer.0.attention.q_lin.weight", "model.distilbert.transformer.layer.0.attention.q_lin.bias", "model.distilbert.transformer.layer.0.attention.v_lin.weight", "model.distilbert.transformer.layer.0.attention.v_lin.bias", "model.distilbert.transformer.layer.1.attention.q_lin.weight", "model.distilbert.transformer.layer.1.attention.q_lin.bias", "model.distilbert.transformer.layer.1.attention.v_lin.weight", "model.distilbert.transformer.layer.1.attention.v_lin.bias", "model.distilbert.transformer.layer.2.attention.q_lin.weight", "model.distilbert.transformer.layer.2.attention.q_lin.bias", "model.distilbert.transformer.layer.2.attention.v_lin.weight", "model.distilbert.transformer.layer.2.attention.v_lin.bias", "model.distilbert.transformer.layer.3.attention.q_lin.weight", "model.distilbert.transformer.layer.3.attention.q_lin.bias", "model.distilbert.transformer.layer.3.attention.v_lin.weight", "model.distilbert.transformer.layer.3.attention.v_lin.bias", "model.distilbert.transformer.layer.4.attention.q_lin.weight", "model.distilbert.transformer.layer.4.attention.q_lin.bias", "model.distilbert.transformer.layer.4.attention.v_lin.weight", "model.distilbert.transformer.layer.4.attention.v_lin.bias", "model.distilbert.transformer.layer.5.attention.q_lin.weight", "model.distilbert.transformer.layer.5.attention.q_lin.bias", "model.distilbert.transformer.layer.5.attention.v_lin.weight", "model.distilbert.transformer.layer.5.attention.v_lin.bias". 
	Unexpected key(s) in state_dict: "model.distilbert.transformer.layer.0.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.0.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.0.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.0.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.0.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.0.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.0.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.0.attention.v_lin.lora.W_b", "model.distilbert.transformer.layer.1.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.1.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.1.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.1.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.1.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.1.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.1.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.1.attention.v_lin.lora.W_b", "model.distilbert.transformer.layer.2.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.2.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.2.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.2.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.2.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.2.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.2.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.2.attention.v_lin.lora.W_b", "model.distilbert.transformer.layer.3.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.3.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.3.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.3.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.3.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.3.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.3.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.3.attention.v_lin.lora.W_b", "model.distilbert.transformer.layer.4.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.4.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.4.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.4.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.4.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.4.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.4.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.4.attention.v_lin.lora.W_b", "model.distilbert.transformer.layer.5.attention.q_lin.linear.weight", "model.distilbert.transformer.layer.5.attention.q_lin.linear.bias", "model.distilbert.transformer.layer.5.attention.q_lin.lora.W_a", "model.distilbert.transformer.layer.5.attention.q_lin.lora.W_b", "model.distilbert.transformer.layer.5.attention.v_lin.linear.weight", "model.distilbert.transformer.layer.5.attention.v_lin.linear.bias", "model.distilbert.transformer.layer.5.attention.v_lin.lora.W_a", "model.distilbert.transformer.layer.5.attention.v_lin.lora.W_b". 

## Using HF's LoRA

We can replace our custom LoRA implementation with an implementation from the [peft library](https://github.com/huggingface/peft). Peft is an open-source, one-stop-shop library from HuggingFace for *parameter efficient fine-tuning* (PEFT) and is integrated with the their [transformers library](https://github.com/huggingface/transformers) for easy model training and inference. 

Here's a sample snippet for how to prepare a model for PEFT training with LoRA. We can easily fine-tune the DistillBert model we had before with this implementation of the LoRA layer, instead of our custom layer.

In [None]:
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"],
)

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## What's this [DoRA](https://arxiv.org/pdf/2402.09353) thing I keep hearing about?

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2)

for param in model.parameters():
    param.requires_grad = False

In [None]:
# Code inspired by https://github.com/catid/dora/blob/main/dora.py
class LinearWithDoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha,
        )

        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.W_a @ self.lora.W_b
        combined_weight = self.linear.weight + self.lora.alpha*lora.T
        column_norm = combined_weight.norm(p=2, dim=0, keepdim=True)
        V = combined_weight / column_norm
        new_weight = self.m * V
        return F.linear(x, new_weight, self.linear.bias)

In [None]:
config = {
    "lora_query": True,
    "lora_key": False,
    "lora_value": True,
    "lora_projection": False,
    "lora_mlp": False,
    "lora_head": False,
}
apply_adaptation_layer(model, adaptation_layer=LinearWithDoRA, lora_r=8, lora_alpha=16, config=config)

In [None]:
model

In [None]:
print(f"Total number of trainable parameters: {count_parameters(model):,}" )

## Finetune with DoRA

In [None]:
lightning_model = CustomLightningModule(model)

callbacks = [
    ModelCheckpoint(
        dirpath="",
        filename="dora",
        save_top_k=1, # save top 1 model
        mode="max",
        monitor="val_acc",
    ),
]

logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    max_epochs=3,
    callbacks=callbacks,
    accelerator="gpu",
    precision="16-mixed",
    devices=1,
    logger=logger,
    log_every_n_steps=10,
)

In [None]:
# Comment cell below if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(model=lightning_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

end = time.time()
elapsed = end - start
print(f"Time elapsed {elapsed/60:.2f} min")

In [None]:
# Load from a saved model
lightning_model = CustomLightningModule.load_from_checkpoint(checkpoint_path="/projects/fta_bootcamp/trained_models/peft_demo/dora.ckpt", model=model)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")