#### !!!! DO NOT RUN THIS FIRST CELL UNLESS YOU HAVE THE SAME VENV PATH ISSUE THAT I DO

In [1]:
import sys
sys.path.append('/Users/tunadorable/local-repos/learning_medusa/venv/lib/python3.11/site-packages')

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
import random

#### !!!! ONLY FOR APPLE SILICON

In [3]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [14]:
# hyperparameters
b = 3 # how many independent sequences will we process in parallel?
t = 8 # what is the maximum context length for predictions?
max_iters = 100
eval_interval = 10
lr = 3e-4 # learning rate for each backprop step
eval_iters = 20
d = 16 # embedding aka hidden dimension
h = 4 # number of attention heads
l = 4 # number of transormer layers
dropout = 0.2 # % of parameters to ignore every iteration
l2 = 0.01 # multiplier for our L2 norm to encourage sparsity

k = 4

In [15]:
# the dataset is TinyShakespeare
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [16]:
# here are all the unique characters that occur in this text
# we'll be using individual characters instead of tokens
chars = sorted(list(set(text)))
v = len(chars)
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [17]:
# create a mapping from characters to integers & vice versa
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [18]:
# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [19]:
def get_batch(split, k=k, b=b, t=t):
    # Assume train_data and val_data are defined outside this function
    data = train_data if split == 'train' else val_data
    max_index = len(data) - k * t
    ix = torch.randint(max_index, (b,))

    # Initialize x and y
    x = torch.zeros((k, b, t), dtype=data.dtype).to(device)
    y = torch.zeros((k, b, t), dtype=data.dtype).to(device)

    # Fill in x and y
    for i in range(k):
        x[i] = torch.stack([data[j+i*t:j+i*t+t] for j in ix])
        y[i] = torch.stack([data[j+i*t+1:j+i*t+t+1] for j in ix])

    return x, y

In [21]:
# so you can see what the tokenized data looks like
x,y = get_batch('train')
print("x ", x.shape, "\n", x)
print("y ", y.shape, "\n", y)

x  torch.Size([4, 3, 8]) 
 tensor([[[ 1, 43, 39, 57, 43,  1, 51, 63],
         [ 1, 45, 59, 39, 56, 42,  6,  0],
         [61, 47, 58, 46, 47, 52,  2,  1]],

        [[ 1, 46, 43, 39, 56, 58,  8,  0],
         [35, 43,  1, 51, 39, 63,  1, 57],
         [24, 47, 44, 58,  1, 59, 54,  1]],

        [[32, 46, 43,  1, 57, 47, 45, 46],
         [59, 56, 54, 56, 47, 57, 43,  1],
         [58, 46, 63,  1, 50, 53, 53, 49]],

        [[58,  1, 53, 44,  1, 39, 52, 63],
         [39, 52, 42,  1, 58, 39, 49, 43],
         [57, 10,  0, 18, 56, 53, 51,  1]]], device='mps:0')
y  torch.Size([4, 3, 8]) 
 tensor([[[43, 39, 57, 43,  1, 51, 63,  1],
         [45, 59, 39, 56, 42,  6,  0, 35],
         [47, 58, 46, 47, 52,  2,  1, 24]],

        [[46, 43, 39, 56, 58,  8,  0, 32],
         [43,  1, 51, 39, 63,  1, 57, 59],
         [47, 44, 58,  1, 59, 54,  1, 58]],

        [[46, 43,  1, 57, 47, 45, 46, 58],
         [56, 54, 56, 47, 57, 43,  1, 39],
         [46, 63,  1, 50, 53, 53, 49, 57]],

        [[ 1,

In [27]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, c_vecs, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

In [23]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 4 * d), # the 4 is arbitrary, but i wouldn't go smaller
            nn.ReLU(), 
            nn.Linear(4 * d, d),
            nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

In [24]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(d, head_size, bias=False)
        self.query = nn.Linear(d, head_size, bias=False)
        self.value = nn.Linear(d, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(t, t))) # mask future timestesps
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        b,t,d = x.shape
        k = self.key(x)   # (b,t,d/h)
        q = self.query(x) # (b,t,d/h)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (b, t, d/h) @ (b, d/h, t) -> (b, t, t)
        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf')) # (b, t, t)
        wei = F.softmax(wei, dim=-1) # (b, t, t)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (b,t,d/h)
        out = wei @ v # (b, t, t) @ (b, t, d/h) -> (b, t, d/h)
        return out

In [25]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, h, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(h)])
        self.proj = nn.Linear(head_size * h, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [26]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, d, h):
        # d: embedding dimension, h: the number of heads we'd like
        super().__init__()
        head_size = d // h # the double backslash just makes the output an int instead of float
        self.sa = MultiHeadAttention(h, head_size)
        self.ffwd = FeedFoward(d)
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [17]:
class conceptGPT(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(v, d)
        
        # simple learned positional encodings rather than sine or RoPE
        self.position_embedding_table = nn.Embedding(t, d) 
        self.blocks = nn.Sequential(*[Block(d, h) for _ in range(l)]) # bulk of the beast
        self.ln_f = nn.LayerNorm(d) # final layer norm
        self.lm_head = nn.Linear(d, v)
        
        self.conc_head = FeedForward(d)
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, c, targets=None, c_hat=None, verbose=False):
        b, t = idx.shape

        # idx and targets are both (b,t) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (b,t,d)
        pos_emb = self.position_embedding_table(torch.arange(t, device=device)) # (t,d)
        x = tok_emb + pos_emb # (b,t,d)
        x = self.ln_f(self.blocks(x)) # (b,t,d) -> (b,t,d)
        
        # the regular next-token prediction head
        logits = self.lm_head(x) # (b,t,d)@(d,v)=(b,t,v)
        
        c_vecs = self.conc_head(x)[:,0,:] # (b,d)
        
        if targets is None:
            # if we're not training at all we can ignore loss
            loss = None
        elif c_hat: 
            # if we've been given a c_hat, aka we're NOT on the end of the sequence
            
            # regular NTP loss
            b, t, v = logits.shape
            logits = logits.view(b*t, v)
            targets = targets.view(b*t)
            loss = F.cross_entropy(logits, targets)
            
            # cosine similarity loss for concept vector
            cosine_similarity = nn.CosineSimilarity(dim=1)
            similarity = cosine_similarity(c_vecs, c_hat)
            closs = 1 - similarity.mean()  # Maximizing cosine similarity is equivalent to minimizing 1 - cosine similarity
            
            # this will likely need a parameter to balance the two
            loss = loss + closs
        else:
            # if we're on the first run of training, aka don't have a c_hat
            b, t, v = logits.shape
            logits = logits.view(b*t, v)
            targets = targets.view(b*t)
            loss = F.cross_entropy(logits, targets)

        return logits, c_vecs, loss
    
    ##### i've definitely got something here but ofc a later version needs to
    # work over every timestep rather than only going into effect at the end of
    # the context length, and maybe it should also incorporate this dynamic effect
    # with the gamma neighborhoods that would be used for inference but idk how
    # that would work yet. but i'm pretty convinced this is along the lines of
    # what i've been imagining in my head so often. it feels like there's so many
    # different ways i could go with it to implement into so many of the dispirate ideas

    def generate_gpt(self, idx, max_new_tokens, temperature=1.0):
        # idx is (b, t) array of indices in the current context
        #assert temperature >= 0
        
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -t:]
            
            # get the predictions
            logits, loss, c_vecs = self(idx_cond)
            
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, d)
            
            # scale logits by the temperature
            logits = logits / (temperature+1e-10)
            
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, d)
            
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)
            
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)
        
        return idx

