In [1]:
import sys
sys.path.append('/Users/tunadorable/local-repos/learning_medusa/venv/lib/python3.11/site-packages')

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time

In [3]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [4]:
# hyperparameters
b = 24 # how many independent sequences will we process in parallel?
t = 128 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
lr = 3e-4
eval_iters = 20
d = 128
h = 8
l = 8
dropout = 0.2
l2 = 0.01

medusa_headcount = 3
medusa_discount = torch.tensor(0.8).to(device)

In [5]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [8]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - t - medusa_headcount, (b,))
    x = torch.stack([data[i:i+t] for i in ix])
    y = torch.stack([torch.stack([data[i+1+j:i+t+1+j] for i in ix]) for j in range(medusa_headcount+1)])
    return x.to(device), y.to(device)

In [10]:
x,y = get_batch('train')
print("x[0] ", x[0].shape, "\n", x[0])
print("y[:,0,...] ", y[:,0,...].shape, "\n", y[:,0,...])

x[0]  torch.Size([128]) 
 tensor([ 0, 13, 57,  1, 61, 43, 56, 43,  1, 53, 59, 56,  1, 17, 52, 45, 50, 39,
        52, 42,  1, 47, 52,  1, 56, 43, 60, 43, 56, 57, 47, 53, 52,  1, 46, 47,
        57,  6,  0, 13, 52, 42,  1, 46, 43,  1, 53, 59, 56,  1, 57, 59, 40, 48,
        43, 41, 58, 57,  5,  1, 52, 43, 62, 58,  1, 42, 43, 45, 56, 43, 43,  1,
        47, 52,  1, 46, 53, 54, 43,  8,  0,  0, 19, 30, 17, 17, 26, 10,  0, 35,
        43, 50, 50,  6,  1, 46, 43,  1, 47, 57,  1, 45, 53, 52, 43, 11,  1, 39,
        52, 42,  1, 61, 47, 58, 46,  1, 46, 47, 51,  1, 45, 53,  1, 58, 46, 43,
        57, 43], device='mps:0')
y[:,0,...]  torch.Size([4, 128]) 
 tensor([[13, 57,  1, 61, 43, 56, 43,  1, 53, 59, 56,  1, 17, 52, 45, 50, 39, 52,
         42,  1, 47, 52,  1, 56, 43, 60, 43, 56, 57, 47, 53, 52,  1, 46, 47, 57,
          6,  0, 13, 52, 42,  1, 46, 43,  1, 53, 59, 56,  1, 57, 59, 40, 48, 43,
         41, 58, 57,  5,  1, 52, 43, 62, 58,  1, 42, 43, 45, 56, 43, 43,  1, 47,
         52,  1, 46, 5

In [11]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss, medusa_logits = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [12]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(), 
            nn.Linear(4 * d, d),
            nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

In [13]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d, head_size, bias=False)
        self.query = nn.Linear(d, head_size, bias=False)
        self.value = nn.Linear(d, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(t, t)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        b,t,d = x.shape
        k = self.key(x)   # (b,t,d/h)
        q = self.query(x) # (b,t,d/h)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (b, t, d/h) @ (b, d/h, t) -> (b, t, t)
        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf')) # (b, t, t)
        wei = F.softmax(wei, dim=-1) # (b, t, t)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (b,t,d/h)
        out = wei @ v # (b, t, t) @ (b, t, d/h) -> (b, t, d/h)
        return out

In [14]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(h)])
        self.proj = nn.Linear(head_size * h, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [15]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, d, h):
        # d: embedding dimension, h: the number of heads we'd like
        super().__init__()
        head_size = d // h
        self.sa = MultiHeadAttention(h, head_size)
        self.ffwd = FeedFoward(d)
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [16]:
class snake(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(d,d)
        self.relu = nn.ReLU() # actual paper uses SiLU bc they build off Llama
        self.w2 = nn.Linear(d,v)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.w2(self.dropout(self.relu(self.w1(x))+x)) # outputs logits shape (b,t,v)

In [17]:
class medusaGPT(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(v, d)
        self.position_embedding_table = nn.Embedding(t, d)
        self.blocks = nn.Sequential(*[Block(d, h) for _ in range(l)])
        self.ln_f = nn.LayerNorm(d) # final layer norm
        self.lm_head = nn.Linear(d, v)
        
        # Create a list of Medusa heads
        self.medusa_heads = nn.ModuleList([snake() for _ in range(medusa_headcount)])
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, medusa_targets=None, verbose=False):
        b, t = idx.shape

        # idx and targets are both (b,t) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (b,t,d)
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        x = tok_emb + pos_emb # (b,t,d)
        x = self.ln_f(self.blocks(x)) # (b,t,d) -> (b,t,d)
        
        logits = self.lm_head(x) # (b,t,d)@(d,v)=(b,t,v)
        
        # Apply each snake head to x and store the results
        medusa_logits = torch.stack([head(x) for head in self.medusa_heads], dim=0)
        
        if verbose:
            print("logits: ", logits.shape, logits)
            print("medusa_logits: ", medusa_logits.shape, medusa_logits)
        
        if targets is None:
            loss = None
        else:
            m, b, t, v = medusa_logits.shape
            logits = logits.view(b*t, v)
            targets0 = targets[0].view(b*t)
            loss = F.cross_entropy(logits, targets0)

            # todo: make this all one line
            mloss = []
            for i in range(m):
                mloss.append(F.cross_entropy(medusa_logits[i].view(b*t,v), targets[i+1].view(b*t))*medusa_discount**(i+1))
            medusa_loss = torch.stack(mloss)
            # this might work but i haven't tested it \/
            #medusa_loss = torch.stack([F.cross_entropy(medusa_logits[i].view(b*t, v), targets[i+1].view(b*t)) * medusa_discount**(i+1) for i in range(m)])
            
            loss = loss + medusa_loss.sum()

        return logits, loss, medusa_logits

    def generate_gpt(self, idx, max_new_tokens, temperature=1.0):
        # idx is (b, t) array of indices in the current context
        #assert temperature >= 0
        
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -t:]
            
            # get the predictions
            logits, loss, mlogits = self(idx_cond)
            
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, d)
            
            # scale logits by the temperature
            logits = logits / (temperature+1e-10)
            
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, d)
            
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)
            
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)
        
        return idx

