In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.utils.parametrize as parametrize
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from datasets import load_dataset
from tqdm.auto import tqdm
import transformers
transformers.logging.set_verbosity_error()  # Suppress all messages except errors
import warnings
# Suppress future warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


# Implementation of LoRA in DistilBERT

This Jupyter notebook demonstrates Low-Rank Adaptation (LoRA) applied to the DistilBERT model for sequence classification, here binary sentiment analysis on the IMDb dataset. LoRA is an efficient alternative to full fine-tuning. It does not update all of the model's weights. Instead, it freezes the pretrained weights and trains small low-rank matrices, whose product is added to selected weight matrices. As a result, only a small fraction of the parameters is trained.

## Preliminary
- Using DistilBertForSequenceClassification

We use the `DistilBertForSequenceClassification` model from Hugging Face's Transformers library. Below is the structure of the model:

  ```
  DistilBertForSequenceClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
        )
      )
    )
    (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
    (classifier): Linear(in_features=768, out_features=2, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  ```

  The model consists of several parts that contain linear layers, identified for potential adaptation using LoRA:
  ```
      - Encoder block
          - model.distilbert.transformer.layer.attention:
              - q_lin, k_lin, v_lin, out_lin
          - model.distilbert.transformer.layer.ffn:
              - lin1, lin2
      - model.pre_classifier
      - model.classifier
  ```

  The original LoRA paper restricts adaptation to the attention weights and freezes the MLP (feed-forward network) modules. Its ablation (Section 7.1) found that, under a fixed parameter budget, adapting the query and value projections together works well. We follow this guidance and apply LoRA to the `q_lin` (query) and `v_lin` (value) linear layers within the attention block of each Transformer layer.

- Using parametrize.register_parametrization

  The `parametrize.register_parametrization` method lets us redefine how a parameter (such as the weight of a linear layer) is computed, without changing the underlying model architecture. This is particularly useful for implementing LoRA. LoRA introduces low-rank matrices $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, with rank $r \ll \min(d, k)$, to modify an original weight $W \in \mathbb{R}^{d \times k}$. For `nn.Linear`, $d$ is `out_features` and $k$ is `in_features`. The reparametrized weight is computed as $W' = W + \frac{\alpha}{r} BA$, where $\alpha$ is a scaling constant. Only $A$ and $B$ are trained; $W$ stays frozen.

  The following sections of this notebook detail the implementation of the LoRA class, parameterization functions, and application methods to integrate LoRA into the DistilBERT model for efficient fine-tuning on the IMDb dataset.

