In [348]:
import tiktoken
import torch

import torch.nn as nn
import torch.nn.functional as F

In [349]:
class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        n_embd: Width of the last dimension of the input.
        block_size: Side length of the lower-triangular mask buffer, i.e. the
            longest sequence the mask can cover.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        # Bias-free linear projections; creation order is kept so parameter
        # ordering (and seeded initialisation) stays the same.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones, stored as a (non-trainable) buffer.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape

        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Scale the raw scores by 1 / sqrt(head_size).
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Positions where the mask is zero (the future) get -inf so that the
        # softmax assigns them zero weight.
        hidden = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))

        attention = self.dropout(F.softmax(scores, dim=-1))
        return attention @ values
class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects their concatenation.

    Args:
        num_heads: Number of `Head` modules to create.
        head_size: Output width of each head.
        n_embd: Input width of every head and output width of the projection.
        block_size: Forwarded to each `Head`.
        dropout: Dropout probability, used inside each head and after the
            output projection.
    """

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(head_size, n_embd, block_size, dropout))
        self.heads = nn.ModuleList(head_modules)
        # Maps the concatenated head outputs (num_heads * head_size) to n_embd.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to `x`, concatenate on the last axis, project, then drop out."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, dropout.

    Args:
        n_embd: Input and output width.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_width = 4 * n_embd
        # Layer order (and therefore the `net.<index>` state-dict keys) is
        # Linear, ReLU, Linear, Dropout.
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the MLP applied to `x`."""
        return self.net(x)
class Block(nn.Module):
    """Transformer block: pre-LayerNorm self-attention and MLP, each with a residual.

    Args:
        n_embd: Model width.
        n_head: Number of attention heads; each head gets `n_embd // n_head`
            features (integer division).
        block_size: Forwarded to the attention module.
        dropout: Dropout probability for the attention and MLP sub-layers.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        per_head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return `x` after the attention and feed-forward residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class Embedding(nn.Module):
    """Sum of learned token and position embeddings, followed by dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        block_size: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
        n_embd: Embedding width.
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, block_size, n_embd, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, idx):
        """Embed token ids `idx` of shape (B, T) into a (B, T, n_embd) tensor."""
        _, seq_len = idx.shape
        # Positions 0..T-1, created on the same device as the token ids.
        positions = torch.arange(seq_len, device=idx.device)

        token_vectors = self.token_embedding_table(idx)
        position_vectors = self.position_embedding_table(positions)
        # The (T, n_embd) position vectors broadcast over the batch axis.
        return self.dropout(token_vectors + position_vectors)
