In [1]:
import sys
import os
# Add the project root to the path so we can import from src
sys.path.append(os.path.abspath(".."))

import torch
from torch.utils.data import DataLoader
from src.datasets.shakespeare.shakespeare import ShakespeareDataset
from src.tokenizers.character_level.character_level import CharacterLevelTokenizer

## Configuration
Define the hyperparameters for the dataset and dataloader.

In [2]:
BATCH_SIZE = 128
MAX_LENGTH = 256
MIN_T = 1e-6
NUM_WORKERS = 0 

## Tokenizer and Dataset
Initialize the character-level tokenizer and the Shakespeare dataset.

In [3]:
tokenizer = CharacterLevelTokenizer()
dataset = ShakespeareDataset(tokenizer=tokenizer, max_length=MAX_LENGTH, min_t=MIN_T, train=True)

print(f"Dataset size: {len(dataset)}")
print(f"Vocab size: {tokenizer.vocab_size()}")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Dataset size: 1003854
Vocab size: 35


## Custom Collate Function
Define a collate function that packs the batch into a single sequence and returns document IDs.

In [4]:
def packed_collate_fn(batch):
    """
    Collate function to pack a batch of sequences into a single sequence.
    
    Args:
        batch: List of dictionaries, each containing 'x' (sequence) and 't' (timestep).
        
    Returns:
        Dictionary containing:
        - 'x': Packed sequence tensor of shape (total_seq_len,)
        - 'doc_ids': Document ID tensor of shape (total_seq_len,)
        - 't': Packed timestep tensor of shape (total_seq_len,)
    """
    xs = [item['x'] for item in batch]
    ts = [item['t'] for item in batch]
    
    # Pack x into a single sequence
    packed_x = torch.cat(xs, dim=0) # (total_seq_len,)
    
    # Create document ids
    doc_ids = []
    for i, x in enumerate(xs):
        doc_ids.append(torch.full_like(x, i))
    packed_doc_ids = torch.cat(doc_ids, dim=0) # (total_seq_len,)
    
    # Pack t (expand t for each token in the sequence)
    # t is (1,) per sample. We need to repeat it for len(x)
    packed_ts = []
    for i, (x, t) in enumerate(zip(xs, ts)):
        packed_ts.append(t.repeat(len(x)))
    packed_t = torch.cat(packed_ts, dim=0)

    return {
        "x": packed_x,
        "doc_ids": packed_doc_ids,
        "t": packed_t
    }

## DataLoader
Create the DataLoader using the custom collate function.

In [5]:
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=packed_collate_fn,
    num_workers=NUM_WORKERS
)

## Verification
Run a single batch to verify the shapes and content.

In [6]:
batch = next(iter(dataloader))
print("Batch keys:", batch.keys())
print("Packed x shape:", batch['x'].shape)
print("Doc ids shape:", batch['doc_ids'].shape)
print("Packed t shape:", batch['t'].shape)

print("\nSample check:")
print("First 10 tokens:", batch['x'][:10])
print("First 10 doc ids:", batch['doc_ids'][:10])
print("First 10 t values:", batch['t'][:10])

Batch keys: dict_keys(['x', 'doc_ids', 't'])
Packed x shape: torch.Size([32768])
Doc ids shape: torch.Size([32768])
Packed t shape: torch.Size([32768])

Sample check:
First 10 tokens: tensor([10,  8, 11, 11, 32, 12, 24, 18,  4, 11])
First 10 doc ids: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
First 10 t values: tensor([0.8537, 0.8537, 0.8537, 0.8537, 0.8537, 0.8537, 0.8537, 0.8537, 0.8537,
        0.8537])


## Model Components
Import necessary libraries and define the `PackDynamicSequenceChunker` and Flex Attention utilities.

In [7]:
from collections import namedtuple
import torch
from torch import cat, arange
from torch.nested import nested_tensor
from torch.nn import Module, Linear, Parameter
from torch.nn.functional import cosine_similarity, pad, softmax
from torch.nn.utils.rnn import pad_sequence
from einx import multiply
from einops import repeat, rearrange
from mamba_ssm import Mamba2
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from assoc_scan import AssocScan
# following section 2.2 of the paper