In [2]:
# LoRA Parametrization Class
class LoRAParametrization(nn.Module):
    """Low-rank additive update of a linear layer's weight: W' = W + (alpha / rank) * B @ A.

    Follows the ``nn.Linear`` convention, where ``weight`` has shape
    ``(features_out, features_in)``. ``lora_A`` has shape ``(rank, features_in)`` and
    ``lora_B`` has shape ``(features_out, rank)``, so ``lora_B @ lora_A`` has exactly
    the weight's shape.

    Args:
        features_in: input dimension of the adapted layer (``weight.shape[1]``).
        features_out: output dimension of the adapted layer (``weight.shape[0]``).
        rank: rank r of the update; must be >= 1.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
        device: device on which the LoRA matrices are allocated.
        dtype: dtype of the LoRA matrices (``None`` -> PyTorch default dtype).

    Raises:
        ValueError: if ``rank`` < 1.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu', dtype=None):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        # Section 4.1 of the LoRA paper: A is random Gaussian and B is zero, so B @ A == 0
        # at initialization and the adapted layer starts out identical to the pretrained one.
        self.lora_A = nn.Parameter(torch.empty((rank, features_in), device=device, dtype=dtype))
        self.lora_B = nn.Parameter(torch.zeros((features_out, rank), device=device, dtype=dtype))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Scale factor alpha / r, as in the paper (alpha is not tuned beyond this setup)
        self.scale = alpha / rank

    def forward(self, original_weights):
        """Return the adapted weight ``original_weights + scale * lora_B @ lora_A``."""
        delta_w = torch.matmul(self.lora_B, self.lora_A) * self.scale
        return original_weights + delta_w

# Parameterization Function
def linear_layer_parameterization(layer: nn.Linear, device, rank, lora_alpha):
    '''
    Reparametrize ``layer.weight`` as W' = W + (lora_alpha / rank) * B @ A
    using parametrize.register_parametrization().

    The LoRA matrices are created on ``device`` with the same dtype as the layer's
    weight. Because the update keeps the weight's shape and dtype, the checked
    (non-``unsafe``) registration path is used, and it validates this at registration time.

    Returns:
        The LoRAParametrization module registered on ``layer``.
    '''
    # nn.Linear stores its weight as (out_features, in_features)
    features_out, features_in = layer.weight.shape
    lora_param = LoRAParametrization(
        features_in,
        features_out,
        rank=rank,
        alpha=lora_alpha,
        device=device,
        dtype=layer.weight.dtype,
    )
    parametrize.register_parametrization(layer, "weight", lora_param)
    return lora_param

# Function to apply LoRA and freeze parameters and return the model
def apply_lora_and_freeze(model, device, config):
    '''
    Register LoRA parametrizations on selected linear layers of a DistilBERT
    sequence classifier, then freeze every parameter that is not a LoRA matrix.

    Candidate linear layers:
    - Encoder blocks (each entry of model.distilbert.transformer.layer)
        - attention: q_lin, k_lin, v_lin, out_lin
        - ffn: lin1, lin2
    - model.pre_classifier
    - model.classifier

    Args:
        model: DistilBertForSequenceClassification; modified in place.
        device: device on which the LoRA matrices are allocated.
        config: dict with the keys
            'lora_r', 'lora_alpha' -- rank and alpha shared by every adapter
            'lora_query', 'lora_key', 'lora_value', 'lora_projection'
                -- adapt attention q_lin / k_lin / v_lin / out_lin
            'lora_mlp'  -- adapt ffn.lin1 and ffn.lin2
            'lora_head' -- adapt pre_classifier and classifier
            'train_head' (optional, default False) -- keep every pre_classifier /
                classifier parameter trainable instead of freezing it

    Returns:
        The same model object.

    Note:
        With a base checkpoint such as 'distilbert-base-uncased', the classification
        heads are typically newly (randomly) initialized by from_pretrained. Under the
        default settings they are frozen at that initialization; set 'train_head' to
        True to train them alongside the LoRA matrices.
    '''
    rank, alpha = config['lora_r'], config['lora_alpha']

    def add_lora(linear):
        # Every adapter shares the same rank / alpha from the config
        linear_layer_parameterization(linear, device, rank, alpha)

    for block in model.distilbert.transformer.layer:
        # Order matches q, k, v, out, lin1, lin2 so the random initialization of
        # the LoRA A matrices consumes the RNG in the same sequence as before.
        targets = (
            ('lora_query', (block.attention.q_lin,)),
            ('lora_key', (block.attention.k_lin,)),
            ('lora_value', (block.attention.v_lin,)),
            ('lora_projection', (block.attention.out_lin,)),
            ('lora_mlp', (block.ffn.lin1, block.ffn.lin2)),
        )
        for config_key, linears in targets:
            if config[config_key]:
                for linear in linears:
                    add_lora(linear)

    # Apply LoRA to classification heads if enabled
    if config['lora_head']:
        add_lora(model.pre_classifier)
        add_lora(model.classifier)

    # Freeze everything except the LoRA matrices (and, optionally, the heads).
    # After registration the LoRA matrices are named '...parametrizations.weight.<i>.lora_A/B'
    # and the pretrained weight they adapt is '...parametrizations.weight.original'.
    train_head = config.get('train_head', False)
    head_prefixes = ('pre_classifier.', 'classifier.')
    for name, param in model.named_parameters():
        is_lora = name.endswith(('lora_A', 'lora_B'))
        keep_head = train_head and name.startswith(head_prefixes)
        if not (is_lora or keep_head):
            param.requires_grad = False

    return model

# Print out the requires_grad_status
def print_requires_grad_status(model):
    """Print a header line followed by one ``<name>: requires_grad=<bool>`` line per parameter."""
    report_lines = [
        f"{param_name}: requires_grad={tensor.requires_grad}"
        for param_name, tensor in model.named_parameters()
    ]
    print("Requires grad status for model parameters:", *report_lines, sep="\n")



# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG (CPU and CUDA). This makes any newly initialized weights
# (e.g. the classification head), the Gaussian LoRA A matrices and, when the cells
# run in order, the later DataLoader shuffling reproducible across runs.
torch.manual_seed(42)

# Model
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2, ignore_mismatched_sizes=True)
model.to(device)

# Configuration for LoRA application for DistilBERT
lora_config = {
    'lora_r': 8,
    'lora_alpha': 16,
    'lora_query': True,
    'lora_key': False,
    'lora_value': True,
    'lora_projection': False,
    'lora_mlp': False,
    'lora_head': False
}

print("Before adapt LoRA:")
print_requires_grad_status(model)
# Apply LoRA and freeze parameters
model = apply_lora_and_freeze(model, device, lora_config)
print("\nAfter adapt LoRA:")
print_requires_grad_status(model)

# Summarize how small the trainable portion of the model is after freezing
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total = sum(p.numel() for p in model.parameters())
print(f"\nTrainable parameters: {num_trainable:,} / {num_total:,} ({100 * num_trainable / num_total:.2f}%)")





Before adapt LoRA:
Requires grad status for model parameters:
distilbert.embeddings.word_embeddings.weight: requires_grad=True
distilbert.embeddings.position_embeddings.weight: requires_grad=True
distilbert.embeddings.LayerNorm.weight: requires_grad=True
distilbert.embeddings.LayerNorm.bias: requires_grad=True
distilbert.transformer.layer.0.attention.q_lin.weight: requires_grad=True
distilbert.transformer.layer.0.attention.q_lin.bias: requires_grad=True
distilbert.transformer.layer.0.attention.k_lin.weight: requires_grad=True
distilbert.transformer.layer.0.attention.k_lin.bias: requires_grad=True
distilbert.transformer.layer.0.attention.v_lin.weight: requires_grad=True
distilbert.transformer.layer.0.attention.v_lin.bias: requires_grad=True
distilbert.transformer.layer.0.attention.out_lin.weight: requires_grad=True
distilbert.transformer.layer.0.attention.out_lin.bias: requires_grad=True
distilbert.transformer.layer.0.sa_layer_norm.weight: requires_grad=True
distilbert.transformer.layer

## Training and Evaluation

With the model prepared and LoRA applied, we now proceed to train the model on the IMDb dataset and evaluate its performance.


In [None]:
# Training Function
def train(model, train_loader, optimizer, device, epochs=3):
    """Fine-tune ``model`` for ``epochs`` passes over ``train_loader``.

    Each batch must be a dict holding 'input_ids', 'attention_mask' and 'label'
    tensors. The model is called with ``labels=`` and its ``outputs.loss`` is minimized.

    Args:
        model: model to train (put into train mode).
        train_loader: iterable of batches that supports ``len()``.
        optimizer: optimizer over the parameters to update.
        device: device to move each batch to.
        epochs: number of passes over the data.

    Returns:
        list[float]: the mean training loss of each epoch.

    Raises:
        ValueError: if ``epochs`` > 0 but ``train_loader`` has no batches.
    """
    num_batches = len(train_loader)
    if epochs > 0 and num_batches == 0:
        # Would otherwise surface as a ZeroDivisionError when averaging the loss
        raise ValueError("train_loader is empty; nothing to train on")

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
            # Move batch to the appropriate device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass (the model computes the loss when labels are given)
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            # Backward and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}: Loss {mean_loss:.4f}")

    return epoch_losses

# Testing Function
def test(model, test_loader, device):
    """Compute and print the classification accuracy of ``model`` on ``test_loader``.

    Each batch must be a dict holding 'input_ids', 'attention_mask' and 'label'
    tensors. The prediction is the argmax over ``outputs.logits`` along dim 1.

    Returns:
        float: accuracy in [0, 1], or ``nan`` if the loader yields no examples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)  # used only to score the predictions

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)

            # Accumulate accuracy counts
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Avoid ZeroDivisionError on an empty test set
        print("Accuracy: n/a (no test examples)")
        return float('nan')

    accuracy = correct / total
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy

