In [1]:
import torch 
import torch.nn as nn
from torch import Tensor
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn import functional as F

In [2]:
@dataclass
class ChunkConfig:
    """A single chunked-attention setting, as returned by ``ChunkConfigSampler.sample``.

    Fields:
        chunk_size: number of patches (tokens) grouped into each chunk.
            ``None`` disables chunking (full context). ``float('inf')`` is a
            legal value that is distinct from ``None``.
        context_chunks: number of earlier chunks each chunk may look back at.
            ``None`` means every previous chunk is visible.
    """

    chunk_size: Optional[Union[int, float]]
    context_chunks: Optional[int]

    def is_full_context(self) -> bool:
        """Report whether chunking is switched off (``chunk_size`` is ``None``)."""
        chunking_enabled = self.chunk_size is not None
        return not chunking_enabled

    def is_causal_attention(self) -> bool:
        """Report whether this is the full-context *causal* configuration.

        The test is the same one ``is_full_context`` performs: ``chunk_size``
        is ``None``. It deliberately does not match ``chunk_size=float('inf')``,
        which is meant to denote full-context *bidirectional* attention.
        """
        no_chunk_size = self.chunk_size is None
        return no_chunk_size
    
class ChunkConfigSampler:
    """Draw random ``ChunkConfig`` instances from user-specified ranges.

    One call to :meth:`sample` makes, in order:
      1. a chunking decision (true with probability ``chunkwise_prob``),
      2. a chunk-size draw from ``chunk_size_range``,
      3. a left-constraint decision (true with probability ``left_constrain_prob``),
      4. if constrained, a context-chunks draw from ``context_chunks_range``.
    Steps 2-4 are skipped when step 1 decides not to chunk. All randomness
    comes from a private ``random.Random``, so a fixed ``seed`` reproduces the
    same sequence of configs.
    """

    def __init__(
        self,
        *,
        chunk_size_range: Tuple[Union[int, float], Union[int, float]],
        context_chunks_range: Tuple[int, int],
        chunkwise_prob: float = 1.0,
        left_constrain_prob: float = 1.0,
        seed: Optional[int] = None,
    ) -> None:
        """
        Args:
            chunk_size_range: (min_chunk_size, max_chunk_size), inclusive.
                Use ``(float('inf'), float('inf'))`` for an infinite chunk
                size. A finite minimum with an infinite maximum cannot be
                sampled; ``sample()`` raises ``ValueError`` for it.
            context_chunks_range: (min_context_chunks, max_context_chunks),
                inclusive.
            chunkwise_prob: probability of using chunked attention
                (0.0 = never chunk, 1.0 = always chunk). Clamped to [0, 1].
            left_constrain_prob: probability of limiting the left context
                (0.0 = never, 1.0 = always). Clamped to [0, 1].
            seed: optional seed for the private RNG, for reproducibility.

        Raises:
            ValueError: if either range has its maximum below its minimum.
        """
        chunk_size_min, chunk_size_max = chunk_size_range
        if chunk_size_max < chunk_size_min:
            raise ValueError(f"Chunk size range fault: max size is {chunk_size_max} and min size is {chunk_size_min}")
        context_chunks_min, context_chunks_max = context_chunks_range
        if context_chunks_max < context_chunks_min:
            raise ValueError(f"Context chunks range fault: max size is {context_chunks_max} and min size is {context_chunks_min}")

        self.chunk_size_range = chunk_size_range
        self.context_chunks_range = context_chunks_range

        # Clamp probabilities into [0, 1] so out-of-range inputs degrade gracefully.
        self.left_constrain_prob = max(0.0, min(1.0, float(left_constrain_prob)))
        self.chunkwise_prob = max(0.0, min(1.0, float(chunkwise_prob)))
        self._rng = random.Random(seed)

    def _coin_flip(self, prob: float) -> bool:
        """Return True with probability ``prob`` (expected to be in [0, 1]).

        When ``prob >= 1.0`` no random number is consumed, so always-on
        options do not perturb a seeded sequence. The strict ``<`` makes
        ``prob == 0.0`` mean exactly "never" (``random()`` can return 0.0).
        """
        if prob >= 1.0:
            return True
        return self._rng.random() < prob

    def _sample_range(
        self,
        range_values: Optional[Tuple[Union[int, float], Union[int, float]]],
    ) -> Optional[Union[int, float]]:
        """Draw a value uniformly from the inclusive range ``range_values``.

        Returns:
            ``None`` if ``range_values`` is ``None``; ``float('inf')`` if the
            lower bound is ``+inf``; otherwise an int. Finite bounds are
            truncated with ``int()``, the lower bound is clamped to >= 0 and
            the upper bound is raised to at least the lower bound. No random
            number is consumed when the resulting range holds a single value.

        Raises:
            ValueError: if the range is otherwise non-finite (e.g. a finite
                lower bound with an infinite upper bound), since no uniform
                integer distribution exists for it.
        """
        if range_values is None:
            return None
        low, high = range_values

        # An infinite lower bound (the constructor guarantees high >= low)
        # means the whole range is infinite.
        if low == math.inf:
            return math.inf

        if not (math.isfinite(low) and math.isfinite(high)):
            raise ValueError(
                f"Cannot sample uniformly from range ({low}, {high}): both bounds "
                f"must be finite, or the lower bound must be float('inf')"
            )

        low = max(0, int(low))
        high = max(low, int(high))
        if low == high:
            return low
        return self._rng.randint(low, high)

    def sample(self) -> "ChunkConfig":
        """Draw one ``ChunkConfig``.

        Returns ``ChunkConfig(chunk_size=None, context_chunks=None)`` when the
        chunking coin flip fails. Otherwise ``chunk_size`` is drawn from
        ``chunk_size_range``, and ``context_chunks`` is either ``None`` (no
        left limit) or drawn from ``context_chunks_range``.
        """
        if not self._coin_flip(self.chunkwise_prob):
            # No chunking: run in full-context mode.
            return ChunkConfig(chunk_size=None, context_chunks=None)

        chunk_size = self._sample_range(self.chunk_size_range)

        if self._coin_flip(self.left_constrain_prob):
            # Left-constrained: only a bounded number of previous chunks is visible.
            context_chunks = self._sample_range(self.context_chunks_range)
        else:
            context_chunks = None  # "no limit" case

        return ChunkConfig(chunk_size=chunk_size, context_chunks=context_chunks)


