In [1]:
import torch
import torch.nn as nn
from src.config import get_lora_config

from transformers import AutoModelForSequenceClassification, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LoRALayer(nn.Module):
    """Trainable low-rank term ``lora_alpha * (x @ lora_A @ lora_B)``.

    ``lora_A`` has shape ``(in_dim, r)`` and is initialised from a standard
    normal sample scaled by ``1 / sqrt(r)``. ``lora_B`` has shape
    ``(r, out_dim)`` and starts as all zeros, so for finite inputs the layer
    outputs zeros until ``lora_B`` is updated.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        lora_config: Object exposing ``r`` (the rank) and ``lora_alpha``
            (the multiplier applied to the low-rank product).
    """

    def __init__(self, in_dim, out_dim, lora_config):
        super().__init__()
        self.lora_config = lora_config
        rank = lora_config.r
        # Scale the random init of A by 1/sqrt(rank).
        init_scale = 1 / torch.tensor(rank).float().sqrt()
        self.lora_A = nn.Parameter(init_scale * torch.randn(in_dim, rank))
        # B starts at zero so the low-rank product is initially zero.
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))

    def forward(self, x):
        """Return ``lora_alpha * ((x @ lora_A) @ lora_B)``."""
        low_rank = (x @ self.lora_A) @ self.lora_B
        return self.lora_config.lora_alpha * low_rank


class LinearWithLoRA(nn.Module):
    """Wrap a linear layer and add a ``LoRALayer`` term to its output.

    Args:
        linear: Layer to wrap; must expose ``in_features`` and
            ``out_features`` and be callable on the input.
        lora_config: Configuration object forwarded unchanged to ``LoRALayer``.
    """

    def __init__(self, linear, lora_config):
        super().__init__()
        self.linear = linear
        # The adapter uses the same input/output widths as the wrapped layer.
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            lora_config=lora_config,
        )

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        base_out = self.linear(x)
        lora_out = self.lora(x)
        return base_out + lora_out


In [3]:
# Load pretrained BERT with a 2-label sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# NOTE(review): "vo" presumably selects the value and output projections;
# confirm against `get_lora_config` in src.config.
lora_config = get_lora_config(rank=1, target_mods="vo")

# Each name that may appear in `lora_config.target_modules`, mapped to the
# (submodule, attribute) path of the linear layer to wrap, relative to an
# encoder layer's `attention` module. Insertion order fixes the wrapping order
# (query, key, value, output), which determines the order of random inits.
LORA_TARGET_PATHS = {
    "query": ("self", "query"),
    "key": ("self", "key"),
    "value": ("self", "value"),
    "output": ("output", "dense"),
}

# Resolve the requested targets once; the config does not change in the loop.
active_lora_targets = [
    path
    for target_name, path in LORA_TARGET_PATHS.items()
    if target_name in lora_config.target_modules
]

# Iterate over the encoder layers the model actually has instead of assuming
# a fixed count of 12, so checkpoints of any depth are fully adapted.
for encoder_layer in model.bert.encoder.layer:
    for submodule_name, linear_name in active_lora_targets:
        parent = getattr(encoder_layer.attention, submodule_name)
        wrapped = LinearWithLoRA(getattr(parent, linear_name), lora_config)
        # setattr on an nn.Module registers `wrapped` as the new submodule.
        setattr(parent, linear_name, wrapped)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
