In [22]:
import torch
import re
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, TensorDataset
import transformers
from tqdm import tqdm
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

In [2]:
# Reproducibility: seed torch's RNGs, which drive the newly initialised head weights,
# the LoRA factor init, DataLoader shuffling and dropout.
SEED = 42
torch.manual_seed(SEED)

# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


## Load and Process Dataset

In [3]:
# IMDb reviews, "train" split only; both the train and the test slices used below are drawn from it.
dataset = load_dataset(path="imdb", split="train")

In [4]:
# Raw strings: the previous non-raw literals ('\.', '\S', '\s') are invalid escape
# sequences that emit warnings; the compiled patterns below are the same regexes.
_STRIP_PUNCT_RE = re.compile(r'[,\.!?:()"]')
_HTML_TAG_RE = re.compile(r'<.*?>')
_URL_RE = re.compile(r'http\S+')
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9]')
_WHITESPACE_RE = re.compile(r'\s+')


def process(x):
    """Normalise a raw review into lowercase, single-space-separated alphanumeric text.

    Steps, in order (the order matters):
      1. delete , . ! ? : ( ) and double quotes outright. No space is inserted, so
         words separated only by such punctuation are joined ("end.The" -> "endthe");
      2. replace HTML tags such as ``<br />`` with a space;
      3. replace any run starting with "http" up to the next whitespace with a space;
      4. replace every remaining non-alphanumeric character (e.g. apostrophes) with a space;
      5. collapse whitespace runs, lowercase and trim.

    Args:
        x: the review text.

    Returns:
        The cleaned string, which may be empty.
    """
    x = _STRIP_PUNCT_RE.sub('', x)
    x = _HTML_TAG_RE.sub(' ', x)
    x = _URL_RE.sub(' ', x)
    x = _NON_ALNUM_RE.sub(' ', x)
    x = _WHITESPACE_RE.sub(' ', x)
    return x.lower().strip()

In [5]:
# Clean every review, then drop the ones that end up empty after cleaning.
dataset = dataset.map(lambda example: {"text": process(example["text"])})
dataset = dataset.filter(lambda example: bool(example["text"]))

In [6]:
tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_function(examples):
    """Tokenize a batch of texts, truncating/padding each sequence to exactly 512 tokens."""
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)
# When indexed, return only the model inputs and the label, as torch tensors.
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

In [7]:
# Shuffle (seeded) into a NEW name so re-running this cell does not reshuffle an
# already-shuffled dataset, then take two disjoint slices.
# NOTE: both slices come from the IMDb *train* split; the official test split is not used.
N_TRAIN = 10000
N_TEST = 5000
shuffled_datasets = tokenized_datasets.shuffle(seed=42)
assert len(shuffled_datasets) >= N_TRAIN + N_TEST, "not enough examples for the requested split"
train_dataset = shuffled_datasets.select(range(N_TRAIN))
test_dataset = shuffled_datasets.select(range(N_TRAIN, N_TRAIN + N_TEST))

print("Train dataset size:", len(train_dataset))
print("Test dataset size:", len(test_dataset))

Train dataset size: 10000
Test dataset size: 5000


In [8]:
# Inspect the training slice: its feature columns and row count.
train_dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 10000
})

In [9]:
# Check class balance in each split.
train_df, test_df = (
    pd.DataFrame(split["label"], columns=["label"])
    for split in (train_dataset, test_dataset)
)

for heading, frame in (("Train dataset:", train_df), ("Test dataset:", test_df)):
    print(heading)
    print(frame["label"].value_counts())

Train dataset:
0    5004
1    4996
Name: label, dtype: int64
Test dataset:
1    2504
0    2496
Name: label, dtype: int64


In [10]:
def create_dataloader(dataset, batch_size=32, shuffle=True):
    """Wrap the tokenized columns of ``dataset`` in a DataLoader.

    ``dataset`` must support ``dataset[column]`` returning a tensor for
    "input_ids", "attention_mask" and "label". Each yielded batch is the tuple
    (input_ids, attention_mask, labels), in that order.
    """
    columns = ("input_ids", "attention_mask", "label")
    tensors = TensorDataset(*(dataset[column] for column in columns))
    return DataLoader(tensors, batch_size=batch_size, shuffle=shuffle)

In [11]:
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32
train_dataloader = create_dataloader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
# Order does not affect accuracy, so keep evaluation deterministic.
test_dataloader = create_dataloader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

## Load model

In [12]:
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,  # two classes: labels 0 and 1
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Show the architecture, including the q_lin / k_lin / ffn submodules that LoRA targets later.
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

## Set up Training / Eval Code

In [16]:
def train(model, train_dataloader, num_epochs=5, lr=5e-5):
    """Fine-tune ``model`` in place with AdamW and print per-epoch training metrics.

    The model is moved to the notebook-wide ``device`` *before* the optimizer is
    built. Only parameters with ``requires_grad=True`` are optimized, e.g. the
    classifier head or the LoRA factors.

    Args:
        model: a sequence-classification model called with ``input_ids``,
            ``attention_mask`` and ``labels``, returning ``.loss`` and ``.logits``.
        train_dataloader: iterable of batches indexable as
            (input_ids, attention_mask, labels), e.g. from ``create_dataloader``.
        num_epochs: number of full passes over ``train_dataloader``.
        lr: AdamW learning rate.

    Raises:
        ValueError: if the model has no trainable parameters, or if the dataloader
            yields no examples in an epoch.
    """
    model.to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise ValueError("model has no parameters with requires_grad=True")
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}"):
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels = batch[2].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # outputs.loss is presumably the batch mean; weighting by batch size gives an
            # exact per-example epoch mean even when the last batch is smaller.
            total_loss += loss.item() * batch_size
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += batch_size

        if total_predictions == 0:
            raise ValueError("train_dataloader yielded no examples")
        accuracy = correct_predictions / total_predictions
        avg_loss = total_loss / total_predictions
        print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

In [17]:
def eval(model, test_dataloader):
    """Compute, print and return the classification accuracy of ``model``.

    NOTE: this name shadows Python's builtin ``eval`` inside the notebook. It is
    kept so existing calls continue to work.

    Args:
        model: a sequence-classification model returning ``.logits``; it is moved
            to the notebook-wide ``device`` and left in eval mode.
        test_dataloader: iterable of batches indexable as
            (input_ids, attention_mask, labels).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    model.to(device)
    model.eval()
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating"):
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels = batch[2].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            predictions = torch.argmax(logits, dim=-1)
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)

    if total_predictions == 0:
        raise ValueError("test_dataloader yielded no examples")
    accuracy = correct_predictions / total_predictions
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

## Baseline: fine-tune only the classification head

In [18]:
# Only finetune the classifier head: freeze the whole DistilBERT encoder,
# then make both head layers trainable.
for p in model.distilbert.parameters():
    p.requires_grad = False
for head_layer in (model.pre_classifier, model.classifier):
    for p in head_layer.parameters():
        p.requires_grad = True

In [19]:
# Baseline run: 5 epochs with only the classification head trainable.
train(model=model, train_dataloader=train_dataloader, num_epochs=5)

Training Epoch 1: 100%|██████████| 625/625 [03:52<00:00,  2.69it/s]


Epoch 1/5 - Loss: 0.5311, Accuracy: 0.7643


Training Epoch 2: 100%|██████████| 625/625 [03:51<00:00,  2.70it/s]


Epoch 2/5 - Loss: 0.4014, Accuracy: 0.8256


Training Epoch 3: 100%|██████████| 625/625 [03:51<00:00,  2.70it/s]


Epoch 3/5 - Loss: 0.3740, Accuracy: 0.8329


Training Epoch 4: 100%|██████████| 625/625 [03:51<00:00,  2.70it/s]


Epoch 4/5 - Loss: 0.3659, Accuracy: 0.8386


Training Epoch 5: 100%|██████████| 625/625 [03:51<00:00,  2.70it/s]

Epoch 5/5 - Loss: 0.3544, Accuracy: 0.8444





In [20]:
# Held-out accuracy of the head-only baseline.
eval(model=model, test_dataloader=test_dataloader)

Evaluating: 100%|██████████| 157/157 [01:34<00:00,  1.65it/s]

Test Accuracy: 0.8564





## Set up LoRA

In [32]:
class LoRAParametrization(nn.Module):
    """Additive low-rank update for a 2-D weight, for use with ``torch.nn.utils.parametrize``.

    While enabled, ``forward(W)`` returns ``W + (alpha / rank) * (lora_B @ lora_A)``.
    While disabled, it returns ``W`` unchanged.

    Args:
        features_in: number of ROWS of the adapted weight. Callers in this notebook
            pass ``layer.weight.shape[0]``, which for ``nn.Linear`` is ``out_features``.
        features_out: number of COLUMNS of the adapted weight (``in_features`` for
            ``nn.Linear``).
        rank: inner dimension of the low-rank factors; must be positive.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, features_in, features_out, rank=4, alpha=1.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        # Allocate directly on the notebook-wide `device` rather than creating on CPU and copying.
        self.lora_A = nn.Parameter(torch.empty((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.empty((features_in, rank), device=device))
        # A ~ N(0, 1) and B = 0, so lora_B @ lora_A is zero at init and the wrapped
        # layer starts out computing exactly its original weight.
        nn.init.normal_(self.lora_A, mean=0.0, std=1)
        nn.init.zeros_(self.lora_B)

        self.scale = alpha / rank
        # Toggled externally to compare the adapted and un-adapted model.
        self.enabled = True

    def forward(self, original_weights):
        if not self.enabled:
            return original_weights
        delta = torch.matmul(self.lora_B, self.lora_A)
        # Fail loudly instead of silently reshaping (the old `.view`) if the factors
        # were built for a differently shaped weight.
        if delta.shape != original_weights.shape:
            raise ValueError(
                f"LoRA update shape {tuple(delta.shape)} does not match "
                f"weight shape {tuple(original_weights.shape)}"
            )
        return original_weights + delta * self.scale

In [33]:
def linear_layer_parametrization(layer, device, rank=4, alpha=1.0):
    """Build a LoRAParametrization sized for ``layer.weight``.

    The weight's (rows, cols), which for ``nn.Linear`` are (out_features, in_features),
    are passed as (features_in, features_out). This makes ``lora_B @ lora_A`` exactly
    the weight's shape.

    Args:
        layer: module with a 2-D ``weight`` (e.g. ``nn.Linear``).
        device: where to place the LoRA factors; if None, the layer weight's device is used.
        rank: low-rank inner dimension, forwarded to LoRAParametrization.
        alpha: scaling numerator, forwarded to LoRAParametrization.

    Returns:
        A LoRAParametrization whose parameters are on ``device`` and have the weight's dtype.
    """
    features_in, features_out = layer.weight.shape
    target_device = layer.weight.device if device is None else device
    # Match dtype too, so that W + delta keeps the layer's dtype (e.g. under half precision).
    return LoRAParametrization(features_in, features_out, rank, alpha).to(
        device=target_device, dtype=layer.weight.dtype
    )

## Apply LoRA to a fresh pretrained model

In [50]:
# A fresh pretrained copy, so the LoRA run does not inherit the head trained above.
model_lora = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
).to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
# Same head-only setup as the baseline. This must run BEFORE the LoRA factors are
# registered: they become children of `model_lora.distilbert`, so freezing the
# encoder afterwards would freeze them too.
for p in model_lora.distilbert.parameters():
    p.requires_grad = False
for head_layer in (model_lora.pre_classifier, model_lora.classifier):
    for p in head_layer.parameters():
        p.requires_grad = True

In [52]:
# Attach a LoRA update (default rank) to the query/key projections and both
# feed-forward linears of every transformer block.
for block in model_lora.distilbert.transformer.layer:
    for target in (block.attention.q_lin, block.attention.k_lin, block.ffn.lin1, block.ffn.lin2):
        parametrize.register_parametrization(
            target, "weight", linear_layer_parametrization(target, device)
        )

In [53]:
# Confirm the LoRA parametrizations are attached (targeted layers show as ParametrizedLinear).
model_lora

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): ParametrizedLinear(
              in_features=768, out_features=768, bias=True
              (parametrizations): ModuleDict(
                (weight): ParametrizationList(
                  (0): LoRAParametrization()
                )
              )
            )
            (k_lin): ParametrizedLinear(
              in_features=768, out_features=768, bias=True
              (parametrizations): ModuleDict(
                (weight): Parametri

In [54]:
def enable_disable_lora(enabled=True, model=None):
    """Switch every LoRAParametrization inside ``model`` on or off.

    The function finds the parametrizations by walking ``model.modules()``, so it stays
    in sync with whichever layers actually received LoRA. It does not rely on a
    hard-coded list of layer names.

    Args:
        enabled: truthy to add the low-rank update in forward, falsy to use only the
            original weights.
        model: module to search; defaults to the notebook-level ``model_lora``.

    Raises:
        ValueError: if ``model`` contains no LoRAParametrization modules.
    """
    if model is None:
        model = model_lora
    found = False
    for module in model.modules():
        if isinstance(module, LoRAParametrization):
            module.enabled = enabled
            found = True
    if not found:
        raise ValueError("model contains no LoRAParametrization modules")

In [55]:
# List the parameters that will receive gradients: the LoRA factors plus the head.
trainable_named = [(n, p) for n, p in model_lora.named_parameters() if p.requires_grad]
for param_name, tensor in trainable_named:
    print(param_name, tensor.shape)

distilbert.transformer.layer.0.attention.q_lin.parametrizations.weight.0.lora_A torch.Size([4, 768])
distilbert.transformer.layer.0.attention.q_lin.parametrizations.weight.0.lora_B torch.Size([768, 4])
distilbert.transformer.layer.0.attention.k_lin.parametrizations.weight.0.lora_A torch.Size([4, 768])
distilbert.transformer.layer.0.attention.k_lin.parametrizations.weight.0.lora_B torch.Size([768, 4])
distilbert.transformer.layer.0.ffn.lin1.parametrizations.weight.0.lora_A torch.Size([4, 768])
distilbert.transformer.layer.0.ffn.lin1.parametrizations.weight.0.lora_B torch.Size([3072, 4])
distilbert.transformer.layer.0.ffn.lin2.parametrizations.weight.0.lora_A torch.Size([4, 3072])
distilbert.transformer.layer.0.ffn.lin2.parametrizations.weight.0.lora_B torch.Size([768, 4])
distilbert.transformer.layer.1.attention.q_lin.parametrizations.weight.0.lora_A torch.Size([4, 768])
distilbert.transformer.layer.1.attention.q_lin.parametrizations.weight.0.lora_B torch.Size([768, 4])
distilbert.trans

In [56]:
# Make sure the low-rank updates are active before training.
enable_disable_lora(True)

## Train LoRA model

In [57]:
# LoRA run: 3 epochs, training the LoRA factors and the classification head.
train(model=model_lora, train_dataloader=train_dataloader, num_epochs=3)

Training Epoch 1: 100%|██████████| 625/625 [10:18<00:00,  1.01it/s]


Epoch 1/3 - Loss: 0.3689, Accuracy: 0.8340


Training Epoch 2: 100%|██████████| 625/625 [09:13<00:00,  1.13it/s]


Epoch 2/3 - Loss: 0.2647, Accuracy: 0.8906


Training Epoch 3: 100%|██████████| 625/625 [09:10<00:00,  1.14it/s]

Epoch 3/3 - Loss: 0.2423, Accuracy: 0.9014





In [58]:
# Held-out accuracy of the LoRA model, for comparison with the baseline.
eval(model=model_lora, test_dataloader=test_dataloader)

Evaluating: 100%|██████████| 157/157 [01:36<00:00,  1.63it/s]

Test Accuracy: 0.9090





## Analysis

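Adding rank-4 LoRA updates to the query and key projections and to both feed-forward layers raises held-out accuracy from 0.8564 (classifier head only, 5 epochs) to 0.9090 (LoRA plus head, 3 epochs). Fine-tuning the entire model was not feasible on my laptop. LoRA strikes a balance: it trains only small low-rank factors alongside the head and still delivers a clear accuracy boost. The trade-off is speed. Each LoRA epoch took roughly 9–10 minutes versus about 4 minutes for head-only training, because gradients must now propagate into the encoder layers. Caveats:
- The two runs used different epoch counts (5 vs 3).
- The "test" set is a held-out 5,000-example slice of the IMDb *train* split, not the official IMDb test split.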