In [1]:
import math; import copy
#torch
import torch; import torch.nn as nn
from torch.nn import functional as F

In [3]:
#Reference: Karpathy's Shakespeare GPT
class CodeLanguageModel(nn.Module):
    """Decoder-only transformer language model over code tokens.

    Hyperparameters are read from names defined outside this cell:
    ``vocab_size``, ``n_embd``, ``block_size``, ``n_head``, ``n_layer`` and,
    in :meth:`generate`, ``PAD``.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head)
                                      for _ in range(n_layer)])
        self.norm = nn.LayerNorm(n_embd)
        self.generator = nn.Linear(n_embd, vocab_size)
        # Re-initialise every submodule with small-std normal weights (GPT-2 style).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Other module types (e.g. LayerNorm) keep their default initialisation.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            idx: integer token ids of shape (B, T). T may not exceed
                ``block_size``, the number of rows in the position table.
            targets: optional token ids with B*T elements (next token per position).

        Returns:
            ``(logits, loss)``. Without targets: logits of shape
            (B, T, vocab_size) and ``loss`` None. With targets: logits flattened
            to (B*T, vocab_size) and the scalar cross-entropy loss.
        """
        B, T = idx.shape
        # Create the position ids on idx's own device so they always match the
        # embedding lookup, independent of any global `device` setting.
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding(idx) + self.position_embedding(positions)  # (B, T, C)
        # Decoder layers followed by the final LayerNorm.
        x = self.norm(self.blocks(x))
        logits = self.generator(x)  # (B, T, vocab_size)
        if targets is None:
            return logits, None
        # Compare the prediction at each position with the next token.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous target tensors are accepted too.
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200):
        """Extend ``idx`` (B, T) by sampling up to ``max_new_tokens`` tokens.

        Sampling stops early when the token drawn for the first sequence in the
        batch equals ``PAD[0]``; that token is not appended. Runs without
        gradient tracking. It does not switch the module to eval mode, so
        dropout stays active unless the caller has called ``.eval()``.

        Returns:
            The extended id tensor of shape (B, T + number of sampled tokens).
        """
        for _ in range(max_new_tokens):
            # Condition on at most the last block_size tokens (position table size).
            logits, _ = self(idx[:, -block_size:])
            # Only the distribution for the final position matters.
            probs = F.softmax(logits[:, -1, :], dim=-1)  # (B, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            if idx_next[0, 0].item() == PAD[0]:
                break
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [12]:
class Block(nn.Module):
    """One pre-LayerNorm decoder layer: attention then feed-forward, each with a residual.

    Args:
        n_embd: feature width of the residual stream; must be divisible by n_head.
        n_head: number of attention (query) heads.
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        # Multi-query attention is used here; MultiHeadAttention(n_head, n_embd // n_head)
        # has the same call signature and can be swapped in for standard MHA.
        self.sa = MultiQueryAttention(n_head, n_embd // n_head)
        self.ff = PositionwiseFeedForward(n_embd)
        # Separate LayerNorms in front of each sub-layer (norm1 registered before norm2).
        self.norm1, self.norm2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.norm1(x))
        return attended + self.ff(self.norm2(attended))
    
class OneHeadAttention(nn.Module):
    """A single causal self-attention head using scaled dot-product attention.

    Args:
        head_size: width of the query/key/value projections (and of the output).
        n_embd: width of the input features.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, head_size, n_embd, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Query, key and value projections, in that order. They are deep copies of one
        # Linear, so they start with identical weights until re-initialised (e.g. by
        # CodeLanguageModel._init_weights).
        self.linears = clones(nn.Linear(n_embd, head_size, bias=False), 3)
        # Lower-triangular causal mask sized by the global `block_size`; as a buffer it
        # follows the module across devices and is saved in the state_dict.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        # Sequence length may be shorter than block_size (e.g. during generation).
        _, seq_len, _ = x.shape
        q, k, v = (proj(x) for proj in self.linears)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        causal = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ v

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention built from independent OneHeadAttention heads.

    Every head receives the full input. The head outputs are concatenated along the
    feature axis, projected back to the embedding width, and passed through dropout.

    Args:
        num_heads: number of parallel heads.
        head_size: output width of each head.
        dropout: dropout probability used inside every head and after the projection.
        embed_dim: input/output feature width. Defaults to ``num_heads * head_size``,
            the same width MultiQueryAttention derives (and equal to n_embd when
            head_size = n_embd // num_heads, as Block computes it).
    """
    def __init__(self, num_heads, head_size, dropout=0.1, embed_dim=None):
        super().__init__()
        if embed_dim is None:
            embed_dim = num_heads * head_size
        # Forward `dropout` so the heads follow the caller's setting rather than a fixed 0.1.
        self.heads = nn.ModuleList(
            [OneHeadAttention(head_size, embed_dim, dropout=dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(head_size * num_heads, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (..., num_heads * head_size)
        return self.dropout(self.proj(out))
    
class MultiQueryAttention(nn.Module):
    """Causal multi-query attention: many query heads share one key/value head.

    Queries get ``num_heads`` heads of width ``head_size``. Keys and values are
    projected to a single head of width ``head_size``, which broadcasts across all
    query heads.

    Args:
        num_heads: number of query heads.
        head_size: width of each head; the model width is ``num_heads * head_size``.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, num_heads, head_size, dropout=0.1):
        super().__init__()
        self.n_embed = num_heads * head_size
        self.num_heads, self.head_size = num_heads, head_size
        self.dropout = nn.Dropout(dropout)
        # wq: all query heads; wk / wv: one shared head each; wo: output projection.
        # There is no per-layer .to(device): the parent model's .to() moves these.
        self.wq = nn.Linear(self.n_embed, self.n_embed)
        self.wk = nn.Linear(self.n_embed, self.head_size)
        self.wv = nn.Linear(self.n_embed, self.head_size)
        self.wo = nn.Linear(self.n_embed, self.n_embed)
        self.scale = math.sqrt(self.head_size)

    def forward(self, x):
        B, T, C = x.shape
        query = self.wq(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, H, T, hs)
        key = self.wk(x).view(B, T, 1, self.head_size).permute(0, 2, 3, 1)  # (B, 1, hs, T)
        value = self.wv(x).view(B, T, 1, self.head_size).transpose(1, 2)  # (B, 1, T, hs)
        # The single key head broadcasts over the H query heads.
        scores = (query @ key) / self.scale  # (B, H, T, T)
        # Build a proper causal mask for exactly T positions on x's device: True marks a
        # future position (column > row), which must not be attended to.
        future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(future, float('-inf'))
        p_attn = F.softmax(scores, dim=-1)
        out = self.dropout(p_attn) @ value  # (B, H, T, hs)
        # Merge heads back into the feature axis.
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_embed)
        return self.wo(out)

class PositionwiseFeedForward(nn.Module):
    """Per-position MLP: widen to 4 * n_embd, ReLU, dropout, then project back to n_embd.

    Args:
        n_embd: input and output feature width.
        dropout: dropout probability applied to the hidden activations.
    """
    def __init__(self, n_embd, dropout=0.2):
        super().__init__()
        hidden_width = 4 * n_embd
        self.w1 = nn.Linear(n_embd, hidden_width)
        self.w2 = nn.Linear(hidden_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(torch.relu(self.w1(x)))
        return self.w2(hidden)

#helper
def clones(layer, N):
    """Return an nn.ModuleList holding N independent deep copies of ``layer``.

    Each copy starts with the same parameter values as ``layer`` but owns its own
    storage, so the copies train independently.
    """
    copies = (copy.deepcopy(layer) for _ in range(N))
    return nn.ModuleList(copies)