In [87]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

device = torch.device("mps")

In [88]:
# simple tokenization by characters

lines = open('./input.txt', 'r').read()

vocab = list(set(lines))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])

print('vocab size:', len(vocab))

vocab size: 65


In [89]:
dataset = torch.tensor(encode(lines), dtype=torch.int8, device=device)
dataset

tensor([23, 64, 21,  ..., 42, 31, 18], device='mps:0', dtype=torch.int8)

We want to run training with minibatches of size 8.

In [90]:
BATCH_SIZE = 8
CONTEXT_WINDOW = 32

config = {
    "d_model": 100,
    "vocab_size": len(vocab),
    "max_len": CONTEXT_WINDOW,
}

def get_batches(data, batch_size, context_window):
    xs = torch.zeros(batch_size, context_window, device=device, dtype=torch.long)
    ys = torch.zeros(batch_size, context_window, device=device, dtype=torch.long)

    # pick random starting points
    starts = torch.randint(0, len(data) - context_window - 1, (batch_size,))
    for item, start in enumerate(starts):
        
        xs[item] += data[start:start+context_window]
        ys[item] += data[start+1:start+context_window+1]
    return xs, ys
    

xs, ys = get_batches(dataset, BATCH_SIZE, CONTEXT_WINDOW)

In [95]:
class MaskedRotarySelfAttentionHead(nn.Module):
    """
    Rotary embeddings.

    Input: (BATCH_SIZE x CONTEXT_WINDOW)
    Output: (BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        if config['d_model'] % 2 != 0:
            raise ValueError("d_model must be divisible by 2")
        self.w_q = torch.randn(config['d_model'], config['d_model'])
        self.w_k = torch.randn(config['d_model'], config['d_model'])
        self.w_v = torch.randn(config['d_model'], config['d_model'])
        self.register_buffer("tril", torch.tril(torch.ones(config['context_window'], config['context_window'])))  # mask
        
        self.R = self.get_rotary_matrix() # [d_model x d_model x context_window]

    def get_rotary_matrix(self):
        context_window = self.config['context_window']
        d = self.config['d_model']
        R = torch.zeros((d, d, context_window))
        for m in range(context_window):
            for i in range(d//2):
                theta = 10000. ** (-2.*(i - 1) / d)
                R[2*i,2*i, m] = np.cos(m * theta)
                R[2*i,2*i+1, m] = - np.sin(m * theta)
                R[2*i+1,2*i, m] = np.sin(m * theta)
                R[2*i+1,2*i+1, m] = np.cos(m * theta)
        return R

    def forward(self, x):
        """
        x: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        out: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        """
        b = self.config['batch_size']
        d = self.config['d_model']
        m = self.config['context_window']

        q_rotated_query_weight = self.w_q @ self.R # [d x d x m]
        q = q_rotated_query_weight.view(m, d, d) @ x.view(m, d, b)
        q = q.view(b, m, d)

        k_rotated_query_weight = self.w_k @ self.R # [d x d x m]
        k = k_rotated_query_weight.view(m, d, d) @ x.view(m, d, b)
        k = k.view(b, m, d)
        
        B = (q @ k.transpose(1,2)) / np.sqrt(self.config['d_model'])
        mask = B.masked_fill(self.tril[:m, :m] == 0, float("-inf"))
        a = F.softmax(mask, dim=-1) # attention
        v = x @ self.w_v 
        out = a @ v
        return out


config = {
    "batch_size": 3,
    "context_window": 10,
    "d_model": 512,
}
r = RotarySelfAttention(config)
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))

r(batch).shape

# ==== check ====
# s = batch[0,:,:]
# (r.w_q @ r.R)

torch.Size([3, 10, 512])

In [97]:
class MaskedRotaryNultiHeadedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.heads = nn.ModuleList([
            MaskedRotarySelfAttentionHead(config) for _ in range(config["n_heads"])
        ])
        self.linear = nn.Linear(config["n_heads"] * config["d_model"], config["d_model"])
        self.dropout = nn.Dropout(0.1)


    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.linear(out))
        return out

config = {
    "n_heads": 8,
    "d_model": 512,
    "batch_size": 3,
    "context_window": 10,
}

batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
m = MaskedRotaryNultiHeadedAttention(config)

m(batch).shape

torch.Size([3, 10, 512])

In [93]:

class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def forward(self, x):
        