In [17]:
import torch
from h_net_dynamic_chunking import DynamicSequenceChunker
from torch import Tensor

In [18]:
batch, seq, dim = 3, 2, 6
tokens = torch.randn(batch, seq, dim)

### Pack Chunker

In [19]:
# following section 2.2 of the paper

from collections import namedtuple

import torch
from torch import cat, arange
from torch.nested import nested_tensor
from torch.nn import Module, Linear, Parameter
from torch.nn.functional import cosine_similarity, pad

from einx import multiply
from einops import repeat, rearrange

from assoc_scan import AssocScan

# constants

Outputs = namedtuple('Outputs', [
    'downsampled',
    'upsample_fn',
    'weighted_aux_ratio_loss'
])

Intermediates = namedtuple('Intermediates', [
    'mask',
    'probs',
    'chunk_lens',
    'boundary_mask',
    'residual',
    'gates',
    'upsampler_output_scale',
    'aux_ratio_loss',
    'new_seq_lens'
])

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def straight_through(t, value):
    return t + (value - t).detach()

def frac_gradient(t, frac = 1.):
    if frac == 1:
        return

    t_grad = t * frac
    return straight_through(t_grad, t)

# classes

class PackDynamicSequenceChunker(Module):
    def __init__(
        self,
        dim,
        dim_queries_keys = None,
        boundary_threshold = 0.5,
        target_avg_token_length = 6.,       # N in eq(10)
        ratio_loss_weight = 3e-2,
        handle_residual_proj = False,       # turning this on will automatically handle a projection of the residual and its application in the inverse upsample function
        assoc_scan_use_accelerated = False,
        learning_rate_difference = 0.75,    # in the paper, they report that as one moves up a hierarchy, the learning rate needs to decrease. we'll default to 0.75 for the rough 2.0 -> 1.5 somewhere in the appendix from level 0 -> 1
        straight_through_frac_vecs = True,  # improvisation where F receives gradients through straight-through with sigmoid
    ):
        super().__init__()
        dim_queries_keys = default(dim_queries_keys, dim)

        # linear to queries and keys

        self.to_queries_keys = Linear(dim, dim_queries_keys * 2, bias = False)

        # start key token, so first token can be segmented / chunked out

        self.start_key_token = Parameter(torch.randn(dim_queries_keys) * 1e-2) # presumably, need a start key token for the first token, open an issue if i got it wrong

        # threshold to determine boundary

        assert 0. < boundary_threshold < 1.

        self.boundary_threshold = boundary_threshold

        # smoothing related

        self.smooth_assoc_scan = AssocScan(use_accelerated = assoc_scan_use_accelerated)

        # maybe residual proj

        self.handle_residual_proj = handle_residual_proj

        if handle_residual_proj:
            self.residual_proj = Linear(dim, dim)

        # learning rate modulation, appendix C
        # the multiplier on the learning rate as one goes from outer to inner of the h-net, and inverse of this value from inner to outer

        self.learning_rate_difference = learning_rate_difference

        # ratio aux loss related

        self.target_avg_token_length = target_avg_token_length

        self.straight_through_frac_vecs = straight_through_frac_vecs

        self.ratio_loss_weight = ratio_loss_weight

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def upsample(
        self,
        downsampled,
        intermediates: Intermediates,
        apply_scale = True
    ):
        batch, needs_grad, device = downsampled.shape[0], downsampled.requires_grad, downsampled.device

        mask = intermediates.mask
        gates = intermediates.gates
        residual = intermediates.residual

        # smoothing module for improved gradients eq(5)

        downsampled = self.smooth_assoc_scan(gates, downsampled)

        # upsample

        downsampled_without_padding = downsampled[mask]
        chunk_lens_without_padding = intermediates.chunk_lens[mask]

        seq = arange(downsampled_without_padding.shape[0], device = device)

        repeated_indices = torch.repeat_interleave(seq, chunk_lens_without_padding, dim = 0)
        upsampled = downsampled_without_padding[repeated_indices]

        upsampled = rearrange(upsampled, '(b n) d -> b n d', b = batch)

        scale = intermediates.upsampler_output_scale

        if needs_grad and apply_scale and exists(scale):
            upsampled = multiply('b n d, b n', upsampled, scale)

        if self.handle_residual_proj:
            upsampled = upsampled + self.residual_proj(residual)

        upsampled = frac_gradient(upsampled, self.learning_rate_difference)

        return upsampled

    def forward(
        self,
        tokens, # float[b n d] or float[total_n d] if seq_lens is specified,
        seq_lens: Tensor | None = None,
        return_intermediates = False,
        return_only_chunk_lens = False
    ):
        with torch.no_grad():
            if seq_lens is not None:
                total_lens = seq_lens.sum().item()
                document_ids = torch.repeat_interleave(
                    torch.arange(len(seq_lens), device=seq_lens.device), seq_lens
                )

                # a sequence position with 1 in probs_mask is the position of the first
                # token of a new document, which means it must be a chunk start with
                # probability 1
                packed_probs_mask = torch.zeros_like(document_ids)
                packed_probs_mask[1:] = document_ids[:-1] != document_ids[1:]

                # however, since the sequence position is the start of a new document,
                # we must prevent the associative scan from reading from the token before 
                # it. To do this, we reverse probs_mask, so the sequence position that used
                # to be 1 becomes 0 and the positions that used to be 0 become 1.
                # this means that at the start of each new document, the token cannot
                # read from the token before it
                packed_gate_mask = -1 * (packed_probs_mask - 1)
                tokens = tokens.unsqueeze(0)
            else:
                packed_probs_mask = None
                packed_gate_mask = None
                document_ids = None

        batch, length, device = *tokens.shape[:2], tokens.device

        residual = tokens

        queries, keys = self.to_queries_keys(tokens).chunk(2, dim = -1)

        start_keys = repeat(self.start_key_token, 'd -> b 1 d', b = batch)

        keys = cat((start_keys, keys), dim = 1)

        if packed_probs_mask is not None:
            # when packed, the keys end up being compared incorrectly at this current stage
            # for example, suppose we have two documents of lengths 2 and 2.
            # if passed individually, each document's first token will compare against the start key token
            # however, when packed, the 3rd token (first token of second document)
            # will compare against the key of the 2nd token, resulting in a wrong cosine_similarity
            # which later impacts the probability
            # at first I thought this would be fine because we hard set the probability, however
            # now I recall that in the associative scan smoothing, this probability term is involved
            # beyond the gate itself, which would result in an incorrect calculation, so
            # we need to make all those keys that are at the start of a new document
            # equal to the start key token
            
            # first, we start by adding a 1 to the right side of the packed_probs_mask, this is to account
            # for the fact that when calculating cosine similarity, we use `keys[:, :-1]`, so it is shifted
            # so the placement of the start key token needs to be shifted as well
            packed_probs_mask_with_start = pad(packed_probs_mask, (0, 1), value = 0)

            # and now, for all sequence positions where packed_probs_mask_with_start is 1,
            # we set the corresponding keys to the start key token
            keys[:, packed_probs_mask_with_start == 1] = start_keys


        # each query looks at the previous key to determine if distance is greater than some threshold for determining a boundary exists (they use 0.5 as threshold)

        cosine_sim  = cosine_similarity(queries, keys[:, :-1], dim = -1)

        probs = (1. - cosine_sim) * 0.5 # cosine sim is -1. to 1., this transforms it to 0. to 1.

        boundary_mask = probs > self.boundary_threshold # bool[b n]

        boundary_mask[:, 0] = True # first token must always be boundary

        if packed_probs_mask is not None:
            # at all positions where the packed_probs_masking is 1, it means it is the start
            # of a new document. We must force these positions to be boundaries
            # previously I tried doing it by setting probs to 1, but that
            # will cause issues later down the line because downsampling tensor is multiplied
            # by the probs, so we must directly set the boundary mask instead
            boundary_mask = torch.where(packed_probs_mask == 1, True, boundary_mask)

        # compute some lengths, per chunk and number of chunks per batch

        num_chunks = boundary_mask.long().sum(dim = -1)

        boundary_mask_with_end = pad(boundary_mask, (0, 1), value = True)
        sel_indices = repeat(arange(boundary_mask_with_end.shape[-1], device = device), 'n -> b n', b = batch)[boundary_mask_with_end]

        sel_indices = nested_tensor(sel_indices.split((num_chunks + 1).tolist()), layout = torch.jagged, device = device)

        sel_indices = sel_indices.to_padded_tensor(padding = -1)

        mask = (sel_indices != -1)[:, 1:]

        chunk_lens = sel_indices[:, 1:] - sel_indices[:, :-1]
        chunk_lens.masked_fill_(~mask, 0)

        # early return chunk lens if using a trained module as a tokenizer

        if return_only_chunk_lens:
            return chunk_lens

        # downsampling - they show in their experiments that picking out the boundary tokens works just fine

        boundary_tokens = tokens[boundary_mask] # pick out boundary tokens

        tokens_nt = nested_tensor(boundary_tokens.split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        downsampled_tokens = tokens_nt.to_padded_tensor(padding = 0.)

        # smoothing module for improved gradients eq(5)

        print(boundary_mask)

        probs_nt = nested_tensor(probs[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        boundary_probs = probs_nt.to_padded_tensor(padding = 0.)

        gates = 1. - boundary_probs

        if packed_gate_mask is not None:
            # at all positions where the packed_gate_masking is 0, it means it is the start
            # of a new document. We must prevent associative scan from allowing
            # this starting token from reading into the past document
            # also, gradients cannot propagate through this to modify this gating, as it is
            # fixed by the document sequence
            packed_gate_mask_nt = nested_tensor(packed_gate_mask.unsqueeze(0)[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = False)
            packed_gate_masking = packed_gate_mask_nt.to_padded_tensor(padding = 1.0)
            gates = gates * packed_gate_masking

        downsampled_tokens = multiply('b n d, b n', downsampled_tokens, boundary_probs)


        # for the upsampler

        confidence = torch.where(boundary_mask, probs, 1. - probs)

        # defaults if not training

        upsampler_output_scale = None
        aux_loss = self.zero
        weighted_aux_loss = self.zero

        needs_grad = tokens.requires_grad

        if needs_grad:
            # straight through for 1. multiplier on the expanded processed boundary tokens

            upsampler_output_scale = straight_through(confidence, 1.)

            # auxiliary ratio loss in section 2.3.2, eq (10)
            # lets follow their notation

            N = self.target_avg_token_length

            F = boundary_mask.float()
            G = probs.mean(dim = -1)

            # allow for a soft F to straight through - https://arxiv.org/abs/2505.22074

            if self.straight_through_frac_vecs:
                F_soft = (probs - self.boundary_threshold).sigmoid()
                F = straight_through(F_soft, F)

            F = F.mean(dim = -1)

            aux_ratio_loss = N / (N - 1) * ((N - 1) * F * G + (1. - F) * (1. - G))

            aux_loss = aux_ratio_loss.mean()
            weighted_aux_loss = aux_loss * self.ratio_loss_weight

        # intermediates
        if document_ids is not None:
            # this minlength should not be necessary as the boundaries should 
            # guarantee that each document has at least one chunk
            new_seq_lens = torch.bincount(document_ids, weights=boundary_mask.squeeze(0).long(), minlength=len(seq_lens)).long()
        else:
            new_seq_lens = num_chunks

        intermediates = Intermediates(mask, probs, chunk_lens, boundary_mask, residual, gates, upsampler_output_scale, aux_loss, new_seq_lens)

        # return the upsample function

        def upsample(downsampled, apply_scale = True):
            downsampled_input = downsampled.unsqueeze(0) if downsampled.ndim == 2 else downsampled
            upsampled = self.upsample(downsampled_input, intermediates, apply_scale = apply_scale)
            return upsampled.squeeze(0) if downsampled.ndim == 2 else upsampled

        # adjust learning rate

        downsampled_tokens = frac_gradient(downsampled_tokens, self.learning_rate_difference ** -1)

        if packed_probs_mask is not None:
            downsampled_tokens = downsampled_tokens.squeeze(0)

        # returning

        outputs = Outputs(downsampled_tokens, upsample, weighted_aux_loss)

        if not return_intermediates:
            return outputs

        return outputs, intermediates

### Testing

In [20]:
tokens.shape

torch.Size([3, 2, 6])

In [21]:
chunker = PackDynamicSequenceChunker(dim=dim, dim_queries_keys=2)

In [22]:
output, intermediates = chunker(tokens, return_intermediates=True)

tensor([[ True,  True],
        [ True, False],
        [ True, False]])


In [23]:
downsampled = output.downsampled
upsample_fn = output.upsample_fn

In [24]:
assert upsample_fn(downsampled).shape == tokens.shape
upsampled = upsample_fn(downsampled)

In [25]:
downsampled

tensor([[[ 0.1346, -0.0466, -0.0365, -0.0531,  0.1216, -0.0309],
         [-0.7949,  2.2300, -1.1471, -1.2490,  0.9231,  0.4797]],

        [[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<AddBackward0>)

In [26]:
upsampled

tensor([[[ 0.1346, -0.0466, -0.0365, -0.0531,  0.1216, -0.0309],
         [-0.7874,  2.2274, -1.1491, -1.2519,  0.9299,  0.4780]],

        [[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
         [ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242]],

        [[-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605],
         [-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605]]],
       grad_fn=<AddBackward0>)

In [27]:
# we'll try packing the tokens into a 2d tensor
packed_tokens = tokens.view(-1, dim)
seq_lens = torch.full((batch,), seq, dtype=torch.long)

In [28]:
seq_lens

tensor([2, 2, 2])

In [29]:
total_lens = seq_lens.sum().item()
document_ids = torch.repeat_interleave(
    torch.arange(len(seq_lens), device=seq_lens.device), seq_lens
)

# a sequence position with 1 in probs_mask is the position of the first
# token of a new document, which means it must be a chunk start with
# probability 1
packed_probs_mask = torch.zeros_like(document_ids)
packed_probs_mask[1:] = document_ids[:-1] != document_ids[1:]

In [30]:
document_ids

tensor([0, 0, 1, 1, 2, 2])

In [31]:
bounds = torch.tensor([[ True,  True,  True, False,  True,  True]])

In [32]:
torch.bincount(document_ids, weights=bounds.squeeze(0)).long()

tensor([2, 1, 2])

In [33]:
packed_output, packed_intermediates = chunker(packed_tokens, seq_lens=seq_lens, return_intermediates=True)

tensor([[ True,  True,  True, False,  True, False]])


In [34]:
packed_downsampled = packed_output.downsampled
packed_upsample_fn = packed_output.upsample_fn

In [35]:
packed_downsampled.shape

torch.Size([4, 6])

In [36]:
packed_upsampled = packed_upsample_fn(packed_downsampled)

In [37]:
packed_upsampled = packed_upsampled.view(batch, seq, dim)

In [38]:
upsampled[1], packed_upsampled[1]

(tensor([[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
         [ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242]],
        grad_fn=<SelectBackward0>),
 tensor([[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
         [ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242]],
        grad_fn=<SelectBackward0>))

In [39]:
upsampled, packed_upsampled

(tensor([[[ 0.1346, -0.0466, -0.0365, -0.0531,  0.1216, -0.0309],
          [-0.7874,  2.2274, -1.1491, -1.2519,  0.9299,  0.4780]],
 
         [[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
          [ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242]],
 
         [[-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605],
          [-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605]]],
        grad_fn=<AddBackward0>),
 tensor([[[ 0.1346, -0.0466, -0.0365, -0.0531,  0.1216, -0.0309],
          [-0.7874,  2.2274, -1.1491, -1.2519,  0.9299,  0.4780]],
 
         [[ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242],
          [ 0.0138, -0.0113,  0.0095, -0.0149,  0.0135,  0.0242]],
 
         [[-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605],
          [-1.6738, -0.8417,  1.4402,  0.0344,  0.8143,  0.0605]]],
        grad_fn=<ViewBackward0>))

In [40]:
torch.allclose(upsampled, packed_upsampled)

True

### More Testing

In [41]:
import random
tokens = []
for _ in range(batch):
    length = random.randint(5, 15)
    tokens.append(torch.randn(length, dim))

In [42]:
seq_lens = torch.tensor([t.shape[0] for t in tokens], dtype=torch.long)

In [43]:
sum(seq_lens)

tensor(20)

In [44]:
packed_tokens = cat(tokens, dim=0)

In [45]:
packed_tokens.requires_grad_(True)

tensor([[ 4.9570e-01,  1.9512e+00,  8.0439e-02, -2.0457e+00,  4.1184e-01,
         -8.3131e-01],
        [-7.9625e-01,  1.3252e+00,  1.0597e-01,  1.2274e+00,  5.1952e-01,
          1.2497e+00],
        [ 9.6132e-01, -8.8569e-01, -1.3992e-01, -5.6533e-02,  4.0542e-01,
          4.3112e-01],
        [-7.3223e-01, -8.1866e-01, -4.8022e-01,  8.8180e-01,  7.6478e-01,
          1.1485e+00],
        [-8.1374e-01,  8.1331e-01,  2.1135e-01,  3.5361e-01,  8.8226e-01,
          2.7702e-01],
        [ 9.4046e-01,  1.6876e+00,  2.7494e-01, -1.1740e+00,  8.6801e-02,
         -9.2056e-01],
        [-1.2628e+00, -8.9110e-01,  1.2257e+00,  4.7461e-01, -9.7161e-02,
         -4.8178e-01],
        [-1.4138e+00,  3.8979e-01, -1.0428e+00, -1.4157e+00, -1.9734e+00,
         -5.5926e-01],
        [ 1.0177e+00,  5.1015e-01, -5.8073e-01, -1.2103e+00,  1.5511e+00,
          1.7335e+00],
        [-9.2158e-02, -1.4642e+00, -1.1722e-01, -4.7172e-01, -1.0320e+00,
         -1.5299e+00],
        [ 3.7557e-01,  8.7319e

In [46]:
assert packed_tokens.shape == (sum(seq_lens), dim)

In [47]:
chunker = DynamicSequenceChunker(dim=dim)

In [48]:
ref_chunks = []
for t in tokens:
    out, intermediates = chunker(t.unsqueeze(0), return_intermediates=True)
    ref_chunks.append(out.downsampled.squeeze(0))

In [49]:
pack_chunker = PackDynamicSequenceChunker(dim=dim)
pack_chunker.load_state_dict(chunker.state_dict())

<All keys matched successfully>

In [50]:
packed_chunks, packed_intermediates = pack_chunker(packed_tokens, seq_lens=seq_lens, return_intermediates=True)

tensor([[ True,  True,  True,  True,  True, False,  True, False,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True, False, False]])


In [51]:
packed_intermediates.new_seq_lens

tensor([5, 8, 3])

In [52]:
packed_chunks.downsampled

tensor([[ 0.3005,  1.1827,  0.0488, -1.2400,  0.2496, -0.5039],
        [-0.4785,  0.7964,  0.0637,  0.7376,  0.3122,  0.7511],
        [ 0.6448, -0.5941, -0.0938, -0.0379,  0.2719,  0.2892],
        [-0.4594, -0.5136, -0.3013,  0.5532,  0.4798,  0.7206],
        [-0.4844,  0.4842,  0.1258,  0.2105,  0.5252,  0.1649],
        [-1.1782, -0.8314,  1.1435,  0.4428, -0.0907, -0.4495],
        [ 0.8711,  0.4367, -0.4971, -1.0360,  1.3277,  1.4838],
        [-0.0482, -0.7651, -0.0613, -0.2465, -0.5393, -0.7995],
        [ 0.2983,  0.6935,  0.0791, -0.9774,  0.8156, -0.4488],
        [ 0.9627, -0.0996,  0.1340, -0.3364, -0.2998, -0.6684],
        [ 0.0784, -0.6903,  0.6434, -0.0922,  0.5986, -0.6597],
        [-0.2476,  0.0046,  0.8370, -0.9601, -0.3629,  0.4958],
        [ 0.6602,  0.2894,  0.2847,  0.7596, -0.5346, -0.6369],
        [ 0.2908,  0.5661,  0.2888,  0.2686,  0.8302, -0.1448],
        [-0.0016,  0.7711,  0.0979,  0.3845, -0.7243,  0.2468],
        [-0.6159, -1.1330, -1.3809, -1.4

In [53]:
# if we reshape packed_chunks to a list of tensors, make it nested, pad with zeros, and compare to ref_chunks, it should be equal
split = packed_chunks.downsampled.split(packed_intermediates.new_seq_lens.tolist())

In [54]:
packed_intermediates.new_seq_lens

tensor([5, 8, 3])

In [55]:
split

(tensor([[ 0.3005,  1.1827,  0.0488, -1.2400,  0.2496, -0.5039],
         [-0.4785,  0.7964,  0.0637,  0.7376,  0.3122,  0.7511],
         [ 0.6448, -0.5941, -0.0938, -0.0379,  0.2719,  0.2892],
         [-0.4594, -0.5136, -0.3013,  0.5532,  0.4798,  0.7206],
         [-0.4844,  0.4842,  0.1258,  0.2105,  0.5252,  0.1649]],
        grad_fn=<SplitWithSizesBackward0>),
 tensor([[-1.1782, -0.8314,  1.1435,  0.4428, -0.0907, -0.4495],
         [ 0.8711,  0.4367, -0.4971, -1.0360,  1.3277,  1.4838],
         [-0.0482, -0.7651, -0.0613, -0.2465, -0.5393, -0.7995],
         [ 0.2983,  0.6935,  0.0791, -0.9774,  0.8156, -0.4488],
         [ 0.9627, -0.0996,  0.1340, -0.3364, -0.2998, -0.6684],
         [ 0.0784, -0.6903,  0.6434, -0.0922,  0.5986, -0.6597],
         [-0.2476,  0.0046,  0.8370, -0.9601, -0.3629,  0.4958],
         [ 0.6602,  0.2894,  0.2847,  0.7596, -0.5346, -0.6369]],
        grad_fn=<SplitWithSizesBackward0>),
 tensor([[ 0.2908,  0.5661,  0.2888,  0.2686,  0.8302, -0.1448],


In [56]:
from torch.nested import nested_tensor, to_padded_tensor

In [57]:
ref_chunks_padded = to_padded_tensor(nested_tensor(ref_chunks), 0.0)

  return _nested.nested_tensor(


In [58]:
split_padded = to_padded_tensor(nested_tensor(list(split)), 0.0)

In [59]:
ref_chunks_padded.shape, split_padded.shape

(torch.Size([3, 8, 6]), torch.Size([3, 8, 6]))

In [60]:
torch.allclose(ref_chunks_padded, split_padded)

True

In [61]:
upsampled_packed = packed_chunks.upsample_fn(packed_chunks.downsampled)

In [62]:
l = upsampled_packed.mean() + packed_chunks.weighted_aux_ratio_loss

In [63]:
l.backward()