# Training

if you don't want to do your own training just scroll down

In [19]:
model = medusaGPT().to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

1740.934 K parameters


In [20]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2)

In [21]:
start_time = time.time()
for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    # train
    logits, loss, mlogits = model(xb, yb)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

step 0: train loss 14.9370, val loss 14.9544, time elapsed: 1.24 seconds
step 100: train loss 11.3176, val loss 11.3019, time elapsed: 36.43 seconds
step 200: train loss 10.9342, val loss 10.9574, time elapsed: 64.07 seconds
step 300: train loss 10.7498, val loss 10.7684, time elapsed: 91.07 seconds
step 400: train loss 10.6352, val loss 10.6123, time elapsed: 118.19 seconds
step 500: train loss 10.4564, val loss 10.4729, time elapsed: 145.52 seconds
step 600: train loss 10.3210, val loss 10.3472, time elapsed: 173.01 seconds
step 700: train loss 10.1753, val loss 10.2610, time elapsed: 200.49 seconds
step 800: train loss 10.0691, val loss 10.1848, time elapsed: 227.94 seconds
step 900: train loss 9.9825, val loss 10.0586, time elapsed: 255.39 seconds
step 1000: train loss 9.8192, val loss 9.9655, time elapsed: 282.88 seconds
step 1100: train loss 9.7813, val loss 9.9270, time elapsed: 310.43 seconds
step 1200: train loss 9.6700, val loss 9.8921, time elapsed: 339.36 seconds
step 1300:

## save the trained model

In [22]:
torch.save(model.state_dict(), f'models/medusa_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}_l2-{l2}_m{m}_mdiscount{medusa_discount:.2f}_{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Load a saved model

In [21]:
model = medusaGPT().to(device)  # Initialize a model with the same architecture

# Load the saved state dictionary
model.load_state_dict(torch.load('models/medusa_b24_t128_d128_h8_l8_lr0.0003_drop0.2_l2-0.01_m5_mdiscount0.80_2024-01-25|23-31-12.pth'))
# that's the better model of the two that I trained. The extra heads were useless tho

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

medusaGPT(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(128, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
   

## Regular GPT Inference

So it seems to me like regular GPT inference, although not the fastest, gives the best output quality. You'll see why my versions have worse output quality later

In [22]:
%%time # to keep track of how long it takes
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output = model.generate_gpt(context_tensor, max_new_tokens=218)
output_str = decode(output[0].tolist())
print(output_str)

JULIET:
O Romeo, Romeo! wherefore art thou Richi? lips, Citizen?

ARCHTGGS:
Here's here: I'll forth his son in those fear;
Lord and far his sight, and not be kneel,
Which shad my many brotherful sea,
And a word a ben the people.

MENENIUS:
Upon my name, sir: I h
CPU times: user 32.7 s, sys: 1.71 s, total: 34.4 s
Wall time: 34.1 s


# Medusa's first sister Stheno (the aggressive one)

ChatGPT said:
Her name translates to "strength" or "forceful". Stheno was the eldest and most fierce of the sisters, known for her strength and ferocity. 

So I guess since this generation is strictly doing greedy decoding i'll name it after her

The idea with Stheno is we're just imementing a tiny foundational concept of Medusa. Basically, while next-token prediction does prediction in parallel not only across batches but also (more importantly here and less frequently talked about) across time-steps in a sequence like this:
<table border="1">
    <thead>
        <tr>
            <th>input (<code>prompt</code>)</th>
            <td>'R'</td>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td>'o'</td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>NTP output</th>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td>'o'</td>
            <td>'?'</td>
        </tr>
    </tbody>
</table>

Our medusa heads will make output predictions multiple steps into the future that look like this
<table border="1">
    <thead>
        <tr>
            <th>input (<code>prompt</code>)</th>
            <td>'R'</td>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td>'o'</td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>medusa head 1</th>
            <td>'m'</td>
            <td>'e'</td>
            <td>'o'</td>
            <td>'?'</td>
            <td>'\n'</td>
        </tr>
    </tbody>
</table>
You can imagine what it looks like for further medusa heads.

The question now is how can we take advantage of this to speed up inference? Basically we take the input text and the output from each attention head and run those into the model, so the steps assuming two medusa heads are
1) input prompt 'R' and recieve output
<table border="1">
    <thead>
        <tr>
            <th>input</th>
            <td>'R'</td>
            <td></td>
            <td></td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>NTP output</th>
            <td>'o'</td>
            <td></td>
            <td></td>
        </tr>
        <tb>
            <th>medusa head 1</th>
            <td></td>
            <td>'m'</td>
            <td></td>
        </tb>
    </tbody>
    <tbody>
        <tb>
            <th>medusa head 2</th>
            <td></td>
            <td></td>
            <td>'e'</td>
        </tb>
    </tbody>
</table>
2) input sequence 'Rome' and recieve output
<table border="1">
    <thead>
        <tr>
            <th>input</th>
            <td>'R'</td>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td></td>
            <td></td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>NTP output</th>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td>'o'</td>
            <td></td>
            <td></td>
        </tr>
        <tb>
            <th>medusa head 1</th>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td>'?'</td>
            <td></td>
        </tb>
    </tbody>
    <tbody>
        <tb>
            <th>medusa head 2</th>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td>'\n'</td>
        </tb>
    </tbody>
</table>

3) Notice how our NTP output has now confirmed for us that medusa heads 1 and 2 were correct because the NTP row matches up with the input. Now, even though we've only run the model twice, we've obtained 4 tokens 'omeo' that are confirmed to be predicted by the regular NTP algorithm.


