In [1]:
import time
import tiktoken
import torch
import torch.nn as nn

In [None]:
from deps.gpt_model import LayerNorm, GELU, FeedForward

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with an optional key/value (KV) cache.

    With ``use_cache=True`` (the default, matching the previous behaviour) the
    keys/values of every call are appended to ``cache_k``/``cache_v`` and the
    queries of the current call attend over all cached positions, so a
    generation loop can feed one new token at a time. Call ``reset_cache()``
    before starting an unrelated sequence.

    With ``use_cache=False`` the layer is stateless: the cache is neither read
    nor written, so repeated calls do not leak keys/values into each other.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: output size; must be divisible by ``num_heads``.
        context_len: size of the causal mask, i.e. the maximum number of
            positions one sequence (cached or not) may span.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias.
    """

    def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order of the projections is kept fixed so seeded init is reproducible.
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # 1s above the diagonal mark "future" key positions a query may not see.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_len, context_len), diagonal=1),
            persistent=False
        )

        # KV cache, laid out as (b, cached_tokens, num_heads, head_dim) once filled.
        self.register_buffer('cache_k', None, persistent=False)
        self.register_buffer('cache_v', None, persistent=False)
        # Absolute position of the next query token in the cached sequence.
        self.ptr_cur_pos = 0

    def forward(self, x, use_cache=True):
        """Attend over ``x`` (and, if ``use_cache``, everything cached so far).

        Args:
            x: tensor of shape (batch, num_tokens, d_in).
            use_cache: if True (default), append this call's keys/values to
                the cache and attend over the whole cached sequence. If False,
                attend only within ``x`` and leave the cache untouched.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if the sequence would extend past ``context_len``; the
                cache is left unchanged in that case.
        """
        b, num_tokens, _ = x.shape
        context_len = self.mask.shape[0]
        # Absolute position of this call's first query token.
        start = self.ptr_cur_pos if use_cache else 0
        # Check before touching the cache so a failed call cannot corrupt it.
        if start + num_tokens > context_len:
            raise ValueError(
                f'sequence of {start + num_tokens} tokens exceeds '
                f'context_len={context_len}'
                + ('; call reset_cache() to start a new sequence' if use_cache else '')
            )

        # (b, n, d_out) -> (b, n, heads, head_dim)
        queries = self.W_q(x).view(b, num_tokens, self.num_heads, self.head_dim)
        keys = self.W_k(x).view(b, num_tokens, self.num_heads, self.head_dim)
        values = self.W_v(x).view(b, num_tokens, self.num_heads, self.head_dim)

        if use_cache:
            if self.cache_k is None:
                self.cache_k, self.cache_v = keys, values
            else:
                # Grow the cache along the token dimension.
                self.cache_k = torch.cat([self.cache_k, keys], dim=1)
                self.cache_v = torch.cat([self.cache_v, values], dim=1)
            keys, values = self.cache_k, self.cache_v
            self.ptr_cur_pos += num_tokens

        # (b, n, heads, head_dim) -> (b, heads, n, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        # Rows are this call's absolute query positions and columns cover every
        # key seen so far, so each query is blocked only from later positions.
        mask_bool = self.mask.bool()[start:start + num_tokens, :keys.shape[-2]]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

    def reset_cache(self):
        """Drop all cached keys/values and rewind the position pointer to 0."""
        self.cache_k, self.cache_v = None, None
        self.ptr_cur_pos = 0

In [11]:
# use new MHA
class TransformerBlock(nn.Module):
    """Pre-norm transformer block built on the cache-aware MultiHeadAttention.

    Each of the two sub-layers (attention, then feed-forward) normalises its
    input, applies dropout to its output and adds the result back onto the
    residual stream. Because the attention layer keeps a KV cache, every call
    to this block extends that cache.

    Args:
        cfg: config dict; reads 'emb_dim', 'context_len', 'n_heads',
            'drop_rate', 'qkv_bias' here (FeedForward may read more).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']
        drop_rate = cfg['drop_rate']

        # Submodules are created in a fixed order so seeded initialisation
        # stays reproducible.
        self.attn = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_len=cfg['context_len'],
            num_heads=cfg['n_heads'],
            dropout=drop_rate,
            qkv_bias=cfg['qkv_bias'],
        )
        self.ff = FeedForward(cfg)
        self.norm1, self.norm2 = LayerNorm(emb_dim), LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(drop_rate)

    def forward(self, x):
        # Attention sub-layer wrapped in a residual connection.
        attn_out = self.drop_shortcut(self.attn(self.norm1(x)))
        x = attn_out + x

        # Feed-forward sub-layer wrapped in a residual connection.
        ff_out = self.drop_shortcut(self.ff(self.norm2(x)))
        return ff_out + x

