#### !!!! DO NOT RUN THIS FIRST CELL UNLESS YOU HAVE THE SAME VENV PATH ISSUE THAT I DO

In [1]:
import sys
sys.path.append('/Users/tunadorable/local-repos/learning_medusa/venv/lib/python3.11/site-packages')

# Ok now start

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
import random

#### !!!! ONLY FOR APPLE SILICON

In [3]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [4]:
# hyperparameters
b = 24 # how many independent sequences will we process in parallel?
t = 128 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
lr = 3e-4
eval_iters = 20
d = 128
h = 8
l = 8
dropout = 0.2
l2 = 0.01

m = 5
medusa_discount = torch.tensor(0.8).to(device)

In [5]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
v = len(chars)

In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [8]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - t - m, (b,))
    x = torch.stack([data[i:i+t] for i in ix])
    y = torch.stack([torch.stack([data[i+1+j:i+t+1+j] for i in ix]) for j in range(m+1)])
    return x.to(device), y.to(device)

In [10]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss, medusa_logits = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [11]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(), 
            nn.Linear(4 * d, d),
            nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

In [12]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d, head_size, bias=False)
        self.query = nn.Linear(d, head_size, bias=False)
        self.value = nn.Linear(d, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(t, t)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        b,t,d = x.shape
        k = self.key(x)   # (b,t,d/h)
        q = self.query(x) # (b,t,d/h)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (b, t, d/h) @ (b, d/h, t) -> (b, t, t)
        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf')) # (b, t, t)
        wei = F.softmax(wei, dim=-1) # (b, t, t)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (b,t,d/h)
        out = wei @ v # (b, t, t) @ (b, t, d/h) -> (b, t, d/h)
        return out

In [13]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(h)])
        self.proj = nn.Linear(head_size * h, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.proj(torch.cat([head(x) for head in self.heads], dim=-1)))

In [14]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, d, h):
        # d: embedding dimension, h: the number of heads we'd like
        super().__init__()
        head_size = d // h
        self.sa = MultiHeadAttention(h, head_size)
        self.ffwd = FeedFoward(d)
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        return x + self.ffwd(self.ln2(x + self.sa(self.ln1(x))))