# Training

if you don't want to do your own training just scroll down

In [18]:
model = medusaGPT()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e3, 'K parameters')

1691.14 K parameters


In [19]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2)

In [20]:
start_time = time.time()
for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    # train
    logits, loss, mlogits = model(xb, yb)

    # L1 Regularization
    l1_norm = sum(p.abs().sum() for p in model.parameters())

    # Total loss
    total_loss = loss + 0.01 * l1_norm
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

step 0: train loss 11.8995, val loss 11.9120, time elapsed: 1.95 seconds
step 100: train loss 8.8519, val loss 8.8598, time elapsed: 46.52 seconds
step 200: train loss 8.5150, val loss 8.5164, time elapsed: 90.72 seconds
step 300: train loss 8.3617, val loss 8.3665, time elapsed: 134.74 seconds
step 400: train loss 8.2472, val loss 8.2383, time elapsed: 178.67 seconds
step 500: train loss 8.1164, val loss 8.1313, time elapsed: 222.43 seconds
step 600: train loss 7.9587, val loss 8.0067, time elapsed: 266.19 seconds
step 700: train loss 7.8637, val loss 7.9270, time elapsed: 309.95 seconds
step 800: train loss 7.7623, val loss 7.7944, time elapsed: 353.73 seconds
step 900: train loss 7.6077, val loss 7.7277, time elapsed: 397.29 seconds
step 1000: train loss 7.5329, val loss 7.6664, time elapsed: 441.00 seconds
step 1100: train loss 7.4500, val loss 7.5704, time elapsed: 484.66 seconds
step 1200: train loss 7.3449, val loss 7.5147, time elapsed: 528.13 seconds
step 1300: train loss 7.26

#### save the trained model

