In [1]:
import sys
# Make the packages installed in a specific virtualenv importable from this kernel.
# NOTE(review): this is an absolute path on the original author's machine; it will not
# exist elsewhere, so adjust or remove it when running the notebook on another system.
sys.path.append('/Users/tunadorable/local-repos/learning_medusa/venv/lib/python3.11/site-packages')

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time

In [3]:
# Run on the MPS backend when this PyTorch build reports it as available, else on the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [4]:
# hyperparameters (these single-letter names are read as module-level globals by the cells below)
b = 24 # batch size: how many independent sequences are processed in parallel
t = 128 # maximum context length (in tokens) for predictions
max_iters = 10000 # number of optimisation steps in the training loop
eval_interval = 100 # the training loop evaluates train/val loss every this many steps
lr = 3e-4 # learning rate passed to AdamW
eval_iters = 20 # number of batches averaged per split inside estimate_loss
d = 128 # embedding width used throughout the model
h = 8 # number of attention heads per transformer block
l = 8 # number of transformer blocks
dropout = 0.2 # dropout probability used by every nn.Dropout in the model
l2 = 0.01 # weight_decay passed to AdamW

m = 3 # number of Medusa ("snake") heads; get_batch also builds m+1 shifted target rows from it
medusa_discount = torch.tensor(0.8).to(device) # the loss of Medusa head i is scaled by medusa_discount**(i+1)

In [5]:
# Read the whole training corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [6]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted order.
chars = sorted(set(text))
v = len(chars)  # vocabulary size
print(chars, v)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [7]:
# Character-level tokenizer: two lookup tables plus the helpers built on top of them.
stoi = {ch: idx for idx, ch in enumerate(chars)}  # character -> integer id
itos = dict(enumerate(chars))                      # integer id -> character


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[tok] for tok in l)

In [8]:
# Tokenise the full corpus once, then cut it 90/10 into training and validation data.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))  # index of the first validation token
train_data, val_data = data[:n], data[n:]

In [9]:
# data loading
def get_batch(split):
    """Sample a random batch of inputs and their shifted targets from one split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        (x, y), both moved to ``device``. ``x`` stacks ``b`` windows of length ``t``.
        ``y`` stacks ``m + 1`` copies of those windows, where row ``j`` is shifted
        ``j + 1`` tokens to the right: row 0 is the next-token target and rows 1..m
        are the targets for the Medusa heads.
    """
    source = train_data if split == 'train' else val_data
    # Leave room after every start index for the window plus the furthest shift.
    starts = torch.randint(len(source) - t - m, (b,))
    x = torch.stack([source[s:s+t] for s in starts])
    shifted_rows = []
    for shift in range(1, m + 2):
        shifted_rows.append(torch.stack([source[s+shift:s+shift+t] for s in starts]))
    y = torch.stack(shifted_rows)
    return x.to(device), y.to(device)

In [10]:
# Draw one training batch and show the first sequence next to all of its target rows.
x, y = get_batch('train')
print("x[0] ", x[0].shape, "\n", x[0])
first_seq_targets = y[:, 0, ...]
print("y[:,0,...] ", first_seq_targets.shape, "\n", first_seq_targets)

x[0]  torch.Size([128]) 
 tensor([46, 43,  1, 61, 39, 57,  1, 39,  1, 51, 39, 52,  7, 41, 46, 47, 50, 42,
         0, 58, 46, 39, 52,  1, 52, 53, 61,  1, 47, 52,  1, 44, 47, 56, 57, 58,
         1, 57, 43, 43, 47, 52, 45,  1, 46, 43,  1, 46, 39, 42,  1, 54, 56, 53,
        60, 43, 42,  1, 46, 47, 51, 57, 43, 50, 44,  1, 39,  0, 51, 39, 52,  8,
         0,  0, 34, 21, 30, 19, 21, 24, 21, 13, 10,  0, 14, 59, 58,  1, 46, 39,
        42,  1, 46, 43,  1, 42, 47, 43, 42,  1, 47, 52,  1, 58, 46, 43,  1, 40,
        59, 57, 47, 52, 43, 57, 57,  6,  1, 51, 39, 42, 39, 51, 11,  1, 46, 53,
        61,  1], device='mps:0')