However, what if the medusa heads hadn't been right? Let's rewind
1) input prompt 'R' and recieve output
<table border="1">
    <thead>
        <tr>
            <th>input</th>
            <td>'R'</td>
            <td></td>
            <td></td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>NTP output</th>
            <td>'o'</td>
            <td></td>
            <td></td>
        </tr>
        <tb>
            <th>medusa head 1</th>
            <td></td>
            <td>'m'</td>
            <td></td>
        </tb>
    </tbody>
    <tbody>
        <tb>
            <th>medusa head 2</th>
            <td></td>
            <td></td>
            <td>'u'</td>
        </tb>
    </tbody>
</table>
3) Maybe medusa head 2 thinks we're going to predict the name 'Romulus'? Let's input this string 'Romu' and see what we get
<table border="1">
    <thead>
        <tr>
            <th>input</th>
            <td>'R'</td>
            <td>'o'</td>
            <td>'m'</td>
            <td>'u'</td>
            <td></td>
            <td></td>
        </tr>
    </thead>
    <tbody>
        <tr>
            <th>NTP output</th>
            <td>'o'</td>
            <td>'m'</td>
            <td>'e'</td>
            <td>'l'</td>
            <td></td>
            <td></td>
        </tr>
        <tb>
            <th>medusa head 1</th>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td>'u'</td>
            <td></td>
        </tb>
    </tbody>
    <tbody>
        <tb>
            <th>medusa head 2</th>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td></td>
            <td>'s'</td>
        </tb>
    </tbody>
</table>
3) Notice how the NTP output gives us an 'e' instead of the expected 'u'. It's unfortunate that medusa head 2's prediction was wrong, but medusa head 1's prediction was correct so we're still correct and this "checking" step has still given us one output 'e' meaning we're still ahead of regular next-token prediction. So far we've run the model twice but received 3 tokens, overall still putting us ahead. For the next step we'll just input the confirmed tokens we've recieved so far, which are 'Rome'

A key insight to this process is that we're considering NTP's output to be the "ground truth" during inference, and we get to use it as confirmation because of how transformers are trained in parallel across steps in the sequence. Thanks to this, even when every single medusa head is wrong we still get 1 token out of the NTP head. If our medusa heads are trained extremely poorly then this process should reduce down to NTP plus the marginal extra compute from the heads & checking, but if the heads are trained well then we can get a pretty significant speed increase

Now let's give it a try implementing this algorithm which I'll call Stheno, the ugliest sister

##### NOTE: I frequently refer to 't' throughout the remainder of the code. keep in mind t does not remain static, it grows over time. If it gets confusing just ignore any references to t and know that it could mean t give or take 1+m