class Model(nn.Module):
    """Decoder-only transformer language model.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows / logit width).
        n_embd: Model width.
        block_size: Maximum context length.
        n_head: Number of attention heads per block.
        n_layer: Number of transformer blocks.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout):
        super().__init__()
        self.block_size = block_size

        self.transformer = nn.ModuleDict({
            "embedding": Embedding(vocab_size, block_size, n_embd, dropout),
            "blocks": nn.ModuleDict({f"block_{i}": Block(n_embd, n_head, block_size, dropout) for i in range(n_layer)}),
            "ln_f": nn.LayerNorm(n_embd),
            "lm_head": nn.Linear(n_embd, vocab_size)
        })

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Map token ids `idx` of shape (B, T) to logits of shape (B, T, vocab_size)."""
        x = self.transformer.embedding(idx)
        for block in self.transformer.blocks.values():
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.transformer.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after the prompt `idx`.

        Args:
            idx: Integer tensor of shape (B, T) holding the prompt token ids.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        Note:
            The module's train/eval mode is left untouched; call `.eval()`
            beforehand if dropout should be disabled while sampling.
        """
        for _ in range(max_new_tokens):
            # Only the most recent `block_size` tokens fit in the context.
            idx_cond = idx[:, -self.block_size:]
            # forward() returns the logits tensor only (there is no loss).
            logits = self(idx_cond)
            # Keep the distribution predicted for the last position.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [350]:
class LoRAHead(nn.Module):
    """Attention head with low-rank (LoRA) updates on its key/query/value projections.

    Each projection `W` of the wrapped head is used as `W + B @ A`, where
    `B` is (head_size, r) and starts at zero and `A` is (r, n_embd) and starts
    from a standard normal draw, so the adapter initially leaves the wrapped
    head's output unchanged.

    Args:
        head: The attention head to wrap; must expose `key`, `query`, `value`,
            `tril` and `dropout`.
        head_size: Output width of the wrapped projections.
        n_embd: Input width of the wrapped projections.
        r: Rank of the low-rank update.
    """

    def __init__(self, head, head_size, n_embd, r):
        super().__init__()

        self.head = head

        # Aliases to the wrapped head's sub-modules (shared, not copied).
        self.key = head.key
        self.query = head.query
        self.value = head.value
        self.dropout = head.dropout

        self.B_k = nn.Parameter(torch.zeros([head_size, r]), requires_grad=True)
        self.A_k = nn.Parameter(torch.randn([r, n_embd]), requires_grad=True)

        self.B_q = nn.Parameter(torch.zeros([head_size, r]), requires_grad=True)
        self.A_q = nn.Parameter(torch.randn([r, n_embd]), requires_grad=True)

        self.B_v = nn.Parameter(torch.zeros([head_size, r]), requires_grad=True)
        self.A_v = nn.Parameter(torch.randn([r, n_embd]), requires_grad=True)

    @property
    def tril(self):
        """Causal mask of the wrapped head.

        Read through the wrapped head on every access: a tensor stored here as
        a plain attribute would not be a registered buffer of this module and
        would be left behind when the module is moved or cast with `.to()`.
        """
        return self.head.tril

    def forward(self, x):
        """Attend over `x` of shape (B, T, C) and return a (B, T, head_size) tensor."""
        B, T, C = x.shape

        dW_k = self.B_k @ self.A_k
        k = self.key(x) + x @ dW_k.T

        dW_q = self.B_q @ self.A_q
        q = self.query(x) + x @ dW_q.T

        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # The value update must use its own B_v factor (not the query's B_q),
        # otherwise B_v never receives a gradient.
        dW_v = self.B_v @ self.A_v
        v = self.value(x) + x @ dW_v.T

        out = wei @ v

        return out
class LoRAMultiHeadAttention(nn.Module):
    """Multi-head attention whose heads and output projection carry LoRA updates.

    Args:
        sa: The attention module to wrap; must expose `heads`, `proj` and
            `dropout`.
        num_heads: Number of heads of `sa` to wrap (taken by index from 0).
        head_size: Output width of each head.
        n_embd: Output width of the projection.
        r: Rank of the low-rank updates.
    """

    def __init__(self, sa, num_heads, head_size, n_embd, r):
        super().__init__()

        wrapped_heads = []
        for head_index in range(num_heads):
            wrapped_heads.append(LoRAHead(sa.heads[head_index], head_size, n_embd, r))
        self.heads = nn.ModuleList(wrapped_heads)

        # Shared with the wrapped module, not copied.
        self.proj = sa.proj
        self.dropout = sa.dropout

        # Low-rank factors for the output projection: B starts at zero, A from
        # a standard normal draw.
        self.B_proj = nn.Parameter(torch.zeros([n_embd, r]), requires_grad=True)
        self.A_proj = nn.Parameter(torch.randn([r, num_heads * head_size]), requires_grad=True)

    def forward(self, x):
        """Concatenate the head outputs, apply `proj` plus its low-rank update, then dropout."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)

        delta_weight = self.B_proj @ self.A_proj
        low_rank_term = concatenated @ delta_weight.T
        projected = self.proj(concatenated) + low_rank_term

        return self.dropout(projected)