In [15]:
class snake(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(d,d)
        self.relu = nn.ReLU() # actual paper uses SiLU bc they build off Llama
        self.w2 = nn.Linear(d,v)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.w2(self.dropout(self.relu(self.w1(x))+x)) # outputs logits shape (b,t,v)

In [16]:
class medusaGPT(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(v, d)
        self.position_embedding_table = nn.Embedding(t, d)
        self.blocks = nn.Sequential(*[Block(d, h) for _ in range(l)])
        self.ln_f = nn.LayerNorm(d) # final layer norm
        self.lm_head = nn.Linear(d, v)
        
        # Create a list of Medusa heads
        self.medusa_heads = nn.ModuleList([snake() for _ in range(m)])
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, medusa_targets=None, verbose=False):
        b, t = idx.shape
        
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        x = self.ln_f(self.blocks(pos_emb + self.token_embedding_table(idx)))
        
        logits = self.lm_head(x) # (b,t,d)@(d,v)=(b,t,v)
        
        # Apply each snake head to x and store the results
        medusa_logits = torch.stack([head(x) for head in self.medusa_heads], dim=0)
        
        if targets is None:
            loss = None
            medusa_loss = None
        else:
            m, b, t, v = medusa_logits.shape
            logits = logits.view(b*t, v)
            targets0 = targets[0].view(b*t)
            loss = F.cross_entropy(logits, targets0)
            
            medusa_loss = torch.stack([F.cross_entropy(medusa_logits[i].view(b*t, v), targets[i+1].view(b*t)) * medusa_discount**(i+1) for i in range(m)])
            
            loss = loss + medusa_loss.sum()

        return logits, loss, medusa_logits

    def generate_gpt(self, idx, max_new_tokens, temperature=1.0):
        # idx is (b, t) array of indices in the current context
        #assert temperature >= 0
        
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _, __ = self(idx[:, -t:])
            
            # apply softmax to get probabilities
            probs = F.softmax(logits[:, -1, :] / (temperature+1e-10), dim=-1) # (b, d)
            
            idx = torch.cat((idx, torch.multinomial(probs, num_samples=1)), dim=1)
            
        return idx

# Load a saved model

In [17]:
# Assuming `MyModel` is the class of your model
model = medusaGPT().to(device)  # Initialize a model with the same architecture

# Load the saved state dictionary
model.load_state_dict(torch.load('models/medusa_b24_t128_d128_h8_l8_lr0.0003_drop0.2_l2-0.01_m5_mdiscount0.80_2024-01-25|23-31-12.pth'))
# this is the better of the two models i trained
# however the extra medusa heads are near useless

# If you only plan to do inference, switch to evaluation mode
model.eval()

medusaGPT(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(128, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
   

## Regular GPT Inference

So in theory this should be the slowest but really it's not much worse bc of how memory inefficient medusa's ugly sisters are

In [18]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output = model.generate_gpt(context_tensor, max_new_tokens=250)
print(decode(output[0].tolist()))

JULIET:
O Romeo, Romeo! wherefore art thou Rome, bounde be as and hownd neep, comal and,-thest my honown go, with tear,
Burking me whond mine statherse before, I's granchists man thilder dest arm,
As then, and and me is e's blonext and on youty.

VOLVOLUMNERCENIUS:
Soundes, our met I hellowish
CPU times: user 20.6 s, sys: 1.14 s, total: 21.7 s
Wall time: 21.3 s


In [19]:
print("tokens per second: ", 250/21.3)

tokens per second:  11.737089201877934


# Medusa's first sister Stheno (the aggressive one)

ChatGPT said:
Her name translates to "strength" or "forceful". Stheno was the eldest and most fierce of the sisters, known for her strength and ferocity. 

So I guess since this generation is strictly doing greedy decoding i'll name it after her

In [18]:
def generate_Stheno(model, idx, max_runs):
    assert idx.size(0) == 1, "idx must be of size (1, t)"
    
    logits, loss, mlogits = model(idx[:, -t:])
    mlogits = mlogits[...,-1,:].squeeze(dim=1)
    
    idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t()
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2)
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) 

    tok_per_inf = [1]
    
    for _ in range(max_runs-1): 
        
        logits, loss, mlogits = model(torch.cat((idx, idx_m_prev), dim=1)[:, -t:]) 
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) 
        
        match_tensor = (idx_m_prev == idx_ntp[:,-(m+1):-1]).int()
        zero_positions = torch.cat((1 - match_tensor, torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype, device=device)), dim=1).argmax(dim=1)
        zero_positions[zero_positions >= match_tensor.size(1)] = match_tensor.size(1)
        range_tensor = torch.arange(match_tensor.size(1), device=device).unsqueeze(0).expand_as(match_tensor)
        mask = range_tensor < zero_positions.unsqueeze(1)
        result = (match_tensor * mask).sum(dim=1).item()

        tok_per_inf.append(result+1)
        
        idx_ntp = idx_ntp[:,-1-m+result].unsqueeze(dim=0)
        idx = torch.cat((idx, idx_m_prev[:,:result], idx_ntp),dim=1)
        
        mlogits = mlogits[...,-1-m+result,:].squeeze(dim=1)
        idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t()
        
    return idx, tok_per_inf

In [19]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output, tok_per_inf = generate_Stheno(model, context_tensor, max_runs=250)
print(decode(output[0].tolist()))

JULIET:
O Romeo, Romeo! wherefore art thou Rome there there there thank the my honother so thild thinks thinks ther than think the streep think ther thinks ther than there thinks think there than the stand and thonger think ther think ther think the stand thonger th there thinks think ther than the stand and thonger think ther think ther think the stand thonger th there thinks think ther than the stand and thonger think t
CPU times: user 18.2 s, sys: 994 ms, total: 19.2 s
Wall time: 18.7 s


In [20]:
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print("tokens per second: ", sum(tok_per_inf)/18.7)

tokens per inference:  1.524
tokens per second:  20.37433155080214


#### even tho Stheno is faster, notice that it's restricted to greedy decoding which means the output quality is lower. If you compare the two passages you'll see that the one above has far less of a problem with repetition

# Medusa's second sister Euryale (the explorative one)

we'll see how this goes



ChatGPT said:

Her name means "far-roaming" in Greek. Euryale was known for her loud crying or bellowing. If your architecture is meant to explore a wide range of possibilities or to "roam" extensively through a dataset, the name Euryale might be suitable. Additionally, if your architecture involves a broad or far-reaching search strategy or is notable for 'broadcasting' its findings extensively (analogous to loud crying), Euryale could be an apt choice.

The goal here with Euryale is effectively to bring probabalistic decoding back to Stheno. My hope is that this will only require more ram and not result in any added latency, but we'll see. The basic idea is to use topk and then randomly or probability-wise select from candidate sequences instead of greedy decoding. If anything i think the re-incorporation of topk results will *maybe* add a speed increase (although likely not as fast as the attention-based mechanism used in actual Medusa)