# Load IMDb, then keep a small, seeded random subset of each split so the pipeline can be smoke-tested quickly
dataset = load_dataset("imdb")
small_train_dataset, small_test_dataset = (
    dataset[split].shuffle(seed=42).select(range(n_examples))
    for split, n_examples in (('train', 200), ('test', 100))
)

# Tokenizer matching the pretrained checkpoint
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
# Define the preprocessing function for tokenization
def preprocess_function(examples):
    """Tokenize the 'text' column of a batch with the module-level ``tokenizer``.

    Every sequence is truncated and padded to exactly 512 tokens, so all examples
    share one length and can be stacked by the default DataLoader collation.
    """
    tokenizer_options = {"truncation": True, "padding": "max_length", "max_length": 512}
    return tokenizer(examples['text'], **tokenizer_options)

# Tokenize the small datasets
small_train_dataset = small_train_dataset.map(preprocess_function, batched=True)
small_test_dataset = small_test_dataset.map(preprocess_function, batched=True)

# Set the format for PyTorch tensors
torch_columns = ['input_ids', 'attention_mask', 'label']
small_train_dataset.set_format(type='torch', columns=torch_columns)
small_test_dataset.set_format(type='torch', columns=torch_columns)

# Create DataLoader for both the training and testing datasets
train_loader = DataLoader(small_train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(small_test_dataset, batch_size=8, shuffle=False)

# Optimize only the parameters that apply_lora_and_freeze left trainable. Frozen
# tensors never receive gradients, so passing them to Adam only adds bookkeeping
# and obscures what is actually being trained. The updates to trainable parameters are unchanged.
trainable_params = [p for p in model.parameters() if p.requires_grad]
assert trainable_params, "No trainable parameters found - was LoRA applied to the model?"
optimizer = torch.optim.Adam(trainable_params, lr=5e-5)

# Train and Test
train(model, train_loader, optimizer, device)
test(model, test_loader, device)