y[:,0,...]  torch.Size([4, 128]) 
 tensor([[43,  1, 61, 39, 57,  1, 39,  1, 51, 39, 52,  7, 41, 46, 47, 50, 42,  0,
         58, 46, 39, 52,  1, 52, 53, 61,  1, 47, 52,  1, 44, 47, 56, 57, 58,  1,
         57, 43, 43, 47, 52, 45,  1, 46, 43,  1, 46, 39, 42,  1, 54, 56, 53, 60,
         43, 42,  1, 46, 47, 51, 57, 43, 50, 44,  1, 39,  0, 51, 39, 52,  8,  0,
          0, 34, 21, 3

In [11]:
@torch.no_grad()
def estimate_loss():
    """Average the model's loss over ``eval_iters`` random batches per split.

    The global ``model`` is put in eval mode for the measurement and switched back
    to train mode before returning.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        batch_losses = []
        for _ in range(eval_iters):
            X, Y = get_batch(split)
            _, loss, _ = model(X, Y)
            batch_losses.append(loss.item())
        averages[split] = torch.tensor(batch_losses).mean()
    model.train()
    return averages

In [12]:
class FeedFoward(nn.Module):
    """Per-position MLP: expand to 4*d, apply ReLU, project back to d, then dropout."""

    def __init__(self, d):
        super().__init__()
        hidden = 4 * d
        layers = [
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

In [13]:
class Head(nn.Module):
    """A single head of causal self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the model width `d` down to this head's width.
        self.key = nn.Linear(d, head_size, bias=False)
        self.query = nn.Linear(d, head_size, bias=False)
        self.value = nn.Linear(d, head_size, bias=False)
        # Lower-triangular (t, t) matrix of ones; zeros mark the positions to mask out.
        lower_triangle = torch.tril(torch.ones(t, t))
        self.register_buffer('tril', lower_triangle)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (batch, time-step, channels) to (batch, time-step, head size)."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        # Attention scores, scaled by head_size ** -0.5.
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale  # (b, t, t)
        # Positions where tril is zero get -inf, so they receive zero weight after softmax.
        masked_out = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(masked_out, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Weighted aggregation of the values.
        values = self.value(x)
        return weights @ values

In [14]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run side by side, concatenated and projected back to d."""

    def __init__(self, h, head_size):
        super().__init__()
        head_list = [Head(head_size) for _ in range(h)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(head_size * h, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head sees the same input; their outputs are joined along the channel axis.
        per_head = [head(x) for head in self.heads]
        joined = torch.cat(per_head, dim=-1)
        projected = self.proj(joined)
        return self.dropout(projected)

In [15]:
class Block(nn.Module):
    """Transformer block: self-attention, then a feed-forward layer, each with a residual add."""

    def __init__(self, d, h):
        """d is the embedding dimension and h the number of attention heads."""
        super().__init__()
        per_head_width = d // h
        self.sa = MultiHeadAttention(h, per_head_width)
        self.ffwd = FeedFoward(d)
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

    def forward(self, x):
        # LayerNorm is applied to the input of each sub-layer; the residual path is untouched.
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

In [16]:
class snake(nn.Module):
    """One Medusa head: a residual ReLU layer followed by a projection to vocabulary logits."""

    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(d, d)
        self.relu = nn.ReLU()  # author's note: the actual paper uses SiLU because it builds on Llama
        self.w2 = nn.Linear(d, v)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return logits of shape (b, t, v) for x of shape (b, t, d)."""
        activated = self.relu(self.w1(x))
        residual = activated + x
        return self.w2(self.dropout(residual))

In [17]:
class medusaGPT(nn.Module):
    """GPT-style trunk with one ordinary LM head plus a list of Medusa ("snake") heads.

    All sizes come from the module-level hyperparameters: ``v`` (vocabulary size),
    ``d`` (width), ``t`` (maximum context), ``h`` (heads), ``l`` (blocks) and ``m``
    (number of Medusa heads).
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(v, d)
        self.position_embedding_table = nn.Embedding(t, d)
        self.blocks = nn.Sequential(*[Block(d, h) for _ in range(l)])
        self.ln_f = nn.LayerNorm(d) # final layer norm
        self.lm_head = nn.Linear(d, v)

        # Create a list of Medusa heads
        self.medusa_heads = nn.ModuleList([snake() for _ in range(m)])

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, medusa_targets=None, verbose=False):
        """Run the trunk and every head on a batch of token ids.

        Args:
            idx: (b, t) tensor of token ids.
            targets: optional tensor indexed as ``targets[0]`` for the LM head and
                ``targets[i+1]`` for Medusa head ``i``, each viewable as (b*t,).
            medusa_targets: accepted for interface compatibility; not used.
            verbose: when True, print the logits and Medusa logits with their shapes.

        Returns:
            (logits, loss, medusa_logits). Without targets, ``logits`` is (b, t, v) and
            ``loss`` is None. With targets, ``logits`` is flattened to (b*t, v) and
            ``loss`` is the LM cross-entropy plus the discounted Medusa cross-entropies.
            ``medusa_logits`` is always (m, b, t, v).
        """
        b, t = idx.shape

        tok_emb = self.token_embedding_table(idx) # (b,t,d)
        # Build the position indices on the same device as the input rather than on the
        # global `device`, so the model also works if it is moved somewhere else.
        pos_emb = self.position_embedding_table(torch.arange(t, device=idx.device)) # (t,d)
        x = tok_emb + pos_emb # (b,t,d)
        x = self.ln_f(self.blocks(x)) # (b,t,d) -> (b,t,d)

        logits = self.lm_head(x) # (b,t,d)@(d,v)=(b,t,v)

        # Apply each snake head to x and store the results
        medusa_logits = torch.stack([head(x) for head in self.medusa_heads], dim=0)

        if verbose:
            print("logits: ", logits.shape, logits)
            print("medusa_logits: ", medusa_logits.shape, medusa_logits)

        if targets is None:
            loss = None
        else:
            m, b, t, v = medusa_logits.shape
            logits = logits.view(b*t, v)
            loss = F.cross_entropy(logits, targets[0].view(b*t))

            # Head i is trained on targets[i+1] and weighted by medusa_discount**(i+1).
            medusa_loss = torch.stack([
                F.cross_entropy(medusa_logits[i].view(b*t, v), targets[i+1].view(b*t)) * medusa_discount**(i+1)
                for i in range(m)
            ])

            loss = loss + medusa_loss.sum()

        return logits, loss, medusa_logits

    @torch.no_grad()  # sampling needs no gradients; avoids building an autograd graph per token
    def generate_gpt(self, idx, max_new_tokens, temperature=1.0):
        """Sample ``max_new_tokens`` tokens one at a time using only the LM head.

        Args:
            idx: (b, t) tensor of token ids forming the current context.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the logits before the softmax.

        Returns:
            (b, t + max_new_tokens) tensor: the context followed by the sampled tokens.
        """
        # The position table has exactly as many rows as the model can attend over.
        max_context = self.position_embedding_table.num_embeddings

        for _ in range(max_new_tokens):
            # crop idx to the last max_context tokens
            idx_cond = idx[:, -max_context:]

            # get the predictions
            logits, loss, mlogits = self(idx_cond)

            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (b, v)

            # scale logits by the temperature (the epsilon guards against division by zero)
            logits = logits / (temperature+1e-10)

            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (b, v)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (b, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (b, t+1)

        return idx

# Training

If you don't want to train the model yourself, scroll down to "Load a saved model".

In [18]:
# Build the model and move it to the target device.
# The result is deliberately NOT bound to `m`: `m` is the number of Medusa heads and is
# read by get_batch, medusaGPT.__init__, generate_Stheno and the checkpoint filename, so
# overwriting it with the model breaks every one of them.
model = medusaGPT().to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

1691.14 K parameters


In [19]:
# AdamW over every model parameter; `lr` and `l2` come from the hyperparameter cell.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=l2,
)

In [20]:
start_time = time.time()
# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the kernel.
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train')
    # train
    logits, loss, mlogits = model(xb, yb)

    # NOTE: an L1 penalty (`loss + 0.01 * l1_norm`) used to be computed here, but the result
    # was never back-propagated: only `loss.backward()` was called. The dead computation,
    # which summed |p| over every parameter on each step, has been removed. The optimised
    # objective is unchanged. Regularisation comes from AdamW's weight_decay.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        # elapsed time is taken before the evaluation pass starts
        elapsed_time = time.time() - start_time
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

step 0: train loss 11.8995, val loss 11.9120, time elapsed: 1.95 seconds
step 100: train loss 8.8519, val loss 8.8598, time elapsed: 46.52 seconds
step 200: train loss 8.5150, val loss 8.5164, time elapsed: 90.72 seconds
step 300: train loss 8.3617, val loss 8.3665, time elapsed: 134.74 seconds
step 400: train loss 8.2472, val loss 8.2383, time elapsed: 178.67 seconds
step 500: train loss 8.1164, val loss 8.1313, time elapsed: 222.43 seconds
step 600: train loss 7.9587, val loss 8.0067, time elapsed: 266.19 seconds
step 700: train loss 7.8637, val loss 7.9270, time elapsed: 309.95 seconds
step 800: train loss 7.7623, val loss 7.7944, time elapsed: 353.73 seconds
step 900: train loss 7.6077, val loss 7.7277, time elapsed: 397.29 seconds
step 1000: train loss 7.5329, val loss 7.6664, time elapsed: 441.00 seconds
step 1100: train loss 7.4500, val loss 7.5704, time elapsed: 484.66 seconds
step 1200: train loss 7.3449, val loss 7.5147, time elapsed: 528.13 seconds
step 1300: train loss 7.26

#### save the trained model

In [23]:
# Encode the run's hyperparameters and a timestamp in the checkpoint filename, then save the weights.
timestamp = time.strftime("%Y-%m-%d|%H-%M-%S")
checkpoint_path = (
    f'models/medusa_b{b}_t{t}_d{d}_h{h}_l{l}_lr{lr}_drop{dropout}'
    f'_l2:{l2}_m{m}_mdiscount{medusa_discount:.2f}_{timestamp}.pth'
)
torch.save(model.state_dict(), checkpoint_path)

# Load a saved model

In [18]:
# The checkpoint stores only a state_dict, so first build a model with the same architecture
model = medusaGPT().to(device)

# Load the saved state dictionary.
# map_location remaps tensors that were saved on another device (e.g. 'mps') onto the
# current one, so the checkpoint also loads on machines without that backend.
model.load_state_dict(torch.load('models/medusa_b24_t128_d128_h8_l8_lr0.0003_drop0.2_m3_mdiscount0.8_2024-01-24|17-36-02.pth', map_location=device))

# If you plan to continue training the model, switch to training mode
#model.train()

# If you only plan to do inference, switch to evaluation mode
model.eval()

medusaGPT(
  (token_embedding_table): Embedding(65, 128)
  (position_embedding_table): Embedding(128, 128)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=128, out_features=16, bias=False)
            (query): Linear(in_features=128, out_features=16, bias=False)
            (value): Linear(in_features=128, out_features=16, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
   

## Regular GPT Inference

For some reason, if I run this cell before the later functions, the kernel crashes. I recommend restarting with a fresh kernel before running those.

In [117]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output = model.generate_gpt(context_tensor, max_new_tokens=218)
output_str = decode(output[0].tolist())
print(output_str)

JULIET:
O Romeo, Romeo! wherefore art thou Richard?'

SICINIUS:
Why, didst have sone the woll reesign:
He fault is this mighty, terror, day and by so;
But faries, and he should seak to more fellow ils.

WARWICK:
Will you not been verlain action to hell,
By the l
CPU times: user 16.9 s, sys: 1.37 s, total: 18.2 s
Wall time: 19.4 s


# Medusa's first sister Stheno (the aggressive one)

ChatGPT said:
Her name translates to "strength" or "forceful". Stheno was the eldest and most fierce of the sisters, known for her strength and ferocity. 

So, since this generation function does strictly greedy decoding, I'll name it after her.

In [20]:
@torch.no_grad()  # pure inference: do not build an autograd graph on every forward pass
def generate_Stheno(model, idx, max_runs, verbose=True, max_context=None):
    """Greedy Medusa-style speculative decoding for a single sequence.

    Each forward pass yields one greedy next token from the LM head plus one speculative
    token per Medusa head. On the following pass the speculative tokens are fed back in
    and kept only up to the first position where the LM head disagrees with them.

    Args:
        model: callable returning ``(logits, loss, mlogits)`` for a (1, T) id tensor, with
            ``logits`` of shape (1, T, v) and ``mlogits`` of shape (m, 1, T, v).
        idx: (1, T) tensor of context token ids.
        max_runs: number of forward passes. The first pass always runs, so values below 1
            behave like 1.
        verbose: when True (default, the original behaviour) print the debugging trace.
        max_context: maximum number of tokens fed to the model; defaults to the global ``t``.

    Returns:
        (idx, tok_per_inf): the extended (1, T') sequence and a list with the number of
        tokens gained from each forward pass.
    """
    # Ensure idx is a single sequence
    assert idx.size(0) == 1, "idx must be of size (1, t)"

    log = print if verbose else (lambda *args, **kwargs: None)
    if max_context is None:
        max_context = t  # module-level maximum context length
    # Helper tensors are created on the input's device, not on the global `device`.
    dev = idx.device

    log("idx: ", idx.shape)

    # crop idx to the last max_context tokens
    input_cond = idx[:, -max_context:]
    log("idx_cond should be size (1, t): ", input_cond.shape)

    # get the predictions
    logits, loss, mlogits = model(input_cond) # (1,t,v), loss, and (m,1,t,v)
    log("logits should be (1,t,v): ", logits.shape)
    log("mlogits should be (m,1,t,v): ", mlogits.shape)

    # Take the number of Medusa heads from the model output instead of the global `m`,
    # so this function keeps working even if that global has been rebound.
    n_heads = mlogits.shape[0]

    # we only want the medusa predictions for the newest future timestep
    mlogits = mlogits[...,-1,:] # becomes (m,1,v)
    mlogits = mlogits.squeeze(dim=1) # (m,v)
    log("mlogits should be (m,v): ", mlogits.shape)

    # actual Medusa verifies a tree of candidates; here we assume greedy decoding on the snake heads
    idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (1,m)
    log("idx_m_prev should be (1,m): ", idx_m_prev.shape)
    # named _prev because it is checked on the next pass

    # greedy decoding for the regular next token
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1, t)
    log("idx_ntp should be (1, t): ", idx_ntp.shape)

    # append the predicted token to the running sequence
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) # (1,t)+(1,1) -> (1, t+1)
    log("idx should be (1, t+1): ", idx.shape)

    # keep track of how many tokens we get per model inference
    tok_per_inf = [1]

    for _ in range(max_runs-1): # -1 since one pass was done above
        # idx holds the accepted tokens; idx_m_prev holds m speculative tokens to verify.
        # Feed both so the LM head's own predictions can be compared with the speculation.
        inp = torch.cat((idx, idx_m_prev), dim=1) # (1,t+1) & (1,m) -> (1,t+1+m)
        log("input should be (1,t+1+m): ", inp.shape)

        # the model cannot take more than max_context tokens; shorter inputs are fine
        input_cond = inp[:, -max_context:]
        log("input_cond should be (1,t): ", input_cond.shape)

        logits, loss, mlogits = model(input_cond) # (1,t,v), loss, & (m,1,t,v)

        # greedy decoding, so no softmax is needed
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1,t,v) -> (1, t)
        log("idx_ntp should be (1, t): ", idx_ntp.shape)

        # LM-head predictions made at the last accepted token and at the first m-1 speculative
        # tokens, i.e. what the model itself would have produced in the m speculative slots
        idx_check = idx_ntp[:,-(n_heads+1):-1] # (1,t) -> (1,m)
        log("idx_check should be (1,m): ", idx_check.shape, idx_check)

        # check whether they match
        match_tensor = (idx_m_prev == idx_check).int() # (1,m) of 1's and 0's
        log("match_tensor should be (1,m) of 1's and 0's: ", match_tensor.shape, match_tensor)

        # Count the matches before the first mismatch. Inverting turns mismatches into 1's and
        # argmax returns the first of them; the extra trailing 1 handles rows with no mismatch.
        pad = torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype, device=dev)
        log("pad: ", pad.shape, pad)
        padded_tensor = torch.cat((1 - match_tensor, pad), dim=1)
        log("padded_tensor: ", padded_tensor.shape, padded_tensor)
        zero_positions = padded_tensor.argmax(dim=1)

        # Clamp the "no mismatch" case to m
        zero_positions[zero_positions >= match_tensor.size(1)] = match_tensor.size(1)
        log("zero_positions: ", zero_positions.shape, zero_positions)

        range_tensor = torch.arange(match_tensor.size(1), device=dev).unsqueeze(0).expand_as(match_tensor)
        log("range_tensor: ", range_tensor.shape, range_tensor)

        # True for every position before the first mismatch
        mask = range_tensor < zero_positions.unsqueeze(1)
        log("mask: ", mask.shape, mask)

        # number of accepted speculative tokens, an integer in [0, m]
        result = (match_tensor * mask).sum(dim=1).item()
        log("result: ", result)

        tok_per_inf.append(result+1)

        # Keep the accepted speculative tokens, then add the LM-head prediction made at the
        # last accepted position; that token comes for free from this same forward pass.
        log("idx: ", idx.shape, idx)
        idx_m_prev = idx_m_prev[:,:result]
        log("idx_m_prev should be (1,result): ", idx_m_prev.shape, idx_m_prev)
        idx_ntp = idx_ntp[:,-1-n_heads+result].unsqueeze(dim=0)
        log("idx_ntp should be (1,1): ", idx_ntp.shape, idx_ntp)
        idx = torch.cat((idx, idx_m_prev, idx_ntp),dim=1)
        log("idx should be (1,t+result+1: ", idx.shape, idx)

        # The next speculative tokens must come from the same position as idx_ntp: the heads
        # at the last accepted token predict what follows the token just appended.
        mlogits = mlogits[...,-1-n_heads+result,:].squeeze(dim=1) # (m,1,t,v) -> (m,v)
        log("mlogits: ", mlogits.shape)
        idx_m_prev = torch.argmax(mlogits, dim=-1, keepdim=True).t() # (m,v) -> (1,m)
        log("idx_m_prev: ", idx_m_prev.shape, idx_m_prev)
        # named _prev because it is checked on the next pass

    return idx, tok_per_inf

In [21]:
%%time
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
output, tok_per_inf = generate_Stheno(model, context_tensor, max_runs=100)
output_str = decode(output[0].tolist())
print("tokens per inference: ", sum(tok_per_inf)/len(tok_per_inf))
print(output_str)

idx:  torch.Size([1, 44])
idx_cond should be size (1, t):  torch.Size([1, 44])
logits should be (1,t,v):  torch.Size([1, 44, 65])
mlogits should be (m,1,t,v):  torch.Size([3, 1, 44, 65])
mlogits should be (m,v):  torch.Size([3, 65])
idx_m_prev should be (1,m):  torch.Size([1, 3])
idx_ntp should be (1, t):  torch.Size([1, 44])
idx should be (1, t+1):  torch.Size([1, 45])
input should be (1,t+1+m):  torch.Size([1, 48])
input_cond should be (1,t):  torch.Size([1, 48])
idx_ntp should be (1, t):  torch.Size([1, 48])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[51, 43, 53]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 

idx_ntp should be (1, t):  torch.Size([1, 67])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[ 1, 39, 46]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 64]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21]], device='mps:0')
idx_m_prev should be (1,resu

idx_ntp should be (1, t):  torch.Size([1, 76])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[43,  1, 57]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[ True,  True, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 73]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51]],

idx_ntp should be (1, t):  torch.Size([1, 84])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[53,  1, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[ True,  True, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 81]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 4

idx_ntp should be (1, t):  torch.Size([1, 93])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[60,  1, 53]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 90]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 4

idx:  torch.Size([1, 100]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46]], device='mps:0')
idx_m_prev should be (1,result):  torch.Size([1, 2]) tensor([[43,  1]], device='mps:0')
idx_ntp should be (1,1):  torch.Size([1, 1]) tensor([[41]], device='mps:0')
idx should be (1,t+result+1:  torch.Size([1, 103]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0

idx_m_prev should be (1,result):  torch.Size([1, 2]) tensor([[0, 0]], device='mps:0')
idx_ntp should be (1,1):  torch.Size([1, 1]) tensor([[15]], device='mps:0')
idx should be (1,t+result+1:  torch.Size([1, 111]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15]], device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[33, 30, 21]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 114])
input_cond should be (1,t):  torch.Size([1, 114])
idx_ntp should be (1, t):  torch.Size([1, 114])
idx_check shou

idx_ntp should be (1, t):  torch.Size([1, 127])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[46, 46, 51]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 124]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51,

idx should be (1,t+result+1:  torch.Size([1, 133]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52]], device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[1, 1, 1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 136])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[ 1, 58, 58]], device='mps:

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[56,  1,  1]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 140]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51,

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[43,  1, 57]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([2], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[ True,  True, False]], device='mps:0')
result:  2
idx:  torch.Size([1, 151]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51,

idx should be (1,t+result+1:  torch.Size([1, 161]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57]],
       device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[0, 0, 1]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 164])
input_cond should be (1,t):  torch.Size([1, 128])
id

idx should be (1,t+result+1:  torch.Size([1, 173]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33]], device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[31, 10,  0]], device='mps:0')
input should be (1,t+1+m):  torch.Size([1, 176])

idx should be (1,t+result+1:  torch.Size([1, 179]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46]],
       device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[39,  1, 43]], device='mps:0')
input should be (

idx should be (1,t+result+1:  torch.Size([1, 187]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52]], device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.Size([1, 3]) tensor([[1, 1, 1]], de

idx should be (1,t+result+1:  torch.Size([1, 194]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43]],
       device='mps:0')
mlogits:  torch.Size([3, 65])
idx_m_prev:  torch.

idx:  torch.Size([1, 201]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53]], device='mps:0')
idx_m_prev should be (1,result):  torch.Siz

input should be (1,t+1+m):  torch.Size([1, 213])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[60, 47, 53]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([1], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[ True, False, False]], device='mps:0')
result:  1
idx:  torch.Size([1, 210]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27

idx:  torch.Size([1, 216]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  1]],

input should be (1,t+1+m):  torch.Size([1, 227])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[53, 56, 57]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 1, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 0, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 224]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27

zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 232]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30,

idx_m_prev should be (1,result):  torch.Size([1, 2]) tensor([[10,  0]], device='mps:0')
idx_ntp should be (1,1):  torch.Size([1, 1]) tensor([[21]], device='mps:0')
idx should be (1,t+result+1:  torch.Size([1, 244]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31,

idx should be (1,t+result+1:  torch.Size([1, 248]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43, 57,  8,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 57, 43, 43, 52,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41,
         43,  1, 53, 44,  1, 58, 46, 43,  1, 57, 43, 56, 6

input should be (1,t+1+m):  torch.Size([1, 257])
input_cond should be (1,t):  torch.Size([1, 128])
idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[ 1, 58, 58]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[1, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[0, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([1], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[ True, False, False]], device='mps:0')
result:  1
idx:  torch.Size([1, 254]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27

idx_ntp should be (1, t):  torch.Size([1, 128])
idx_check should be (1,m):  torch.Size([1, 3]) tensor([[43, 59, 56]], device='mps:0')
match_tensor should be (1,m) of 1's and 0's:  torch.Size([1, 3]) tensor([[0, 0, 0]], device='mps:0', dtype=torch.int32)
pad:  torch.Size([1, 1]) tensor([[1]], device='mps:0', dtype=torch.int32)
padded_tensor:  torch.Size([1, 4]) tensor([[1, 1, 1, 1]], device='mps:0', dtype=torch.int32)
zero_positions:  torch.Size([1]) tensor([0], device='mps:0')
range_tensor:  torch.Size([1, 3]) tensor([[0, 1, 2]], device='mps:0')
mask:  torch.Size([1, 3]) tensor([[False, False, False]], device='mps:0')
result:  0
idx:  torch.Size([1, 260]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51,

#### Notice that it's restricted to greedy decoding, which means the output quality is lower. If you compare the two passages, you'll see that the NTP one has far less of a problem with repetition.

# Medusa's second sister Euryale (the explorative one)


ChatGPT said

Her name means "far-roaming" in Greek. Euryale was known for her loud crying or bellowing. If your architecture is meant to explore a wide range of possibilities or to "roam" extensively through a dataset, the name Euryale might be suitable. Additionally, if your architecture involves a broad or far-reaching search strategy or is notable for 'broadcasting' its findings extensively (analogous to loud crying), Euryale could be an apt choice.



The goal here with Euryale is effectively to bring probabilistic decoding back to Stheno. My hope is that this will only require more RAM and not result in any added latency, but we'll see. The basic idea is to use topk instead of argmax, construct every single possible candidate sequence, and then select whichever candidate has the longest run of accepted tokens. If anything, I think the re-incorporation of topk results will *maybe* give a speed increase if it gets us longer accepted sequences (although likely not as fast as the attention-based mechanism used in actual Medusa).

In [110]:
def combinations(tensor):
    """Enumerate every way of picking one entry from each row of ``tensor``.

    Args:
        tensor: 2-D tensor of shape (m, k); row ``i`` holds the k candidates
            for slot ``i``.

    Returns:
        Tensor of shape (k**m, m) with the dtype and device of ``tensor``.
        Column ``i`` of every output row is taken from ``tensor[i]``. Rows are
        ordered like ``itertools.product(*tensor)``: the last column cycles
        fastest and the first column slowest.

    Raises:
        ValueError: if ``tensor`` is not 2-D.
    """
    if tensor.dim() != 2:
        raise ValueError(f"expected a 2-D (m, k) tensor, got shape {tuple(tensor.shape)}")

    # One grid per row: grids[i][a_0, ..., a_{m-1}] == tensor[i][a_i].
    # Building the grids from the rows themselves keeps everything on the
    # input's device and avoids calling `.T` on an N-D tensor, which PyTorch
    # has deprecated for anything other than 2-D.
    grids = torch.meshgrid(*tensor.unbind(dim=0), indexing="ij")

    # Row-major flattening makes the last grid axis vary fastest, so stacking
    # the flattened grids as columns yields the product in lexicographic order.
    return torch.stack([grid.reshape(-1) for grid in grids], dim=1)

@torch.no_grad()
def generate_Euryale(model, idx, max_runs, k=2, verbose=True):
    """Greedy speculative generation that verifies every top-k head guess in one batch.

    Each pass after the first takes the top-k token ids proposed by each of the
    m speculative heads, builds all k**m candidate continuations, runs them
    through the model as one batch, and keeps the candidate whose leading
    tokens agree with the model's own argmax predictions for the longest run.
    Those accepted tokens plus one regular next-token prediction are appended.

    Relies on the notebook-level hyperparameters ``t`` (maximum context length
    fed to the model) and ``m`` (number of speculative heads), and on
    ``combinations``. The model's train/eval mode is not changed here.

    Args:
        model: callable returning ``(logits, loss, mlogits)`` for a batch of
            token ids, with ``logits`` of shape (b, T, v) and ``mlogits`` of
            shape (m, b, T, v).
        idx: integer tensor of shape (1, T0) holding the prompt token ids.
        max_runs: total number of model calls; one call is always made, so
            values below 1 behave like 1.
        k: number of candidates taken from each speculative head.
        verbose: when True (the default) print the shape/value trace of every
            intermediate tensor; pass False to silence it.

    Returns:
        Tuple ``(idx, tok_per_inf)``: ``idx`` is the prompt followed by all
        generated tokens, shape (1, T0 + sum(tok_per_inf)); ``tok_per_inf`` is
        a list with the number of tokens appended after each model call.

    Raises:
        AssertionError: if ``idx`` does not have a batch dimension of 1.
    """
    # Single switch for the debug trace; arguments are still cheap views/shapes.
    log = print if verbose else (lambda *args, **kwargs: None)

    # Ensure idx is a single sequence
    assert idx.size(0) == 1, "idx must be of size (1, t)"
    log("idx: ", idx.shape)

    # crop idx to at most the last t tokens (the model cannot take more than t)
    input_cond = idx[:, -t:]
    log("idx_cond should be size (1, t): ", input_cond.shape)

    # get the predictions
    logits, loss, mlogits = model(input_cond) # (1,t,v), loss, and (m,1,t,v)
    log("logits should be (1,t,v): ", logits.shape)
    log("mlogits should be (m,1,t,v): ", mlogits.shape)

    # we only want the speculative predictions made at the newest timestep
    mlogits = mlogits[...,-1,:] # becomes (m,1,v)
    mlogits = mlogits.squeeze(dim=1) # (m,v)
    log("mlogits should be (m,v): ", mlogits.shape)

    # actual Medusa builds a probability/attention-mask tree;
    # here we simply keep the top-k options of every head
    idx_m_topk = torch.topk(mlogits, k, dim=-1, largest=True).indices
    log("idx_m_topk should be (m,k): ", idx_m_topk.shape, "\n", idx_m_topk)

    # the regular next-token prediction is decoded greedily
    idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (1, t)
    log("idx_ntp should be (1, t): ", idx_ntp.shape)

    # append the greedy token to the running sequence
    idx = torch.cat((idx, idx_ntp[:,-1].unsqueeze(dim=0)), dim=1) # (1,t)+(1,1) -> (1, t+1)
    log("idx should be (1, t+1): ", idx.shape)

    # keep track of how many tokens we get per model inference
    tok_per_inf = [1]

    for _ in range(max_runs-1): # -1 since one iteration was done above

        # idx is (1,t+1): context plus 1 predicted token
        # idx_m_topk is (m,k): top-k speculative tokens per head
        # to learn which speculative tokens the model itself would have
        # predicted, run the model on them and compare

        # all possible combinations of the top-k results from each head
        mcomb = combinations(idx_m_topk) # (m, k) -> (k^m,m)
        log("mcomb should be (k^m,m): ", mcomb.shape, "\n", mcomb)

        # replicate the running sequence so all candidates go through as one batch
        idx_rep = idx.repeat(k**m,1) # (1,t+1) -> (k^m, t+1)
        log("idx_rep should be (k^m,t+1): ", idx_rep.shape)

        # context + the single ntp token + the speculative tokens;
        # the k^m dimension is effectively the batch size
        inp = torch.cat((idx_rep, mcomb), dim=1) # (k^m,t+1) & (k^m,m) -> (k^m,t+1+m)
        log("input should be (k^m,t+1+m): ", inp.shape)

        # the model can't take more than t inputs; shorter inputs are fine
        # after cropping: prior context, then the ntp token, then the last m candidates
        input_cond = inp[:, -t:]
        log("input_cond should be (k^m,t): ", input_cond.shape, "\n", input_cond[:,-2*m:])

        # then we pass it in
        logits, loss, mlogits = model(input_cond) # (k^m,t,v), loss, & (m,k^m,t,v)

        # greedy decoding, so there's no need to softmax
        idx_ntp = torch.argmax(logits, dim=-1, keepdim=True).squeeze(dim=2) # (k^m,t,v) -> (k^m, t)
        log("idx_ntp should be (k^m, t): ", idx_ntp.shape, "\n", idx_ntp[:,-2*m:])

        # predictions made at the ntp token and at the first m-1 candidates are
        # the model's own guesses for candidate positions 1..m
        idx_check = idx_ntp[:,-(m+1):-1] # (k^m,t) -> (k^m,m)
        log("idx_check should be (k^m,m): ", idx_check.shape, "\n", idx_check)

        # check whether they match
        match_tensor = (mcomb == idx_check).int() # (k^m,m) of 1's and 0's
        log("match_tensor should be (k^m,m) of 1's and 0's: ", match_tensor.shape, "\n", match_tensor)

        ########### count, per row, the leading run of matches
        # Invert so mismatches become 1's and append a column of 1's as a
        # sentinel; argmax then gives the position of the first mismatch, or m
        # when the whole row matched. Tensors are created on match_tensor's
        # device so they always agree with the tensors they are combined with.
        pad = torch.ones(match_tensor.size(0), 1, dtype=match_tensor.dtype, device=match_tensor.device)
        log("pad: ", pad.shape)
        padded_tensor = torch.cat((1 - match_tensor, pad), dim=1)
        log("padded_tensor: ", padded_tensor.shape)
        # argmax over m+1 columns is already bounded by m, so no clamping is needed
        zero_positions = padded_tensor.argmax(dim=1)
        log("zero_positions: ", zero_positions.shape, zero_positions)

        # Column positions 0..m-1 for every row
        range_tensor = torch.arange(match_tensor.size(1), device=match_tensor.device).unsqueeze(0).expand_as(match_tensor)
        log("range_tensor: ", range_tensor.shape, range_tensor)

        # True for every position before the first mismatch in its row
        mask = range_tensor < zero_positions.unsqueeze(1)
        log("mask: ", mask.shape, mask)

        # Number of accepted speculative tokens per candidate row
        result = (match_tensor * mask).sum(dim=1)
        log("result: ", result)

        # search the flipped vector so that, among ties, the LAST row wins; later
        # rows of mcomb use lower-ranked top-k entries (the original author's
        # stand-in for a higher temperature)
        max_val, reversed_max_idx = torch.max(result.flip(dims=[0]), 0)
        max_idx = len(result) - 1 - reversed_max_idx
        log("max_val: ", max_val)
        log("reversed_max_idx: ", reversed_max_idx)
        log("max_idx: ", max_idx)

        # pull the two scalars to Python once instead of indexing with 0-d tensors
        n_accepted = max_val.item()  # between 0 and m (inclusive)
        best_row = max_idx.item()    # row of the winning candidate in the k^m batch
        # position whose prediction follows the last accepted token
        # (the ntp token itself when nothing was accepted)
        next_pos = -1 - m + n_accepted

        tok_per_inf.append(n_accepted+1)

        log("idx: ", idx.shape, idx)

        # accepted speculative tokens of the winning candidate
        idx_m = mcomb[best_row, :n_accepted].unsqueeze(0) # (k^m,m) -> (1,max_val)
        log("idx_m should be (1,max_val): ", idx_m.shape, idx_m)

        # regular next-token prediction right after the accepted tokens
        idx_ntp = idx_ntp[best_row, next_pos].unsqueeze(0).unsqueeze(0) # (k^m, t) -> (1,1)
        log("idx_ntp should be (1,1): ", idx_ntp.shape, idx_ntp)

        idx = torch.cat((idx, idx_m, idx_ntp),dim=1) # (1,L), (1,max_val), (1,1) -> (1,L+max_val+1)
        log("idx should be (1,t+max_val+1): ", idx.shape, idx)

        # If fewer than m tokens were accepted, the head outputs at the final
        # position look too far into the future, so take them from the same
        # position that supplied idx_ntp.
        # NOTE(review): the original author flagged this index choice as an
        # unverified assumption.
        mlogits = mlogits[:, best_row, next_pos, :] # (m,k^m,t,v) -> (m,v)
        log("mlogits should be (m,v): ", mlogits.shape)

        # candidates for the next pass
        idx_m_topk = torch.topk(mlogits, k, dim=-1, largest=True).indices # (m,v) -> (m,k)
        log("idx_m_topk should be (m,k): ", idx_m_topk.shape)

    return idx, tok_per_inf

In [111]:
# Sanity check of the decoder: map a few token ids seen in the debug trace back to characters.
print(decode([53, 51, 43, 53, 12]))
decode([58])  # last expression of the cell, so the notebook displays its repr

omeo?


't'

In [115]:
%%time
# Prompt to continue; encode() turns each character into its integer id.
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R"  # alternative prompt: "who fart thou b"
# Wrapping the encoded list in [] gives the batch dimension of 1 that generate_Euryale asserts on.
context_tensor = torch.tensor([encode(input_str)], dtype=torch.long, device=device)
# k=4 candidates per head; generate_Euryale batches k**m candidate rows per pass.
output, tok_per_inf = generate_Euryale(model, context_tensor, max_runs=100, k=4)

idx:  torch.Size([1, 44])
idx_cond should be size (1, t):  torch.Size([1, 44])
logits should be (1,t,v):  torch.Size([1, 44, 65])
mlogits should be (m,1,t,v):  torch.Size([3, 1, 44, 65])
mlogits should be (m,v):  torch.Size([3, 65])
idx_m_topk should be (m,k):  torch.Size([3, 4]) 
 tensor([[51, 41, 57, 58],
        [43, 46, 47, 39],
        [39, 53,  1, 52]], device='mps:0')
idx_ntp should be (1, t):  torch.Size([1, 44])
idx should be (1, t+1):  torch.Size([1, 45])
mcomb should be (k^m,m):  torch.Size([64, 3]) 
 tensor([[51, 43, 39],
        [51, 43, 53],
        [51, 43,  1],
        [51, 43, 52],
        [51, 46, 39],
        [51, 46, 53],
        [51, 46,  1],
        [51, 46, 52],
        [51, 47, 39],
        [51, 47, 53],
        [51, 47,  1],
        [51, 47, 52],
        [51, 39, 39],
        [51, 39, 53],
        [51, 39,  1],
        [51, 39, 52],
        [41, 43, 39],
        [41, 43, 53],
        [41, 43,  1],
        [41, 43, 52],
        [41, 46, 39],
        [41, 46, 53]

zero_positions:  torch.Size([64]) tensor([2, 3, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
range_tensor:  torch.Size([64, 3]) tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1,

idx_ntp should be (k^m, t):  torch.Size([64, 52]) 
 tensor([[53, 12,  0,  0, 15, 27],
        [53, 12,  0,  0, 15, 44],
        [53, 12,  0,  0, 15, 24],
        [53, 12,  0,  0, 15, 24],
        [53, 12,  0,  0, 52,  6],
        [53, 12,  0,  0, 52, 58],
        [53, 12,  0,  0, 52, 51],
        [53, 12,  0,  0, 52,  1],
        [53, 12,  0,  0, 46, 39],
        [53, 12,  0,  0, 46, 59],
        [53, 12,  0,  0, 46, 46],
        [53, 12,  0,  0, 46,  1],
        [53, 12,  0,  0, 44, 43],
        [53, 12,  0,  0, 44, 58],
        [53, 12,  0,  0, 44, 51],
        [53, 12,  0,  0, 44,  1],
        [53, 12,  0, 21,  0, 43],
        [53, 12,  0, 21,  0, 44],
        [53, 12,  0, 21,  0, 15],
        [53, 12,  0, 21,  0, 60],
        [53, 12,  0, 21, 52,  6],
        [53, 12,  0, 21, 52, 58],
        [53, 12,  0, 21, 52, 51],
        [53, 12,  0, 21, 52,  1],
        [53, 12,  0, 21, 46, 39],
        [53, 12,  0, 21, 46, 59],
        [53, 12,  0, 21, 46, 46],
        [53, 12,  0, 21, 46, 5

 tensor([[ 0,  0, 15, 33, 30, 17],
        [ 0,  0, 15, 33, 30, 21],
        [ 0,  0, 15, 33, 30, 33],
        [ 0,  0, 15, 33, 30, 32],
        [ 0,  0, 15, 33, 26, 17],
        [ 0,  0, 15, 33, 26, 21],
        [ 0,  0, 15, 33, 26, 33],
        [ 0,  0, 15, 33, 26, 32],
        [ 0,  0, 15, 33, 15, 17],
        [ 0,  0, 15, 33, 15, 21],
        [ 0,  0, 15, 33, 15, 33],
        [ 0,  0, 15, 33, 15, 32],
        [ 0,  0, 15, 33, 25, 17],
        [ 0,  0, 15, 33, 25, 21],
        [ 0,  0, 15, 33, 25, 33],
        [ 0,  0, 15, 33, 25, 32],
        [ 0,  0, 15, 27, 30, 17],
        [ 0,  0, 15, 27, 30, 21],
        [ 0,  0, 15, 27, 30, 33],
        [ 0,  0, 15, 27, 30, 32],
        [ 0,  0, 15, 27, 26, 17],
        [ 0,  0, 15, 27, 26, 21],
        [ 0,  0, 15, 27, 26, 33],
        [ 0,  0, 15, 27, 26, 32],
        [ 0,  0, 15, 27, 15, 17],
        [ 0,  0, 15, 27, 15, 21],
        [ 0,  0, 15, 27, 15, 33],
        [ 0,  0, 15, 27, 15, 32],
        [ 0,  0, 15, 27, 25, 17],
        [ 0, 

 tensor([[30, 21, 27, 24, 13, 26],
        [30, 21, 27, 24, 13, 30],
        [30, 21, 27, 24, 13, 10],
        [30, 21, 27, 24, 13, 27],
        [30, 21, 27, 24, 27, 26],
        [30, 21, 27, 24, 27, 30],
        [30, 21, 27, 24, 27, 10],
        [30, 21, 27, 24, 27, 27],
        [30, 21, 27, 24, 24, 26],
        [30, 21, 27, 24, 24, 30],
        [30, 21, 27, 24, 24, 10],
        [30, 21, 27, 24, 24, 27],
        [30, 21, 27, 24,  0, 26],
        [30, 21, 27, 24,  0, 30],
        [30, 21, 27, 24,  0, 10],
        [30, 21, 27, 24,  0, 27],
        [30, 21, 27, 20, 13, 26],
        [30, 21, 27, 20, 13, 30],
        [30, 21, 27, 20, 13, 10],
        [30, 21, 27, 20, 13, 27],
        [30, 21, 27, 20, 27, 26],
        [30, 21, 27, 20, 27, 30],
        [30, 21, 27, 20, 27, 10],
        [30, 21, 27, 20, 27, 27],
        [30, 21, 27, 20, 24, 26],
        [30, 21, 27, 20, 24, 30],
        [30, 21, 27, 20, 24, 10],
        [30, 21, 27, 20, 24, 27],
        [30, 21, 27, 20,  0, 26],
        [30, 

 tensor([[13, 26, 33, 31, 10,  0],
        [13, 26, 33, 31, 10, 10],
        [13, 26, 33, 31, 10, 26],
        [13, 26, 33, 31, 10, 32],
        [13, 26, 33, 31,  8,  0],
        [13, 26, 33, 31,  8, 10],
        [13, 26, 33, 31,  8, 26],
        [13, 26, 33, 31,  8, 32],
        [13, 26, 33, 31,  6,  0],
        [13, 26, 33, 31,  6, 10],
        [13, 26, 33, 31,  6, 26],
        [13, 26, 33, 31,  6, 32],
        [13, 26, 33, 31,  0,  0],
        [13, 26, 33, 31,  0, 10],
        [13, 26, 33, 31,  0, 26],
        [13, 26, 33, 31,  0, 32],
        [13, 26, 33, 24, 10,  0],
        [13, 26, 33, 24, 10, 10],
        [13, 26, 33, 24, 10, 26],
        [13, 26, 33, 24, 10, 32],
        [13, 26, 33, 24,  8,  0],
        [13, 26, 33, 24,  8, 10],
        [13, 26, 33, 24,  8, 26],
        [13, 26, 33, 24,  8, 32],
        [13, 26, 33, 24,  6,  0],
        [13, 26, 33, 24,  6, 10],
        [13, 26, 33, 24,  6, 26],
        [13, 26, 33, 24,  6, 32],
        [13, 26, 33, 24,  0,  0],
        [13, 

 tensor([[10,  0, 21, 46,  1,  1],
        [10,  0, 21, 46,  1, 58],
        [10,  0, 21, 46,  1, 43],
        [10,  0, 21, 46,  1, 50],
        [10,  0, 21, 46, 43,  1],
        [10,  0, 21, 46, 43, 58],
        [10,  0, 21, 46, 43, 43],
        [10,  0, 21, 46, 43, 50],
        [10,  0, 21, 46, 39,  1],
        [10,  0, 21, 46, 39, 58],
        [10,  0, 21, 46, 39, 43],
        [10,  0, 21, 46, 39, 50],
        [10,  0, 21, 46, 58,  1],
        [10,  0, 21, 46, 58, 58],
        [10,  0, 21, 46, 58, 43],
        [10,  0, 21, 46, 58, 50],
        [10,  0, 21, 53,  1,  1],
        [10,  0, 21, 53,  1, 58],
        [10,  0, 21, 53,  1, 43],
        [10,  0, 21, 53,  1, 50],
        [10,  0, 21, 53, 43,  1],
        [10,  0, 21, 53, 43, 58],
        [10,  0, 21, 53, 43, 43],
        [10,  0, 21, 53, 43, 50],
        [10,  0, 21, 53, 39,  1],
        [10,  0, 21, 53, 39, 58],
        [10,  0, 21, 53, 39, 43],
        [10,  0, 21, 53, 39, 50],
        [10,  0, 21, 53, 58,  1],
        [10, 

idx_rep should be (k^m,t+1):  torch.Size([64, 66])
input should be (k^m,t+1+m):  torch.Size([64, 69])
input_cond should be (k^m,t):  torch.Size([64, 69]) 
 tensor([[21,  1, 46, 39,  1, 43],
        [21,  1, 46, 39,  1, 50],
        [21,  1, 46, 39,  1,  1],
        [21,  1, 46, 39,  1, 52],
        [21,  1, 46, 39, 39, 43],
        [21,  1, 46, 39, 39, 50],
        [21,  1, 46, 39, 39,  1],
        [21,  1, 46, 39, 39, 52],
        [21,  1, 46, 39, 50, 43],
        [21,  1, 46, 39, 50, 50],
        [21,  1, 46, 39, 50,  1],
        [21,  1, 46, 39, 50, 52],
        [21,  1, 46, 39, 60, 43],
        [21,  1, 46, 39, 60, 50],
        [21,  1, 46, 39, 60,  1],
        [21,  1, 46, 39, 60, 52],
        [21,  1, 46, 47,  1, 43],
        [21,  1, 46, 47,  1, 50],
        [21,  1, 46, 47,  1,  1],
        [21,  1, 46, 47,  1, 52],
        [21,  1, 46, 47, 39, 43],
        [21,  1, 46, 47, 39, 50],
        [21,  1, 46, 47, 39,  1],
        [21,  1, 46, 47, 39, 52],
        [21,  1, 46, 47, 50,

idx_rep should be (k^m,t+1):  torch.Size([64, 70])
input should be (k^m,t+1+m):  torch.Size([64, 73])
input_cond should be (k^m,t):  torch.Size([64, 73]) 
 tensor([[60, 43,  1, 57, 53, 58],
        [60, 43,  1, 57, 53, 43],
        [60, 43,  1, 57, 53,  1],
        [60, 43,  1, 57, 53, 39],
        [60, 43,  1, 57, 43, 58],
        [60, 43,  1, 57, 43, 43],
        [60, 43,  1, 57, 43,  1],
        [60, 43,  1, 57, 43, 39],
        [60, 43,  1, 57, 39, 58],
        [60, 43,  1, 57, 39, 43],
        [60, 43,  1, 57, 39,  1],
        [60, 43,  1, 57, 39, 39],
        [60, 43,  1, 57, 46, 58],
        [60, 43,  1, 57, 46, 43],
        [60, 43,  1, 57, 46,  1],
        [60, 43,  1, 57, 46, 39],
        [60, 43,  1, 58, 53, 58],
        [60, 43,  1, 58, 53, 43],
        [60, 43,  1, 58, 53,  1],
        [60, 43,  1, 58, 53, 39],
        [60, 43,  1, 58, 43, 58],
        [60, 43,  1, 58, 43, 43],
        [60, 43,  1, 58, 43,  1],
        [60, 43,  1, 58, 43, 39],
        [60, 43,  1, 58, 39,

idx_rep should be (k^m,t+1):  torch.Size([64, 73])
input should be (k^m,t+1+m):  torch.Size([64, 76])
input_cond should be (k^m,t):  torch.Size([64, 76]) 
 tensor([[57, 53, 51, 43,  1,  1],
        [57, 53, 51, 43,  1, 58],
        [57, 53, 51, 43,  1, 53],
        [57, 53, 51, 43,  1, 43],
        [57, 53, 51, 43, 53,  1],
        [57, 53, 51, 43, 53, 58],
        [57, 53, 51, 43, 53, 53],
        [57, 53, 51, 43, 53, 43],
        [57, 53, 51, 43, 58,  1],
        [57, 53, 51, 43, 58, 58],
        [57, 53, 51, 43, 58, 53],
        [57, 53, 51, 43, 58, 43],
        [57, 53, 51, 43, 42,  1],
        [57, 53, 51, 43, 42, 58],
        [57, 53, 51, 43, 42, 53],
        [57, 53, 51, 43, 42, 43],
        [57, 53, 51,  1,  1,  1],
        [57, 53, 51,  1,  1, 58],
        [57, 53, 51,  1,  1, 53],
        [57, 53, 51,  1,  1, 43],
        [57, 53, 51,  1, 53,  1],
        [57, 53, 51,  1, 53, 58],
        [57, 53, 51,  1, 53, 53],
        [57, 53, 51,  1, 53, 43],
        [57, 53, 51,  1, 58,

input should be (k^m,t+1+m):  torch.Size([64, 79])
input_cond should be (k^m,t):  torch.Size([64, 79]) 
 tensor([[43,  1, 57, 53, 56, 43],
        [43,  1, 57, 53, 56, 58],
        [43,  1, 57, 53, 56,  1],
        [43,  1, 57, 53, 56, 56],
        [43,  1, 57, 53, 52, 43],
        [43,  1, 57, 53, 52, 58],
        [43,  1, 57, 53, 52,  1],
        [43,  1, 57, 53, 52, 56],
        [43,  1, 57, 53, 43, 43],
        [43,  1, 57, 53, 43, 58],
        [43,  1, 57, 53, 43,  1],
        [43,  1, 57, 53, 43, 56],
        [43,  1, 57, 53, 58, 43],
        [43,  1, 57, 53, 58, 58],
        [43,  1, 57, 53, 58,  1],
        [43,  1, 57, 53, 58, 56],
        [43,  1, 57, 43, 56, 43],
        [43,  1, 57, 43, 56, 58],
        [43,  1, 57, 43, 56,  1],
        [43,  1, 57, 43, 56, 56],
        [43,  1, 57, 43, 52, 43],
        [43,  1, 57, 43, 52, 58],
        [43,  1, 57, 43, 52,  1],
        [43,  1, 57, 43, 52, 56],
        [43,  1, 57, 43, 43, 43],
        [43,  1, 57, 43, 43, 58],
        [43

input should be (k^m,t+1+m):  torch.Size([64, 81])
input_cond should be (k^m,t):  torch.Size([64, 81]) 
 tensor([[57, 53, 59, 43,  1,  1],
        [57, 53, 59, 43,  1, 53],
        [57, 53, 59, 43,  1, 43],
        [57, 53, 59, 43,  1, 58],
        [57, 53, 59, 43, 53,  1],
        [57, 53, 59, 43, 53, 53],
        [57, 53, 59, 43, 53, 43],
        [57, 53, 59, 43, 53, 58],
        [57, 53, 59, 43, 43,  1],
        [57, 53, 59, 43, 43, 53],
        [57, 53, 59, 43, 43, 43],
        [57, 53, 59, 43, 43, 58],
        [57, 53, 59, 43, 58,  1],
        [57, 53, 59, 43, 58, 53],
        [57, 53, 59, 43, 58, 43],
        [57, 53, 59, 43, 58, 58],
        [57, 53, 59,  1,  1,  1],
        [57, 53, 59,  1,  1, 53],
        [57, 53, 59,  1,  1, 43],
        [57, 53, 59,  1,  1, 58],
        [57, 53, 59,  1, 53,  1],
        [57, 53, 59,  1, 53, 53],
        [57, 53, 59,  1, 53, 43],
        [57, 53, 59,  1, 53, 58],
        [57, 53, 59,  1, 43,  1],
        [57, 53, 59,  1, 43, 53],
        [57

 tensor([[ 1, 58,  1],
        [ 1, 58, 43],
        [ 1, 58, 53],
        [ 1, 58, 46],
        [ 1, 52,  1],
        [ 1, 52, 43],
        [ 1, 52, 53],
        [ 1, 52, 46],
        [ 1,  1,  1],
        [ 1,  1, 43],
        [ 1,  1, 53],
        [ 1,  1, 46],
        [ 1, 57,  1],
        [ 1, 57, 43],
        [ 1, 57, 53],
        [ 1, 57, 46],
        [43, 58,  1],
        [43, 58, 43],
        [43, 58, 53],
        [43, 58, 46],
        [43, 52,  1],
        [43, 52, 43],
        [43, 52, 53],
        [43, 52, 46],
        [43,  1,  1],
        [43,  1, 43],
        [43,  1, 53],
        [43,  1, 46],
        [43, 57,  1],
        [43, 57, 43],
        [43, 57, 53],
        [43, 57, 46],
        [39, 58,  1],
        [39, 58, 43],
        [39, 58, 53],
        [39, 58, 46],
        [39, 52,  1],
        [39, 52, 43],
        [39, 52, 53],
        [39, 52, 46],
        [39,  1,  1],
        [39,  1, 43],
        [39,  1, 53],
        [39,  1, 46],
        [39, 57,  1],
        [

zero_positions:  torch.Size([64]) tensor([2, 2, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
range_tensor:  torch.Size([64, 3]) tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1,

idx_ntp should be (k^m, t):  torch.Size([64, 89]) 
 tensor([[46, 43,  1, 57, 57, 61],
        [46, 43,  1, 57, 57, 52],
        [46, 43,  1, 57, 57, 56],
        [46, 43,  1, 57, 57, 57],
        [46, 43,  1, 57, 52, 54],
        [46, 43,  1, 57, 52, 56],
        [46, 43,  1, 57, 52, 56],
        [46, 43,  1, 57, 52, 58],
        [46, 43,  1, 57, 53, 59],
        [46, 43,  1, 57, 53, 56],
        [46, 43,  1, 57, 53, 59],
        [46, 43,  1, 57, 53, 53],
        [46, 43,  1, 57, 56, 53],
        [46, 43,  1, 57, 56, 53],
        [46, 43,  1, 57, 56, 56],
        [46, 43,  1, 57, 56, 53],
        [46, 43,  1, 43, 57, 44],
        [46, 43,  1, 43, 57, 52],
        [46, 43,  1, 43, 57, 56],
        [46, 43,  1, 43, 57, 58],
        [46, 43,  1, 43,  1, 52],
        [46, 43,  1, 43,  1,  1],
        [46, 43,  1, 43,  1, 57],
        [46, 43,  1, 43,  1, 57],
        [46, 43,  1, 43, 43, 52],
        [46, 43,  1, 43, 43, 52],
        [46, 43,  1, 43, 43, 50],
        [46, 43,  1, 43, 43, 3

idx_rep should be (k^m,t+1):  torch.Size([64, 88])
input should be (k^m,t+1+m):  torch.Size([64, 91])
input_cond should be (k^m,t):  torch.Size([64, 91]) 
 tensor([[43,  1, 57, 53, 56, 43],
        [43,  1, 57, 53, 56, 58],
        [43,  1, 57, 53, 56, 42],
        [43,  1, 57, 53, 56, 57],
        [43,  1, 57, 53, 52, 43],
        [43,  1, 57, 53, 52, 58],
        [43,  1, 57, 53, 52, 42],
        [43,  1, 57, 53, 52, 57],
        [43,  1, 57, 53, 39, 43],
        [43,  1, 57, 53, 39, 58],
        [43,  1, 57, 53, 39, 42],
        [43,  1, 57, 53, 39, 57],
        [43,  1, 57, 53, 53, 43],
        [43,  1, 57, 53, 53, 58],
        [43,  1, 57, 53, 53, 42],
        [43,  1, 57, 53, 53, 57],
        [43,  1, 57, 39, 56, 43],
        [43,  1, 57, 39, 56, 58],
        [43,  1, 57, 39, 56, 42],
        [43,  1, 57, 39, 56, 57],
        [43,  1, 57, 39, 52, 43],
        [43,  1, 57, 39, 52, 58],
        [43,  1, 57, 39, 52, 42],
        [43,  1, 57, 39, 52, 57],
        [43,  1, 57, 39, 39,

 tensor([[47, 52, 58],
        [47, 52, 43],
        [47, 52, 57],
        [47, 52, 45],
        [47, 41, 58],
        [47, 41, 43],
        [47, 41, 57],
        [47, 41, 45],
        [47, 45, 58],
        [47, 45, 43],
        [47, 45, 57],
        [47, 45, 45],
        [47, 56, 58],
        [47, 56, 43],
        [47, 56, 57],
        [47, 56, 45],
        [39, 52, 58],
        [39, 52, 43],
        [39, 52, 57],
        [39, 52, 45],
        [39, 41, 58],
        [39, 41, 43],
        [39, 41, 57],
        [39, 41, 45],
        [39, 45, 58],
        [39, 45, 43],
        [39, 45, 57],
        [39, 45, 45],
        [39, 56, 58],
        [39, 56, 43],
        [39, 56, 57],
        [39, 56, 45],
        [43, 52, 58],
        [43, 52, 43],
        [43, 52, 57],
        [43, 52, 45],
        [43, 41, 58],
        [43, 41, 43],
        [43, 41, 57],
        [43, 41, 45],
        [43, 45, 58],
        [43, 45, 43],
        [43, 45, 57],
        [43, 45, 45],
        [43, 56, 58],
        [

range_tensor:  torch.Size([64, 3]) tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 

idx_ntp should be (k^m, t):  torch.Size([64, 98]) 
 tensor([[43,  1, 53, 13, 15, 24],
        [43,  1, 53, 13, 15, 44],
        [43,  1, 53, 13, 15, 24],
        [43,  1, 53, 13, 15, 27],
        [43,  1, 53, 13, 53, 58],
        [43,  1, 53, 13, 53, 56],
        [43,  1, 53, 13, 53, 50],
        [43,  1, 53, 13, 53, 43],
        [43,  1, 53, 13, 44, 58],
        [43,  1, 53, 13, 44, 44],
        [43,  1, 53, 13, 44,  1],
        [43,  1, 53, 13, 44, 43],
        [43,  1, 53, 13, 53, 53],
        [43,  1, 53, 13, 53, 44],
        [43,  1, 53, 13, 53, 62],
        [43,  1, 53, 13, 53, 39],
        [43,  1, 53, 53, 13, 58],
        [43,  1, 53, 53, 13, 44],
        [43,  1, 53, 53, 13, 60],
        [43,  1, 53, 53, 13, 39],
        [43,  1, 53, 53, 53, 58],
        [43,  1, 53, 53, 53, 56],
        [43,  1, 53, 53, 53, 50],
        [43,  1, 53, 53, 53, 43],
        [43,  1, 53, 53, 44, 58],
        [43,  1, 53, 53, 44, 44],
        [43,  1, 53, 53, 44,  1],
        [43,  1, 53, 53, 44, 4

idx_rep should be (k^m,t+1):  torch.Size([64, 99])
input should be (k^m,t+1+m):  torch.Size([64, 102])
input_cond should be (k^m,t):  torch.Size([64, 102]) 
 tensor([[44,  1, 58, 46, 43,  1],
        [44,  1, 58, 46, 43, 43],
        [44,  1, 58, 46, 43, 57],
        [44,  1, 58, 46, 43, 56],
        [44,  1, 58, 46, 56,  1],
        [44,  1, 58, 46, 56, 43],
        [44,  1, 58, 46, 56, 57],
        [44,  1, 58, 46, 56, 56],
        [44,  1, 58, 46, 39,  1],
        [44,  1, 58, 46, 39, 43],
        [44,  1, 58, 46, 39, 57],
        [44,  1, 58, 46, 39, 56],
        [44,  1, 58, 46,  1,  1],
        [44,  1, 58, 46,  1, 43],
        [44,  1, 58, 46,  1, 57],
        [44,  1, 58, 46,  1, 56],
        [44,  1, 58, 53, 43,  1],
        [44,  1, 58, 53, 43, 43],
        [44,  1, 58, 53, 43, 57],
        [44,  1, 58, 53, 43, 56],
        [44,  1, 58, 53, 56,  1],
        [44,  1, 58, 53, 56, 43],
        [44,  1, 58, 53, 56, 57],
        [44,  1, 58, 53, 56, 56],
        [44,  1, 58, 53, 3

input should be (k^m,t+1+m):  torch.Size([64, 106])
input_cond should be (k^m,t):  torch.Size([64, 106]) 
 tensor([[43,  1, 41, 53, 56, 43],
        [43,  1, 41, 53, 56, 58],
        [43,  1, 41, 53, 56, 42],
        [43,  1, 41, 53, 56, 57],
        [43,  1, 41, 53, 39, 43],
        [43,  1, 41, 53, 39, 58],
        [43,  1, 41, 53, 39, 42],
        [43,  1, 41, 53, 39, 57],
        [43,  1, 41, 53, 52, 43],
        [43,  1, 41, 53, 52, 58],
        [43,  1, 41, 53, 52, 42],
        [43,  1, 41, 53, 52, 57],
        [43,  1, 41, 53, 53, 43],
        [43,  1, 41, 53, 53, 58],
        [43,  1, 41, 53, 53, 42],
        [43,  1, 41, 53, 53, 57],
        [43,  1, 41, 39, 56, 43],
        [43,  1, 41, 39, 56, 58],
        [43,  1, 41, 39, 56, 42],
        [43,  1, 41, 39, 56, 57],
        [43,  1, 41, 39, 39, 43],
        [43,  1, 41, 39, 39, 58],
        [43,  1, 41, 39, 39, 42],
        [43,  1, 41, 39, 39, 57],
        [43,  1, 41, 39, 52, 43],
        [43,  1, 41, 39, 52, 58],
        [

 tensor([[52, 57,  0],
        [52, 57, 52],
        [52, 57, 56],
        [52, 57, 43],
        [52, 43,  0],
        [52, 43, 52],
        [52, 43, 56],
        [52, 43, 43],
        [52, 58,  0],
        [52, 58, 52],
        [52, 58, 56],
        [52, 58, 43],
        [52, 59,  0],
        [52, 59, 52],
        [52, 59, 56],
        [52, 59, 43],
        [56, 57,  0],
        [56, 57, 52],
        [56, 57, 56],
        [56, 57, 43],
        [56, 43,  0],
        [56, 43, 52],
        [56, 43, 56],
        [56, 43, 43],
        [56, 58,  0],
        [56, 58, 52],
        [56, 58, 56],
        [56, 58, 43],
        [56, 59,  0],
        [56, 59, 52],
        [56, 59, 56],
        [56, 59, 43],
        [57, 57,  0],
        [57, 57, 52],
        [57, 57, 56],
        [57, 57, 43],
        [57, 43,  0],
        [57, 43, 52],
        [57, 43, 56],
        [57, 43, 43],
        [57, 58,  0],
        [57, 58, 52],
        [57, 58, 56],
        [57, 58, 43],
        [57, 59,  0],
        [

zero_positions:  torch.Size([64]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
range_tensor:  torch.Size([64, 3]) tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1,

idx_ntp should be (k^m, t):  torch.Size([64, 111]) 
 tensor([[58, 12,  0,  0, 15, 59],
        [58, 12,  0,  0, 15, 27],
        [58, 12,  0,  0, 15, 44],
        [58, 12,  0,  0, 15, 25],
        [58, 12,  0,  0, 52, 42],
        [58, 12,  0,  0, 52,  6],
        [58, 12,  0,  0, 52, 52],
        [58, 12,  0,  0, 52, 51],
        [58, 12,  0,  0, 46, 53],
        [58, 12,  0,  0, 46, 43],
        [58, 12,  0,  0, 46,  1],
        [58, 12,  0,  0, 46, 58],
        [58, 12,  0,  0, 46, 53],
        [58, 12,  0,  0, 46, 39],
        [58, 12,  0,  0, 46, 59],
        [58, 12,  0,  0, 46, 46],
        [58, 12,  0,  0,  0, 53],
        [58, 12,  0,  0,  0, 43],
        [58, 12,  0,  0,  0, 56],
        [58, 12,  0,  0,  0, 61],
        [58, 12,  0,  0,  0, 42],
        [58, 12,  0,  0,  0,  6],
        [58, 12,  0,  0,  0, 52],
        [58, 12,  0,  0,  0, 61],
        [58, 12,  0,  0, 46, 39],
        [58, 12,  0,  0, 46, 43],
        [58, 12,  0,  0, 46,  1],
        [58, 12,  0,  0, 46, 

 tensor([[33, 30, 21],
        [33, 30, 17],
        [33, 30, 33],
        [33, 30, 32],
        [33, 15, 21],
        [33, 15, 17],
        [33, 15, 33],
        [33, 15, 32],
        [33, 26, 21],
        [33, 26, 17],
        [33, 26, 33],
        [33, 26, 32],
        [33, 56, 21],
        [33, 56, 17],
        [33, 56, 33],
        [33, 56, 32],
        [27, 30, 21],
        [27, 30, 17],
        [27, 30, 33],
        [27, 30, 32],
        [27, 15, 21],
        [27, 15, 17],
        [27, 15, 33],
        [27, 15, 32],
        [27, 26, 21],
        [27, 26, 17],
        [27, 26, 33],
        [27, 26, 32],
        [27, 56, 21],
        [27, 56, 17],
        [27, 56, 33],
        [27, 56, 32],
        [13, 30, 21],
        [13, 30, 17],
        [13, 30, 33],
        [13, 30, 32],
        [13, 15, 21],
        [13, 15, 17],
        [13, 15, 33],
        [13, 15, 32],
        [13, 26, 21],
        [13, 26, 17],
        [13, 26, 33],
        [13, 26, 32],
        [13, 56, 21],
        [

range_tensor:  torch.Size([64, 3]) tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 

idx_ntp should be (k^m, t):  torch.Size([64, 118]) 
 tensor([[21, 27, 24, 13, 26, 33],
        [21, 27, 24, 13, 26, 17],
        [21, 27, 24, 13, 26, 24],
        [21, 27, 24, 13, 26,  0],
        [21, 27, 24, 13, 24, 33],
        [21, 27, 24, 13, 24, 17],
        [21, 27, 24, 13, 24, 24],
        [21, 27, 24, 13, 24,  0],
        [21, 27, 24, 13, 13, 33],
        [21, 27, 24, 13, 13, 17],
        [21, 27, 24, 13, 13, 26],
        [21, 27, 24, 13, 13,  0],
        [21, 27, 24, 13,  0, 39],
        [21, 27, 24, 13,  0, 39],
        [21, 27, 24, 13,  0,  6],
        [21, 27, 24, 13,  0,  0],
        [21, 27, 24, 13, 24, 33],
        [21, 27, 24, 13, 24, 10],
        [21, 27, 24, 13, 24, 10],
        [21, 27, 24, 13, 24,  0],
        [21, 27, 24, 13, 24, 21],
        [21, 27, 24, 13, 24, 17],
        [21, 27, 24, 13, 24, 24],
        [21, 27, 24, 13, 24,  0],
        [21, 27, 24, 13, 13, 33],
        [21, 27, 24, 13, 13, 17],
        [21, 27, 24, 13, 13, 10],
        [21, 27, 24, 13, 13, 

input should be (k^m,t+1+m):  torch.Size([64, 122])
input_cond should be (k^m,t):  torch.Size([64, 122]) 
 tensor([[13, 26, 33, 31, 10,  0],
        [13, 26, 33, 31, 10, 10],
        [13, 26, 33, 31, 10, 26],
        [13, 26, 33, 31, 10, 32],
        [13, 26, 33, 31,  8,  0],
        [13, 26, 33, 31,  8, 10],
        [13, 26, 33, 31,  8, 26],
        [13, 26, 33, 31,  8, 32],
        [13, 26, 33, 31,  6,  0],
        [13, 26, 33, 31,  6, 10],
        [13, 26, 33, 31,  6, 26],
        [13, 26, 33, 31,  6, 32],
        [13, 26, 33, 31, 11,  0],
        [13, 26, 33, 31, 11, 10],
        [13, 26, 33, 31, 11, 26],
        [13, 26, 33, 31, 11, 32],
        [13, 26, 33, 24, 10,  0],
        [13, 26, 33, 24, 10, 10],
        [13, 26, 33, 24, 10, 26],
        [13, 26, 33, 24, 10, 32],
        [13, 26, 33, 24,  8,  0],
        [13, 26, 33, 24,  8, 10],
        [13, 26, 33, 24,  8, 26],
        [13, 26, 33, 24,  8, 32],
        [13, 26, 33, 24,  6,  0],
        [13, 26, 33, 24,  6, 10],
        [

 tensor([[10,  0, 21, 53,  1,  1],
        [10,  0, 21, 53,  1, 58],
        [10,  0, 21, 53,  1, 43],
        [10,  0, 21, 53,  1, 50],
        [10,  0, 21, 53, 39,  1],
        [10,  0, 21, 53, 39, 58],
        [10,  0, 21, 53, 39, 43],
        [10,  0, 21, 53, 39, 50],
        [10,  0, 21, 53, 43,  1],
        [10,  0, 21, 53, 43, 58],
        [10,  0, 21, 53, 43, 43],
        [10,  0, 21, 53, 43, 50],
        [10,  0, 21, 53, 58,  1],
        [10,  0, 21, 53, 58, 58],
        [10,  0, 21, 53, 58, 43],
        [10,  0, 21, 53, 58, 50],
        [10,  0, 21, 46,  1,  1],
        [10,  0, 21, 46,  1, 58],
        [10,  0, 21, 46,  1, 43],
        [10,  0, 21, 46,  1, 50],
        [10,  0, 21, 46, 39,  1],
        [10,  0, 21, 46, 39, 58],
        [10,  0, 21, 46, 39, 43],
        [10,  0, 21, 46, 39, 50],
        [10,  0, 21, 46, 43,  1],
        [10,  0, 21, 46, 43, 58],
        [10,  0, 21, 46, 43, 43],
        [10,  0, 21, 46, 43, 50],
        [10,  0, 21, 46, 58,  1],
        [10, 

idx should be (1,t+max_val+1):  torch.Size([1, 125]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46]],
       device='mps:0')
mlogits should be (m,v):  torch.Size([3, 65])
idx_m_topk should be (m,k):  torch.Size([3, 4])
mcomb should be (k^m,m):  torch.Size([64, 3]) 
 tensor([[39,  1, 43],
        [39,  1, 50],
        [39,  1,  1],
        [39,  1, 52],
        [39, 39, 43],
        [39, 39, 50],
        [39, 39,  1],
        [39, 39, 52],
        [39, 60, 43],
        [39, 60, 50

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(55, device='mps:0')
max_idx:  tensor(8, device='mps:0')
idx:  torch.Size([1, 125]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46]],
       device='mps:0')
idx_m should be (1,max_va

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0,

 tensor([[58,  1, 57, 53,  1, 43],
        [58,  1, 57, 53,  1,  1],
        [58,  1, 57, 53,  1, 58],
        [58,  1, 57, 53,  1, 56],
        [58,  1, 57, 53, 43, 43],
        [58,  1, 57, 53, 43,  1],
        [58,  1, 57, 53, 43, 58],
        [58,  1, 57, 53, 43, 56],
        [58,  1, 57, 53, 52, 43],
        [58,  1, 57, 53, 52,  1],
        [58,  1, 57, 53, 52, 58],
        [58,  1, 57, 53, 52, 56],
        [58,  1, 57, 53, 58, 43],
        [58,  1, 57, 53, 58,  1],
        [58,  1, 57, 53, 58, 58],
        [58,  1, 57, 53, 58, 56],
        [58,  1, 57, 43,  1, 43],
        [58,  1, 57, 43,  1,  1],
        [58,  1, 57, 43,  1, 58],
        [58,  1, 57, 43,  1, 56],
        [58,  1, 57, 43, 43, 43],
        [58,  1, 57, 43, 43,  1],
        [58,  1, 57, 43, 43, 58],
        [58,  1, 57, 43, 43, 56],
        [58,  1, 57, 43, 52, 43],
        [58,  1, 57, 43, 52,  1],
        [58,  1, 57, 43, 52, 58],
        [58,  1, 57, 43, 52, 56],
        [58,  1, 57, 43, 58, 43],
        [58, 

result:  tensor([2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(2, device='mps:0')
reversed_max_idx:  tensor(60, device='mps:0')
max_idx:  tensor(3, device='mps:0')
idx:  torch.Size([1, 134]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57]], dev

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

max_val:  tensor(2, device='mps:0')
reversed_max_idx:  tensor(52, device='mps:0')
max_idx:  tensor(11, device='mps:0')
idx:  torch.Size([1, 140]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57]],
       device='mps:0')
idx_m should be (1,max_val):  torch.Size([1, 2]) tensor([[53,  1]], device='mps:0')
idx_ntp should be (1,1):  torch.Size([1, 1]) tensor([[57]], device='mps:0')
idx should be (1,t+max_val+1):  torch

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(2, device='mps:0')
reversed_max_idx:  tensor(52, device='mps:0')
max_idx:  tensor(11, device='mps:0')
idx:  torch.Size([1, 143]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53, 

result:  tensor([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(2, device='mps:0')
reversed_max_idx:  tensor(56, device='mps:0')
max_idx:  tensor(7, device='mps:0')
idx:  torch.Size([1, 146]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0,

 tensor([[43,  1, 58, 53,  1, 43],
        [43,  1, 58, 53,  1,  1],
        [43,  1, 58, 53,  1, 58],
        [43,  1, 58, 53,  1, 56],
        [43,  1, 58, 53, 52, 43],
        [43,  1, 58, 53, 52,  1],
        [43,  1, 58, 53, 52, 58],
        [43,  1, 58, 53, 52, 56],
        [43,  1, 58, 53, 58, 43],
        [43,  1, 58, 53, 58,  1],
        [43,  1, 58, 53, 58, 58],
        [43,  1, 58, 53, 58, 56],
        [43,  1, 58, 53, 43, 43],
        [43,  1, 58, 53, 43,  1],
        [43,  1, 58, 53, 43, 58],
        [43,  1, 58, 53, 43, 56],
        [43,  1, 58, 43,  1, 43],
        [43,  1, 58, 43,  1,  1],
        [43,  1, 58, 43,  1, 58],
        [43,  1, 58, 43,  1, 56],
        [43,  1, 58, 43, 52, 43],
        [43,  1, 58, 43, 52,  1],
        [43,  1, 58, 43, 52, 58],
        [43,  1, 58, 43, 52, 56],
        [43,  1, 58, 43, 58, 43],
        [43,  1, 58, 43, 58,  1],
        [43,  1, 58, 43, 58, 58],
        [43,  1, 58, 43, 58, 56],
        [43,  1, 58, 43, 43, 43],
        [43, 

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True,  True],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(1, device='mps:0')
reversed_max_idx:  tensor(48, device='mps:0')
max_idx:  tensor(15, device='mps:0')
idx:  torch.Size([1, 166]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53, 

result:  tensor([1, 1, 1, 1, 2, 3, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(58, device='mps:0')
max_idx:  tensor(5, device='mps:0')
idx:  torch.Size([1, 168]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 0, 15, 27, 30, 21, 27],
        [ 0, 15, 27, 30, 21, 24],
        [ 0, 15, 27, 30, 21, 32],
        [ 0, 15, 27, 30, 21, 21],
        [ 0, 15, 27, 30, 13, 26],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 47],
        [ 0, 15, 27, 30, 58, 27],
        [ 0, 15, 27, 30, 58, 26],
        [ 0, 15, 27, 30, 58, 32],
        [ 0, 15, 27, 30, 58, 43],
        [ 0, 15, 27, 30, 21, 33],
        [ 0, 15, 27, 30, 21, 26],
        [ 0, 15, 27, 30, 21, 31],
        [ 0, 15, 27, 30, 21, 58],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 39],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 57, 27],
        [ 0, 15, 27, 32, 57, 31],
        [ 0, 15, 27, 32, 57, 32],
        [ 0, 15, 27, 32, 57, 43],
        [ 0, 15, 27, 32, 21, 33],
        [ 0, 

idx should be (1,t+max_val+1):  torch.Size([1, 179]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27]],
       device='mps:0')
mlogits should be (m,v):  torch.Size([3, 65])
idx_m_topk should be (m,k):  torch.Size([3, 4])
mcomb should be (k^m,m

result:  tensor([3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(63, device='mps:0')
max_idx:  tensor(0, device='mps:0')
idx:  torch.Size([1, 179]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

result:  tensor([3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(63, device='mps:0')
max_idx:  tensor(0, device='mps:0')
idx:  torch.Size([1, 183]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 61, 47, 57, 57, 52],
        [ 1, 61, 47, 57, 57, 47],
        [ 1, 61, 47, 57, 57, 61],
        [ 1, 61, 47, 57, 57, 53],
        [ 1, 61, 47, 57, 57,  1],
        [ 1, 61, 47, 57, 57, 50],
        [ 1, 61, 47, 57, 57, 58],
        [ 1, 61, 47, 57, 57, 58],
        [ 1, 61, 47, 57, 43,  1],
        [ 1, 61, 47, 57, 43,  1],
        [ 1, 61, 47, 57, 43, 58],
        [ 1, 61, 47, 57, 43,  1],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 57, 50, 58],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 59, 57, 52],
        [ 1, 61, 47, 59, 57, 47],
        [ 1, 61, 47, 59, 57, 57],
        [ 1, 61, 47, 59, 57, 53],
        [ 1, 61, 47, 59, 63,  1],
        [ 1, 61, 47, 59, 63, 50],
        [ 1, 61, 47, 59, 63, 57],
        [ 1, 61, 47, 59, 63,  1],
        [ 1, 61, 47, 59, 43,  1],
        [ 1, 61, 47, 59, 43, 43],
        [ 1, 61, 47, 59, 43, 58],
        [ 1, 61, 47, 59, 43, 

idx should be (1,t+max_val+1):  torch.Size([1, 193]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1]], device='mps:0')
mlogits should be (m,v):  torch.Size([3, 65])
idx_m_topk s

result:  tensor([2, 3, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(62, device='mps:0')
max_idx:  tensor(1, device='mps:0')
idx:  torch.Size([1, 193]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 57, 53,  1, 58, 52],
        [ 1, 57, 53,  1, 58, 58],
        [ 1, 57, 53,  1, 58, 46],
        [ 1, 57, 53,  1, 58, 43],
        [ 1, 57, 53,  1, 46, 42],
        [ 1, 57, 53,  1, 46, 58],
        [ 1, 57, 53,  1, 46, 43],
        [ 1, 57, 53,  1, 46, 39],
        [ 1, 57, 53,  1, 45,  8],
        [ 1, 57, 53,  1, 45, 58],
        [ 1, 57, 53,  1, 45, 43],
        [ 1, 57, 53,  1, 45, 53],
        [ 1, 57, 53,  1,  1, 42],
        [ 1, 57, 53,  1,  1, 58],
        [ 1, 57, 53,  1,  1, 46],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53, 52, 58, 52],
        [ 1, 57, 53, 52, 58, 58],
        [ 1, 57, 53, 52, 58, 53],
        [ 1, 57, 53, 52, 58, 43],
        [ 1, 57, 53, 52, 43, 56],
        [ 1, 57, 53, 52, 43, 58],
        [ 1, 57, 53, 52, 43, 43],
        [ 1, 57, 53, 52, 43, 39],
        [ 1, 57, 53, 52, 42, 42],
        [ 1, 57, 53, 52, 42, 58],
        [ 1, 57, 53, 52, 42, 43],
        [ 1, 57, 53, 52, 42, 

idx:  torch.Size([1, 201]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58]], device='mps:0')
idx_m should be (1,max_val):  torch.Size([1

 tensor([[1, 1, 0],
        [1, 1, 0],
        [1, 1, 1],
        [1, 1, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0,

 tensor([[ 1, 57, 43, 56, 43, 43],
        [ 1, 57, 43, 56, 43, 58],
        [ 1, 57, 43, 56, 43,  1],
        [ 1, 57, 43, 56, 43, 50],
        [ 1, 57, 43, 56, 56, 43],
        [ 1, 57, 43, 56, 56, 58],
        [ 1, 57, 43, 56, 56,  1],
        [ 1, 57, 43, 56, 56, 50],
        [ 1, 57, 43, 56, 50, 43],
        [ 1, 57, 43, 56, 50, 58],
        [ 1, 57, 43, 56, 50,  1],
        [ 1, 57, 43, 56, 50, 50],
        [ 1, 57, 43, 56, 58, 43],
        [ 1, 57, 43, 56, 58, 58],
        [ 1, 57, 43, 56, 58,  1],
        [ 1, 57, 43, 56, 58, 50],
        [ 1, 57, 43, 52, 43, 43],
        [ 1, 57, 43, 52, 43, 58],
        [ 1, 57, 43, 52, 43,  1],
        [ 1, 57, 43, 52, 43, 50],
        [ 1, 57, 43, 52, 56, 43],
        [ 1, 57, 43, 52, 56, 58],
        [ 1, 57, 43, 52, 56,  1],
        [ 1, 57, 43, 52, 56, 50],
        [ 1, 57, 43, 52, 50, 43],
        [ 1, 57, 43, 52, 50, 58],
        [ 1, 57, 43, 52, 50,  1],
        [ 1, 57, 43, 52, 50, 50],
        [ 1, 57, 43, 52, 58, 43],
        [ 1, 

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(1, device='mps:0')
reversed_max_idx:  tensor(48, device='mps:0')
max_idx:  tensor(15, device='mps:0')
idx:  torch.Size([1, 209]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53, 

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[43,  1, 53, 32, 18, 24],
        [43,  1, 53, 32, 18, 44],
        [43,  1, 53, 32, 18, 47],
        [43,  1, 53, 32, 18, 24],
        [43,  1, 53, 32, 53, 58],
        [43,  1, 53, 32, 53, 56],
        [43,  1, 53, 32, 53, 43],
        [43,  1, 53, 32, 53, 39],
        [43,  1, 53, 32, 44, 58],
        [43,  1, 53, 32, 44, 44],
        [43,  1, 53, 32, 44, 43],
        [43,  1, 53, 32, 44,  1],
        [43,  1, 53, 32, 53, 58],
        [43,  1, 53, 32, 53, 58],
        [43,  1, 53, 32, 53, 43],
        [43,  1, 53, 32, 53,  5],
        [43,  1, 53, 53, 32, 30],
        [43,  1, 53, 53, 32, 44],
        [43,  1, 53, 53, 32, 39],
        [43,  1, 53, 53, 32, 60],
        [43,  1, 53, 53, 53, 58],
        [43,  1, 53, 53, 53, 56],
        [43,  1, 53, 53, 53, 43],
        [43,  1, 53, 53, 53, 50],
        [43,  1, 53, 53, 44, 58],
        [43,  1, 53, 53, 44, 44],
        [43,  1, 53, 53, 44, 43],
        [43,  1, 53, 53, 44,  1],
        [43,  1, 53, 53, 53, 53],
        [43, 

idx should be (1,t+max_val+1):  torch.Size([1, 219]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

result:  tensor([3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(63, device='mps:0')
max_idx:  tensor(0, device='mps:0')
idx:  torch.Size([1, 219]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45, 57],
        [56, 60, 47, 41, 45,  8],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43, 47],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 52, 58,  1],
        [56, 60, 47, 52, 58,  8],
        [56, 60, 47, 52, 58,  8],
        [56, 60, 47, 52, 58, 43],
        [56, 60, 47, 52, 43, 47],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43, 43],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43, 43],
        [56, 60, 47, 52, 43, 47],
        [56, 

idx should be (1,t+max_val+1):  torch.Size([1, 230]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 0, 15, 27, 30, 21, 27],
        [ 0, 15, 27, 30, 21, 24],
        [ 0, 15, 27, 30, 21, 32],
        [ 0, 15, 27, 30, 21, 21],
        [ 0, 15, 27, 30, 13, 26],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 47],
        [ 0, 15, 27, 30, 58, 27],
        [ 0, 15, 27, 30, 58, 26],
        [ 0, 15, 27, 30, 58, 32],
        [ 0, 15, 27, 30, 58, 43],
        [ 0, 15, 27, 30, 21, 33],
        [ 0, 15, 27, 30, 21, 26],
        [ 0, 15, 27, 30, 21, 31],
        [ 0, 15, 27, 30, 21, 58],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 39],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 57, 27],
        [ 0, 15, 27, 32, 57, 31],
        [ 0, 15, 27, 32, 57, 32],
        [ 0, 15, 27, 32, 57, 43],
        [ 0, 15, 27, 32, 21, 33],
        [ 0, 

result:  tensor([3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(63, device='mps:0')
max_idx:  tensor(0, device='mps:0')
idx:  torch.Size([1, 233]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

idx_check should be (k^m,m):  torch.Size([64, 3]) 
 tensor([[24, 13, 26],
        [24, 13, 26],
        [24, 13, 26],
        [24, 13, 26],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13,  0],
        [24, 13,  0],
        [24, 13,  0],
        [24, 13,  0],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 24],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13, 31],
        [24, 13,  0],
        [24, 13,  0],
        [24, 13,  0],
        [24, 13,  0],
        [24, 33, 25],
        [24, 33, 25],
        [24, 33, 25],
        [24, 33, 25],
        [24, 33, 24],
        [24, 33, 24],
        [24, 33, 24],
        [24, 33, 24],
        [24, 33, 31],
        [24, 33, 31],
        [24, 33, 31],
  

idx_rep should be (k^m,t+1):  torch.Size([64, 241])
input should be (k^m,t+1+m):  torch.Size([64, 244])
input_cond should be (k^m,t):  torch.Size([64, 128]) 
 tensor([[13, 26, 33, 31, 10,  0],
        [13, 26, 33, 31, 10, 10],
        [13, 26, 33, 31, 10, 26],
        [13, 26, 33, 31, 10, 32],
        [13, 26, 33, 31,  8,  0],
        [13, 26, 33, 31,  8, 10],
        [13, 26, 33, 31,  8, 26],
        [13, 26, 33, 31,  8, 32],
        [13, 26, 33, 31,  6,  0],
        [13, 26, 33, 31,  6, 10],
        [13, 26, 33, 31,  6, 26],
        [13, 26, 33, 31,  6, 32],
        [13, 26, 33, 31, 11,  0],
        [13, 26, 33, 31, 11, 10],
        [13, 26, 33, 31, 11, 26],
        [13, 26, 33, 31, 11, 32],
        [13, 26, 33, 24, 10,  0],
        [13, 26, 33, 24, 10, 10],
        [13, 26, 33, 24, 10, 26],
        [13, 26, 33, 24, 10, 32],
        [13, 26, 33, 24,  8,  0],
        [13, 26, 33, 24,  8, 10],
        [13, 26, 33, 24,  8, 26],
        [13, 26, 33, 24,  8, 32],
        [13, 26, 33, 24, 

idx_m should be (1,max_val):  torch.Size([1, 3]) tensor([[31, 10,  0]], device='mps:0')
idx_ntp should be (1,1):  torch.Size([1, 1]) tensor([[21]], device='mps:0')
idx should be (1,t+max_val+1):  torch.Size([1, 245]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 1, 46, 39, 60, 57, 52],
        [ 1, 46, 39, 60, 57, 47],
        [ 1, 46, 39, 60, 57, 57],
        [ 1, 46, 39, 60, 57, 53],
        [ 1, 46, 39, 60, 60,  1],
        [ 1, 46, 39, 60, 60, 50],
        [ 1, 46, 39, 60, 60, 57],
        [ 1, 46, 39, 60, 60, 45],
        [ 1, 46, 39, 60, 43,  1],
        [ 1, 46, 39, 60, 43, 43],
        [ 1, 46, 39, 60, 43, 57],
        [ 1, 46, 39, 60, 43, 45],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 60, 50, 52],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 59, 57, 52],
        [ 1, 46, 39, 59, 57, 47],
        [ 1, 46, 39, 59, 57, 57],
        [ 1, 46, 39, 59, 57, 53],
        [ 1, 46, 39, 59, 63,  1],
        [ 1, 46, 39, 59, 63, 50],
        [ 1, 46, 39, 59, 63, 57],
        [ 1, 46, 39, 59, 63,  1],
        [ 1, 46, 39, 59, 43,  1],
        [ 1, 46, 39, 59, 43, 43],
        [ 1, 46, 39, 59, 43, 40],
        [ 1, 46, 39, 59, 43,  1],
        [ 1, 46, 39, 59, 43, 57],
        [ 1, 

idx:  torch.Size([1, 247]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43,  1, 53,
 

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 57, 53, 58, 43, 58],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58,  1, 57],
        [ 1, 57, 53, 58,  1, 42],
        [ 1, 57, 53, 58,  1, 39],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 58, 51, 51],
        [ 1, 57, 53, 58, 51, 52],
        [ 1, 57, 53, 58, 51, 43],
        [ 1, 57, 53, 58, 51, 43],
        [ 1, 57, 53, 58,  1, 58],
        [ 1, 57, 53, 58,  1, 42],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 61, 42, 58],
        [ 1, 57, 53, 61, 42, 51],
        [ 1, 57, 53, 61, 42,  1],
        [ 1, 57, 53, 61, 42, 43],
        [ 1, 57, 53, 61,  1, 58],
        [ 1, 57, 53, 61,  1, 56],
        [ 1, 57, 53, 61,  1, 39],
        [ 1, 57, 53, 61,  1, 43],
        [ 1, 57, 53, 61, 57, 57],
        [ 1, 57, 53, 61, 57, 52],
        [ 1, 57, 53, 61, 57, 43],
        [ 1, 57, 53, 61, 57, 

idx should be (1,t+max_val+1):  torch.Size([1, 256]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 57, 53,  1, 56,  1],
        [ 1, 57, 53,  1, 56, 58],
        [ 1, 57, 53,  1, 56,  1],
        [ 1, 57, 53,  1, 56, 53],
        [ 1, 57, 53,  1, 57, 52],
        [ 1, 57, 53,  1, 57, 57],
        [ 1, 57, 53,  1, 57, 46],
        [ 1, 57, 53,  1, 57, 43],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 58],
        [ 1, 57, 53,  1,  1, 47],
        [ 1, 57, 53,  1,  1, 43],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 58],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 60],
        [ 1, 57, 53, 43, 60,  1],
        [ 1, 57, 53, 43, 60, 58],
        [ 1, 57, 53, 43, 60,  1],
        [ 1, 57, 53, 43, 60, 39],
        [ 1, 57, 53, 43, 58, 52],
        [ 1, 57, 53, 43, 58, 58],
        [ 1, 57, 53, 43, 58, 53],
        [ 1, 57, 53, 43, 58, 43],
        [ 1, 57, 53, 43,  1, 52],
        [ 1, 57, 53, 43,  1, 58],
        [ 1, 57, 53, 43,  1, 43],
        [ 1, 57, 53, 43,  1, 

idx:  torch.Size([1, 259]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43,  1, 53,
 

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 57, 53,  1, 56,  1],
        [ 1, 57, 53,  1, 56, 39],
        [ 1, 57, 53,  1, 56,  1],
        [ 1, 57, 53,  1, 56, 53],
        [ 1, 57, 53,  1, 46, 42],
        [ 1, 57, 53,  1, 46, 58],
        [ 1, 57, 53,  1, 46, 47],
        [ 1, 57, 53,  1, 46, 43],
        [ 1, 57, 53,  1, 57, 52],
        [ 1, 57, 53,  1, 57, 58],
        [ 1, 57, 53,  1, 57, 46],
        [ 1, 57, 53,  1, 57, 43],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 58],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 60],
        [ 1, 57, 53, 43, 60,  1],
        [ 1, 57, 53, 43, 60, 58],
        [ 1, 57, 53, 43, 60,  1],
        [ 1, 57, 53, 43, 60, 39],
        [ 1, 57, 53, 43,  1, 52],
        [ 1, 57, 53, 43,  1, 58],
        [ 1, 57, 53, 43,  1, 43],
        [ 1, 57, 53, 43,  1, 39],
        [ 1, 57, 53, 43, 58, 52],
        [ 1, 57, 53, 43, 58, 58],
        [ 1, 57, 53, 43, 58, 53],
        [ 1, 57, 53, 43, 58, 

idx should be (1,t+max_val+1):  torch.Size([1, 268]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[ 1, 57, 46,  1, 56,  1],
        [ 1, 57, 46,  1, 56, 39],
        [ 1, 57, 46,  1, 56,  1],
        [ 1, 57, 46,  1, 56, 53],
        [ 1, 57, 46,  1, 57, 52],
        [ 1, 57, 46,  1, 57, 58],
        [ 1, 57, 46,  1, 57, 46],
        [ 1, 57, 46,  1, 57, 43],
        [ 1, 57, 46,  1, 46, 52],
        [ 1, 57, 46,  1, 46, 58],
        [ 1, 57, 46,  1, 46, 47],
        [ 1, 57, 46,  1, 46, 43],
        [ 1, 57, 46,  1, 52,  1],
        [ 1, 57, 46,  1, 52, 58],
        [ 1, 57, 46,  1, 52,  1],
        [ 1, 57, 46,  1, 52,  1],
        [ 1, 57, 46, 56, 60,  1],
        [ 1, 57, 46, 56, 60, 58],
        [ 1, 57, 46, 56, 60,  1],
        [ 1, 57, 46, 56, 60, 39],
        [ 1, 57, 46, 56, 58, 52],
        [ 1, 57, 46, 56, 58, 58],
        [ 1, 57, 46, 56, 58, 53],
        [ 1, 57, 46, 56, 58, 43],
        [ 1, 57, 46, 56,  1, 52],
        [ 1, 57, 46, 56,  1, 58],
        [ 1, 57, 46, 56,  1, 43],
        [ 1, 57, 46, 56,  1, 

idx:  torch.Size([1, 271]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43,  1, 53,
 

idx_check should be (k^m,m):  torch.Size([64, 3]) 
 tensor([[39, 50,  1],
        [39, 50,  1],
        [39, 50,  1],
        [39, 50,  1],
        [39, 50, 57],
        [39, 50, 57],
        [39, 50, 57],
        [39, 50, 57],
        [39, 50, 50],
        [39, 50, 50],
        [39, 50, 50],
        [39, 50, 50],
        [39, 50, 43],
        [39, 50, 43],
        [39, 50, 43],
        [39, 50, 43],
        [39,  1, 54],
        [39,  1, 54],
        [39,  1, 54],
        [39,  1, 54],
        [39,  1, 61],
        [39,  1, 61],
        [39,  1, 61],
        [39,  1, 61],
        [39,  1, 44],
        [39,  1, 44],
        [39,  1, 44],
        [39,  1, 44],
        [39,  1, 43],
        [39,  1, 43],
        [39,  1, 43],
        [39,  1, 43],
        [39, 47, 61],
        [39, 47, 61],
        [39, 47, 61],
        [39, 47, 61],
        [39, 47, 39],
        [39, 47, 39],
        [39, 47, 39],
        [39, 47, 39],
        [39, 47, 39],
        [39, 47, 39],
        [39, 47, 39],
  

idx should be (1,t+max_val+1):  torch.Size([1, 276]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 1, 57, 53,  1, 57, 52],
        [ 1, 57, 53,  1, 57, 57],
        [ 1, 57, 53,  1, 57, 53],
        [ 1, 57, 53,  1, 57, 43],
        [ 1, 57, 53,  1,  8,  8],
        [ 1, 57, 53,  1,  8, 53],
        [ 1, 57, 53,  1,  8,  1],
        [ 1, 57, 53,  1,  8,  5],
        [ 1, 57, 53,  1, 43, 42],
        [ 1, 57, 53,  1, 43, 58],
        [ 1, 57, 53,  1, 43, 43],
        [ 1, 57, 53,  1, 43, 43],
        [ 1, 57, 53,  1,  1, 42],
        [ 1, 57, 53,  1,  1, 58],
        [ 1, 57, 53,  1,  1,  1],
        [ 1, 57, 53,  1,  1, 60],
        [ 1, 57, 53, 56, 39, 52],
        [ 1, 57, 53, 56, 39, 58],
        [ 1, 57, 53, 56, 39, 53],
        [ 1, 57, 53, 56, 39, 43],
        [ 1, 57, 53, 56, 58, 42],
        [ 1, 57, 53, 56, 58, 58],
        [ 1, 57, 53, 56, 58,  1],
        [ 1, 57, 53, 56, 58,  5],
        [ 1, 57, 53, 56,  1, 52],
        [ 1, 57, 53, 56,  1, 58],
        [ 1, 57, 53, 56,  1, 43],
        [ 1, 57, 53, 56,  1, 43],
        [ 1, 57, 53, 56, 52, 52],
        [ 1, 

idx should be (1,t+max_val+1):  torch.Size([1, 283]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[53,  8, 57, 32, 18, 18],
        [53,  8, 57, 32, 18, 24],
        [53,  8, 57, 32, 18, 44],
        [53,  8, 57, 32, 18, 57],
        [53,  8, 57, 32, 58, 18],
        [53,  8, 57, 32, 58, 58],
        [53,  8, 57, 32, 58, 44],
        [53,  8, 57, 32, 58, 62],
        [53,  8, 57, 32, 44, 32],
        [53,  8, 57, 32, 44, 51],
        [53,  8, 57, 32, 44, 44],
        [53,  8, 57, 32, 44, 42],
        [53,  8, 57, 32, 52, 32],
        [53,  8, 57, 32, 52, 58],
        [53,  8, 57, 32, 52, 54],
        [53,  8, 57, 32, 52, 60],
        [53,  8, 57, 52, 32, 18],
        [53,  8, 57, 52, 32, 58],
        [53,  8, 57, 52, 32, 44],
        [53,  8, 57, 52, 32, 52],
        [53,  8, 57, 52, 58, 32],
        [53,  8, 57, 52, 58, 58],
        [53,  8, 57, 52, 58, 59],
        [53,  8, 57, 52, 58, 52],
        [53,  8, 57, 52, 54, 32],
        [53,  8, 57, 52, 54, 58],
        [53,  8, 57, 52, 54, 54],
        [53,  8, 57, 52, 54, 57],
        [53,  8, 57, 52, 56, 32],
        [53, 

idx should be (1,t+max_val+1):  torch.Size([1, 286]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(1, device='mps:0')
reversed_max_idx:  tensor(48, device='mps:0')
max_idx:  tensor(15, device='mps:0')
idx:  torch.Size([1, 286]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53, 

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 0, 15, 27, 30, 21, 27],
        [ 0, 15, 27, 30, 21, 24],
        [ 0, 15, 27, 30, 21, 32],
        [ 0, 15, 27, 30, 21, 21],
        [ 0, 15, 27, 30, 58, 27],
        [ 0, 15, 27, 30, 58, 26],
        [ 0, 15, 27, 30, 58, 32],
        [ 0, 15, 27, 30, 58, 43],
        [ 0, 15, 27, 30, 13, 26],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 47],
        [ 0, 15, 27, 30, 21, 26],
        [ 0, 15, 27, 30, 21, 26],
        [ 0, 15, 27, 30, 21, 25],
        [ 0, 15, 27, 30, 21, 57],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 39],
        [ 0, 15, 27, 32, 57, 27],
        [ 0, 15, 27, 32, 57, 31],
        [ 0, 15, 27, 32, 57, 32],
        [ 0, 15, 27, 32, 57, 43],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 24],
        [ 0, 

idx should be (1,t+max_val+1):  torch.Size([1, 295]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 39],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 53],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 39],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 

idx should be (1,t+max_val+1):  torch.Size([1, 303]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 1, 61, 47, 57, 57, 52],
        [ 1, 61, 47, 57, 57, 47],
        [ 1, 61, 47, 57, 57, 61],
        [ 1, 61, 47, 57, 57, 53],
        [ 1, 61, 47, 57, 57,  1],
        [ 1, 61, 47, 57, 57, 50],
        [ 1, 61, 47, 57, 57, 58],
        [ 1, 61, 47, 57, 57, 58],
        [ 1, 61, 47, 57, 43,  1],
        [ 1, 61, 47, 57, 43, 43],
        [ 1, 61, 47, 57, 43, 58],
        [ 1, 61, 47, 57, 43,  1],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 57, 50, 58],
        [ 1, 61, 47, 57, 50,  1],
        [ 1, 61, 47, 59, 57, 52],
        [ 1, 61, 47, 59, 57, 47],
        [ 1, 61, 47, 59, 57, 57],
        [ 1, 61, 47, 59, 57, 53],
        [ 1, 61, 47, 59, 63,  1],
        [ 1, 61, 47, 59, 63, 50],
        [ 1, 61, 47, 59, 63, 52],
        [ 1, 61, 47, 59, 63,  1],
        [ 1, 61, 47, 59, 43, 42],
        [ 1, 61, 47, 59, 43, 43],
        [ 1, 61, 47, 59, 43, 58],
        [ 1, 61, 47, 59, 43,  1],
        [ 1, 61, 47, 59, 43,  1],
        [ 1, 

idx should be (1,t+max_val+1):  torch.Size([1, 309]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[58,  1, 40, 43,  1, 57],
        [58,  1, 40, 43,  1,  1],
        [58,  1, 40, 43,  1,  1],
        [58,  1, 40, 43,  1, 56],
        [58,  1, 40, 43, 43, 58],
        [58,  1, 40, 43, 43,  1],
        [58,  1, 40, 43, 43,  1],
        [58,  1, 40, 43, 43, 60],
        [58,  1, 40, 43, 63, 58],
        [58,  1, 40, 43, 63,  1],
        [58,  1, 40, 43, 63, 47],
        [58,  1, 40, 43, 63,  1],
        [58,  1, 40, 43, 43, 40],
        [58,  1, 40, 43, 43,  1],
        [58,  1, 40, 43, 43,  1],
        [58,  1, 40, 43, 43, 47],
        [58,  1, 40, 46,  1, 58],
        [58,  1, 40, 46,  1,  1],
        [58,  1, 40, 46,  1, 46],
        [58,  1, 40, 46,  1, 58],
        [58,  1, 40, 46, 50, 58],
        [58,  1, 40, 46, 50, 42],
        [58,  1, 40, 46, 50, 43],
        [58,  1, 40, 46, 50, 56],
        [58,  1, 40, 46, 49, 58],
        [58,  1, 40, 46, 49,  1],
        [58,  1, 40, 46, 49, 43],
        [58,  1, 40, 46, 49, 56],
        [58,  1, 40, 46, 43, 57],
        [58, 

idx should be (1,t+max_val+1):  torch.Size([1, 317]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[58, 46, 43,  1, 57, 57],
        [58, 46, 43,  1, 57, 52],
        [58, 46, 43,  1, 57, 43],
        [58, 46, 43,  1, 57, 43],
        [58, 46, 43,  1,  1, 58],
        [58, 46, 43,  1,  1,  1],
        [58, 46, 43,  1,  1,  1],
        [58, 46, 43,  1,  1,  1],
        [58, 46, 43,  1, 43, 57],
        [58, 46, 43,  1, 43,  1],
        [58, 46, 43,  1, 43, 43],
        [58, 46, 43,  1, 43, 39],
        [58, 46, 43,  1,  1, 58],
        [58, 46, 43,  1,  1, 56],
        [58, 46, 43,  1,  1,  1],
        [58, 46, 43,  1,  1, 43],
        [58, 46, 43, 57, 57, 57],
        [58, 46, 43, 57, 57, 52],
        [58, 46, 43, 57, 57, 43],
        [58, 46, 43, 57, 57, 43],
        [58, 46, 43, 57, 44, 58],
        [58, 46, 43, 57, 44, 42],
        [58, 46, 43, 57, 44,  1],
        [58, 46, 43, 57, 44,  1],
        [58, 46, 43, 57,  1, 57],
        [58, 46, 43, 57,  1, 50],
        [58, 46, 43, 57,  1,  1],
        [58, 46, 43, 57,  1, 43],
        [58, 46, 43, 57, 46, 58],
        [58, 

idx should be (1,t+max_val+1):  torch.Size([1, 325]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

mask:  torch.Size([64, 3]) tensor([[ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43, 47],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43, 42],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 52, 58,  1],
        [56, 60, 47, 52, 58,  1],
        [56, 60, 47, 52, 58, 43],
        [56, 60, 47, 52, 58,  1],
        [56, 60, 47, 52, 43,  1],
        [56, 60, 47, 52, 43,  1],
        [56, 60, 47, 52, 43, 43],
        [56, 60, 47, 52, 43,  1],
        [56, 60, 47, 52, 43, 47],
        [56, 60, 47, 52, 43,  1],
        [56, 60, 47, 52, 43, 43],
        [56, 60, 47, 52, 43,  1],
        [56, 60, 47, 52, 43,  1],
        [56, 

idx should be (1,t+max_val+1):  torch.Size([1, 331]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

idx:  torch.Size([1, 331]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43,  1, 53,
 

 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 1, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0,

idx_rep should be (k^m,t+1):  torch.Size([64, 339])
input should be (k^m,t+1+m):  torch.Size([64, 342])
input_cond should be (k^m,t):  torch.Size([64, 128]) 
 tensor([[43,  1, 57, 53, 56, 43],
        [43,  1, 57, 53, 56, 58],
        [43,  1, 57, 53, 56, 57],
        [43,  1, 57, 53, 56, 56],
        [43,  1, 57, 53, 39, 43],
        [43,  1, 57, 53, 39, 58],
        [43,  1, 57, 53, 39, 57],
        [43,  1, 57, 53, 39, 56],
        [43,  1, 57, 53, 52, 43],
        [43,  1, 57, 53, 52, 58],
        [43,  1, 57, 53, 52, 57],
        [43,  1, 57, 53, 52, 56],
        [43,  1, 57, 53, 53, 43],
        [43,  1, 57, 53, 53, 58],
        [43,  1, 57, 53, 53, 57],
        [43,  1, 57, 53, 53, 56],
        [43,  1, 57, 39, 56, 43],
        [43,  1, 57, 39, 56, 58],
        [43,  1, 57, 39, 56, 57],
        [43,  1, 57, 39, 56, 56],
        [43,  1, 57, 39, 39, 43],
        [43,  1, 57, 39, 39, 58],
        [43,  1, 57, 39, 39, 57],
        [43,  1, 57, 39, 39, 56],
        [43,  1, 57, 39, 

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

idx_ntp should be (k^m, t):  torch.Size([64, 128]) 
 tensor([[56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 45,  8],
        [56, 60, 47, 41, 45,  1],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43,  1],
        [56, 60, 47, 41, 43, 47],
        [56, 60, 47, 41, 43,  8],
        [56, 60, 47, 41, 43, 43],
        [56, 60, 47, 52, 58, 57],
        [56, 60, 47, 52, 58,  1],
        [56, 60, 47, 52, 58,  8],
        [56, 60, 47, 52, 58, 43],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43, 47],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43, 43],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43,  8],
        [56, 60, 47, 52, 43, 

idx:  torch.Size([1, 342]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56, 60, 47, 41, 43,  1, 53,
 

mask:  torch.Size([64, 3]) tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 0, 15, 27, 30, 21, 27],
        [ 0, 15, 27, 30, 21, 24],
        [ 0, 15, 27, 30, 21, 32],
        [ 0, 15, 27, 30, 21, 21],
        [ 0, 15, 27, 30, 13, 33],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 31],
        [ 0, 15, 27, 30, 13, 47],
        [ 0, 15, 27, 30, 58, 27],
        [ 0, 15, 27, 30, 58, 26],
        [ 0, 15, 27, 30, 58, 32],
        [ 0, 15, 27, 30, 58, 43],
        [ 0, 15, 27, 30, 21, 33],
        [ 0, 15, 27, 30, 21, 26],
        [ 0, 15, 27, 30, 21, 31],
        [ 0, 15, 27, 30, 21, 58],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 39],
        [ 0, 15, 27, 32, 21, 27],
        [ 0, 15, 27, 32, 21, 26],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 21, 32],
        [ 0, 15, 27, 32, 57, 27],
        [ 0, 15, 27, 32, 57, 31],
        [ 0, 15, 27, 32, 57, 32],
        [ 0, 15, 27, 32, 57, 43],
        [ 0, 15, 27, 32, 21, 33],
        [ 0, 

idx_rep should be (k^m,t+1):  torch.Size([64, 353])
input should be (k^m,t+1+m):  torch.Size([64, 356])
input_cond should be (k^m,t):  torch.Size([64, 128]) 
 tensor([[30, 21, 27, 24, 13, 26],
        [30, 21, 27, 24, 13, 30],
        [30, 21, 27, 24, 13, 31],
        [30, 21, 27, 24, 13, 27],
        [30, 21, 27, 24, 27, 26],
        [30, 21, 27, 24, 27, 30],
        [30, 21, 27, 24, 27, 31],
        [30, 21, 27, 24, 27, 27],
        [30, 21, 27, 24, 10, 26],
        [30, 21, 27, 24, 10, 30],
        [30, 21, 27, 24, 10, 31],
        [30, 21, 27, 24, 10, 27],
        [30, 21, 27, 24, 24, 26],
        [30, 21, 27, 24, 24, 30],
        [30, 21, 27, 24, 24, 31],
        [30, 21, 27, 24, 24, 27],
        [30, 21, 27, 31, 13, 26],
        [30, 21, 27, 31, 13, 30],
        [30, 21, 27, 31, 13, 31],
        [30, 21, 27, 31, 13, 27],
        [30, 21, 27, 31, 27, 26],
        [30, 21, 27, 31, 27, 30],
        [30, 21, 27, 31, 27, 31],
        [30, 21, 27, 31, 27, 27],
        [30, 21, 27, 31, 

mask:  torch.Size([64, 3]) tensor([[ True,  True,  True],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 39],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 13],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 10,  0, 21],
        [26, 33, 31, 10,  0,  0],
        [26, 33, 31, 10,  0, 53],
        [26, 33, 31, 10,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 39],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0,  0],
        [26, 33, 31, 13,  0, 21],
        [26, 33, 31, 13,  0, 46],
        [26, 33, 31, 13,  0, 21],
        [26, 

result:  tensor([3, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(63, device='mps:0')
max_idx:  tensor(0, device='mps:0')
idx:  torch.Size([1, 357]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 1, 46, 39, 60, 57, 52],
        [ 1, 46, 39, 60, 57, 47],
        [ 1, 46, 39, 60, 57, 57],
        [ 1, 46, 39, 60, 57, 53],
        [ 1, 46, 39, 60, 60,  1],
        [ 1, 46, 39, 60, 60, 50],
        [ 1, 46, 39, 60, 60, 57],
        [ 1, 46, 39, 60, 60, 42],
        [ 1, 46, 39, 60, 43,  1],
        [ 1, 46, 39, 60, 43, 43],
        [ 1, 46, 39, 60, 43, 52],
        [ 1, 46, 39, 60, 43, 45],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 60, 50, 52],
        [ 1, 46, 39, 60, 50,  1],
        [ 1, 46, 39, 59, 57, 52],
        [ 1, 46, 39, 59, 57, 47],
        [ 1, 46, 39, 59, 57, 57],
        [ 1, 46, 39, 59, 57, 53],
        [ 1, 46, 39, 59, 63,  1],
        [ 1, 46, 39, 59, 63, 50],
        [ 1, 46, 39, 59, 63, 57],
        [ 1, 46, 39, 59, 63,  1],
        [ 1, 46, 39, 59, 43,  1],
        [ 1, 46, 39, 59, 43, 43],
        [ 1, 46, 39, 59, 43, 40],
        [ 1, 46, 39, 59, 43,  1],
        [ 1, 46, 39, 59, 43, 57],
        [ 1, 

result:  tensor([1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0')
max_val:  tensor(3, device='mps:0')
reversed_max_idx:  tensor(55, device='mps:0')
max_idx:  tensor(8, device='mps:0')
idx:  torch.Size([1, 363]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  

mask:  torch.Size([64, 3]) tensor([[False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [False, False, False],
        [Fal

 tensor([[ 1, 57, 53, 58, 43, 58],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58, 43,  1],
        [ 1, 57, 53, 58,  1, 57],
        [ 1, 57, 53, 58,  1, 42],
        [ 1, 57, 53, 58,  1, 39],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 58, 51, 51],
        [ 1, 57, 53, 58, 51, 52],
        [ 1, 57, 53, 58, 51, 43],
        [ 1, 57, 53, 58, 51, 43],
        [ 1, 57, 53, 58,  1, 58],
        [ 1, 57, 53, 58,  1, 42],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 58,  1,  1],
        [ 1, 57, 53, 61, 42, 58],
        [ 1, 57, 53, 61, 42, 51],
        [ 1, 57, 53, 61, 42,  1],
        [ 1, 57, 53, 61, 42, 43],
        [ 1, 57, 53, 61,  1, 58],
        [ 1, 57, 53, 61,  1, 56],
        [ 1, 57, 53, 61,  1, 39],
        [ 1, 57, 53, 61,  1, 43],
        [ 1, 57, 53, 61, 57, 57],
        [ 1, 57, 53, 61, 57, 52],
        [ 1, 57, 53, 61, 57, 43],
        [ 1, 57, 53, 61, 57, 53],
        [ 1, 57, 53, 61, 42, 58],
        [ 1, 

idx should be (1,t+max_val+1):  torch.Size([1, 372]) tensor([[22, 33, 24, 21, 17, 32, 10,  0, 27,  1, 30, 53, 51, 43, 53,  6,  1, 30,
         53, 51, 43, 53,  2,  1, 61, 46, 43, 56, 43, 44, 53, 56, 43,  1, 39, 56,
         58,  1, 58, 46, 53, 59,  1, 30, 53, 51, 43, 53, 12,  0,  0, 15, 27, 30,
         21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39, 60, 43,  1, 57, 53,
         51, 43,  1, 57, 53, 59, 50,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,
         60, 47, 41, 43,  1, 53, 44,  1, 58, 46, 43,  1, 41, 53, 59, 56, 58, 12,
          0,  0, 15, 27, 30, 21, 27, 24, 13, 26, 33, 31, 10,  0, 21,  1, 46, 39,
         60, 43,  1, 52, 53, 58,  1, 57, 53,  1, 57, 53,  1, 57, 53,  1, 57, 53,
          1, 57, 53,  1, 57, 46, 39, 50, 50,  1, 40, 43,  1, 58, 53,  1, 58, 46,
         43,  1, 57, 43, 56, 60, 47, 41, 43,  8,  0,  0, 15, 27, 30, 21, 27, 24,
         13, 26, 33, 31, 10,  0, 21,  1, 61, 47, 50, 50,  1, 52, 53, 58,  1, 40,
         43,  1, 58, 53,  1, 58, 46, 43,  1, 57, 43, 56,

In [116]:
# Report the mean of `tok_per_inf` (printed under the label "tokens per inference"),
# then decode and print the first row of `output` as text.
tokens_total, inference_count = sum(tok_per_inf), len(tok_per_inf)
print("tokens per inference: ", tokens_total / inference_count)

# `output[0]` is converted to a plain Python list so `decode` can map ids back to characters.
first_sequence_ids = output[0].tolist()
print(decode(first_sequence_ids))

tokens per inference:  3.28
JULIET:
O Romeo, Romeo! wherefore art thou Romeo?

CORIOLANUS:
I have some soul to the service of the court?

CORIOLANUS:
I have not so so so so so shall be to the service.

CORIOLANUS:
I will not be to the service of the service.

CORIOLANUS:
I have not so so so so so shall be so so so.

CORIOLANUS:
I will not be to the service of the service.

CORIOLANUS:
I have not s


# The problem is that there's still not enough incentive to veer away from repetition, because the reference I'm checking the candidates against is produced by greedy decoding. Maybe in the checking step I could sample from the distribution instead of taking the argmax? Or take the top-k and then select the lowest of the top-k?