In [23]:
torch.save(model.state_dict(), f'models/medusa_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}_l2:{l2}_m{medusa_headcount}_mdiscount{medusa_discount:.2f}_{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Load a saved model

In [18]:
# Assuming `MyModel` is the class of your model
model = medusaGPT().to(device)  # Initialize a model with the same architecture


# Load the saved state dictionary
model.load_state_dict(torch.load('models/medusa_b24_t128_d128_h8_l8_lr0.0003_drop0.2_m3_mdiscount0.8_2024-01-24|17-36-02.pth'))

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

medusaGPT(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(128, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
   

## Regular GPT Inference

for some reason if i try to run this before doing my later functions than the kernel crashes. I recommend restarting with a fresh kernel for those

In [29]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output = model.generate_gpt(context_tensor, max_new_tokens=218)
output_str = decode(output[0].tolist())
print(output_str)

JULIET:
O Romeo, Romeo! wherefore art thou Roman?

WARWICK:
The viley and whiled, then the name of this mantering.
Forst of when forfeing in our honour with the dukre
Ere it mock and along to my own Angelo;
Even the haresd than before my lord embove.
Gate speak,
CPU times: user 19.6 s, sys: 1.03 s, total: 20.7 s
Wall time: 20.4 s


# Medusa's first sister Stheno (the aggressive one)

ChatGPT said:
Her name translates to "strength" or "forceful". Stheno was the eldest and most fierce of the sisters, known for her strength and ferocity. 

So I guess since this generation is strictly doing greedy decoding i'll name it after her

In [20]:
def generate_Stheno(model, idx, max_runs):
    # Ensure idx is a single sequence
    assert idx.size(0) == 1, "idx must be of size (1, t)"
    print("idx: ", idx.shape)
    
    # crop idx to the last block_size tokens
    input_cond = idx[:, -t:] # should this be just t???? previosly i had it as -(t+medusa_headcount)
    print("idx_cond should be size (1, t): ", input_cond.shape)
          
    # get the predictions
    logits, loss, mlogits = model(input_cond) # (1,t,v), int, and (medusa_headcount,1,t,v)
    print("logits should be (1,t,v): ", logits.shape)
    print("mlogits should be (medusa_headcount,1,t,v): ", mlogits.shape)
    
    # we only want the medusa preditions for the newest future timestep
    mlogits = mlogits[...,-1,:] # becomes (medusa_headcount,1,v)
    mlogits = mlogits.squeeze(dim=1) # (medusa_headcount,v)
    print("mlogits should be (medusa_headcount,v): ", mlogits.shape)
    
    # actual medusa does a whole proabability attention mask tree thing, but here we'll assume greedy decoding on snake heads
    idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (1,medusa_headcount)
    print("idx_m_prev should be (1,medusa_headcount): ", idx_m_prev.shape)
    # we name it _prev since it'll be used on the next loop
    
    # medusa requires greedy decoding for the first regular t+1'th token
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1, t)
    print("idx_ntp should be (1, t): ", idx_ntp.shape)
    
    # append sampled index to the running sequence
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) # (1,t)+(1,1) -> (1, t+1)
    print("idx should be (1, t+1): ", idx.shape)

    # keep track of how many tokens we get per model inference
    tok_per_inf = [1]
    
    for _ in range(max_runs-1): # -1 since one iteration was done above
        ###### I think this is where i begin the looped part. above stuff should be before the for loop
        
        # so now we have idx shape (1,t+1) of context tokens plus 1 predicted token
        # and idx_m_prev shape  (1,medusa_headcount) of speculative tokens to check against later & include some
        # want to know whether model actually would've predicted the tokens in idx_m so we need to run it with them & compare
        
        # first we construct a tensor composed of the initial context, the single ntp token, and then the speculative tokens
        input = torch.cat((idx, idx_m_prev), dim=1) # (1,t+1) & (1,medusa_headcount) -> (1,t+1+medusa_headcount)
        print("input should be (1,t+1+medusa_headcount): ", input.shape)
        
        # but since the model can't take more than t inputs
        # note: t is max content limit, but input might be shorter. if so that's fine
        input_cond = input[:, -t:] 
        # (1,t+1+medusa_headcount) -> (1,t) where first t-1-m are prior context, the t-m'th is ntp, & last m are candidates
        print("input_cond should be (1,t): ", input_cond.shape)
        
        # then we pass it in
        logits, loss, mlogits = model(input_cond) # (1,t,v), tensor of a single float, & (medusa_headcount,1,t,v)
        
        # we're just greedy decoding so there's no need to softmax
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1,t,v) -> (1, t)
        print("idx_ntp should be (1, t): ", idx_ntp.shape)
        
        # now for comparison's sake we need to ignore the indices that were part of the context
        # the regular ntp prediction can also be ignored
        # we just need to know the number of previous turn's medusa preds to keep
        idx_check = idx_ntp[:,-(medusa_headcount+1):-1] # (1,t) -> (1,medusa_headcount)
        print("idx_check should be (1,medusa_headcount): ", idx_check.shape, idx_check)
        
        # check whether they match
        match_tensor = (idx_m_prev == idx_check).int()
        #(1,medusa_headcount) of ints & (1,medusa_headcount) of ints -> (1,medusa_headcount) of 1's and 0's
        print("match_tensor should be (1,medusa_headcount) of 1's and 0's: ", match_tensor.shape, match_tensor)
        
        ########### then define our "definitely good" indices as those which are 1's for the similarity check
        # We invert the tensor so that all zeros & 1's flip
        # Find the first 1 in each row. The max function returns the first occurrence of the maximum value.
        # We add one extra zero at the end of each row to handle rows that contain no zeros.
        pad = torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype, device=device)
        print("pad: ", pad.shape, pad)
        padded_tensor = torch.cat((1 - match_tensor, pad), dim=1)
        print("padded_tensor: ", padded_tensor.shape, padded_tensor)
        zero_positions = padded_tensor.argmax(dim=1)
        
        # Adjust indices where the last position is selected (meaning there were no zeros)
        zero_positions[zero_positions >= match_tensor.size(1)] = match_tensor.size(1)
        print("zero_positions: ", zero_positions.shape, zero_positions)
        
        # Create a range tensor
        range_tensor = torch.arange(match_tensor.size(1), device=device).unsqueeze(0).expand_as(match_tensor)
        print("range_tensor: ", range_tensor.shape, range_tensor)
        
        # Create a mask where each element is 1 if it is before the first zero in its row
        mask = range_tensor < zero_positions.unsqueeze(1)
        print("mask: ", mask.shape, mask)
        
        # Apply the mask and sum along each row
        result = (match_tensor * mask).sum(dim=1).item()
        print("result: ", result)

        tok_per_inf.append(result+1)
        
        # so now i've got result which is an integer between 0 and medusa_headcount (inclusive)
        # i need to use it as an index on the actual medusa head outputs of interest and then append that to our sequence
        # but we also get to use the regular ntp prediction right after any of last run's accepted medusa head predictions
        print("idx: ", idx.shape, idx)
        idx_m_prev = idx_m_prev[:,:result]
        print("idx_m_prev should be (1,result): ", idx_m_prev.shape, idx_m_prev)
        idx_ntp = idx_ntp[:,-1-medusa_headcount+result].unsqueeze(dim=0)
        print("idx_ntp should be (1,1): ", idx_ntp.shape, idx_ntp)
        idx = torch.cat((idx, idx_m_prev, idx_ntp),dim=1)
        print("idx should be (1,t+result+1: ", idx.shape, idx)
        
        # so now we have idx of size (1,t+1+result)
        # and idx_m of size (1,medusa_headcount)
        # wait but there's a problem. If we didn't use the largest result=medusa_headcount then 
        #   idx_m is for items too far into the future. we need to pass forward an idx_m that's relevant
        # ngl i'm too tired to think abou tit so i'm assuming i can use the same index i used for idx_ntp
        mlogits = mlogits[...,-1-medusa_headcount+result,:].squeeze(dim=1) # (medusa_headcount,1,t,v) -> (medusa_headcount,1,v)
        print("mlogits: ", mlogits.shape)
        idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (medusa_headcount,1,v) -> (1,medusa_headcount)
        print("idx_m_prev: ", idx_m_prev.shape, idx_m_prev)
        # we name it _prev since it'll be used on the next loop
        
    return idx, tok_per_inf

