In [None]:
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from lora_modules import LoRARobertaSdpaSelfAttention
from torch.nn.parameter import Parameter
import torch
from utils import get_loss_and_accuracy, SST2Dataset
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [13]:
SEED = 42
LORA_RANK = 8
LORA_ALPHA = 8

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


def load_roberta_classifier(seed=SEED):
    """Load roberta-base for sequence classification, moved to `device`.

    The classifier head is not part of the checkpoint (see the warning in the
    output), so it is randomly initialized; re-seeding before every load gives
    each model an identical head. The LoRA hyperparameters are stored on the
    config, which LoRARobertaSdpaSelfAttention is constructed from.
    """
    torch.manual_seed(seed)
    model = RobertaForSequenceClassification.from_pretrained('roberta-base')
    model.to(device)
    model.config.lora_rank = LORA_RANK
    model.config.lora_alpha = LORA_ALPHA
    return model


model_lora = load_roberta_classifier()
model_original = load_roberta_classifier()

# The sanity check further down is only meaningful if both models start out
# with exactly the same weights (including the randomly initialized head).
assert all(
    torch.equal(p_lora, p_original)
    for p_lora, p_original in zip(model_lora.parameters(), model_original.parameters())
), "model_lora and model_original were not initialized identically"

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Inspect the pretrained architecture: every encoder layer's self-attention is a
# RobertaSdpaSelfAttention whose query/key/value projections are plain Linear layers.
model_original

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
         

### Plugging in the custom LoRA attention module

In [15]:
# Build a throwaway LoRA attention module purely to inspect its structure.
# model_original.config carries the same lora_rank / lora_alpha as model_lora.config.
# In the repr, query and value become LoRALinear while key stays a plain Linear.
LoRARobertaSdpaSelfAttention(model_original.config)

CustomRobertaSdpaSelfAttention(
  (query): LoRALinear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): LoRALinear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [16]:
# Replace each encoder layer's self-attention with the LoRA variant, carrying
# over the pretrained query/key/value weights so the model starts out computing
# the same function as before the swap.
for layer in model_lora.roberta.encoder.layer:
    pretrained_attention = layer.attention.self
    lora_attention = LoRARobertaSdpaSelfAttention(model_lora.config).to(device)

    # A freshly constructed nn.Module is in training mode, but from_pretrained
    # returns the model in eval mode. Match the replaced module so dropout
    # behaves exactly as it did before the swap.
    lora_attention.train(pretrained_attention.training)

    # Copy the pretrained weights into the LoRA module's existing parameters
    # instead of wrapping clones in new Parameter objects. This keeps whatever
    # the LoRA module set up for its base weights. A new Parameter(...) would
    # always reset requires_grad=True, which matters if LoRALinear freezes its
    # base weight (NOTE(review): confirm in lora_modules).
    with torch.no_grad():
        for proj_name in ("query", "key", "value"):
            source = getattr(pretrained_attention, proj_name)
            target = getattr(lora_attention, proj_name)
            target.weight.copy_(source.weight)
            target.bias.copy_(source.bias)

    layer.attention.self = lora_attention

### Sanity check

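Before any training, the LoRA model should give exactly the same validation loss and accuracy as the original model. It reuses the same pretrained query/key/value weights and an identically seeded classification head, so the LoRA branch must not change the output at initialization.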

In [17]:
# The file presumably holds a pickled SST2Dataset object, not just tensors
# (SST2Dataset is imported at the top of the notebook but not otherwise used).
# A full unpickle is therefore required. Passing weights_only=False explicitly
# silences torch.load's warning and keeps this working on PyTorch >= 2.6, where
# the default became weights_only=True.
# Unpickling can execute arbitrary code: only load files you created or trust.
val_dataset = torch.load('./datasets/val_dataset.pth', weights_only=False)

  val_dataset = torch.load('./datasets/val_dataset.pth')


In [18]:
# Baseline: validation (loss, accuracy) of the untouched pretrained model.
get_loss_and_accuracy(
    model=model_original,
    dataset=val_dataset,
    shuffle=False,
    device=device,
)

(0.7031794318131038, 0.49107142857142855)

In [19]:
# Sanity check: before any training, the LoRA model must match the original.
# The baseline is recomputed here rather than read from the previous cell's
# output, so the check does not depend on hidden notebook state. It fails
# loudly under Restart & Run All if the LoRA swap ever changes the outputs.
# get_loss_and_accuracy returns a (loss, accuracy) tuple, per the outputs above.
lora_metrics = get_loss_and_accuracy(model=model_lora, dataset=val_dataset, device=device, shuffle=False)
baseline_metrics = get_loss_and_accuracy(model=model_original, dataset=val_dataset, device=device, shuffle=False)
assert torch.allclose(
    torch.tensor(lora_metrics, dtype=torch.float64),
    torch.tensor(baseline_metrics, dtype=torch.float64),
), f"LoRA model diverges from the original: {lora_metrics} vs {baseline_metrics}"
lora_metrics

(0.7031794318131038, 0.49107142857142855)

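The loss and accuracy are identical, so swapping in the LoRA attention modules leaves the validation metrics unchanged at initialization, as expected.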