In [12]:
class GPTModel(nn.Module):
    """GPT-style decoder whose forward pass continues a cached sequence.

    Every call to ``forward`` picks up where the previous one stopped: position
    ids start at ``self.cur_pos`` and each transformer block's attention layer
    appends the new keys/values to its cache. Call ``reset_kv_cache()`` before
    processing an unrelated sequence.

    Args:
        cfg: config dict; reads 'vocab_size', 'emb_dim', 'context_len',
            'drop_rate', 'n_layers' here (TransformerBlock reads more).
    """

    def __init__(self, cfg):
        super().__init__()

        # Submodule creation order is kept fixed so seeded init is reproducible.
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_len'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.ModuleList([
            TransformerBlock(cfg) for _ in range(cfg['n_layers'])
        ])

        # Absolute position of the next token fed to forward().
        self.cur_pos = 0

        self.final_norm = LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        """Return next-token logits for ``in_idx``, continuing from ``cur_pos``.

        Args:
            in_idx: integer tensor of token ids, shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``cur_pos + seq_len`` exceeds the positional
                embedding table (cfg['context_len']). ``cur_pos`` is left
                unchanged in that case.
        """
        _, seq_len = in_idx.shape
        max_pos = self.pos_emb.num_embeddings
        end_pos = self.cur_pos + seq_len
        # Validate first: an out-of-range embedding index is an opaque
        # IndexError on CPU and a context-poisoning device assert on CUDA.
        if end_pos > max_pos:
            raise ValueError(
                f'position {end_pos} exceeds context_len={max_pos} '
                f'({self.cur_pos} cached + {seq_len} new tokens); '
                f'call reset_kv_cache() to start a new sequence'
            )

        tok_embs = self.tok_emb(in_idx)
        # Positions continue from the cache instead of restarting at 0.
        pos_ids = torch.arange(
            self.cur_pos, end_pos,
            device=in_idx.device,
            dtype=torch.long,
        )
        self.cur_pos = end_pos
        pos_embs = self.pos_emb(pos_ids).unsqueeze(0)  # broadcast over batch

        x = self.drop_emb(tok_embs + pos_embs)
        for blk in self.trf_blocks:
            x = blk(x)

        x = self.final_norm(x)
        return self.out_head(x)

    def reset_kv_cache(self):
        """Clear every block's attention cache and rewind positions to 0."""
        for blk in self.trf_blocks:
            blk.attn.reset_cache()
        self.cur_pos = 0