class LoRABlock(nn.Module):
    """Transformer block that reuses a base block's layers with LoRA-wrapped attention.

    Args:
        block: The block to wrap; must expose `sa`, `ffwd`, `ln1` and `ln2`.
        n_embd: Model width.
        n_head: Number of attention heads; the per-head width is
            `n_embd // n_head` (integer division).
        r: Rank of the low-rank updates.
    """

    def __init__(self, block, n_embd, n_head, r):
        super().__init__()

        per_head_width = n_embd // n_head
        self.sa = LoRAMultiHeadAttention(block.sa, n_head, per_head_width, n_embd, r)
        # Feed-forward and layer norms are shared with the wrapped block.
        self.ffwd = block.ffwd
        self.ln1 = block.ln1
        self.ln2 = block.ln2

    def forward(self, x):
        """Return `x` after the attention and feed-forward residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class LoRAModel(nn.Module):
    """Frozen base `Model` whose transformer blocks are wrapped with LoRA adapters.

    Args:
        vocab_size: Size of the token vocabulary.
        n_embd: Model width.
        block_size: Maximum context length.
        n_head: Number of attention heads per block.
        n_layer: Number of transformer blocks.
        dropout: Dropout probability of the base model.
        r: Rank of the low-rank updates.
        state_dict: Optional state dict loaded into the base model before it is
            frozen; its keys must match `Model.state_dict()`. When `None`
            the base model keeps its fresh initialisation.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, r, state_dict=None):
        super().__init__()

        self.model = Model(vocab_size, n_embd, block_size, n_head, n_layer, dropout)
        if state_dict is not None:
            self.model.load_state_dict(state_dict)
        # Freeze every base parameter; the LoRA factors created below are new
        # parameters and therefore remain trainable.
        for p in self.model.parameters():
            p.requires_grad = False

        self.lora_blocks = nn.ModuleDict({})
        for i, block in enumerate(self.model.transformer.blocks.values()):
            self.lora_blocks[f"lora_block_{i}"] = LoRABlock(block, n_embd, n_head, r)

    def forward(self, idx):
        """Map token ids `idx` of shape (B, T) to logits of shape (B, T, vocab_size)."""
        x = self.model.transformer.embedding(idx)
        for lora_block in self.lora_blocks.values():
            x = lora_block(x)
        x = self.model.transformer.ln_f(x)
        logits = self.model.transformer.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens using the LoRA-adapted model.

        Args:
            idx: Integer tensor of shape (B, T) holding the prompt token ids.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        Note:
            The module's train/eval mode is left untouched; call `.eval()`
            beforehand if dropout should be disabled while sampling.
        """
        for _ in range(max_new_tokens):
            # Only the most recent `block_size` tokens fit in the context.
            idx_cond = idx[:, -self.model.block_size:]
            # Go through this module's forward so the LoRA blocks are used.
            logits = self(idx_cond)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [364]:
# Tokenizer plus the hyperparameters of the base model and its LoRA adapters.
tokenizer = tiktoken.encoding_for_model("gpt2")
VOCAB_SIZE = tokenizer.n_vocab  # vocabulary size reported by the tokenizer
N_EMBD = 128  # model width
BLOCK_SIZE = 32  # maximum context length
# NOTE(review): N_EMBD is not a multiple of N_HEAD, so each head gets
# 128 // 12 = 10 features and the concatenated heads span 120 features before
# the output projection maps them back to 128.
N_HEAD = 12
N_LAYER = 12
DROPOUT = 0.01
R = 10  # rank of the LoRA updates
model = LoRAModel(
    vocab_size=VOCAB_SIZE,
    n_embd=N_EMBD,
    block_size=BLOCK_SIZE,
    n_head=N_HEAD,
    n_layer=N_LAYER,
    dropout=DROPOUT,
    r=R,
)

In [365]:
# Size of the wrapped base model: every parameter, trainable or not.
n_model_parameters = sum(
    param.numel() for param in model.model.parameters()
)
# Size of the adapters: only parameters that still require gradients.
n_lora_parameters = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
# Thousands are formatted with "," and then switched to "." for display.
model_line = f"Model: {n_model_parameters:,}"
lora_line = f"LoRA: {n_lora_parameters:,}"
print(model_line.replace(",", "."))
print(lora_line.replace(",", "."))
print(f"ratio: {(n_lora_parameters / n_model_parameters) * 100:.02f}%")

Model: 15.245.905
LoRA: 625.920
ratio: 4.11%


In [366]:
# Smoke test: push a batch of 16 all-zero token-id sequences of full context
# length through the model and show the input and logits shapes.
x = torch.zeros(16, BLOCK_SIZE, dtype=torch.long)
print(x.shape)
y = model(x)
print(y.shape)

torch.Size([16, 32])
torch.Size([16, 32, 50257])
