# 09: Fine-Tune Decoder-Only Model with LoRA (Low-Rank Adaptation)
This notebook demonstrates how to use LoRA to fine-tune a GPT-style decoder model in a parameter-efficient way.

We'll:
- Build a GPT-style base decoder model
- Inject LoRA adapters into the query and value projections of each self-attention layer
- Freeze all base parameters and train only the LoRA parameters
- Report the average training loss for each epoch

In [None]:
!pip install torch transformers

In [None]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from models.lora_adapter import LoRALinear
import os

## Load and tokenize dataset

In [None]:
# Fetch the "gpt2" tokenizer and register '[PAD]' as its padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

corpus_path = "../data/tiny_shakespeare.txt"
corpus_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# Download the corpus only when no local copy exists yet.
if not os.path.exists(corpus_path):
    from urllib.request import urlretrieve
    os.makedirs("../data", exist_ok=True)
    urlretrieve(corpus_url, corpus_path)

with open(corpus_path, "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

# Encode the whole corpus in one call, without adding special tokens.
tokens = tokenizer.encode(text, add_special_tokens=False)

## Create dataset

In [None]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Example ``i`` is the pair
    ``(tokens[i:i + block_size], tokens[i + 1:i + block_size + 1])``: an input
    window and the same window shifted one token to the right.

    Args:
        tokens: Sequence of integer token ids.
        block_size: Number of tokens in each input (and target) window.
    """

    def __init__(self, tokens, block_size):
        self.block_size = block_size
        # Keep ONE shared tensor and slice windows on demand; materialising every
        # overlapping window up front costs roughly block_size times the corpus.
        self.data = torch.as_tensor(tokens, dtype=torch.long)
        # Each example needs block_size + 1 tokens, so valid start indices are
        # 0 .. len(tokens) - block_size - 1 inclusive (none if the corpus is too short).
        self.num_examples = max(0, self.data.size(0) - block_size)

    def __len__(self):
        """Return the number of (input, target) windows."""
        return self.num_examples

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair whose window starts at ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.num_examples
        if not 0 <= idx < self.num_examples:
            raise IndexError("TextDataset index out of range")
        window = self.data[idx:idx + self.block_size + 1]
        return window[:-1], window[1:]

# Context length: each example holds block_size input tokens plus their shifted targets.
block_size = 128
dataset = TextDataset(tokens=tokens, block_size=block_size)
# Shuffled mini-batches of 8 windows each.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=8)

## Define decoder with LoRA in self-attention

In [None]:
class LoRAAttentionBlock(nn.Module):
    """Causal multi-head self-attention block with LoRA on the query/value projections.

    The query and value projections are ``LoRALinear`` layers built with ``r=4``;
    the key and output projections are plain ``nn.Linear`` layers. The attention
    output goes through the output projection and dropout (p=0.1), is added back
    to the block input, and the sum is layer-normalised.

    Args:
        embed_dim: Feature size of the input; must be a multiple of ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``embed_dim``.
    """

    def __init__(self, embed_dim, heads):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error in forward().
        if heads <= 0 or embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of heads ({heads})"
            )
        self.heads = heads
        self.head_dim = embed_dim // heads

        self.q_proj = LoRALinear(embed_dim, embed_dim, r=4)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = LoRALinear(embed_dim, embed_dim, r=4)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def _split_heads(self, t, batch, seq_len):
        """Reshape (batch, seq_len, heads * head_dim) to (batch, heads, seq_len, head_dim)."""
        return t.view(batch, seq_len, self.heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq_len, embed_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        B, T, C = q.size()

        q = self._split_heads(q, B, T)
        k = self._split_heads(k, B, T)
        v = self._split_heads(v, B, T)

        # Scaled dot-product scores, shape (batch, heads, seq_len, seq_len).
        attn_weights = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # True strictly above the diagonal: position i must not attend to j > i.
        future_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        attn_weights = attn_weights.masked_fill(future_mask, float('-inf'))

        attn = torch.softmax(attn_weights, dim=-1)
        # Merge heads back into a single feature dimension.
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.ln(x + self.dropout(self.out_proj(out)))

In [None]:
class LoRAGPTDecoder(nn.Module):
    """Decoder-only language model built from ``LoRAAttentionBlock`` layers.

    Token embeddings are summed with learned positional embeddings, passed
    through ``depth`` blocks (attention followed by a feed-forward sub-layer),
    layer-normalised, and projected to vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding and size of the output logits.
        embed_dim: Hidden size used throughout the model.
        depth: Number of attention + feed-forward blocks.
        heads: Number of attention heads per block.
        max_len: Longest sequence the positional embedding can cover.
    """

    def __init__(self, vocab_size, embed_dim, depth, heads, max_len):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        # Each block stays an nn.Sequential(attention, mlp) so that parameter
        # names in the state_dict (blocks.<i>.0.*, blocks.<i>.1.*) are unchanged.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                LoRAAttentionBlock(embed_dim, heads),
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, embed_dim * 4),
                    nn.GELU(),
                    nn.Linear(embed_dim * 4, embed_dim),
                )
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional embedding length.
        """
        seq_len = x.size(1)
        max_len = self.pos_embed.size(1)
        if seq_len > max_len:
            # Otherwise the addition below fails with an opaque broadcasting
            # error (or silently broadcasts when max_len == 1).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        x = self.token_embed(x) + self.pos_embed[:, :seq_len]
        for block in self.blocks:
            attn, mlp = block
            x = attn(x)
            # Residual connection around the feed-forward sub-layer: calling the
            # Sequential directly would REPLACE the hidden state with the MLP
            # output and discard the attention block's result.
            x = x + mlp(x)
        return self.head(self.norm(x))

## Fine-tune only LoRA params

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = LoRAGPTDecoder(
    vocab_size=len(tokenizer),
    embed_dim=512,
    depth=4,
    heads=8,
    max_len=block_size,
)
model = model.to(device)

# NOTE(review): no pretrained weights are loaded in this cell, so the frozen base
# is whatever LoRAGPTDecoder's constructor initialised - confirm this is intended.

# Freeze all but LoRA: only parameters whose name contains 'lora' stay trainable.
for param_name, weight in model.named_parameters():
    if 'lora' in param_name:
        continue
    weight.requires_grad = False

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=5e-4)
criterion = nn.CrossEntropyLoss()

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        # Flatten batch and time so every position is one classification example.
        flat_logits = logits.view(-1, logits.size(-1))
        loss = criterion(flat_logits, targets.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    mean_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1} LoRA fine-tuning loss: {mean_loss:.4f}")

# Persist the full state_dict (frozen base and LoRA parameters alike).
torch.save(model.state_dict(), "lora_finetuned_decoder.pt")
print("✅ LoRA fine-tuned model saved.")