In [30]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
output, tok_per_inf = generate_Stheno(model, context_tensor, max_runs=100)
output_str = decode(output[0].tolist())
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print(output_str)

idx:  torch.Size([1, 44])
idx_cond should be size (1, t):  torch.Size([1, 44])
logits should be (1,t,v):  torch.Size([1, 44, 65])
mlogits should be (medusa_headcount,1,t,v):  torch.Size([3, 1, 44, 65])
mlogits should be (medusa_headcount,v):  torch.Size([3, 65])
idx_m_prev should be (1,medusa_headcount):  torch.Size([1, 3])
idx_ntp should be (1, t):  torch.Size([1, 44])
idx should be (1, t+1):  torch.Size([1, 45])
input should be (1,t+1+medusa_headcount):  torch.Size([1, 48])
input_cond should be (1,t):  torch.Size([1, 48])
idx_ntp should be (1, t):  torch.Size([1, 48])
idx_check should be (1,medusa_headcount):  torch.Size([1, 3]) tensor([[51, 43, 53]], device='mps:0')
match_tensor should be (1,medusa_headcount) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  tor

In [23]:
sum(tok_per_inf)/len(tok_per_inf)

2.4

In [140]:
A = torch.randint(10, (1,t))
print(A)
print(A[:,-t:])
print(A[:,-(t+2):])

tensor([[9, 1, 8, 0, 5, 2, 8, 7]])
tensor([[9, 1, 8, 0, 5, 2, 8, 7]])
tensor([[9, 1, 8, 0, 5, 2, 8, 7]])