from collections import namedtuple

import torch
from torch import Tensor
from torch import cat, arange
from torch.nested import nested_tensor
from torch.nn import Module, Linear, Parameter
from torch.nn.functional import cosine_similarity, pad

from einx import multiply
from einops import repeat, rearrange

from assoc_scan import AssocScan

# constants

Outputs = namedtuple('Outputs', [
    'downsampled',
    'upsample_fn',
    'weighted_aux_ratio_loss'
])

Intermediates = namedtuple('Intermediates', [
    'mask',
    'probs',
    'chunk_lens',
    'boundary_mask',
    'residual',
    'gates',
    'upsampler_output_scale',
    'aux_ratio_loss',
    'new_seq_lens'
])

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def straight_through(t, value):
    return t + (value - t).detach()

def frac_gradient(t, frac = 1.):
    if frac == 1:
        return

    t_grad = t * frac
    return straight_through(t_grad, t)

# classes

class PackDynamicSequenceChunker(Module):
    def __init__(
        self,
        dim,
        dim_queries_keys = None,
        boundary_threshold = 0.5,
        target_avg_token_length = 6.,       # N in eq(10)
        ratio_loss_weight = 3e-2,
        handle_residual_proj = False,       # turning this on will automatically handle a projection of the residual and its application in the inverse upsample function
        assoc_scan_use_accelerated = False,
        learning_rate_difference = 0.75,    # in the paper, they report that as one moves up a hierarchy, the learning rate needs to decrease. we'll default to 0.75 for the rough 2.0 -> 1.5 somewhere in the appendix from level 0 -> 1
        straight_through_frac_vecs = True,  # improvisation where F receives gradients through straight-through with sigmoid
    ):
        super().__init__()
        dim_queries_keys = default(dim_queries_keys, dim)

        # linear to queries and keys

        self.to_queries_keys = Linear(dim, dim_queries_keys * 2, bias = False)

        # start key token, so first token can be segmented / chunked out

        self.start_key_token = Parameter(torch.randn(dim_queries_keys) * 1e-2) # presumably, need a start key token for the first token, open an issue if i got it wrong

        # threshold to determine boundary

        assert 0. < boundary_threshold < 1.

        self.boundary_threshold = boundary_threshold

        # smoothing related

        self.smooth_assoc_scan = AssocScan(use_accelerated = assoc_scan_use_accelerated)

        # maybe residual proj

        self.handle_residual_proj = handle_residual_proj

        if handle_residual_proj:
            self.residual_proj = Linear(dim, dim)

        # learning rate modulation, appendix C
        # the multiplier on the learning rate as one goes from outer to inner of the h-net, and inverse of this value from inner to outer

        self.learning_rate_difference = learning_rate_difference

        # ratio aux loss related

        self.target_avg_token_length = target_avg_token_length

        self.straight_through_frac_vecs = straight_through_frac_vecs

        self.ratio_loss_weight = ratio_loss_weight

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def upsample(
        self,
        downsampled,
        intermediates: Intermediates,
        apply_scale = True
    ):
        batch, needs_grad, device = downsampled.shape[0], downsampled.requires_grad, downsampled.device

        mask = intermediates.mask
        gates = intermediates.gates
        residual = intermediates.residual

        # smoothing module for improved gradients eq(5)

        downsampled = self.smooth_assoc_scan(gates, downsampled)

        # upsample

        downsampled_without_padding = downsampled[mask]
        chunk_lens_without_padding = intermediates.chunk_lens[mask]

        seq = arange(downsampled_without_padding.shape[0], device = device)

        repeated_indices = torch.repeat_interleave(seq, chunk_lens_without_padding, dim = 0)
        upsampled = downsampled_without_padding[repeated_indices]

        upsampled = rearrange(upsampled, '(b n) d -> b n d', b = batch)

        scale = intermediates.upsampler_output_scale

        if needs_grad and apply_scale and exists(scale):
            upsampled = multiply('b n d, b n', upsampled, scale)

        if self.handle_residual_proj:
            upsampled = upsampled + self.residual_proj(residual)

        upsampled = frac_gradient(upsampled, self.learning_rate_difference)

        return upsampled

    def forward(
        self,
        tokens, # float[b n d] or float[total_n d] if seq_lens is specified,
        seq_lens: Tensor | None = None,
        return_intermediates = False,
        return_only_chunk_lens = False
    ):
        with torch.no_grad():
            if seq_lens is not None:
                total_lens = seq_lens.sum().item()
                document_ids = torch.repeat_interleave(
                    torch.arange(len(seq_lens), device=seq_lens.device), seq_lens
                )

                # a sequence position with 1 in probs_mask is the position of the first
                # token of a new document, which means it must be a chunk start with
                # probability 1
                packed_probs_mask = torch.zeros_like(document_ids)
                packed_probs_mask[1:] = document_ids[:-1] != document_ids[1:]

                # however, since the sequence position is the start of a new document,
                # we must prevent the associative scan from reading from the token before 
                # it. To do this, we reverse probs_mask, so the sequence position that used
                # to be 1 becomes 0 and the positions that used to be 0 become 1.
                # this means that at the start of each new document, the token cannot
                # read from the token before it
                packed_gate_mask = -1 * (packed_probs_mask - 1)
                tokens = tokens.unsqueeze(0)
            else:
                packed_probs_mask = None
                packed_gate_mask = None
                document_ids = None

        batch, length, device = *tokens.shape[:2], tokens.device

        residual = tokens

        queries, keys = self.to_queries_keys(tokens).chunk(2, dim = -1)

        start_keys = repeat(self.start_key_token, 'd -> b 1 d', b = batch)

        keys = cat((start_keys, keys), dim = 1)

        if packed_probs_mask is not None:
            # when packed, the keys end up being compared incorrectly at this current stage
            # for example, suppose we have two documents of lengths 2 and 2.
            # if passed individually, each document's first token will compare against the start key token
            # however, when packed, the 3rd token (first token of second document)
            # will compare against the key of the 2nd token, resulting in a wrong cosine_similarity
            # which later impacts the probability
            # at first I thought this would be fine because we hard set the probability, however
            # now I recall that in the associative scan smoothing, this probability term is involved
            # beyond the gate itself, which would result in an incorrect calculation, so
            # we need to make all those keys that are at the start of a new document
            # equal to the start key token
            
            # first, we start by adding a 1 to the right side of the packed_probs_mask, this is to account
            # for the fact that when calculating cosine similarity, we use `keys[:, :-1]`, so it is shifted
            # so the placement of the start key token needs to be shifted as well
            packed_probs_mask_with_start = pad(packed_probs_mask, (0, 1), value = 0)

            # and now, for all sequence positions where packed_probs_mask_with_start is 1,
            # we set the corresponding keys to the start key token
            keys[:, packed_probs_mask_with_start == 1] = start_keys


        # each query looks at the previous key to determine if distance is greater than some threshold for determining a boundary exists (they use 0.5 as threshold)

        cosine_sim  = cosine_similarity(queries, keys[:, :-1], dim = -1)

        probs = (1. - cosine_sim) * 0.5 # cosine sim is -1. to 1., this transforms it to 0. to 1.

        boundary_mask = probs > self.boundary_threshold # bool[b n]

        boundary_mask[:, 0] = True # first token must always be boundary

        if packed_probs_mask is not None:
            # at all positions where the packed_probs_masking is 1, it means it is the start
            # of a new document. We must force these positions to be boundaries
            # previously I tried doing it by setting probs to 1, but that
            # will cause issues later down the line because downsampling tensor is multiplied
            # by the probs, so we must directly set the boundary mask instead
            boundary_mask = torch.where(packed_probs_mask == 1, True, boundary_mask)

        # compute some lengths, per chunk and number of chunks per batch

        num_chunks = boundary_mask.long().sum(dim = -1)

        boundary_mask_with_end = pad(boundary_mask, (0, 1), value = True)
        sel_indices = repeat(arange(boundary_mask_with_end.shape[-1], device = device), 'n -> b n', b = batch)[boundary_mask_with_end]

        sel_indices = nested_tensor(sel_indices.split((num_chunks + 1).tolist()), layout = torch.jagged, device = device)

        sel_indices = sel_indices.to_padded_tensor(padding = -1)

        mask = (sel_indices != -1)[:, 1:]

        chunk_lens = sel_indices[:, 1:] - sel_indices[:, :-1]
        chunk_lens.masked_fill_(~mask, 0)

        # early return chunk lens if using a trained module as a tokenizer

        if return_only_chunk_lens:
            return chunk_lens

        # downsampling - they show in their experiments that picking out the boundary tokens works just fine

        boundary_tokens = tokens[boundary_mask] # pick out boundary tokens

        tokens_nt = nested_tensor(boundary_tokens.split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        downsampled_tokens = tokens_nt.to_padded_tensor(padding = 0.)

        # smoothing module for improved gradients eq(5)

        probs_nt = nested_tensor(probs[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        boundary_probs = probs_nt.to_padded_tensor(padding = 0.)

        gates = 1. - boundary_probs

        if packed_gate_mask is not None:
            # at all positions where the packed_gate_masking is 0, it means it is the start
            # of a new document. We must prevent associative scan from allowing
            # this starting token from reading into the past document
            # also, gradients cannot propagate through this to modify this gating, as it is
            # fixed by the document sequence
            packed_gate_mask_nt = nested_tensor(packed_gate_mask.unsqueeze(0)[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = False)
            packed_gate_masking = packed_gate_mask_nt.to_padded_tensor(padding = 1.0)
            gates = gates * packed_gate_masking

        downsampled_tokens = multiply('b n d, b n', downsampled_tokens, boundary_probs)


        # for the upsampler

        confidence = torch.where(boundary_mask, probs, 1. - probs)

        # defaults if not training

        upsampler_output_scale = None
        aux_loss = self.zero
        weighted_aux_loss = self.zero

        needs_grad = tokens.requires_grad

        if needs_grad:
            # straight through for 1. multiplier on the expanded processed boundary tokens

            upsampler_output_scale = straight_through(confidence, 1.)

            # auxiliary ratio loss in section 2.3.2, eq (10)
            # lets follow their notation

            N = self.target_avg_token_length

            F = boundary_mask.float()
            G = probs.mean(dim = -1)

            # allow for a soft F to straight through - https://arxiv.org/abs/2505.22074

            if self.straight_through_frac_vecs:
                F_soft = (probs - self.boundary_threshold).sigmoid()
                F = straight_through(F_soft, F)

            F = F.mean(dim = -1)

            aux_ratio_loss = N / (N - 1) * ((N - 1) * F * G + (1. - F) * (1. - G))

            aux_loss = aux_ratio_loss.mean()
            weighted_aux_loss = aux_loss * self.ratio_loss_weight

        # intermediates
        if document_ids is not None:
            # this minlength should not be necessary as the boundaries should 
            # guarantee that each document has at least one chunk
            new_seq_lens = torch.bincount(document_ids, weights=boundary_mask.squeeze(0).long(), minlength=len(seq_lens)).long()
        else:
            new_seq_lens = num_chunks

        intermediates = Intermediates(mask, probs, chunk_lens, boundary_mask, residual, gates, upsampler_output_scale, aux_loss, new_seq_lens)

        # return the upsample function

        def upsample(downsampled, apply_scale = True):
            downsampled_input = downsampled.unsqueeze(0) if downsampled.ndim == 2 else downsampled
            upsampled = self.upsample(downsampled_input, intermediates, apply_scale = apply_scale)
            return upsampled.squeeze(0) if downsampled.ndim == 2 else upsampled

        # adjust learning rate

        downsampled_tokens = frac_gradient(downsampled_tokens, self.learning_rate_difference ** -1)

        if packed_probs_mask is not None:
            downsampled_tokens = downsampled_tokens.squeeze(0)

        # returning

        outputs = Outputs(downsampled_tokens, upsample, weighted_aux_loss)

        if not return_intermediates:
            return outputs

        return outputs, intermediates


    
# --- Flex Attention Utils ---

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def generate_doc_mask_mod(mask_mod, document_id):
    # can feed in another mask modifier function such as `causal`

    # Get unique document IDs and their counts
    _, counts = torch.unique_consecutive(document_id, return_counts=True)
    # Create cumulative counts (offsets)
    offsets = torch.cat([torch.tensor([0], device=document_id.device), counts.cumsum(0)[:-1]])
    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        same_doc = document_id[q_idx] == document_id[kv_idx]
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask
    return doc_mask_wrapper


## Model Definition
Define the `TransformerBlock` using Flex Attention and the `HybridModel` combining Mamba, H-Net Chunker, and Transformer.

In [8]:
class TransformerBlock(Module):
    def __init__(self, dim, heads, dim_head):
        super().__init__()
        self.norm1 = torch.nn.RMSNorm(dim)
        self.heads = heads
        self.dim_head = dim_head
        
        self.to_qkv = Linear(dim, heads * dim_head * 3, bias=False)
        self.to_out = Linear(heads * dim_head, dim, bias=False)
        
        self.norm2 = torch.nn.RMSNorm(dim)
        self.ff = torch.nn.Sequential(
            Linear(dim, dim * 4),
            torch.nn.GELU(),
            Linear(dim * 4, dim)
        )
        self.flex = torch.compile(flex_attention)

    def forward(self, x, block_mask):
        # x: (1, SeqLen, Dim) - treating packed as batch 1
        B, S, D = x.shape
        
        residual = x
        x = self.norm1(x)
        
        qkv = self.to_qkv(x) # (B, S, 3 * H * Dh)
        q, k, v = rearrange(qkv, 'b s (t h d) -> t b h s d', t=3, h=self.heads, d=self.dim_head)
        
        # Flex Attention
        out = self.flex(q, k, v, block_mask=block_mask) # (B, S, H, D)
        
        out = rearrange(out, 'b h s d -> b s (h d)')
        out = self.to_out(out)
        
        x = residual + out
        
        residual = x
        x = self.norm2(x)
        x = self.ff(x)
        x = residual + x
        
        return x

class HybridModel(Module):
    def __init__(self, dim, vocab_size, depth=4):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, dim)
        
        # Pre-chunk Mamba
        self.mamba_pre = Mamba2(
            d_model=dim,
            headdim=dim//16 if dim >= 16 else 4,
            d_state=16,
            d_conv=4,
            expand=2
        )
        
        # Chunker
        self.chunker = PackDynamicSequenceChunker(dim=dim)
        
        # Main Transformer (Flex Attention)
        self.transformer_blocks = torch.nn.ModuleList([
            TransformerBlock(dim, heads=4, dim_head=dim//4)
            for _ in range(depth)
        ])
        
        # Post-chunk Mamba
        self.mamba_post = Mamba2(
            d_model=dim,
            headdim=dim//16 if dim >= 16 else 4,
            d_state=16,
            d_conv=4,
            expand=2
        )
        
        self.head = Linear(dim, vocab_size)

    def forward(self, x, doc_ids):
        # x: (Total_L,)
        # doc_ids: (Total_L,)
        
        total_l = x.shape[0]

        # Embedding
        x = self.embedding(x) # (Total_L, Dim)
        
        # Mamba Pre
        # Treat as batch 1, but use seq_idx for packed logic
        x_unsqueezed = x.unsqueeze(0) # (1, Total_L, Dim)
        seq_idx = doc_ids.unsqueeze(0).int() # (1, Total_L)
        
        x = self.mamba_pre(x_unsqueezed, seq_idx=seq_idx) # (1, Total_L, Dim)
        x = x.squeeze(0) # (Total_L, Dim)
        
        # Chunker Downsample
        # Need seq_lens
        unique_doc_ids, counts = torch.unique_consecutive(doc_ids, return_counts=True)
        seq_lens = counts
        
        outputs, intermediates = self.chunker(x, seq_lens=seq_lens, return_intermediates=True)
        x_down = outputs.downsampled # (Total_Chunks, Dim)
        
        # Transformer
        # Reconstruct doc_ids for downsampled sequence
        with torch.no_grad():
            new_seq_lens = intermediates.new_seq_lens # (Batch_Size,)
            doc_ids_down = torch.repeat_interleave(unique_doc_ids, new_seq_lens)
        
        # Create mask for Flex Attention
        x_down_unsqueezed = x_down.unsqueeze(0) # (1, Total_Chunks, Dim)
        
        # Generate mask
        mask_mod = generate_doc_mask_mod(causal, doc_ids_down)
        block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=x_down_unsqueezed.shape[1], KV_LEN=x_down_unsqueezed.shape[1], device=x.device)
        
        for block in self.transformer_blocks:
            x_down_unsqueezed = block(x_down_unsqueezed, block_mask)
            
        x_down = x_down_unsqueezed.squeeze(0)
        
        # Chunker Upsample
        x_up = outputs.upsample_fn(x_down) # (Total_L, Dim)
        
        # Mamba Post
        x_up_unsqueezed = x_up.unsqueeze(0)
        x_final = self.mamba_post(x_up_unsqueezed, seq_idx=seq_idx)
        x_final = x_final.squeeze(0)
        
        # Head
        logits = self.head(x_final)
        return logits, outputs.weighted_aux_ratio_loss


## Model Verification
Instantiate the model and run a forward pass with the sample batch.

In [9]:
# Hyperparameters
DIM = 64
VOCAB_SIZE = tokenizer.vocab_size()
DEPTH = 2

# Instantiate Model
model = HybridModel(dim=DIM, vocab_size=VOCAB_SIZE, depth=DEPTH)
model = model.to('cuda')

# Get batch from previous step
x = batch['x'].to('cuda')
doc_ids = batch['doc_ids'].to('cuda')

print(f"Input shape: {x.shape}")
print(f"Doc IDs shape: {doc_ids.shape}")

Input shape: torch.Size([32768])
Doc IDs shape: torch.Size([32768])


In [10]:
# Forward Pass
logits, aux_loss = model(x, doc_ids)

print(f"Logits shape: {logits.shape}")
print(f"Aux Loss: {aux_loss.item()}")


  return _C._get_float32_matmul_precision()


Logits shape: torch.Size([32768, 35])
Aux Loss: 0.055008966475725174


In [11]:
aux_loss

tensor(0.0550, device='cuda:0', grad_fn=<MulBackward0>)

In [12]:
aux_loss.backward()

## Training Loop
Set up the optimizer and the training loop. The loop iterates through the dataloader, handles document boundaries for the loss, and updates the model.

In [None]:
from tqdm.auto import tqdm
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss(reduction='none') # We need to mask manually

# Training Loop
NUM_STEPS = 10000
PRINT_EVERY = 1000

model.train()
iter_dataloader = iter(dataloader)

print("Starting training...")

for step in tqdm(range(NUM_STEPS)):
    try:
        batch = next(iter_dataloader)
    except StopIteration:
        iter_dataloader = iter(dataloader)
        batch = next(iter_dataloader)
        
    x = batch['x'].to('cuda')
    doc_ids = batch['doc_ids'].to('cuda')
    
    # Prepare inputs and targets
    # We predict the next token
    input_ids = x[:-1]
    target_ids = x[1:]
    input_doc_ids = doc_ids[:-1]
    target_doc_ids = doc_ids[1:]
    
    # Forward pass
    logits, aux_loss = model(input_ids, input_doc_ids)
    
    # Calculate Loss
    ce_loss = criterion(logits, target_ids)
    
    # Mask loss at document boundaries
    # If input_doc_ids[i] != target_doc_ids[i], it means target is from a new doc
    # We shouldn't predict across documents
    valid_mask = (input_doc_ids == target_doc_ids).float()
    
    masked_ce_loss = (ce_loss * valid_mask).sum() / valid_mask.sum()
    
    total_loss = masked_ce_loss + aux_loss
    
    # Backward
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    if step % PRINT_EVERY == 0:
        print(f"Step {step}: Loss {total_loss.item():.4f} (CE: {masked_ce_loss.item():.4f}, Aux: {aux_loss.item():.4f})")


Starting training...


  0%|          | 0/10000 [00:00<?, ?it/s]

Step 0: Loss 3.6436 (CE: 3.5889, Aux: 0.0547)
Step 1000: Loss 1.9931 (CE: 1.9368, Aux: 0.0564)
Step 2000: Loss 1.9263 (CE: 1.8706, Aux: 0.0556)


## Inference
Generate text using the trained model to qualitatively assess performance.

In [None]:
def generate_text(model, tokenizer, start_text="ROMEO:", max_length=200, temperature=1.0):
    model.eval()
    device = next(model.parameters()).device
    
    # Encode start text
    input_ids = tokenizer.encode(start_text).to(device)
    
    # We need to track doc_ids, assuming single document for generation
    doc_ids = torch.zeros_like(input_ids).to(device)
    
    generated_ids = input_ids.clone()
    
    print(f"Generating from: '{start_text}'")
    
    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass
            logits, _ = model(generated_ids, doc_ids)
            
            # Get logits for the last token
            next_token_logits = logits[-1, :] / temperature
            
            # Sample
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append
            generated_ids = torch.cat([generated_ids, next_token], dim=0)
            doc_ids = torch.cat([doc_ids, torch.zeros(1, device=device, dtype=doc_ids.dtype)], dim=0)
            
    # Decode
    return tokenizer._decode(generated_ids)

# Test generation
print("-" * 50)
print(generate_text(model, tokenizer, start_text="ROMEO:", max_length=500))
print("-" * 50)

--------------------------------------------------
Generating from: 'ROMEO:'


W1127 20:19:12.059000 263834 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
W1127 20:19:12.059000 263834 torch/_dynamo/convert_frame.py:1358] [0/8]    function: 'flex_attention' (/media/john/Tertiary/Projects/ML/BayesianFlowNet/.venv/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py:1449)
W1127 20:19:12.059000 263834 torch/_dynamo/convert_frame.py:1358] [0/8]    last reason: 0/7: tensor 'block_mask.q_indices' size mismatch at index 2. expected 1, actual 2
W1127 20:19:12.059000 263834 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1127 20:19:12.059000 263834 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html

SOLUTION: Use torch.compile(flex_attention)(...)

If you want to debug your score_mod/mask_mod, you can set:
torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMP

romeo:<UNK><UNK>ens t assw r dnth hr lldoeley n rtitetne aangsnefteniu<UNK>ecetes ,<UNK>snemotor uatur i .<UNK>pzbmrtsli: yo h yoyseusuenbf<UNK>lnoddller p;leiid our oteuyovole,<UNK>l'meniod<UNK>so a rdnm kin pete ei 'sn mebitsuaem,<UNK>fou'ebltwn: m<UNK>d fyishrn wodhnrali?<UNK>l<UNK>fmele, moly pm tho h  fteuwbl ekbu dsmicu d;at yo lley o nehes ogrf zchsp her:<UNK>venh mas yo:<UNK>ih thlt s' f; in, ;orle wa or son:<UNK>n whidait aacnneot har yoy p a t tktt'ys'ay k eseor, an sudr meat cgnefniesireilo gto hald lstlcre.'aeebunst.,<UNK>hor hsdtul , ellilr.<UNK>onnr .nts
--------------------------------------------------