def create_dynamic_chunk_mask(seq_len: int, config: "ChunkConfig", device=None) -> Optional[Tensor]:
    """Build a boolean chunk-wise attention mask.

    Positions ``0 .. seq_len-1`` are grouped into consecutive chunks of
    ``config.chunk_size`` tokens. Query ``i`` may attend to key ``j`` when
    ``j``'s chunk is ``i``'s own chunk or one of the ``config.context_chunks``
    chunks immediately before it; keys in later chunks are never visible.

    Args:
        seq_len: sequence length T (after padding).
        config: chunking configuration. A ``chunk_size`` of at least T
            (including ``float('inf')``) puts every position in one chunk,
            giving an all-True mask. A ``context_chunks`` of ``None`` or
            ``float('inf')`` places no limit on the left context.
        device: torch device on which the mask is created.

    Returns:
        ``None`` when ``config.is_full_context()`` is True; otherwise a bool
        tensor of shape (1, 1, T, T) where ``mask[..., i, j]`` is True when
        query ``i`` may attend to key ``j``.
    """
    if config.is_full_context():
        return None

    # Clamp to seq_len *before* int(): int(float('inf')) raises OverflowError.
    chunk_size = max(1, int(min(config.chunk_size, seq_len)))
    chunk_ids = torch.arange(seq_len, device=device) // chunk_size

    query_chunk_ids = chunk_ids.unsqueeze(1)  # (T, 1)
    key_chunk_ids = chunk_ids.unsqueeze(0)    # (1, T)

    # Never look ahead into a later chunk.
    mask = key_chunk_ids <= query_chunk_ids

    context_chunks = config.context_chunks
    if context_chunks is not None and context_chunks != math.inf:
        # Limit the look-back to `context_chunks` previous chunks. Chunk ids
        # are non-negative, so no explicit clamp of the lower bound is needed.
        context_chunks = max(0, int(context_chunks))
        mask = mask & (key_chunk_ids >= query_chunk_ids - context_chunks)

    return mask.unsqueeze(0).unsqueeze(0)

In [12]:
# Demo: draw one random chunk configuration and display the mask it yields
# for a 10-token sequence. seed=None, so successive runs may print different configs.
chunk_config_sampler = ChunkConfigSampler(
    chunk_size_range=(1, 3),
    context_chunks_range=(1, 2),
    chunkwise_prob=1.0,
    left_constrain_prob=1.0,
    seed=None,
)

chunk_config = chunk_config_sampler.sample()
mask = create_dynamic_chunk_mask(seq_len=10, config=chunk_config, device='cpu')

print("Chunk Config:", chunk_config)
print("Chunk Mask:\n", mask)


Chunk Config: ChunkConfig(chunk_size=1, context_chunks=2)
Chunk Mask:
 tensor([[[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [False,  True,  True,  True, False, False, False, False, False, False],
          [False, False,  True,  True,  True, False, False, False, False, False],
          [False, False, False,  True,  True,  True, False, False, False, False],
          [False, False, False, False,  True,  True,  True, False, False, False],
          [False, False, False, False, False,  True,  True,  True, False, False],
          [False, False, False, False, False, False,  True,  True,  True, False],
          [False, False, False, False, False, False, False,  True,  True,  True]]]])