In [4]:
def generate_Stheno(model, idx, max_runs):
    # Ensure idx is a single sequence. This won't work for batched inference
    assert idx.size(0) == 1, "idx must be of size (1, t)"
    print("idx: ", idx.shape)
    
    # crop idx to the last block_size tokens jic it's too long
    input_cond = idx[:, -t:] 
    print("idx_cond should be size (1, t): ", input_cond.shape)
          
    # get the predictions
    logits, loss, mlogits = model(input_cond) # (1,t,v), int, and (m,1,t,v)
    print("logits should be (1,t,v): ", logits.shape)
    print("mlogits should be (m,1,t,v): ", mlogits.shape)
    
    # we only want the medusa preditions for the newest future timestep
    mlogits = mlogits[...,-1,:] # becomes (m,1,v)
    mlogits = mlogits.squeeze(dim=1) # (m,v)
    print("mlogits should be (m,v): ", mlogits.shape)
    
    # actual medusa does a whole proabability attention mask tree thing, 
    #   but here we'll assume greedy decoding on snake heads
    idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (1,m)
    print("idx_m_prev should be (1,m): ", idx_m_prev.shape)
    # we name it _prev since it'll be used on the next loop
    
    # medusa requires greedy decoding for the first regular t+1'th token
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1, t)
    print("idx_ntp should be (1, t): ", idx_ntp.shape)
    
    # append sampled index to the running sequence
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) # (1,t)+(1,1) -> (1, t+1)
    print("idx should be (1, t+1): ", idx.shape)

    # keep track of how many tokens we get per model inference
    tok_per_inf = [1]
    
    for _ in range(max_runs-1): # -1 since one iteration was done above
        ###### I think this is where i begin the looped part. above stuff should be before the for loop
        
        # so now we have idx shape (1,t+1) of context tokens plus 1 predicted token
        # and idx_m_prev shape  (1,m) of speculative tokens to check against later & include some
        # want to know whether model actually would've predicted the tokens in idx_m so we need to run it with them & compare
        
        # first we construct a tensor composed of the initial context, the single ntp token, and then the speculative tokens
        inp = torch.cat((idx, idx_m_prev), dim=1) # (1,t+1) & (1,m) -> (1,t+1+m)
        print("input should be (1,t+1+m): ", inp.shape)
        
        # but since the model can't take more than t inputs
        # note: t is max content limit, but input might be shorter. if so that's fine
        input_cond = inp[:, -t:] 
        # (1,t+1+m) -> (1,t) where first t-1-m are prior context, the t-m'th is ntp, & last m are candidates
        print("input_cond should be (1,t): ", input_cond.shape)
        
        # then we pass it in
        logits, loss, mlogits = model(input_cond) # (1,t,v), tensor of a single float, & (m,1,t,v)
        
        # we're just greedy decoding so there's no need to softmax
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1,t,v) -> (1, t)
        print("idx_ntp should be (1, t): ", idx_ntp.shape)
        
        # now for making the comparison we need to ignore the indices that were part of the context
        # the regular ntp prediction can also be ignored
        # we just need to know the number of previous turn's medusa preds to keep
        idx_check = idx_ntp[:,-(m+1):-1] # (1,t) -> (1,m)
        print("idx_check should be (1,m): ", idx_check.shape, "\n", idx_check)
        
        # let's start checking whether they match
        match_tensor = (idx_m_prev == idx_check).int()
        #(1,m) of ints & (1,m) of ints -> (1,m) of 1's and 0's
        print("match_tensor should be (1,m) of 1's and 0's: ", match_tensor.shape, "\n", match_tensor)
        
        # then define our "definitely good" indices as those which are 1's for the similarity check
        # We invert the tensor so that all zeros & 1's flip
        # Find the first 1 in each row. The max function returns the index of the first occurrence of the maximum value.
        # We add one extra zero at the end of each row to handle rows that contain no zeros.
        pad = torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype, device=device)
        print("pad: ", pad.shape, "\n", pad)
        padded_tensor = torch.cat((1 - match_tensor, pad), dim=1)
        print("padded_tensor: ", padded_tensor.shape, "\n", padded_tensor)
        zero_positions = padded_tensor.argmax(dim=1)
        
        # Adjust indices where the last position is selected (meaning there were no zeros)
        zero_positions[zero_positions >= match_tensor.size(1)] = match_tensor.size(1)
        print("zero_positions: ", zero_positions.shape, "\n", zero_positions)
        
        # Create a range tensor
        range_tensor = torch.arange(match_tensor.size(1), device=device).unsqueeze(0).expand_as(match_tensor)
        print("range_tensor: ", range_tensor.shape, "\n", range_tensor)
        
        # Create a mask where each element is 1 if it is before the first zero in its row
        mask = range_tensor < zero_positions.unsqueeze(1)
        print("mask: ", mask.shape, "\n", mask)
        
        # Apply the mask and sum along each row
        result = (match_tensor * mask).sum(dim=1).item()
        print("result: ", result)

        tok_per_inf.append(result+1)
        
        # so now i've got result which is an integer between 0 and m (inclusive)
        # i need to use it as an index on the actual medusa head outputs of interest and then append that to our sequence
        # but we also get to use the regular ntp prediction right after any of last run's accepted medusa head predictions
        print("idx: ", idx.shape, "\n", idx)
        idx_m_prev = idx_m_prev[:,:result]
        print("idx_m_prev should be (1,result): ", idx_m_prev.shape, "\n", idx_m_prev)
        idx_ntp = idx_ntp[:,-1-m+result].unsqueeze(dim=0)
        print("idx_ntp should be (1,1): ", idx_ntp.shape, "\n", idx_ntp)
        idx = torch.cat((idx, idx_m_prev, idx_ntp),dim=1)
        print("idx should be (1,t+result+1: ", idx.shape, "\n", idx)
        
        # so now we have idx of size (1,t+1+result)
        # and idx_m of size (1,m)
        # wait but there's a problem. If we didn't use the largest result=m then 
        #   idx_m is for items too far into the future. we need to pass forward an idx_m that's relevant
        # ngl i'm too tired to think about it so i'm assuming i can use the same index i used for idx_ntp
        mlogits = mlogits[...,-1-m+result,:].squeeze(dim=1) # (m,1,t,v) -> (m,1,v)
        print("mlogits: ", mlogits.shape)
        idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (m,1,v) -> (1,m)
        print("idx_m_prev: ", idx_m_prev.shape, "\n", idx_m_prev)
        # we name it _prev since it'll be used on the next loop
        
    return idx, tok_per_inf

In [24]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output, tok_per_inf = generate_Stheno(model, context_tensor, max_runs=100)
output_str = decode(output[0].tolist())

idx:  torch.Size([1, 44])
idx_cond should be size (1, t):  torch.Size([1, 44])
logits should be (1,t,v):  torch.Size([1, 44, 65])
mlogits should be (m,1,t,v):  torch.Size([5, 1, 44, 65])
mlogits should be (m,v):  torch.Size([5, 65])
idx_m_prev should be (1,m):  torch.Size([1, 5])
idx_ntp should be (1, t):  torch.Size([1, 44])
idx should be (1, t+1):  torch.Size([1, 45])
input should be (1,t+1+m):  torch.Size([1, 50])
input_cond should be (1,t):  torch.Size([1, 50])
idx_ntp should be (1, t):  torch.Size([1, 50])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[51, 43, 12, 12,  0]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 1, 1]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 1, 0, 0, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1,

idx_ntp should be (1, t):  torch.Size([1, 67])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[58,  1, 47, 47, 44]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True,  True, False, False, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 62]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39]], device='mps:

idx should be (1,t+result+1:  torch.Size([1, 77]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12]], device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[0, 0, 1, 1, 1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 82])
input_cond should be (1,t):  torch.Size([1, 82])
idx_ntp should be (1, t):  torch.Size([1, 82])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[ 0,  0, 15, 15, 24]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) te

