### Create LoRA Layer

In [7]:
import torch
import torch.nn as nn

class LoRA(nn.Module):
    def __init__(self, embed_dim, rank):
        super().__init__()
        self.rank = rank 

        # Low-rank trainable matrices
        self.A = nn.Parameter(torch.randn(embed_dim, rank) * 0.01) # Matrix A is intialized from a Gaussian distribution; 0.01 as standard deviation
        self.B = nn.Parameter(torch.zeros(rank, embed_dim)) # Matrix B is initialized to 0 

    def forward(self, X):
        # LoRA update: matrix multiplication
        X = X @ self.A @ self.B

        return X 

###

### Implement LoRA in Attention

In [8]:
import torch
import torch.nn as nn

class LoRAAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, rank):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.rank = rank 
        self.head_dim = embed_dim // num_heads

        # Original frozen weights matrices (Wq, Wk, Wv, Wo)
        self.Wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wk = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=False)

        # LoRA low-rank adapaters for Wq and Wv; don't modify Wk, Wo
        self.lora_q = LoRA(embed_dim, rank)
        self.lora_v = LoRA(embed_dim, rank)
        
    def forward(self, X):

        # Compute query, key, value 
        Q = self.Wq(X) + self.lora_q(X) # LoRa for Q 
        K = self.Wk(X) # no LoRA for K 
        V = self.Wv(X) + self.lora_v(X) # LoRa for V 

        # Compute scaled dot-product attention 
        scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = scores.softmax(dim=-1)
        attn_output = attn_weights @ V 

        # Apply output project (Wo)
        output = self.Wo(attn_output)

        return output

### Implement LoRA in Transformer Block

In [9]:
# Load libraries
import torch 
import torch.nn as nn

class LoRAEncoderTransformerBlock(nn.Module):
    
    def __init__(self, embed_dim, num_heads, ff_dim, rank):
        super().__init__()

        # Define layers 
        self.lora_attention = LoRAAttention(embed_dim, num_heads, rank)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(), 
            nn.Linear(ff_dim, embed_dim)
        )

    # Forward pass
    def forward(self, X):

        # Multi-Head Attention
        attn_output = self.lora_attention(X)

        # Add & Norm
        X = self.norm1(X + attn_output)

        # Feed-forward
        ff_output = self.ff(X)

        # Add & Norm
        X = self.norm2(X + ff_output)

        return X

### Implement LoRA in Transformer (Encoder)

In [10]:
class LoRASimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_len, rank):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim) # input embedding
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim)) # Positional encoding
        self.layers = nn.ModuleList([
            LoRAEncoderTransformerBlock(embed_dim, num_heads, ff_dim, rank) for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embed_dim, vocab_size) # final output

    def forward(self, X):
        # Apply positional encoding to input embeddings
        X = self.embed(X) + self.pos_embed[:, :X.shape[1], :]

        # Nx blocks 
        for layer in self.layers:
            X = layer(X)

        # Project to vocab size for classification
        X = self.output_layer(X)

        return X 

### Implement function to freeze weights

In [11]:
def freeze_original_weights(model):
    """
    Freezes all non-LoRA parameters so only LoRA parameters are trainable.
    """
    for name, param in model.named_parameters():
        if "lora" not in name:  # Only LoRA parameters should be trainable
            param.requires_grad = False

### Test Code

In [12]:
import torch
import torch.optim as optim 
import torch.nn as nn

# Model hyperparameters 
vocab_size = 1000
embed_dim = 128
num_heads = 4
ff_dim = 256 
num_layers = 2
max_len = 20
rank = 4 # LoRA rank 

# Initialize LoRA transformer
model = LoRASimpleTransformer(vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads, ff_dim=ff_dim, num_layers=num_layers, max_len=max_len, rank=rank)

# Freeze original Transformer weights (only train LoRA)
freeze_original_weights(model)

# Loss function & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# Dummy dataset (random tokens)
X_train = torch.randint(0, vocab_size, (32, max_len))  # Batch of 32
y_train = torch.randint(0, vocab_size, (32, max_len))  # Target

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X_train)  # Forward pass
    loss = criterion(outputs.view(-1, vocab_size), y_train.view(-1))  # Compute loss
    loss.backward()  # Backpropagation
    optimizer.step()  # Update only LoRA parameters

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/5], Loss: 7.0529
Epoch [2/5], Loss: 7.0143
Epoch [3/5], Loss: 6.9758
Epoch [4/5], Loss: 6.9365
Epoch [5/5], Loss: 6.8959