In [143]:
b=1
t = 5
medusa_headcount = 3
A = torch.randint(10, (b,t))
B = torch.ones((b,1))
C = torch.zeros((b,medusa_headcount))
D = torch.cat((A,B,C),dim=1)
print("A: ", A)
print("B: ", B)
print("C: ", C)
print("D: ", D)
print("D[:,-medusa_headcount:] ", D[:,-medusa_headcount:])

A:  tensor([[7, 8, 5, 3, 5]])
B:  tensor([[1.]])
C:  tensor([[0., 0., 0.]])
D:  tensor([[7., 8., 5., 3., 5., 1., 0., 0., 0.]])
D[:,-medusa_headcount:]  tensor([[0., 0., 0.]])


In [152]:
A = torch.tensor([[0,21,14,33]])
B = torch.tensor([[0,21,27,33]])
print(A)
print(B)
print(A[:,1:])
print(B[:,:-1])
match_tensor = (A == B).int()
print(match_tensor)

tensor([[ 0, 21, 14, 33]])
tensor([[ 0, 21, 27, 33]])
tensor([[21, 14, 33]])
tensor([[ 0, 21, 27]])
tensor([[1, 1, 0, 1]], dtype=torch.int32)


In [153]:
# Find the first zero in each row. The max function returns the first occurrence of the maximum value.
# We invert the tensor so the first zero becomes the max value in each row.
# We add one extra zero at the end of each row to handle rows that contain no zeros.
padded_tensor = torch.cat((1 - match_tensor, torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype)), dim=1)
zero_positions = padded_tensor.argmax(dim=1)

# Adjust indices where the last position is selected (meaning there were no zeros)
zero_positions[zero_positions >= match_tensor.size(1)] = match_tensor.size(1)

# Create a range tensor
range_tensor = torch.arange(match_tensor.size(1), device=match_tensor.device).unsqueeze(0).expand_as(match_tensor)

# Create a mask where each element is 1 if it is before the first zero in its row
mask = range_tensor < zero_positions.unsqueeze(1)

# Apply the mask and sum along each row
result = (match_tensor * mask).sum(dim=1)

print(result, result.dtype, result.shape)
print(result.item())


tensor([2]) torch.int64 torch.Size([1])
2


In [146]:
inverted_result = medusa_headcount - result
inverted_result

tensor([1])

In [160]:
# Example tensor of shape (1, m)
m = 5
tensor_a = torch.randn(1, m)
print("tensor A: ", tensor_a)

# Integer x
x = 0  # Can be any value from 0 to m

# Slicing tensor_a based on x
sliced_tensor_a = tensor_a[:, :x] if x > 0 else torch.tensor([], dtype=tensor_a.dtype)

# Another tensor to concatenate with
tensor_b = torch.randn(1, 2)  # Example tensor
print("tensor_b: ", tensor_b)

# Concatenating the tensors
result_tensor = torch.cat((tensor_b, sliced_tensor_a), dim=1)

print("Sliced Tensor A:", sliced_tensor_a)
print("Result Tensor:", result_tensor)


tensor A:  tensor([[ 0.1422, -0.4801,  1.8750,  0.3247,  1.3350]])
tensor_b:  tensor([[1.5440, 2.1109]])
Sliced Tensor A: tensor([])
Result Tensor: tensor([[1.5440, 2.1109]])


In [35]:
print(len(encode(input_str)))
match_tensor[:,len(encode(input_str))-2:]

27


tensor([[1]], device='mps:0', dtype=torch.int32)

tensor([[8, 4, 9, 7, 4, 1, 3, 3, 6, 9]])
tensor([[2, 9, 9, 5, 2, 3, 8, 9, 3, 8]])
tensor([[4, 9, 7, 4, 1, 3, 3, 6, 9]])
tensor([[2, 9, 9, 5, 2, 3, 8, 9, 3]])
tensor([[0, 1, 0, 0, 0, 1, 0, 0, 0]], dtype=torch.int32)


In [30]:
def generate(model, input_str, medusa=False, max_runs=64, temperature=1.0, verbose=False):
    # Encode the input string to a list of integers
    context = encode(input_str)

    # Convert the context to a PyTorch tensor, add a batch dimension, and send to the device
    context_tensor = torch.tensor([context], dtype=torch.long, device=device)

    # Generate output from the model
    if not medusa:
        output = model.generate_ntp(context_tensor, max_runs=max_runs, temperature=temperature, verbose=verbose)
        candidates = None
        output_list = None
    else:
        output, candidates, output_list = model.generate_medusa(context_tensor, max_runs=max_runs, temperature=temperature, verbose=verbose)

    # Convert the output to a list and decode it to a string
    output_str = decode(output[0].tolist())

    return output_str, candidates, output_list