idx should be (1,t+result+1:  torch.Size([1, 91]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61]], device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[39, 39, 43,  1,  1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 96])
input_cond should be (1,t):  torch.Size([1, 96])
idx_ntp should be (1, t):  torch.Size([1, 96])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[47, 57, 56,  1, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[0, 0, 0, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], de

idx_ntp should be (1, t):  torch.Size([1, 106])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[ 1, 58, 46,  1, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 1, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True,  True, False, False, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 101]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx_ntp should be (1, t):  torch.Size([1, 114])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[12, 12, 56, 12, 53]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[0, 0, 1, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[1, 1, 0, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[False, False, False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 109]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx should be (1,t+result+1:  torch.Size([1, 121]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0]], device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[35, 46,  1,  1,  1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 126])
input_cond should be (1,t):  torch.Size([1, 126])
idx_ntp should be (1, t):  torch.Size([1, 126])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[21, 46, 39, 61, 61]], device='mps:0')
match_tensor should be (1,m) of 1'

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[50, 50,  1, 52, 52]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 1, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 0, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([3], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True,  True,  True, False, False]], device='mps:0')
result:  3
idx:  torch.Size([1, 125]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[46, 43, 57, 58, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 0, 0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 1, 1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([1], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True, False, False, False, False]], device='mps:0')
result:  1
idx:  torch.Size([1, 137]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[ 1,  1, 56,  1, 53]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[0, 0, 1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[1, 1, 0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[False, False, False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 143]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[39, 56, 43,  1, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[0, 1, 1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[1, 0, 0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[False, False, False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 152]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx should be (1,t+result+1:  torch.Size([1, 158]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39]],
       device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[58,  1,  0, 53, 43]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 163])
input_cond should be (1,t):  torch.Size([1, 128])
idx

idx should be (1,t+result+1:  torch.Size([1, 166]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57]], device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[39, 60, 43,  0,  0]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 171])
input_cond should b

idx should be (1,t+result+1:  torch.Size([1, 172]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40]], device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[43, 43, 43,  0,  0]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 1

idx should be (1,t+result+1:  torch.Size([1, 180]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24]],
       device='mps:0')
mlogits:  torch.Size([5, 65])
idx_m_prev:  torch.Size([1, 5]) tensor([[33, 30, 17, 17, 10]], device='mps:0')
input

input should be (1,t+1+m):  torch.Size([1, 194])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[47, 56, 56,  1, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[0, 0, 0, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[1, 1, 1, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[False, False, False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 189]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1

idx should be (1,t+result+1:  torch.Size([1, 198]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57]],
       device='mps:0')
mlogits:  torch.Size([5, 65])
idx

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[43,  1, 58, 58, 44]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True,  True, False, False, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 206]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 

idx:  torch.Size([1, 212]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 39, 52]],
       device='

idx should be (1,t+result+1:  torch.Size([1, 221]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

idx_m_prev:  torch.Size([1, 5]) tensor([[52, 45,  1,  1,  1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 231])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 5]) tensor([[52, 45, 58, 58, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 5]) tensor([[1, 1, 0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 6]) tensor([[0, 0, 1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 5]) tensor([[0, 1, 2, 3, 4]], device='mps:0')
mask:  torch.Size([1, 5]) tensor([[ True,  True, False, False, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 226]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 

idx should be (1,t+result+1:  torch.Size([1, 237]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

idx should be (1,t+result+1:  torch.Size([1, 240]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

idx should be (1,t+result+1:  torch.Size([1, 249]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

idx should be (1,t+result+1:  torch.Size([1, 256]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

idx should be (1,t+result+1:  torch.Size([1, 259]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0, 28, 56, 53, 60,
         53, 57, 58, 10,  0, 35, 46, 39, 58,  1, 47, 57,  1, 58, 46, 43,  1, 61,
         53, 56, 50, 42, 12,  0,  0, 15, 13, 25, 21, 24, 24, 27, 10,  0, 21,  1,
         61, 47, 50, 50,  1, 52, 53, 58,  1, 40, 43,  1, 58, 46, 43,  1, 51, 39,
         52, 12,  0,  0, 24, 17, 27, 26, 32, 17, 31, 10,  0, 21,  1, 61, 47, 50,
         50,  1, 52, 53, 58,  1, 57, 43, 43,  1, 58, 46, 43,  1, 51, 39, 52,  1,
         53, 44,  1, 58, 46, 43,  1, 51, 39, 52,  1, 58, 46, 39, 58,  1, 58, 46,
         43, 63,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 57, 53,  8,  0,  0, 24,
         33, 15, 21, 27, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 57,
         53,  1, 57, 53,  1, 51, 53, 56, 43,  1, 58, 46, 3

In [25]:
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print(output_str)

tokens per inference:  2.16
JULIET:
O Romeo, Romeo! wherefore art thou Rome?

Provost:
What is the world?

CAMILLO:
I will not be the man?

LEONTES:
I will not see the man of the man that they shall be so.

LUCIO:
I will not so so more than than the strength of the strength,
And the stre


#### notice that Stheno is restricted to greedy decoding which means the output quality is lower. If you compare the two passages you'll see that the NTP one has far less of a problem with repetition

---

# Medusa's second sister Euryale (the explorative one)


ChatGPT said

Her name means "far-roaming" in Greek. Euryale was known for her loud crying or bellowing. If your architecture is meant to explore a wide range of possibilities or to "roam" extensively through a dataset, the name Euryale might be suitable. Additionally, if your architecture involves a broad or far-reaching search strategy or is notable for 'broadcasting' its findings extensively (analogous to loud crying), Euryale could be an apt choice.



The goal here with Euryale is effectively to bring probabalistic decoding back to Stheno. Medusa does this using this whole attention thing, and their approach is certainly better. I created Euryale as a way to explore an alternative method I thought of that aims at doing largely the same thing but by taking advantage of batched inference rather than an attention mechanism. My hope is that this will only require more ram and not result in any added latency, but we'll see. 

The basic idea is to 
1) use topk instead of argmax
2) construct every single possible candidate sequence from the topk results of the medusa heads 
3) run every single one of those possible candidates through the model using batched inference
4) instead of checking against a greedy version of the NTP's output, check against the NTP's topk results
5) select whichever successful candidate is the longest. If there are multiple acceptable candidates, we now have to choose whether we want the one that was most probable, least probable, or somewhere in-between. There was no really clear way to do this; you can see what I settled on below

If anything i think the re-incorporation of topk results will *maybe* add a speed increase if they give us longer accepted sequences (although likely not as fast as the attention-based mechanism used in actual Medusa). However the real bottleneck I'm adding is memory. We're literally performing batched inference as opposed to single, and the check mechanism is also much larger. I'm not sure exactly how to do the calculation but I think the batched inference results in a linear increase in memory consumption on what was already the largest chunk of our computation, and then the topk matching results in a maybe exponential increase in memory usage but only with reference to a part that was already relatively small. Idk you decide,  i'm too tired at this point

In [26]:
# setup functions to make the actual loop less ugly

def combinations(tensor):
    '''
    takes our topk results and creates every possible permutation of them
    '''
    # Get the shape of the tensor
    m, k = tensor.shape

    # Create index grids for each dimension
    index_grids = [torch.arange(k) for _ in range(m)]

    # Create a meshgrid of indices with reverse order
    mesh_indices = torch.meshgrid(index_grids[::-1], indexing="ij")

    # Use advanced indexing to create the combination tensor
    # Reverse the order of tensor indexing to match the reversed mesh_indices
    combinations = torch.cat([tensor[m-1-i][mesh_indices[i]].unsqueeze(0) for i in range(m)], dim=0)

    # Reshape to the desired shape (k^m, m) and reverse the columns for the final output
    combinations = combinations.T.reshape(-1, m)
    combinations = combinations.flip(dims=[1])

    return combinations

def compare(A,B):
    '''
    compares our two ugly mess of candidate possibilities & checks. 
    Outputs a matrix that can be used to find the indices of interest
    '''
    i,j,k = A.shape
    
    match_tensor = (A == B).int() 
    print("match_tensor: ", match_tensor.shape, "\n", match_tensor)

    # We invert the tensor so that all zeros & 1's flip
    # Find the first 1 in each row. The max function returns the first occurrence of the maximum value.
    # We add one extra zero at the end of each row to handle rows that contain no zeros.
    pad = torch.ones((i,j,1), dtype=match_tensor.dtype, device=device)
    print("pad: ", pad.shape)#, "\n", pad)
    padded_tensor = torch.cat((1 - match_tensor, pad), dim=-1)
    print("padded_tensor: ", padded_tensor.shape, "\n", padded_tensor)
    zero_positions = padded_tensor.argmax(dim=-1)
    print("zero_positions: ", zero_positions.shape, "\n", zero_positions)

    # Adjust indices where the last position is selected (meaning there were no zeros)
    zero_positions[zero_positions >= k] = k
    print("zero_positions: ", zero_positions.shape, "\n", zero_positions)

    # Create a range tensor
    range_tensor = torch.arange(m, device=device).unsqueeze(0).expand_as(match_tensor)
    print("range_tensor: ", range_tensor.shape, "\n", range_tensor)

    # Create a mask where each element is 1 if it is before the first zero in its row
    mask = range_tensor < zero_positions.unsqueeze(-1)
    print("mask: ", mask.shape, "\n", mask)

    # Apply the mask and sum along each row
    result = (match_tensor * mask).sum(dim=-1)
    #print("result: ", result.shape)#, "\n", result)
    
    return result

In [27]:
def generate_Euryale(model, idx, max_runs, k=2):
    # Ensure idx is a single sequence
    assert idx.size(0) == 1, "idx must be of size (1, t)"
    print("idx: ", idx.shape)
    
    # crop idx to the last block_size tokens
    input_cond = idx[:, -t:] # should this be just t???? previosly i had it as -(t+m)
    print("idx_cond should be size (1, t): ", input_cond.shape)
    
    # get the predictions
    logits, loss, mlogits = model(input_cond) # (1,t,v), int, and (m,1,t,v)
    print("logits should be (1,t,v): ", logits.shape)
    print("mlogits should be (m,1,t,v): ", mlogits.shape)
    
    # medusa requires greedy decoding for the first regular t+1'th token
    # technically we could just do regular temperature probability sampling here BUT
    #   the problem is that would shift the distribution from what the medusa heads expect
    #   thereby giving us far worse results
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1, t)
    print("idx_ntp should be (1, t): ", idx_ntp.shape)
    
    # append sampled index to the running sequence
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) # (1,t)+(1,1) -> (1, t+1)
    print("idx should be (1, t+1): ", idx.shape)

    # keep track of how many tokens we get per model inference
    tok_per_inf = [1]
    
    # we only want the medusa preditions for the newest future timestep
    mlogits = mlogits[...,-1,:] # becomes (m,1,v)
    mlogits = mlogits.squeeze(dim=1) # (m,v)
    print("mlogits should be (m,v): ", mlogits.shape)
    
    for _ in range(max_runs-1): # -1 since one iteration was done above
        
        # get the medusa heads' best guesses
        idx_m_topk = torch.topk(mlogits, k, dim=-1, largest=True).indices # (m,1,v) -> (m,k)
        print("idx_m_topk should be (m,k): ", idx_m_topk.shape)
        
        # so now we have idx shape (1,t+1) of confirmed tokens
        # and idx_m_topk shape (m,k) of topk speculative tokens from medusa heads
        # want to know whether model actually would've predicted any of the tokens 
        #   in idx_m so we need to run it with them & compare
        
        # get all the possible combinations of the topk results from each snake head
        mcomb = combinations(idx_m_topk) # (m, k) -> (k^m,m) of all k^m combinations
        print("mcomb should be (k^m,m): ", mcomb.shape, "\n", mcomb)
        
        # expanding out input sequence to take advantage of batched inference
        idx_rep = idx.repeat(k**m,1) # (1,t+1) -> (k^m, t+1)
        print("idx_rep should be (k^m,t+1): ", idx_rep.shape)
        
        # we construct a tensor composed of the initial context and then the speculative tokens
        # now this k^m dimension is effecitvely our batch size b as if we were doing batched inference
        inp = torch.cat((idx_rep, mcomb), dim=1) # (k^m,t+1) & (k^m,m) -> (k^m,t+1+m)
        print("input should be (k^m,t+1+m): ", inp.shape)
        
        # but since the model can't take more than t inputs
        # note: t is max content limit, but input might be shorter. if so that's fine
        input_cond = inp[:, -t:] 
        # (k^m,t+1+m) -> (k^m,t) where first t-1-m are prior context, the t-m'th is ntp, & last m are candidates
        print("input_cond should be (k^m,t): ", input_cond.shape, "\n", input_cond[:,-2*m:])
        
        # then we pass it in
        logits, loss, mlogits = model(input_cond) # (k^m,t,v), tensor of a single float, & (m,k^m,t,v)
        print("logits should be (k^m,t,v): ", logits.shape)
        print("mlogits should be (m,k^m,t,v): ", mlogits.shape)
        
        # instead of greedy or softmax & temperature, we grab topk values to be compared against mcomb later
        idx_ntp_topk = torch.topk(logits, k, dim=-1, largest=True).indices # (k^m,t,v) -> (k^m,t,k)
        print("idx_ntp_topk should be (k^m, t,k): ", idx_ntp_topk.shape, "\n", idx_ntp[:,-2*m:])
        
        # now for comparison's sake we need to ignore the indices that were part of the context
        # the regular ntp prediction can also be ignored
        # we just need to know the number of previous turn's medusa preds to keep
        idx_check = idx_ntp_topk[:,-(m+1):-1,:] # (k^m,t,k) -> (k^m,m,k)
        print("idx_check should be (k^m,m,k): ", idx_check.shape, "\n", idx_check)
        
        # get every possible combination to be checked against
        idx_check_comb = torch.stack([combinations(idx_check[i]) for i in range(idx_check.shape[0])])
        print("idx_check_comb should be (k^m,k^m,m): ", idx_check_comb.shape, "\n", idx_check_comb)
        
        # duplicate them so they can be compared to idx_check_comb
        mcomb_check = mcomb.unsqueeze(0).repeat(k**m,1,1) # (k^m,m) -> (1,k^m,m) -> (k^m,k^m,m)
        print("mcomb_check should be (k^m,k^m,m): ", mcomb_check.shape, "\n", mcomb_check)

        # check whether they match
        result = compare(mcomb_check,idx_check_comb)
        print("result:\n", result)
        
        # this is how many tokens our medusa head gave us
        max_val = torch.max(result).item()
        # keeping track
        tok_per_inf.append(max_val+1) # +1 for the extra NTP greedy token
        
        # however, there may be multiple possible candidates that give that number of tokens
        # the below rows look through 'result' and grab either the most or least likely rows or columns
        # this index flip works bc argmax always returns the index of the FIRST max value if there are multiple max values
        if random.choice([True, False]):
            max_idx_row = torch.max(result,1).indices[0].item() # most likely
            max_idx_col = len(result[max_idx_row]) - 1 - torch.argmax(result[max_idx_row].flip(0)).item() # least likely
        elif random.choice([True, False]):
            max_idx_row = (len(result) - 1 - torch.max(result.flip(dims=[0]),1).indices)[0].item() # least likely
            max_idx_col = torch.argmax(result[max_idx_row]).item() # most likely
        else:
            max_idx_row = torch.max(result,1).indices[0].item() # most likely
            max_idx_col = torch.argmax(result[max_idx_row]).item() # most likely
        print("max_val: ", max_val)
        print("max_idx_row: ", max_idx_row)
        print("max_idx_col: ", max_idx_col)
        # notice how those were not every possible combination, they're just the combination of probabilites I settled on
        # the problem with my approach is that if there are a lot of possibilities, only 3 of the options are made available
        # a more sophisticated approach would randomly choose between all the potential maximum candidates
        # actually now that I think about it, doing it that way might also be more computationally efficient
        # ugh i'm too tired tho idc anymore
        
        # so now i've got max_val which is an integer between 0 and m (inclusive)
        # and max_idx_row & max_idx_col which are indices of the best sequence in the k^m batch 
        # need to use them to get the actual medusa head outputs of interest and then append that to our sequence
        # bonus: we also get to use the regular ntp prediction right after any of last run's accepted medusa head predictions
        print("idx: ", idx.shape, idx)
        
        # now we splice out the chosen candidate sequence
        idx_m = mcomb[max_idx_row, :max_val].unsqueeze(0) # (k^m,m) -> (1,max_val)
        print("idx_m should be (1,max_val): ", idx_m.shape, "\n", idx_m)
        
        # and we also grab the free greedy token we get
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (k^m,t,v) -> (k^m, t)
        print("idx_ntp should be (k^m,t): ", idx_ntp.shape, "\n", idx_ntp)
        idx_ntp = idx_ntp[max_idx_row, -1-m+max_val].unsqueeze(0).unsqueeze(0) # (k^m, t) -> (1,1)
        print("idx_ntp should be (1,1): ", idx_ntp.shape, "\n", idx_ntp)
        
        # and here we add them all together
        idx = torch.cat((idx, idx_m, idx_ntp),dim=1) # (1,t), (1,1), (1,max_val) -> (1,t+1+max_val)
        print("idx should be (1,t+max_val+1): ", idx.shape, "\n", idx)
        
        # wait but there's a problem. If we didn't use the largest max_val=m then 
        #     idx_m_topk is for items too far into the future. we need to pass forward 
        #     an idx_m_topk that's relevant
        # ngl i'm too tired to think about it so i'm assuming i can use the same index i used for idx_ntp
        mlogits = mlogits[:,max_idx_row,-1-m+max_val,:] # (m,k^m,t,v) -> (m,v)
        print("mlogits should be (m,v): ", mlogits.shape)
         
    return idx, tok_per_inf

In [28]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output, tok_per_inf = generate_Euryale(model, context_tensor, max_runs=100, k=2)

idx:  torch.Size([1, 44])
idx_cond should be size (1, t):  torch.Size([1, 44])
logits should be (1,t,v):  torch.Size([1, 44, 65])
mlogits should be (m,1,t,v):  torch.Size([5, 1, 44, 65])
idx_ntp should be (1, t):  torch.Size([1, 44])
idx should be (1, t+1):  torch.Size([1, 45])
mlogits should be (m,v):  torch.Size([5, 65])
idx_m_topk should be (m,k):  torch.Size([5, 2])


  combinations = combinations.T.reshape(-1, m)


mcomb should be (k^m,m):  torch.Size([32, 5])
idx_rep should be (k^m,t+1):  torch.Size([32, 45])
input should be (k^m,t+1+m):  torch.Size([32, 50])
input_cond should be (k^m,t):  torch.Size([32, 50])
logits should be (k^m,t,v):  torch.Size([32, 50, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 50, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 50, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
 

input should be (k^m,t+1+m):  torch.Size([32, 64])
input_cond should be (k^m,t):  torch.Size([32, 64])
logits should be (k^m,t,v):  torch.Size([32, 64, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 64, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 64, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='mps:0')
max_val:  1
max_idx_row:  0
max_idx_col:  15
id

idx_rep should be (k^m,t+1):  torch.Size([32, 76])
input should be (k^m,t+1+m):  torch.Size([32, 81])
input_cond should be (k^m,t):  torch.Size([32, 81])
logits should be (k^m,t,v):  torch.Size([32, 81, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 81, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 81, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:

input should be (k^m,t+1+m):  torch.Size([32, 88])
input_cond should be (k^m,t):  torch.Size([32, 88])
logits should be (k^m,t,v):  torch.Size([32, 88, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 88, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 88, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  31
max_idx_col:  0
id

input should be (k^m,t+1+m):  torch.Size([32, 103])
input_cond should be (k^m,t):  torch.Size([32, 103])
logits should be (k^m,t,v):  torch.Size([32, 103, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 103, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 103, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  

input should be (k^m,t+1+m):  torch.Size([32, 118])
input_cond should be (k^m,t):  torch.Size([32, 118])
logits should be (k^m,t,v):  torch.Size([32, 118, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 118, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 118, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[3, 3, 4,  ..., 1, 1, 1],
        [3, 3, 4,  ..., 1, 1, 1],
        [3, 3, 5,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='mps:0')
max_val:  5
max_idx_row:  31
max_idx_col: 

idx_rep should be (k^m,t+1):  torch.Size([32, 125])
input should be (k^m,t+1+m):  torch.Size([32, 130])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  31
idx:  torch.Size([1, 134]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  31
idx:  torch.Size([1, 143]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  31
idx:  torch.Size([1, 153]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

idx_rep should be (k^m,t+1):  torch.Size([32, 163])
input should be (k^m,t+1+m):  torch.Size([32, 168])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device

idx_rep should be (k^m,t+1):  torch.Size([32, 174])
input should be (k^m,t+1+m):  torch.Size([32, 179])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device

idx_rep should be (k^m,t+1):  torch.Size([32, 183])
input should be (k^m,t+1+m):  torch.Size([32, 188])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 3, 3, 3],
        [1, 1, 1,  ..., 3, 3, 3],
        [1, 1, 1,  ..., 3, 3, 3]], device

input should be (k^m,t+1+m):  torch.Size([32, 202])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  31
max_idx_col: 

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[4, 5, 3,  ..., 0, 0, 0],
        [4, 5, 3,  ..., 0, 0, 0],
        [4, 4, 3,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='mps:0')
max_val:  5
max_idx_row:  1
max_idx_col:  1
idx:  torch.Size([1, 206]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1,

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  31
idx:  torch.Size([1, 228]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

idx:  torch.Size([1, 235]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0,  0, 24, 27, 30,
         16,  1,  0, 24, 33, 15, 21, 27, 10,  0, 27,  6,  1, 46, 43, 61, 43, 43,
         54,  1,  1, 58, 46, 43,  1, 51, 39, 52,  1, 53, 44,  1,  1,  1,  1, 58,
         53, 56, 43, 43, 43,  1, 58, 46, 39, 52,  1,  1, 53, 44, 56, 47, 43, 52,
         42,  6,  0, 13, 52,  1, 47, 46, 53, 43, 42,  1, 58, 46, 43,  1, 57, 58,
         56, 43, 52, 45, 43, 58,  1, 53, 44,  1, 58, 46, 43,  1, 57, 58, 56, 43,
         49, 45, 43,  1, 53, 53,  1, 43, 39, 56, 58,  0, 32, 46,  1,  1,  1,  1,
         58, 46, 43,  1, 57, 58, 39, 43, 43, 43, 43,  1, 53, 44,  1, 58, 46, 43,
          1, 57, 58, 56, 43, 43,  1,  1, 53, 44,  1, 58, 46, 43,  1, 57, 58, 56,
         43, 52, 45, 43,  0,  0, 18, 47, 56, 57, 58,  1, 31, 43, 56, 60, 47, 52,
 

input should be (k^m,t+1+m):  torch.Size([32, 253])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='mps:0')
max_val:  0
max_idx_row:  0
max_idx_col:  

input should be (k^m,t+1+m):  torch.Size([32, 263])
input_cond should be (k^m,t):  torch.Size([32, 128])
logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]], device='mps:0')
max_val:  1
max_idx_row:  16
max_idx_col: 

idx:  torch.Size([1, 265]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0,  0, 24, 27, 30,
         16,  1,  0, 24, 33, 15, 21, 27, 10,  0, 27,  6,  1, 46, 43, 61, 43, 43,
         54,  1,  1, 58, 46, 43,  1, 51, 39, 52,  1, 53, 44,  1,  1,  1,  1, 58,
         53, 56, 43, 43, 43,  1, 58, 46, 39, 52,  1,  1, 53, 44, 56, 47, 43, 52,
         42,  6,  0, 13, 52,  1, 47, 46, 53, 43, 42,  1, 58, 46, 43,  1, 57, 58,
         56, 43, 52, 45, 43, 58,  1, 53, 44,  1, 58, 46, 43,  1, 57, 58, 56, 43,
         49, 45, 43,  1, 53, 53,  1, 43, 39, 56, 58,  0, 32, 46,  1,  1,  1,  1,
         58, 46, 43,  1, 57, 58, 39, 43, 43, 43, 43,  1, 53, 44,  1, 58, 46, 43,
          1, 57, 58, 56, 43, 43,  1,  1, 53, 44,  1, 58, 46, 43,  1, 57, 58, 56,
         43, 52, 45, 43,  0,  0, 18, 47, 56, 57, 58,  1, 31, 43, 56, 60, 47, 52,
 

idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='mps:0')
max_val:  2
max_idx_row:  31
max_idx_col:  0
idx:  torch.Size([1, 274]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 12,  0,  0,  0, 24, 27, 30,
         16,  1,  0, 24, 33, 15, 21, 27, 10,  0, 27,  6,  1, 46,

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='mps:0')
max_val:  3
max_idx_row:  0
max_idx_col:  7
idx:  torch.Size([1, 283]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1,

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='mps:0')
max_val:  2
max_idx_row:  31
max_idx_col:  0
idx:  torch.Size([1, 293]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        ...,
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0]], device='mps:0')
max_val:  2
max_idx_row:  0
max_idx_col:  7
idx:  torch.Size([1, 299]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1,

logits should be (k^m,t,v):  torch.Size([32, 128, 65])
mlogits should be (m,k^m,t,v):  torch.Size([5, 32, 128, 65])
idx_ntp_topk should be (k^m, t,k):  torch.Size([32, 128, 2])
idx_check should be (k^m,m,k):  torch.Size([32, 5, 2])
idx_check_comb should be (k^m,k^m,m):  torch.Size([32, 32, 5])
mcomb_check should be (k^m,k^m,m):  torch.Size([32, 32, 5])
match_tensor:  torch.Size([32, 32, 5])
pad:  torch.Size([32, 32, 1])
padded_tensor:  torch.Size([32, 32, 6])
zero_positions:  torch.Size([32, 32])
zero_positions:  torch.Size([32, 32])
range_tensor:  torch.Size([32, 32, 5])
mask:  torch.Size([32, 32, 5])
result:
 tensor([[2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        [2, 2, 2,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='mps:0')
max_val:  3
max_idx_row:  0
max_idx_col:  15
idx:  torch.Size([1, 308]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1

In [29]:
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print(decode(output[0].tolist()))

tokens per inference:  2.69
JULIET:
O Romeo, Romeo! wherefore art thou Rome?


LORD 
LUCIO:
O, heweep  the man of    toreee than  ofriend,
An ihoed the strenget of the strekge oo eart
Th    the staeeee of the stree  of the strenge

First Servinr: are tou thie?

Clown:
What  is,moe?
WMaster the   trualo ooa that  we haav 
shall both the  se
