In [3]:
import torch
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import time

## Model Configuration

In [316]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048

config = ModelArgs()
print(config)

ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=None, vocab_size=-1, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=2048, rope=None)


In [None]:
from typing import Optional

class MyClass:
    def __init__(self, other: Optional['abc']):
        self.other = other

class OtherClass:
    pass

# Now you can create an instance of MyClass with an instance of OtherClass
my_instance = MyClass(OtherClass())

## Prepare Data

In [816]:
# simple tokenization by characters
def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])


lines = open('./data/Shakespeare.txt', 'r').read()
vocab = sorted(list(set(lines)))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}
dataset = torch.tensor(encode(lines), dtype=torch.int8)
print(f'Sentences: {dataset.shape[0]}')

MASTER_CONFIG = {
    "vocab_size": len(vocab),
}

def get_batches(data, split, batch_size, context_window, config=MASTER_CONFIG):
    train = data[:int(.8 * len(data))]
    val = data[int(.8 * len(data)): int(.9 * len(data))]
    test = data[int(.9 * len(data)):]

    if split == 'train':
        batch_data = train
    elif split == 'test':
        batch_data = test
    else:
        batch_data = val

    # pick random starting points
    ix = torch.randint(0, batch_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([batch_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([batch_data[i+1:i+context_window+1] for i in ix]).long()
    return x, y

Sentences: 1115394


## Support Functions

In [817]:
@torch.no_grad()  # don't compute gradients for this function
def evaluate_loss(model, config=MASTER_CONFIG):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = []
        for _ in range(10):
            xb, yb = get_batches(dataset, split, config['batch_size'], config['context_window'])
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)
    model.train()
    return out

def train(model, optimizer, scheduler=None, config=MASTER_CONFIG, print_logs=False):
    losses = []
    start_time = time.time()
    for epoch in range(config['epochs']):
        optimizer.zero_grad()

        xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])
        logits, loss = model(xs, targets=ys)
        loss.backward()
        optimizer.step()

        if scheduler:
            scheduler.step()

        if epoch % config['log_interval'] == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model)
            losses += [x]
            if print_logs:
                print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f} | ETA in seconds {batch_time * (config['epochs'] - epoch)/config['log_interval'] :.3f}")
            start_time = time.time()

            if scheduler:
                print("lr: ", scheduler.get_lr())

    # print(pd.DataFrame(losses))
    print("validation loss: ", losses[-1]['val'])
    return pd.DataFrame(losses).plot()

## Define Models

### RMS Normalization 

