In [1]:
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from lora_modules import LoRARobertaSdpaSelfAttention
from torch.nn.parameter import Parameter
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from utils import get_loss_and_accuracy, SST2Dataset
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


2025-01-19 10:21:36.868813: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-19 10:21:36.883868: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737282096.901575  145957 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737282096.908641  145957 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-19 10:21:36.929185: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [None]:
# Both models are loaded with the same seed so that their newly initialised
# classification heads are identical -- the sanity check below relies on this.
SEED = 42
BASE_CHECKPOINT = 'roberta-base'
LORA_RANK = 8
LORA_ALPHA = 8

tokenizer = RobertaTokenizer.from_pretrained(BASE_CHECKPOINT)


def load_base_model():
    """Load a freshly seeded RoBERTa classifier onto `device` with LoRA settings on its config.

    Re-seeding right before every load makes the randomly initialised classifier
    weights identical across calls. The LoRA rank/alpha are stored on the config
    because the LoRA attention module is later constructed from this config.

    Returns:
        The loaded `RobertaForSequenceClassification`, already moved to `device`.
    """
    torch.manual_seed(SEED)
    base_model = RobertaForSequenceClassification.from_pretrained(BASE_CHECKPOINT)
    base_model.to(device)
    base_model.config.lora_rank = LORA_RANK
    base_model.config.lora_alpha = LORA_ALPHA
    return base_model


model = load_base_model()           # will receive the LoRA attention modules
model_original = load_base_model()  # untouched reference for the sanity check

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model_original

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
         

### Plugging in the custom LoRA attention module

In [4]:
# Build one standalone LoRA attention module just to inspect its structure:
# query and value are LoRALinear layers, key stays a plain Linear.
preview_attention = LoRARobertaSdpaSelfAttention(model_original.config)
preview_attention

CustomRobertaSdpaSelfAttention(
  (query): LoRALinear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): LoRALinear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [3]:
# Replace every encoder layer's self-attention with the LoRA variant. The
# pretrained query/key/value weights and biases are copied into the new module,
# so only the added LoRA matrices differ from the original model.
for layer in model.roberta.encoder.layer:
    original_attention = layer.attention.self

    # Re-running this cell must not wrap an already-patched layer: that would
    # discard its (possibly trained) LoRA matrices and re-initialise them.
    # The "lora_A"/"lora_B" name check matches the one used to unfreeze them.
    already_patched = any(
        "lora_A" in param_name or "lora_B" in param_name
        for param_name, _ in original_attention.named_parameters()
    )
    if already_patched:
        continue

    lora_attention = LoRARobertaSdpaSelfAttention(model.config).to(device)

    for proj_name in ("query", "key", "value"):
        source = getattr(original_attention, proj_name)
        target = getattr(lora_attention, proj_name)
        # detach() so each new Parameter is a fresh leaf tensor rather than an
        # autograd-tracked copy of the old weight.
        target.weight = Parameter(source.weight.detach().clone())
        target.bias = Parameter(source.bias.detach().clone())

    # A newly constructed nn.Module starts in train mode; match the mode of the
    # module being replaced so dropout behaves like the rest of the model.
    lora_attention.train(original_attention.training)

    layer.attention.self = lora_attention

### Sanity check

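Before any training, the LoRA model should give exactly the same result as the original model. The pretrained query/key/value weights were copied over. Assuming the standard LoRA initialization ($B = 0$, so $\Delta W = BA = 0$), the adapters also start out as a no-op.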

In [4]:
# These files presumably hold pickled SST2Dataset objects (SST2Dataset is
# imported in the first cell, which lets unpickling resolve the class) -- confirm.
# Unpickling can execute arbitrary code, so only load dataset files you created.
# weights_only=False is passed explicitly: loading non-tensor objects requires it,
# and PyTorch >= 2.6 changes the default to weights_only=True, which would
# reject these files. Being explicit also avoids relying on the implicit default,
# which is what most likely triggered the warning captured below.
train_dataset = torch.load('./datasets/train_dataset.pth', weights_only=False)
val_dataset = torch.load('./datasets/val_dataset.pth', weights_only=False)

  train_dataset = torch.load('./datasets/train_dataset.pth')
  val_dataset = torch.load('./datasets/val_dataset.pth')


In [7]:
# (loss, accuracy) of the untouched reference model on the validation set.
original_val_metrics = get_loss_and_accuracy(
    model=model_original,
    dataset=val_dataset,
    device=device,
    shuffle=False,
)
original_val_metrics

(0.7031794318131038, 0.49107142857142855)

In [8]:
# (loss, accuracy) of the LoRA-patched model; should equal the reference model's.
lora_val_metrics = get_loss_and_accuracy(
    model=model,
    dataset=val_dataset,
    device=device,
    shuffle=False,
)
lora_val_metrics

(0.7031794318131038, 0.49107142857142855)

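The loss and accuracy are identical, so swapping in the LoRA attention modules did not change the model's behavior. Both accuracies sit near chance level (~0.49) because the classification head is still randomly initialized.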

### Training LoRA model

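We train only the LoRA parameters. First we freeze every parameter, then we unfreeze the LoRA matrices (`lora_A` and `lora_B`). Note that this also leaves the randomly initialized classification head frozen.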

In [5]:
# Freeze everything except the LoRA matrices. Note that this also keeps the
# classification head frozen, even though it was randomly initialised at load time.
for name, param in model.named_parameters():
    param.requires_grad = "lora_A" in name or "lora_B" in name

# Guard against a silent no-op: if the LoRA parameter names do not match,
# nothing would be trainable and training would never update the model.
assert any(p.requires_grad for p in model.parameters()), "no LoRA parameters found to train"

In [6]:
def _count_trainable(module):
    """Return the number of scalar parameters in `module` that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


print("Trainable parameter count:")
for label, counted_model in (("Original model", model_original), ("LoRA model", model)):
    print(f"{label}: {_count_trainable(counted_model):,}")

Trainable parameter count:


NameError: name 'model_original' is not defined

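The LoRA model should have ~0.3M trainable parameters, which matches the figure reported in the LoRA paper for RoBERTa-base. This assumes each adapted projection adds $A \in \mathbb{R}^{r \times 768}$ and $B \in \mathbb{R}^{768 \times r}$ with $r = 8$, giving $12 \text{ layers} \times 2 \text{ projections (query, value)} \times (8 \cdot 768 + 768 \cdot 8) = 294{,}912$. The `NameError` above comes from a kernel session in which `model_original` had not been defined; re-running the notebook from the top prints both counts.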

In [7]:
num_epochs = 60
batch_size = 16
learning_rate = 5e-4

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# LR schedule: linear warm-up over the first 6% of optimizer steps, then linear
# decay to zero. Step counts are cast to int because SequentialLR hands over to
# the next scheduler only when the integer step counter equals a milestone
# exactly; a fractional milestone would never match.
total_steps = num_epochs * len(train_dataloader)
warmup_ratio = 0.06
warmup_steps = int(warmup_ratio * total_steps)

print('========= RUN PARAMETERS: ===============')
print(f'Learning rate: {learning_rate:.1e}, batch size: {batch_size}')

# Hand the optimizer only the unfrozen (LoRA) parameters. Frozen parameters
# never receive gradients, so listing them only adds bookkeeping. An empty list
# raises immediately instead of training nothing.
optimizer = AdamW(
    params=[p for p in model.parameters() if p.requires_grad],
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

warmup_scheduler = LinearLR(
    optimizer=optimizer,
    start_factor=0.05,
    end_factor=1.0,
    total_iters=warmup_steps,
)
decay_scheduler = LinearLR(
    optimizer=optimizer,
    start_factor=1.0,
    end_factor=0.0,
    total_iters=total_steps - warmup_steps,
)

scheduler = SequentialLR(
    optimizer=optimizer,
    schedulers=[warmup_scheduler, decay_scheduler],
    milestones=[warmup_steps],
)


# training loop
train_losses, val_losses, accuracies = [], [], []
for epoch in range(num_epochs):
    # Re-enable train mode (dropout) every epoch; the evaluation helper below
    # presumably switches the model to eval mode -- confirm in utils.
    model.train()

    for batch_idx, (inputs, labels) in enumerate(train_dataloader):

        # forward
        logits = model(**inputs).logits
        loss = F.cross_entropy(logits, labels)

        # backprop; the scheduler advances once per optimizer step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if batch_idx % 200 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}, LR: {scheduler.get_last_lr()[0]:.1e}")

    # get train loss
    train_loss, _ = get_loss_and_accuracy(
        model=model,
        dataset=train_dataset,
        device=device,
        eval_ratio=0.1 # evaluate on 10% of train data
    )

    # Get validation loss and accuracy
    val_loss, val_accuracy = get_loss_and_accuracy(
        model=model,
        dataset=val_dataset,
        device=device
    )

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    accuracies.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}] summary")
    print(f'    Train loss: {train_loss:.4f}')
    print(f'    Validation loss: {val_loss:.4f}')
    print(f'    Accuracy: {val_accuracy:.4f}')
    print('========================================')

best_accuracy = max(accuracies)
print(f'Best accuracy is {best_accuracy:.4f} at epoch {accuracies.index(best_accuracy) + 1}')


Learning rate: 5.0e-04, batch size: 16
Epoch [1/60], Step [1/4210], Loss: 0.6983, LR: 2.5e-05
Epoch [1/60], Step [201/4210], Loss: 0.7243, LR: 3.1e-05
Epoch [1/60], Step [401/4210], Loss: 0.6903, LR: 3.8e-05
Epoch [1/60], Step [601/4210], Loss: 0.6011, LR: 4.4e-05
Epoch [1/60], Step [801/4210], Loss: 0.4438, LR: 5.0e-05
Epoch [1/60], Step [1001/4210], Loss: 0.2743, LR: 5.6e-05
Epoch [1/60], Step [1201/4210], Loss: 0.1618, LR: 6.3e-05
Epoch [1/60], Step [1401/4210], Loss: 0.2986, LR: 6.9e-05
Epoch [1/60], Step [1601/4210], Loss: 0.3885, LR: 7.5e-05
Epoch [1/60], Step [1801/4210], Loss: 0.3012, LR: 8.1e-05
Epoch [1/60], Step [2001/4210], Loss: 0.1849, LR: 8.8e-05
Epoch [1/60], Step [2201/4210], Loss: 0.0700, LR: 9.4e-05
Epoch [1/60], Step [2401/4210], Loss: 0.3770, LR: 1.0e-04
Epoch [1/60], Step [2601/4210], Loss: 0.0838, LR: 1.1e-04
Epoch [1/60], Step [2801/4210], Loss: 0.8829, LR: 1.1e-04
Epoch [1/60], Step [3001/4210], Loss: 0.1688, LR: 1.2e-04
Epoch [1/60], Step [3201/4210], Loss: 0.



Epoch [4/60], Step [2601/4210], Loss: 0.0784, LR: 5.0e-04
Epoch [4/60], Step [2801/4210], Loss: 0.1174, LR: 5.0e-04
Epoch [4/60], Step [3001/4210], Loss: 0.1685, LR: 5.0e-04
Epoch [4/60], Step [3201/4210], Loss: 0.4681, LR: 5.0e-04
Epoch [4/60], Step [3401/4210], Loss: 0.4203, LR: 5.0e-04
Epoch [4/60], Step [3601/4210], Loss: 0.1513, LR: 5.0e-04
Epoch [4/60], Step [3801/4210], Loss: 0.0809, LR: 5.0e-04
Epoch [4/60], Step [4001/4210], Loss: 0.1562, LR: 5.0e-04
Epoch [4/60], Step [4201/4210], Loss: 0.1776, LR: 5.0e-04
Epoch [4/60] summary
    Train loss: 0.0119
    Validation loss: 0.2004
    Accuracy: 0.9364
Epoch [5/60], Step [1/4210], Loss: 0.1011, LR: 5.0e-04
Epoch [5/60], Step [201/4210], Loss: 0.3077, LR: 5.0e-04
Epoch [5/60], Step [401/4210], Loss: 0.0949, LR: 5.0e-04
Epoch [5/60], Step [601/4210], Loss: 0.3268, LR: 5.0e-04
Epoch [5/60], Step [801/4210], Loss: 0.1563, LR: 4.9e-04
Epoch [5/60], Step [1001/4210], Loss: 0.5134, LR: 4.9e-04
Epoch [5/60], Step [1201/4210], Loss: 0.0888

KeyboardInterrupt: 