In [18]:
def combinations(tensor):
    m, k = tensor.shape

    mesh_indices = torch.meshgrid([torch.arange(k) for _ in range(m)][::-1], indexing="ij")

    combinations = torch.cat([tensor[m-1-i][mesh_indices[i]].unsqueeze(0) for i in range(m)], dim=0)

    return combinations.T.reshape(-1, m).flip(dims=[1])

def compare(A,B):
    i,j,k = A.shape
    
    match_tensor = (A == B).int() 
    
    padded_tensor = torch.cat((1 - match_tensor, torch.ones((i,j,1), dtype=match_tensor.dtype, device=device)), dim=-1)
    
    zero_positions = padded_tensor.argmax(dim=-1)
    zero_positions[zero_positions >= k] = k
    
    range_tensor = torch.arange(m, device=device).unsqueeze(0).expand_as(match_tensor)
    
    mask = range_tensor < zero_positions.unsqueeze(-1)
    
    return (match_tensor * mask).sum(dim=-1)

In [19]:
def generate_Euryale(model, idx, max_runs, k=2):
    logits, loss, mlogits = model(idx[:, -t:])
    
    mlogits = mlogits[...,-1,:].squeeze(dim=1)
    
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2)
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1)
    
    tok_per_inf = [1]
    
    for _ in range(max_runs-1): 
        
        idx_m_topk = torch.topk(mlogits, k, dim=-1, largest=True).indices
        mcomb = combinations(idx_m_topk) 
        
        idx_rep = idx.repeat(k**m,1) 
        
        logits, loss, mlogits = model(torch.cat((idx_rep, mcomb), dim=1)[:, -t:])
        
        idx_ntp_topk = torch.topk(logits, k, dim=-1, largest=True).indices 
        
        idx_check = idx_ntp_topk[:,-(m+1):-1,:] 
        result = compare(mcomb.unsqueeze(0).repeat(k**m,1,1),torch.stack([combinations(idx_check[i]) for i in range(idx_check.shape[0])]))
        
        max_val = torch.max(result).item()
        tok_per_inf.append(max_val+1)
        
        if random.choice([True, False]):
            max_idx_row = torch.max(result,1).indices[0].item() # most likely
            max_idx_col = len(result[max_idx_row]) - 1 - torch.argmax(result[max_idx_row].flip(0)).item() # least likely
        elif random.choice([True, False]):
            max_idx_row = (len(result) - 1 - torch.max(result.flip(dims=[0]),1).indices)[0].item() # least likely
            max_idx_col = torch.argmax(result[max_idx_row]).item() # most likely
        else:
            max_idx_row = torch.max(result,1).indices[0].item() # most likely
            max_idx_col = torch.argmax(result[max_idx_row]).item() # most likely
        
        idx_m = mcomb[max_idx_row, :max_val].unsqueeze(0) # (k^m,m) -> (1,max_val)
        
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2)[max_idx_row, -1-m+max_val].unsqueeze(0).unsqueeze(0)
                
        idx = torch.cat((idx, idx_m, idx_ntp),dim=1) 
        
        mlogits = mlogits[:,max_idx_row,-1-m+max_val,:].unsqueeze(dim=1).squeeze(dim=1)        
            
    return idx, tok_per_inf

In [22]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output, tok_per_inf = generate_Euryale(model, context_tensor, max_runs=250, k=2)
print(decode(output[0].tolist()))

JULIET:
O Romeo, Romeo! wherefore art thou Rome there thet  there therefol ,offfort  the me the myself and to the mare  and thonger 
Th   the  herest and to the  honown ther  thinde 
The  that rece theter  thatger think  thet thanks theredit ther  thander th thaee
Toetend th the  th thingsrer th things  thet than thate though , and to the savow th  the  sorry sad think  and think  the stand thonger th thoee thinks
Th the than to  the   thing  of  to een think  ofronder thi eaar then 
Th   to eeek the   of thieee th the th there
CPU times: user 24.4 s, sys: 4.54 s, total: 28.9 s
Wall time: 35.3 s


In [23]:
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print("tokens per second: ", sum(tok_per_inf)/35.3)

tokens per inference:  1.956
tokens per second:  13.852691218130312


### Aight so sometimes when I run this the output looks good and it's fast, but other times the output looks like this and it's not significantly faster than NTP. Lotts leaving things up to chance here. The great thing about real medusa is that it looks more like Stheno in terms of speed and also doesn't suffer from the probability problem of either of the sisters

I'd also like to note that there's no use in trying a value for k other than 2

And I didn't even use basic speedup methods like KV caching. Idk if i'm right but i think using those might disproportionately help my Euryale over NTP