In [31]:
prompt = "JULIET:\nO Romeo, Romeo! wherefore art thou R"

In [36]:
output, candidates, output_list = generate(m, prompt, temperature=0.5, verbose=False)
print(output)

JULIET:
O Romeo, Romeo! wherefore art thou RaZrE;y3uj nyI w:RC
h&olm?A
j&H--BYb.3fm!imk g uTidRf Mi hNifgu n


In [37]:
output, candidates, output_list = generate(m, prompt, medusa=True, temperature=0.5, verbose=False)
print(output)
print(decode(output_list[-1][0].tolist()))
print(candidates[0])
print(candidates[-1])
for i in range(len(candidates)):
    string = decode(output_list[i][0].tolist())
    for j in range(len(candidates[i])):
        string += decode(candidates[i][j][0].tolist())
    print(string)

JULIET:
O Romeo, Romeo! wherefore art thou R                                                                
JULIET:
O Romeo, Romeo! wherefore art thou R                                                                
[tensor([[ 1, 56]], device='mps:0'), tensor([[43,  1, 58, 53]], device='mps:0')]
[tensor([[ 1, 46]], device='mps:0'), tensor([[43,  1, 58, 53]], device='mps:0')]
JULIET:
O Romeo, Romeo! wherefore art thou R  re to
JULIET:
O Romeo, Romeo! wherefore art thou R   ae tr
JULIET:
O Romeo, Romeo! wherefore art thou R   r e to
JULIET:
O Romeo, Romeo! wherefore art thou R     ae to
JULIET:
O Romeo, Romeo! wherefore art thou R      ae ta
JULIET:
O Romeo, Romeo! wherefore art thou R       ae to
JULIET:
O Romeo, Romeo! wherefore art thou R        re to
JULIET:
O Romeo, Romeo! wherefore art thou R         ne to
JULIET:
O Romeo, Romeo! wherefore art thou R          re to
JULIET:
O Romeo, Romeo! wherefore art thou R           a teu
JULIET:
O Romeo, Romeo! wherefore art thou R           

In [130]:
len(candidates[0])

2

In [56]:
# sample a batch of data
xb, yb = get_batch('train')
    
# train
logits, loss, mlogits = model(xb, yb, verbose=True)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

logits:  torch.Size([8, 64, 65]) tensor([[[ 0.5667,  0.6185, -0.1166,  ..., -0.1144, -0.0944, -0.2431],
         [ 0.5633,  0.6154, -0.1280,  ..., -0.0966, -0.0988, -0.2639],
         [ 0.5496,  0.6065, -0.1368,  ..., -0.1071, -0.0946, -0.2643],
         ...,
         [ 0.5523,  0.6001, -0.0924,  ..., -0.1204, -0.0918, -0.2542],
         [ 0.5558,  0.6535, -0.1044,  ..., -0.1009, -0.0640, -0.2409],
         [ 0.5374,  0.5810, -0.1081,  ..., -0.1128, -0.1333, -0.2904]],

        [[ 0.5773,  0.6503, -0.1017,  ..., -0.1351, -0.1084, -0.2505],
         [ 0.5605,  0.6129, -0.1200,  ..., -0.0934, -0.0925, -0.2418],
         [ 0.5610,  0.6189, -0.0824,  ..., -0.1325, -0.0915, -0.2304],
         ...,
         [ 0.5596,  0.6150, -0.1221,  ..., -0.1657, -0.1189, -0.2371],
         [ 0.5911,  0.6233, -0.1136,  ..., -0.1333, -0.1020, -0.2338],
         [ 0.5854,  0.6214, -0.1573,  ..., -0.1324, -0.1047, -0.2477]],

        [[ 0.5618,  0.6575, -0.1159,  ..., -0.0903, -0.0980, -0.2379],
         [ 0

In [57]:
mlogits.shape

torch.Size([4, 8, 64, 65])

In [58]:
mlogits[:,:,-1,:].shape

torch.Size([4, 8, 65])

In [59]:
mlogits[...,-1,:].shape

torch.Size([4, 8, 65])

'Juliet:Where the the the the the the the the the the the the the the the the '