- [Paper](https://arxiv.org/pdf/1910.07467.pdf)
- [Reference implementation](https://github.com/facebookresearch/llama/blob/54d44631054deae836aec8ceff92dcf8f20ca9e7/llama/model.py#L34)

In [None]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """        
        return self._norm(x.float()).type_as(x) * self.weight

### RoPE

- [Paper](https://arxiv.org/pdf/2104.09864.pdf)
- [Reference Implementation](https://github.com/facebookresearch/llama/blob/dccf644213a2771a81fc4a754eed9623ea7f8444/llama/model.py#L80)

In [230]:
class RoPE:
    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0):
        """
        Precompute the frequency tensor for complex exponentials (cis, defined as 'm*theta_i' in the paper) 
        with given dimensions.

        Calculates a frequency tensor with complex exponentials using the given dimension 'dim'
        and the max sequence length. The 'theta_base' parameter scales the frequencies.
        The returned tensor contains complex values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            max_seq_len (int): Max sequence length.
            theta_base (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        freqs = torch.outer(torch.arange(max_seq_len), freqs).float()
        self.freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply rotary embeddings to input tensors using the given frequency tensor.

        This function first reshapes the frequency tensor to have the same shape as the target tensor 'x'
        for the purpose of broadcasting the frequency tensor during element-wise operations. Then, it applies 
        rotary embeddings to 'x' tensor using frequency tensor 'freqs_cis'.         
        """
        assert 1 < x.ndim

        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        assert self.freqs_cis.shape == (x_complex.shape[1], x_complex.shape[-1])

        shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x_complex.shape)]
        freqs_cis = self.freqs_cis.view(*shape)  
                
        x_real = torch.view_as_real(x_complex * freqs_cis).flatten(-2)
        
        return x_real.type_as(x)

#### RoPE Test

In [313]:
dim = 128
max_seq_len = 256

def get_rotary_matrix(context_window, embedding_dim):
    R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
    for position in range(context_window):
        for i in range(embedding_dim//2):
            theta = 10000. ** (-2.*i / embedding_dim)
            m_theta = position * theta
            R[position, 2*i,2*i] = np.cos(m_theta)
            R[position, 2*i,2*i+1] = - np.sin(m_theta)
            R[position, 2*i+1,2*i] = np.sin(m_theta)
            R[position, 2*i+1,2*i+1] = np.cos(m_theta)
    return R

R = get_rotary_matrix(max_seq_len, dim)

X= torch.ones(1, max_seq_len, dim)
rope = RoPE(dim=dim, max_seq_len=max_seq_len)
X1 = rope(X)
X2 = (R @ X.unsqueeze(-1)).flatten(-2)

assert(X1.allclose(X2, atol=1e-3))


### Attention

In [818]:
class AttentionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w_q = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_k = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_v = nn.Linear(config['d_model'], config['d_model'], bias=False)

        self.R = get_rotary_matrix(config['context_window'], config['d_model'])

    def forward(self, x):
        print(f'x: {x[0,:,:5]}')
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # print(f'q:{q[0, :,:5]}')
        # q[:,:-1,:] = torch.zeros(q[:,:-1,:].shape)
        # print(f'q:{q[0, :,:5]}')
        
        if 'inference' in self.config and self.config['inference']:  # implement KV cache            
            if x.shape[-2] > 1:  #reset kv cache for new sentence
                self.k = k
                self.v = v
            else:
                self.k = torch.concat([self.k[:, -self.config['context_window'] + 1:], k], dim = -2)
                self.v = torch.concat([self.v[:, -self.config['context_window'] + 1:], v], dim = -2)            
                k, v = self.k, self.v
        
        _, m, d = k.shape
        # print(f'q:{q.shape}, k:{k.shape}, v:{v.shape}, R:{self.R[m - q.shape[-2]:m].shape}')
        q_rotated = (torch.bmm(q.transpose(0,1), self.R[m - q.shape[-2]:m])).transpose(0,1)
        k_rotated = (torch.bmm(k.transpose(0,1), self.R[:m])).transpose(0,1)
        print(f'q_rotated: {q_rotated[0][-1][:5]}')

        activations = F.scaled_dot_product_attention(
            q_rotated, k_rotated, v, dropout_p =.1, is_causal=True
        )

        return activations


class MaskedMultiheadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.heads = nn.ModuleList([
            AttentionHead(config) for _ in range(config['n_heads'])
        ])
        self.linear = nn.Linear(config['n_heads'] * config['d_model'], config['d_model'])
        self.dropout = nn.Dropout(.1)

    def forward(self, x):
        heads = [h(x) for h in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.linear(x)
        x = self.dropout(x)
        return x

In [819]:
from collections import OrderedDict

class SwiGLU(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear_gate = nn.Linear(size, size)
        self.linear = nn.Linear(size, size)

        self.beta = torch.ones(1, requires_grad=True)

    def forward(self, x): 
        swish_gate = self.linear_gate(x) * torch.sigmoid(self.beta * self.linear_gate(x))
        out = swish_gate * self.linear(x)
        return out

class LlamaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.rms = RMSNorm((config['context_window'], config['d_model']))

        self.attention = RoPEMaskedMultiheadAttention(config)
        self.feedforward = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
        )

    def forward(self, x):
        x = self.rms(x) # rms pre-normalization
        x = x + self.attention(x)

        x = self.rms(x) # rms pre-normalization
        x = x + self.feedforward(x)
        return x

class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config['vocab_size'], config['d_model'])
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock(config)) for i in range(config['n_layers'])])
        )

        self.ffn = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
            nn.Linear(config['d_model'], config['vocab_size']),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        print(f'x:{idx}')
        x = self.embeddings(idx)
        x = self.llama_blocks(x)
        logits = self.ffn(x)

        if targets is None:
            return logits

        else:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss

## Training

In [820]:
%%time

MASTER_CONFIG.update({
    'epochs': 10000,
    'batch_size': 32,
    'd_model': 128,
    'n_heads': 8,
    'n_layers': 4,
    'context_window': 16,
    "log_interval": 100,
})

# llama = Llama(MASTER_CONFIG)
# optimizer = torch.optim.Adam(llama.parameters())
# train(llama, optimizer, config=MASTER_CONFIG, print_logs=True)

# # Save
# torch.save({'model_state_dict': llama.state_dict()}, "./checkpoint/llama_kv.pth")

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.2 µs


## Generate

In [821]:
# Initialize the model and optimizer
llama_infer = Llama(MASTER_CONFIG)

# Load
checkpoint = torch.load("./checkpoint/llama_kv.pth")
llama_infer.load_state_dict(checkpoint['model_state_dict'])

model params: 2370241


<All keys matched successfully>

In [822]:
def generate(model, config=MASTER_CONFIG, max_new_tokens=30):
    idx = torch.zeros(config['batch_size'], 2).long()
    # config = config.copy()

    config['inference'] = True
    for i in range(max_new_tokens):
        if i == 0:
            logits = model(idx)
        else:
            # logits = model(idx[:, -1].unsqueeze(1))
            logits = model(idx[:, -config['context_window']:])
        print(f'logits:{logits[:, -1, :5]}')
            
        last_time_step_logits = logits[:, -1, :]            # all the batches (1), last time step, all the logits
        p = F.softmax(last_time_step_logits, dim=-1)        # softmax to get probabilities
        idx_next = torch.multinomial(
            p, num_samples=1
        )                                                   # sample from the distribution to get the next token
        idx = torch.cat([idx, idx_next], dim=-1)            # append to the sequence
    
    return [decode(x) for x in idx.tolist()]

In [823]:
%%time

MASTER_CONFIG.update({
    'batch_size': 1,
})

torch.manual_seed(123)
print(generate(llama_infer, MASTER_CONFIG, 10)[0])

x:tensor([[0, 0]])
x: tensor([[-0.5340,  0.3109,  0.3360, -1.1308, -0.1879],
        [-0.8379,  0.4162,  0.4345, -2.0445, -0.2723]],
       grad_fn=<SliceBackward0>)
q_rotated: tensor([ 4.4426, -1.5656, -2.8417, -0.4424,  0.2025], grad_fn=<SliceBackward0>)
x: tensor([[-0.5340,  0.3109,  0.3360, -1.1308, -0.1879],
        [-0.8379,  0.4162,  0.4345, -2.0445, -0.2723]],
       grad_fn=<SliceBackward0>)
q_rotated: tensor([ 0.6952,  0.1257,  0.6421,  0.6120, -0.9635], grad_fn=<SliceBackward0>)
x: tensor([[-0.5340,  0.3109,  0.3360, -1.1308, -0.1879],
        [-0.8379,  0.4162,  0.4345, -2.0445, -0.2723]],
       grad_fn=<SliceBackward0>)
q_rotated: tensor([ 0.4299,  0.8142, -0.3514,  0.6077,  1.0348], grad_fn=<SliceBackward0>)
x: tensor([[-0.5340,  0.3109,  0.3360, -1.1308, -0.1879],
        [-0.8379,  0.4162,  0.4345, -2.0445, -0.2723]],
       grad_fn=<SliceBackward0>)
q_rotated: tensor([ 0.6608,  0.0423,  0.1577,  0.7558, -0.6731], grad_fn=<SliceBackward0>)
x: tensor([[-0.5340,  0.3109,

In [824]:
decode([0, 0, 32, 56, 43, 47])

'\n\nTrei'

In [825]:
decode([0,  0, 32, 46, 43,  1, 57, 54, 43, 39, 49])

'\n\nThe speak'

In [10]:
def self_attention_with_kv_cache(query, key, value, kv_cache=None):
    if kv_cache is None:
        attention_scores = torch.matmul(query, key.transpose(1,2))        
        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)
        attention_weights.to(attention_weights)
        
        # Scale attention weights by sqrt(d_k)
        d_k = key.size(-1)
        attention_weights = attention_weights * (d_k ** -0.5)
        context_vector = torch.matmul(attention_weights, value)
        kv_cache = {
            "key": key,
            "value": value
        }
    else:
        key = torch.cat([kv_cache["key"], key], dim=-1)
        value = torch.cat([kv_cache["value"], value], dim=-1)
        attention_scores = torch.matmul(query, key.transpose(1,2))
        attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)
        # Scale attention weights by sqrt(d_k)
        d_k = key.size(-1)
        attention_weights = attention_weights * (d_k ** -0.5)
        context_vector = torch.matmul(attention_weights, value)
        kv_cache = {
            "key": key,
            "value": value
        }

    return context_vector, kv_cache

def scaled_dot_product_attention(query, key, value):
    attention_scores = torch.matmul(query, key, transpose_b=True)
    # Apply softmax attention weights
    attention_weights = torch.nn.functional.softmax(attention_scores, dim=-1)

    # Scale attention weights by sqrt(d_k)
    d_k = key.size(-1)
    attention_weights *= (d_k ** -0.5)

    # Compute context vector
    context_vector = torch.matmul(attention_weights, value)

    return attention_weights, context_vector

In [11]:
query = torch.randn(10, 50, 64)
key = torch.randn(10, 100, 64)
value = torch.randn(10, 100, 64)

# Perform self-attention without KV caching
context_vector_no_cache, _ = self_attention_with_kv_cache(query, key, value)

# Perform self-attention with KV caching
context_vector_with_cache, kv_cache = self_attention_with_kv_cache(query, key, value, kv_cache={})

KeyError: 'key'