In [339]:
import torch
from h_net_dynamic_chunking import DynamicSequenceChunker
from torch import Tensor

In [340]:
batch, seq, dim = 3, 2, 6
tokens = torch.randn(batch, seq, dim)

### Pack Chunker

In [341]:
# following section 2.2 of the paper

from collections import namedtuple

import torch
from torch import cat, arange
from torch.nested import nested_tensor
from torch.nn import Module, Linear, Parameter
from torch.nn.functional import cosine_similarity, pad

from einx import multiply
from einops import repeat, rearrange

from assoc_scan import AssocScan

# constants

Outputs = namedtuple('Outputs', [
    'downsampled',
    'upsample_fn',
    'weighted_aux_ratio_loss'
])

Intermediates = namedtuple('Intermediates', [
    'mask',
    'probs',
    'chunk_lens',
    'boundary_mask',
    'residual',
    'gates',
    'upsampler_output_scale',
    'aux_ratio_loss'
])

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def straight_through(t, value):
    return t + (value - t).detach()

def frac_gradient(t, frac = 1.):
    if frac == 1:
        return

    t_grad = t * frac
    return straight_through(t_grad, t)

# classes

class PackDynamicSequenceChunker(Module):
    def __init__(
        self,
        dim,
        dim_queries_keys = None,
        boundary_threshold = 0.5,
        target_avg_token_length = 6.,       # N in eq(10)
        ratio_loss_weight = 3e-2,
        handle_residual_proj = False,       # turning this on will automatically handle a projection of the residual and its application in the inverse upsample function
        assoc_scan_use_accelerated = False,
        learning_rate_difference = 0.75,    # in the paper, they report that as one moves up a hierarchy, the learning rate needs to decrease. we'll default to 0.75 for the rough 2.0 -> 1.5 somewhere in the appendix from level 0 -> 1
        straight_through_frac_vecs = True,  # improvisation where F receives gradients through straight-through with sigmoid
    ):
        super().__init__()
        dim_queries_keys = default(dim_queries_keys, dim)

        # linear to queries and keys

        self.to_queries_keys = Linear(dim, dim_queries_keys * 2, bias = False)

        # start key token, so first token can be segmented / chunked out

        self.start_key_token = Parameter(torch.randn(dim_queries_keys) * 1e-2) # presumably, need a start key token for the first token, open an issue if i got it wrong

        # threshold to determine boundary

        assert 0. < boundary_threshold < 1.

        self.boundary_threshold = boundary_threshold

        # smoothing related

        self.smooth_assoc_scan = AssocScan(use_accelerated = assoc_scan_use_accelerated)

        # maybe residual proj

        self.handle_residual_proj = handle_residual_proj

        if handle_residual_proj:
            self.residual_proj = Linear(dim, dim)

        # learning rate modulation, appendix C
        # the multiplier on the learning rate as one goes from outer to inner of the h-net, and inverse of this value from inner to outer

        self.learning_rate_difference = learning_rate_difference

        # ratio aux loss related

        self.target_avg_token_length = target_avg_token_length

        self.straight_through_frac_vecs = straight_through_frac_vecs

        self.ratio_loss_weight = ratio_loss_weight

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def upsample(
        self,
        downsampled,
        intermediates: Intermediates,
        apply_scale = True
    ):
        batch, needs_grad, device = downsampled.shape[0], downsampled.requires_grad, downsampled.device

        mask = intermediates.mask
        gates = intermediates.gates
        residual = intermediates.residual

        # smoothing module for improved gradients eq(5)

        downsampled = self.smooth_assoc_scan(gates, downsampled)

        # upsample

        downsampled_without_padding = downsampled[mask]
        chunk_lens_without_padding = intermediates.chunk_lens[mask]

        seq = arange(downsampled_without_padding.shape[0], device = device)

        repeated_indices = torch.repeat_interleave(seq, chunk_lens_without_padding, dim = 0)
        upsampled = downsampled_without_padding[repeated_indices]

        upsampled = rearrange(upsampled, '(b n) d -> b n d', b = batch)

        scale = intermediates.upsampler_output_scale

        if needs_grad and apply_scale and exists(scale):
            upsampled = multiply('b n d, b n', upsampled, scale)

        if self.handle_residual_proj:
            upsampled = upsampled + self.residual_proj(residual)

        upsampled = frac_gradient(upsampled, self.learning_rate_difference)

        return upsampled

    def forward(
        self,
        tokens, # float[b n d] or float[total_n d] if seq_lens is specified,
        seq_lens: Tensor | None = None,
        return_intermediates = False,
        return_only_chunk_lens = False
    ):
        with torch.no_grad():
            if seq_lens is not None:
                total_lens = seq_lens.sum().item()
                document_ids = torch.repeat_interleave(
                    torch.arange(len(seq_lens), device=seq_lens.device), seq_lens
                )

                # a sequence position with 1 in probs_mask is the position of the first
                # token of a new document, which means it must be a chunk start with
                # probability 1
                packed_probs_mask = torch.zeros_like(document_ids)
                packed_probs_mask[1:] = document_ids[:-1] != document_ids[1:]

                # however, since the sequence position is the start of a new document,
                # we must prevent the associative scan from reading from the token before 
                # it. To do this, we reverse probs_mask, so the sequence position that used
                # to be 1 becomes 0 and the positions that used to be 0 become 1.
                # this means that at the start of each new document, the token cannot
                # read from the token before it
                packed_gate_mask = -1 * (packed_probs_mask - 1)
                tokens = tokens.unsqueeze(0)
                print(f"probs mask: {packed_probs_mask}")
                print(f"gate mask: {packed_gate_mask}")
            else:
                packed_probs_mask = None
                packed_gate_mask = None

        batch, length, device = *tokens.shape[:2], tokens.device

        residual = tokens

        queries, keys = self.to_queries_keys(tokens).chunk(2, dim = -1)

        start_keys = repeat(self.start_key_token, 'd -> b 1 d', b = batch)

        keys = cat((start_keys, keys), dim = 1)

        if packed_probs_mask is not None:
            # when packed, the keys end up being compared incorrectly at this current stage
            # for example, suppose we have two documents of lengths 2 and 2.
            # if passed individually, each document's first token will compare against the start key token
            # however, when packed, the 3rd token (first token of second document)
            # will compare against the key of the 2nd token, resulting in a wrong cosine_similarity
            # which later impacts the probability
            # at first I thought this would be fine because we hard set the probability, however
            # now I recall that in the associative scan smoothing, this probability term is involved
            # beyond the gate itself, which would result in an incorrect calculation, so
            # we need to make all those keys that are at the start of a new document
            # equal to the start key token
            
            # first, we start by adding a 1 to the right side of the packed_probs_mask, this is to account
            # for the fact that when calculating cosine similarity, we use `keys[:, :-1]`, so it is shifted
            # so the placement of the start key token needs to be shifted as well
            packed_probs_mask_with_start = pad(packed_probs_mask, (0, 1), value = 0)

            # and now, for all sequence positions where packed_probs_mask_with_start is 1,
            # we set the corresponding keys to the start key token
            keys[:, packed_probs_mask_with_start == 1] = start_keys


        # each query looks at the previous key to determine if distance is greater than some threshold for determining a boundary exists (they use 0.5 as threshold)

        cosine_sim  = cosine_similarity(queries, keys[:, :-1], dim = -1)

        probs = (1. - cosine_sim) * 0.5 # cosine sim is -1. to 1., this transforms it to 0. to 1.

        # if packed_probs_mask is not None:
        #     # at all positions where packed_probs_mask is 1, it is the start of a new document
        #     # which means it must also be the start of a new chunk
        #     probs = torch.where(packed_probs_mask == 1, 1, probs)

        boundary_mask = probs > self.boundary_threshold # bool[b n]

        boundary_mask[:, 0] = True # first token must always be boundary

        if packed_probs_mask is not None:
            # at all positions where the packed_probs_masking is 1, it means it is the start
            # of a new document. We must force these positions to be boundaries
            # previously I tried doing it by setting probs to 1, but that
            # will cause issues later down the line because downsampling tensor is multiplied
            # by the probs, so we must directly set the boundary mask instead
            boundary_mask = torch.where(packed_probs_mask == 1, True, boundary_mask)

        print(f"boundary mask: {boundary_mask}")

        # compute some lengths, per chunk and number of chunks per batch

        num_chunks = boundary_mask.long().sum(dim = -1)

        boundary_mask_with_end = pad(boundary_mask, (0, 1), value = True)
        sel_indices = repeat(arange(boundary_mask_with_end.shape[-1], device = device), 'n -> b n', b = batch)[boundary_mask_with_end]

        sel_indices = nested_tensor(sel_indices.split((num_chunks + 1).tolist()), layout = torch.jagged, device = device)

        sel_indices = sel_indices.to_padded_tensor(padding = -1)

        mask = (sel_indices != -1)[:, 1:]

        chunk_lens = sel_indices[:, 1:] - sel_indices[:, :-1]
        chunk_lens.masked_fill_(~mask, 0)

        # early return chunk lens if using a trained module as a tokenizer

        if return_only_chunk_lens:
            return chunk_lens

        # downsampling - they show in their experiments that picking out the boundary tokens works just fine

        boundary_tokens = tokens[boundary_mask] # pick out boundary tokens

        tokens_nt = nested_tensor(boundary_tokens.split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        downsampled_tokens = tokens_nt.to_padded_tensor(padding = 0.)

        # smoothing module for improved gradients eq(5)

        probs_nt = nested_tensor(probs[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = True)

        boundary_probs = probs_nt.to_padded_tensor(padding = 0.)
        print(f"boundary probs: {boundary_probs}")

        gates = 1. - boundary_probs

        if packed_gate_mask is not None:
            # at all positions where the packed_gate_masking is 0, it means it is the start
            # of a new document. We must prevent associative scan from allowing
            # this starting token from reading into the past document
            # also, gradients cannot propagate through this to modify this gating, as it is
            # fixed by the document sequence
            packed_gate_mask_nt = nested_tensor(packed_gate_mask.unsqueeze(0)[boundary_mask].split(num_chunks.tolist()), layout = torch.jagged, device = device, requires_grad = False)
            packed_gate_masking = packed_gate_mask_nt.to_padded_tensor(padding = 1.0)
            gates = gates * packed_gate_masking

        downsampled_tokens = multiply('b n d, b n', downsampled_tokens, boundary_probs)

        print(f"downsampled tokens: {downsampled_tokens}")

        # for the upsampler

        confidence = torch.where(boundary_mask, probs, 1. - probs)

        # defaults if not training

        upsampler_output_scale = None
        aux_loss = self.zero
        weighted_aux_loss = self.zero

        needs_grad = tokens.requires_grad

        if needs_grad:
            # straight through for 1. multiplier on the expanded processed boundary tokens

            upsampler_output_scale = straight_through(confidence, 1.)

            # auxiliary ratio loss in section 2.3.2, eq (10)
            # lets follow their notation

            N = self.target_avg_token_length

            F = boundary_mask.float()
            G = probs.mean(dim = -1)

            # allow for a soft F to straight through - https://arxiv.org/abs/2505.22074

            if self.straight_through_frac_vecs:
                F_soft = (probs - self.boundary_threshold).sigmoid()
                F = straight_through(F_soft, F)

            F = F.mean(dim = -1)

            aux_ratio_loss = N / (N - 1) * ((N - 1) * F * G + (1. - F) * (1. - G))

            aux_loss = aux_ratio_loss.mean()
            weighted_aux_loss = aux_loss * self.ratio_loss_weight

        # intermediates

        intermediates = Intermediates(mask, probs, chunk_lens, boundary_mask, residual, gates, upsampler_output_scale, aux_loss)

        # return the upsample function

        def upsample(downsampled, apply_scale = True):

            return self.upsample(downsampled, intermediates, apply_scale = apply_scale)

        # adjust learning rate

        downsampled_tokens = frac_gradient(downsampled_tokens, self.learning_rate_difference ** -1)

        # returning

        outputs = Outputs(downsampled_tokens, upsample, weighted_aux_loss)

        if not return_intermediates:
            return outputs

        return outputs, intermediates

### Testing

In [342]:
tokens.shape

torch.Size([3, 2, 6])

In [343]:
chunker = PackDynamicSequenceChunker(dim=dim, dim_queries_keys=2)

In [344]:
output, intermediates = chunker(tokens, return_intermediates=True)

boundary mask: tensor([[ True, False],
        [ True,  True],
        [ True, False]])
boundary probs: tensor([[0.1747, 0.0000],
        [0.1984, 0.8700],
        [0.9648, 0.0000]], grad_fn=<ToPaddedTensorBackward0>)
downsampled tokens: tensor([[[-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
         [-1.2059, -0.7567, -1.7270, -0.0221,  0.4061,  0.9071]],

        [[ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<MulBackward0>)


In [345]:
downsampled = output.downsampled
upsample_fn = output.upsample_fn

In [346]:
assert upsample_fn(downsampled).shape == tokens.shape
upsampled = upsample_fn(downsampled)

In [347]:
upsampled

tensor([[[-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454],
         [-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454]],

        [[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
         [-1.1808, -0.7604, -1.7473, -0.0303,  0.3919,  0.9136]],

        [[ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797],
         [ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797]]],
       grad_fn=<AddBackward0>)

In [348]:
# we'll try packing the tokens into a 2d tensor
packed_tokens = tokens.view(-1, dim)
seq_lens = torch.full((batch,), seq, dtype=torch.long)

In [349]:
seq_lens

tensor([2, 2, 2])

In [350]:
packed_output, packed_intermediates = chunker(packed_tokens, seq_lens=seq_lens, return_intermediates=True)

probs mask: tensor([0, 0, 1, 0, 1, 0])
gate mask: tensor([1, 1, 0, 1, 0, 1])
boundary mask: tensor([[ True, False,  True,  True,  True, False]])
boundary probs: tensor([[0.1747, 0.1984, 0.8700, 0.9648]], grad_fn=<ToPaddedTensorBackward0>)
downsampled tokens: tensor([[[-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454],
         [ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
         [-1.2059, -0.7567, -1.7270, -0.0221,  0.4061,  0.9071],
         [ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797]]],
       grad_fn=<MulBackward0>)


In [351]:
packed_downsampled = packed_output.downsampled
packed_upsample_fn = packed_output.upsample_fn

In [352]:
packed_upsampled = packed_upsample_fn(packed_downsampled)

In [353]:
packed_upsampled = packed_upsampled.view(batch, seq, dim)

In [354]:
upsampled[1], packed_upsampled[1]

(tensor([[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
         [-1.1808, -0.7604, -1.7473, -0.0303,  0.3919,  0.9136]],
        grad_fn=<SelectBackward0>),
 tensor([[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
         [-1.1808, -0.7604, -1.7473, -0.0303,  0.3919,  0.9136]],
        grad_fn=<SelectBackward0>))

In [355]:
upsampled, packed_upsampled

(tensor([[[-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454],
          [-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454]],
 
         [[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
          [-1.1808, -0.7604, -1.7473, -0.0303,  0.3919,  0.9136]],
 
         [[ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797],
          [ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797]]],
        grad_fn=<AddBackward0>),
 tensor([[[-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454],
          [-0.1083,  0.0021,  0.1231,  0.0450, -0.1254,  0.1454]],
 
         [[ 0.1933, -0.0288, -0.1559, -0.0629, -0.1093,  0.0496],
          [-1.1808, -0.7604, -1.7473, -0.0303,  0.3919,  0.9136]],
 
         [[ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797],
          [ 0.6433,  0.1581,  1.3074, -0.2972,  0.2911, -0.6797]]],
        grad_fn=<ViewBackward0>))

In [356]:
torch.allclose(upsampled, packed_upsampled)

True

### More Testing

In [14]:
import random
tokens = []
for _ in range(batch):
    length = random.randint(5, 15)
    tokens.append(torch.randn(length, dim))

In [15]:
seq_lens = torch.tensor([t.shape[0] for t in tokens], dtype=torch.long)

In [16]:
sum(seq_lens)

tensor(46)

In [17]:
packed_tokens = cat(tokens, dim=0)

In [18]:
packed_tokens.requires_grad_(True)

tensor([[-0.2811, -0.0183, -0.7339,  1.3912, -0.0349,  2.0000],
        [-1.6068,  1.2719,  0.6986, -0.1323,  0.1171,  0.2591],
        [ 0.2812,  2.1973, -0.5555, -0.1148, -0.3636,  1.5843],
        [ 0.6419,  0.7335,  0.5826, -1.4516, -0.2430, -1.3605],
        [ 0.8433,  0.2756, -0.0368,  0.1933, -0.4746,  0.1494],
        [-1.0968,  1.1501, -0.6996,  0.4604,  2.6526, -1.3864],
        [ 0.5884,  0.5012, -0.8607, -0.4995,  1.7886,  0.1358],
        [ 0.5600,  1.8541,  0.1609, -0.2951,  1.8283, -1.4094],
        [ 0.3431, -0.9126,  0.1970, -0.0537,  0.6012, -0.2392],
        [-0.6882, -0.2753, -0.4516,  0.9160,  0.0528,  0.2800],
        [ 0.3801,  0.3738,  1.3966,  0.0368,  1.0898, -0.7166],
        [ 0.5416, -0.9451, -0.2158,  0.6629, -1.5498, -0.3925],
        [-1.1886, -0.2104,  1.9907, -0.6726,  0.0207,  0.0789],
        [ 0.8497,  0.4604,  1.0230, -1.3509, -1.6907,  1.4050],
        [-1.9700, -0.8096, -0.1399, -1.6950, -0.7268,  0.5012],
        [-2.2096,  0.7422, -1.0403,  1.1

In [19]:
assert packed_tokens.shape == (sum(seq_lens), dim)

In [20]:
chunker = DynamicSequenceChunker(dim=dim)

In [21]:
ref_chunks = []
for t in tokens:
    out, intermediates = chunker(t.unsqueeze(0), return_intermediates=True)
    ref_chunks.append(out.downsampled.squeeze(0))

In [22]:
pack_chunker = PackDynamicSequenceChunker(dim=dim)
pack_chunker.load_state_dict(chunker.state_dict())

<All keys matched successfully>

In [23]:
packed_chunks, packed_intermediates = pack_chunker(packed_tokens, seq_lens=seq_lens, return_intermediates=True)

In [24]:
packed_intermediates.new_seq_lens

tensor([7, 4, 7, 8])

In [25]:
# if we reshape packed_chunks to a list of tensors, make it nested, pad with zeros, and compare to ref_chunks, it should be equal
split = packed_chunks.downsampled.split(packed_intermediates.new_seq_lens.tolist())

In [26]:
split

(tensor([[-0.1828, -0.0119, -0.4771,  0.9044, -0.0227,  1.3002],
         [-1.2150,  0.9617,  0.5283, -0.1001,  0.0885,  0.1959],
         [ 0.2203,  1.7217, -0.4353, -0.0899, -0.2849,  1.2414],
         [ 0.3581,  0.4092,  0.3250, -0.8098, -0.1355, -0.7590],
         [ 0.3020, -0.8034,  0.1734, -0.0472,  0.5293, -0.2106],
         [-0.3925, -0.1570, -0.2576,  0.5225,  0.0301,  0.1597],
         [ 0.2837,  0.2790,  1.0425,  0.0275,  0.8134, -0.5349]],
        grad_fn=<SplitWithSizesBackward0>),
 tensor([[-0.5576,  0.1873, -0.2625,  0.2836,  0.3786,  0.0533],
         [-0.4788, -0.7717, -0.0710,  0.5091, -1.0340,  0.0134],
         [-0.3860,  2.1308,  0.1816,  1.3300, -0.4545, -0.3480],
         [ 0.8047,  0.4115,  0.1109,  0.0710,  0.1791, -0.5599]],
        grad_fn=<SplitWithSizesBackward0>),
 tensor([[-0.1775, -0.0645, -0.0939, -0.3212,  0.0046,  0.0920],
         [-0.0763,  0.7088, -0.9967,  0.8127,  0.0329, -0.7109],
         [ 0.1348,  0.6369,  0.2630, -0.8723,  0.2338, -0.6997],


In [27]:
from torch.nested import nested_tensor, to_padded_tensor

In [28]:
ref_chunks_padded = to_padded_tensor(nested_tensor(ref_chunks), 0.0)

  return _nested.nested_tensor(


In [29]:
split_padded = to_padded_tensor(nested_tensor(list(split)), 0.0)

In [30]:
torch.allclose(ref_chunks_padded, split_padded)

True

In [31]:
upsampled_packed = packed_chunks.upsample_fn(packed_chunks.downsampled)

In [32]:
packed_chunks.weighted_aux_ratio_loss.backward()

In [None]:
l = upsampled_packed.mean() + packed_chunks.weighted_aux_ratio_loss

tensor(0.2170, grad_fn=<AddBackward0>)

In [71]:
l.backward()