----
# LLM fine-tuning from scratch: Coding the modified LoRA and evaluating it on a task/benchmark

In this notebook, we will fine-tune a BERT model for sentiment analysis.

----

### Implementing our modified version of LoRA
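For reference, here is what the `LoRALinear` forward pass is meant to compute, written with a row-vector input $x$ (matching `x @ self.lora_A`), frozen pretrained weight $W_0$ and bias $b$, rank $r$, and scaling factor $\alpha$:

$$
h = \alpha\,(x W_0^\top + b) + \frac{1}{r}\, x A B, \qquad A \in \mathbb{R}^{d_\text{in} \times r},\; B \in \mathbb{R}^{r \times d_\text{out}}
$$

In the original LoRA formulation (Hu et al., 2021), the frozen path is left unscaled and $\alpha / r$ multiplies only the low-rank update:

$$
h = x W_0^\top + b + \frac{\alpha}{r}\, x A B
$$

Since $B$ starts at zero, the modified layer outputs $\alpha\,(x W_0^\top + b)$ at the start of training, so for $\alpha \neq 1$ it already differs from the pretrained layer before any update.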

In [None]:
import torch
from torch import nn
import math

class LoRALinear(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        # These are the weights from the original pretrained model
        self.linear = linear
        in_dim = linear.in_features
        out_dim = linear.out_features

        # These are the LoRA parameters: A is randomly initialized and B starts at zero,
        # so the low-rank branch contributes nothing before training
        std = 1 / math.sqrt(rank)
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank) * std)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))

        # Other LoRA hyperparameters
        self.rank = rank
        self.alpha = alpha  # This is our alpha parameter in the theory
        # We could also set self.alpha = nn.Parameter(...) to make it trainable

        # Freeze the pretrained linear layer (weight and bias) so that only the LoRA parameters are trained
        for param in self.linear.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Scaled output of the frozen layer plus the low-rank update divided by the rank
        return self.alpha * self.linear(x) + (x @ self.lora_A @ self.lora_B) / self.rank
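To see what the layer does at initialization, the sketch below unrolls the same forward pass with plain tensors instead of the class. The sizes are assumptions for illustration (768 matches BERT-base's hidden size; `rank=8` and `alpha=1.0` are arbitrary), and the "pretrained" layer is just randomly initialized. It checks three things: the output equals `alpha * linear(x)` while `B` is zero, the LoRA factors are far smaller than the frozen layer (12,288 vs. 590,592 parameters here), and after one backward pass only the LoRA factors receive gradients, with `A`'s gradient still zero because it flows through `B`.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumed): BERT-base hidden size, arbitrary rank and alpha
in_dim, out_dim, rank, alpha = 768, 768, 8, 1.0

# Stand-in for a pretrained layer, frozen (weight and bias)
linear = nn.Linear(in_dim, out_dim)
for p in linear.parameters():
    p.requires_grad = False

# LoRA factors initialized as in LoRALinear: A random with std 1/sqrt(rank), B zero
lora_A = nn.Parameter(torch.randn(in_dim, rank) / math.sqrt(rank))
lora_B = nn.Parameter(torch.zeros(rank, out_dim))

x = torch.randn(4, in_dim)
out = alpha * linear(x) + (x @ lora_A @ lora_B) / rank

# With B == 0 the low-rank branch adds nothing yet
initial_matches_base = torch.allclose(out, alpha * linear(x))

# Trainable LoRA parameters vs. parameters of the frozen layer
lora_params = lora_A.numel() + lora_B.numel()
frozen_params = sum(p.numel() for p in linear.parameters())

# One backward pass: gradients reach only the LoRA factors
out.sum().backward()

print(f"output unchanged at init: {initial_matches_base}")
print(f"trainable LoRA params: {lora_params}, frozen params: {frozen_params}")
print(f"grad of A is all zero: {bool((lora_A.grad == 0).all())}")
print(f"grad of B is non-zero: {bool(lora_B.grad.abs().sum() > 0)}")
print(f"frozen weight grad: {linear.weight.grad}")
```

Because `A`'s gradient is multiplied by `B`, the first optimizer step moves only `B`; after that, both factors start to update.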


### Loading the dataset

In [2]:
import os
from datasets import load_dataset

import pandas as pd
import torch

from processing.dataset_utils import download_dataset, load_dataset_into_to_dataframe, partition_dataset, IMDBDataset

In [3]:
if not torch.cuda.is_available():
    print("Please switch to a GPU machine before running this notebook.")

Please switch to a GPU machine before running this notebook.
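Rather than assuming CUDA is present, the device can be chosen at runtime so the notebook still runs (slowly) on a CPU-only machine. A minimal sketch, where the small `nn.Linear` is only a stand-in for the BERT model used later (an assumption for illustration):

```python
import torch
from torch import nn

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The model and every batch must be on the same device
model = nn.Linear(8, 2).to(device)
batch = torch.randn(4, 8, device=device)
logits = model(batch)
```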


In [None]:
files = ("test.csv", "train.csv", "val.csv")

# Download and partition the dataset only if any of the split files is missing
files_missing = any(not os.path.exists(os.path.join("data/sentiment", f)) for f in files)

if files_missing:
    download_dataset()
    df = load_dataset_into_to_dataframe()
    partition_dataset(df)

100% | 80.23 MB | 4.28 MB/s | 18.75 sec elapsed

100%|██████████| 50000/50000 [00:26<00:00, 1888.82it/s]


Class distribution:
