In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

device = torch.device("mps")

In [2]:
# simple tokenization by characters

lines = open('./input.txt', 'r').read()

vocab = list(set(lines))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])

print('vocab size:', len(vocab))

vocab size: 65


In [3]:
dataset = torch.tensor(encode(lines), dtype=torch.int8, device=device)
dataset

tensor([23, 64, 21,  ..., 42, 31, 18], device='mps:0', dtype=torch.int8)

We want to run training with minibatches of size 8.

In [4]:
BATCH_SIZE = 8
CONTEXT_WINDOW = 32

config = {
    "d_model": 100,
    "vocab_size": len(vocab),
    "max_len": CONTEXT_WINDOW,
}

def get_batches(data, batch_size, context_window):
    xs = torch.zeros(batch_size, context_window, device=device, dtype=torch.long)
    ys = torch.zeros(batch_size, context_window, device=device, dtype=torch.long)

    # pick random starting points
    starts = torch.randint(0, len(data) - context_window - 1, (batch_size,))
    for item, start in enumerate(starts):
        
        xs[item] += data[start:start+context_window]
        ys[item] += data[start+1:start+context_window+1]
    return xs, ys
    

xs, ys = get_batches(dataset, BATCH_SIZE, CONTEXT_WINDOW)

In [13]:
class RotarySelfAttention(nn.Module):
    """
    Rotary embeddings.

    Input: (BATCH_SIZE x CONTEXT_WINDOW)
    Output: (BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        if config['d_model'] % 2 != 0:
            raise ValueError("d_model must be divisible by 2")
        self.w_q = torch.randn(config['d_model'], config['d_model'])
        self.w_k = torch.randn(config['d_model'], config['d_model'])
        self.w_v = torch.randn(config['d_model'], config['d_model'])
        self.R = self.get_rotary_matrix() # [d_model x d_model x context_window]

    def get_rotary_matrix(self):
        context_window = self.config['context_window']
        d = self.config['d_model']
        R = torch.zeros((d, d, context_window))
        for m in range(context_window):
            for i in range(d//2):
                theta = 10000. ** (-2.*(i - 1) / d)
                R[2*i,2*i, m] = np.cos(m * theta)
                R[2*i,2*i+1, m] = - np.sin(m * theta)
                R[2*i+1,2*i, m] = np.sin(m * theta)
                R[2*i+1,2*i+1, m] = np.cos(m * theta)
        return R

    def forward(self, x):
        """
        x: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        out: [BATCH_SIZE x CONTEXT_WINDOW x EMBEDDING_DIM]
        """
        # q = R_m W_q x_m
        # okay, just get the embedding for the mth position

        # this is the mth embedding across all the batches
        q = x[0, :, :] @ self.w_q # query embedding
        # the thing i'm expecting should be the rotated query matrix,
        # but not across all m, just for the mth position


config = {
    "batch_size": 1,
    "context_window": 10,
    "d_model": 512,
}
r = RotarySelfAttention(config)
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))

r(batch).shape

torch.Size([512, 512, 512])

In [None]:
# Llama is transformer. what makes it different
# prenormalization at each sub-layer using RMSNorm
# SwiGLU activation function with 2/3*4d as dimension
# Rotary embeddings instead of absolute positional embeddings
class LlamaLayer(nn.Module):
    """
    """
    def __init__(self, config = {
        "num_heads": 5,
        "embed_dim": 100, # how many dimensions in the word embedding
    }):
        super().__init__()
        # multi attention head with pre-normalization
        self.multihead_attention = nn.MultiheadAttention(**config)

In [44]:

class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embeddings = nn.Embedding(config.max_len, config.d_model)
        self.rotary_pos_embeddings = RotaryEmbedding(config.d_model)
        self.layers = nn.ModuleList([LlamaLayer(config) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.d_model)
        self.fc = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, x, pos):
        x = self.embeddings(x)
        pos = self.pos_embeddings(pos)
        x = x + pos
        x = self.rotary_pos_embeddings(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.fc(x)
        return x