In [1]:
import torch
import torch.nn.functional as F
from typing import Tuple

def adjacent_chained_merge(
    x: torch.Tensor,
    r: int,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Merges `r` most similar adjacent token pairs, allowing for chained merging.
    If (A,B) and (B,C) are both selected, they become a single group (A,B,C).

    Args:
        x (torch.Tensor): Input tensor of shape [B, N, C].
        r (int): The number of merge operations (links) to perform.
                 The final number of removed tokens might be >= r.

    Returns:
        The merged tensor and information needed for unmerging.
    """
    if r <= 0:
        return x, (None, None)

    B, N, C = x.shape
    device = x.device
    r = min(r, N - 1)

    with torch.no_grad():
        # --- 1. Calculate Similarity ---
        left_tokens = x[:, :-1]
        right_tokens = x[:, 1:]
        sim = (F.normalize(left_tokens, p=2, dim=-1) * F.normalize(right_tokens, p=2, dim=-1)).sum(dim=-1)

        # --- 2. Select top `r` links ---
        _, top_indices = torch.topk(sim, r, dim=-1)

        # --- 3. Find Connected Components to form groups ---
        # Each token is initially in its own group.
        # Group IDs are 0, 1, 2, ..., N-1.
        group_ids = torch.arange(N, device=device).unsqueeze(0).expand(B, -1)

        # For each selected link (A, B), we merge their groups.
        # We do this iteratively by setting group_id(B) = group_id(A).
        # To make this stable, we always merge the group with the larger ID into the one with the smaller ID.
        # This requires a loop, but it's over `r` which is usually small.
        # For a fully vectorized but more complex version, a union-find algorithm would be needed.
        # This loop-based approach is easier to understand.
        for i in range(r):
            idx = top_indices[:, i]
            # Use batch-aware indexing
            batch_indices = torch.arange(B, device=device)
            left_ids = group_ids[batch_indices, idx]
            right_ids = group_ids[batch_indices, idx + 1]
            
            # Find min and max IDs for stable merging
            min_ids = torch.min(left_ids, right_ids)
            max_ids = torch.max(left_ids, right_ids)

            # Merge groups: all tokens with max_id now get min_id
            for b in range(B):
                group_ids[b, group_ids[b] == max_ids[b]] = min_ids[b]
        
        # Find the final unique group IDs and assign new sequential indices
        # `unique_ids` will be sorted. `inverse_indices` maps each original token to its new group index.
        unique_ids, inverse_indices = torch.unique(group_ids, return_inverse=True, dim=-1)
        num_merged_tokens = unique_ids.shape[-1]
        
        # --- 4. Execute Merge ---
        # Create output tensor for the merged tokens
        x_out = torch.zeros(B, num_merged_tokens, C, device=device, dtype=x.dtype)
        # Create a tensor to count the size of each group for averaging
        group_counts = torch.zeros(B, num_merged_tokens, 1, device=device, dtype=x.dtype)

        # Use scatter_add to sum up token values and counts for each group
        # The `inverse_indices` directly tell us where each original token should go.
        x_out.scatter_add_(1, inverse_indices.unsqueeze(-1).expand(-1, -1, C), x)
        group_counts.scatter_add_(1, inverse_indices.unsqueeze(-1), torch.ones_like(x[..., :1]))

        # Average the values in each group
        x_out /= group_counts

        # For unmerging, we need to know which final token corresponds to which original tokens.
        # The `inverse_indices` tensor already contains this mapping.
        unmerge_info = inverse_indices

    return x_out, unmerge_info


def adjacent_chained_unmerge(x: torch.Tensor, unmerge_info: torch.Tensor) -> torch.Tensor:
    """
    Unmerges tokens that were merged with the chained method.
    """
    if unmerge_info is None:
        return x
    
    B, _, C = x.shape
    
    # `unmerge_info` maps each original position to its merged group index.
    # We can use `gather` to broadcast the merged token values back to their original positions.
    unmerged_output = x.gather(1, unmerge_info.unsqueeze(-1).expand(-1, -1, C))
    
    return unmerged_output


# --- Example Usage ---
if __name__ == "__main__":
    batch_size = 1
    seq_len = 10
    channels = 2
    
    dummy_input = torch.randn(batch_size, seq_len, channels)
    # Make A,B,C similar: (0,1), (1,2)
    dummy_input[:, 1, :] = dummy_input[:, 0, :] + 0.1
    dummy_input[:, 2, :] = dummy_input[:, 1, :] + 0.15
    # Make D,E similar: (4,5)
    dummy_input[:, 5, :] = dummy_input[:, 4, :] + 0.2
    
    print("Original input shape:", dummy_input.shape)
    
    # Merge the top 2 most similar pairs. This should link (0,1) and (1,2).
    # Resulting groups should be {0,1,2}, {3}, {4,5}, {6}, {7}, {8}, {9}
    # This means 10 tokens become 7 tokens.
    merged_output, unmerge_info = adjacent_chained_merge(dummy_input, r=3)
    
    print("Merged output shape:", merged_output.shape)
    print("Unmerge info (group ID for each original token):", unmerge_info)
    
    unmerged_output = adjacent_chained_unmerge(merged_output, unmerge_info)
    print("Unmerged output shape:", unmerged_output.shape)
    
    # Verification
    # The values for tokens 0, 1, 2 should be the same and equal to their average
    original_avg_012 = dummy_input[:, [0,1,2], :].mean(dim=1)
    print("\nVerification for group {0,1,2}:")
    print("Original Average:", original_avg_012)
    print("Unmerged value at pos 0:", unmerged_output[:, 0, :])
    print("Unmerged value at pos 2:", unmerged_output[:, 2, :])
    print("All three positions match:", torch.allclose(unmerged_output[:, 0, :], unmerged_output[:, 1, :]) and torch.allclose(unmerged_output[:, 1, :], unmerged_output[:, 2, :]))
    print("Value matches original average:", torch.allclose(unmerged_output[:, 0, :], original_avg_012))

Original input shape: torch.Size([1, 10, 2])


RuntimeError: The expanded size of the tensor (-1) isn't allowed in a leading, non-existing dimension 0