In [13]:
def generate_text_simple_cached(
        model, idx, max_new_tokens,
        context_size=None,
):
    """Greedily extend ``idx`` by ``max_new_tokens`` tokens using the KV cache.

    The (possibly truncated) prompt is run through the model once to fill the
    cache; after that only the newest token is fed per step. The final
    generated token is not fed back, because its logits would never be used.

    Args:
        model: model exposing ``reset_kv_cache()``, ``pos_emb`` (an
            ``nn.Embedding`` whose size bounds the sequence length) and a
            forward that continues the cached sequence and returns logits of
            shape (batch, seq_len, vocab).
        idx: integer tensor of token ids, shape (batch, seq_len).
        max_new_tokens: number of tokens to append.
        context_size: keep at most this many prompt tokens (from the end of
            ``idx``); defaults to the model's positional-embedding size.

    Returns:
        ``idx`` with ``max_new_tokens`` greedily chosen tokens appended.

    Raises:
        ValueError: if the prompt plus the tokens that must be fed back would
            need more positions than the model has. Raised before any forward
            pass.
    """
    model.eval()
    max_positions = model.pos_emb.num_embeddings
    ctx_len = context_size or max_positions

    prompt = idx[:, -ctx_len:]
    # Every generated token except the last one is fed back into the model.
    positions_needed = prompt.shape[1] + max(max_new_tokens - 1, 0)
    if positions_needed > max_positions:
        raise ValueError(
            f'prompt of {prompt.shape[1]} tokens + {max_new_tokens} new tokens '
            f'needs {positions_needed} positions, but the model supports '
            f'{max_positions}'
        )

    with torch.no_grad():
        # Initialise the cache with the full prompt.
        model.reset_kv_cache()
        logits = model(prompt)

        for step in range(max_new_tokens):
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
            # Feed only the new token, and skip the pointless final pass.
            if step < max_new_tokens - 1:
                logits = model(next_idx)

    return idx

In [14]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_len": 1024,     # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False        # Query-Key-Value bias
}
MAX_NEW_TOKENS = 200         # Tokens to generate in the benchmark below

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

start_context = 'Hello, I am'
tokenizer = tiktoken.get_encoding('gpt2')
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)

# Inner quotes differ from the outer ones so this also parses before Python 3.12.
print(f'\n{"="*50}\n{" "*22}IN\n{"="*50}\n')
print(f'Input text: {start_context}')
print(f'Encoded: {encoded}')
print(f'Encoded shape: {encoded_tensor.shape}')

# CUDA kernels run asynchronously: synchronise around the timed region so the
# measurement covers the actual GPU work. perf_counter is monotonic.
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()

token_ids = generate_text_simple_cached(
    model,
    encoded_tensor,
    max_new_tokens=MAX_NEW_TOKENS,
)

if device.type == 'cuda':
    torch.cuda.synchronize()
end = time.perf_counter()

decoded = tokenizer.decode(token_ids.squeeze(0).tolist())
print(f'\n{"="*50}\n{" "*22}OUT\n{"="*50}\n')
print(f'Output: {token_ids}')
print(f'Output length: {len(token_ids[0])}')
print(f'Output text: {decoded}')

elapsed = end - start
print(f'Took {elapsed:.2f}s')
# Count only generated tokens; the prompt is processed in one batched pass.
num_generated = token_ids.shape[1] - encoded_tensor.shape[1]
toks_per_sec = num_generated / elapsed
print(f'{int(toks_per_sec)} generated tokens/sec')


                      IN

Input text: Hello, I am
Encoded: [15496, 11, 314, 716]
Encoded shape: torch.Size([1, 4])

                      OUT

Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267,
         49706, 43231, 47062, 34657, 18631, 49188, 43312, 45933, 23154, 15983,
         10345, 16369, 46214, 22954, 34674, 21100,  4743, 14056, 42526,  6459,
         12799,  5734, 49274,   136, 49294, 42900, 21193, 20463,  1018,  7864,
         13895, 27167, 12810, 25727, 14388,   985, 15797, 24440, 18557, 48625,
         10579,  4007, 11895, 45365, 19051,  1355, 47705,  5120, 32858, 49293,
          5141, 22900, 36570, 22215, 16369, 25803,  9254, 33694, 23188, 21624,
         12696,  1697, 12315, 23338,  1361, 49487, 27970, 21641, 28170, 36226,
          8980, 34715, 15683, 21370,   829, 41165, 19250, 40921, 47972, 29169,
         17681, 13937,   719,  7781, 46519, 39685, 35637, 38254, 37355, 48054,
          6960, 32389, 49945, 48307, 43363,  9451, 44360, 