In [3]:
import torch 
import torch.nn as nn
from torch import Tensor
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.nn import functional as F

In [4]:

@dataclass
class ChunkConfig:
    """Chunked-attention settings, as produced by ``ChunkConfigSampler``.

    Attributes:
        chunk_size: number of patches (tokens) grouped into one chunk, or
            ``None`` to disable chunking and use full-context attention.
        context_chunks: how many preceding chunks each chunk may look back on,
            or ``None`` to allow every preceding chunk.
    """

    chunk_size: Optional[int]  # None -> no chunking (full context)
    context_chunks: Optional[int]  # None -> unlimited left context

    def is_full_context(self) -> bool:
        """Return True when no chunk size is set, i.e. attention is not chunked."""
        unchunked = self.chunk_size is None
        return unchunked
    
class ChunkConfigSampler:
    """Randomly samples ``ChunkConfig`` objects for dynamic chunked attention.

    Every random decision is drawn from a private ``random.Random`` instance, so
    a fixed ``seed`` gives a reproducible sequence of configs. Each call to
    :meth:`sample`:

    1. Returns a full-context config with probability ``1 - chunkwise_prob``.
    2. Otherwise draws an integer chunk size from ``chunk_size_range``
       (inclusive, floored at 1). An ``inf`` bound yields a full-context config.
    3. With probability ``left_constrain_prob``, draws a left-context duration
       in seconds from ``context_sec_range`` and converts it to a whole number
       of chunks (rounded up). Otherwise every previous chunk is visible
       (``context_chunks=None``).

    Args:
        chunk_size_range: inclusive (min, max) chunk size in timesteps.
        context_sec_range: (min, max) left-context duration in seconds; an
            ``inf`` bound means unlimited left context.
        timestep_duration_sec: duration of one timestep in seconds; must be > 0.
        chunkwise_prob: probability of producing a chunked config, clamped to [0, 1].
        left_constrain_prob: probability that a chunked config gets a bounded
            left context, clamped to [0, 1].
        seed: optional seed for the internal RNG.

    Raises:
        ValueError: if either range has max < min, or if
            ``timestep_duration_sec`` is not positive.
    """

    # Subtracted before math.ceil when converting seconds to chunks, so float
    # round-off (e.g. 0.07 / 0.01 -> 7.000000000000001) cannot add a spurious
    # extra context chunk.
    _CEIL_TOLERANCE = 1e-9

    def __init__(
        self,
        *,
        chunk_size_range: Tuple[int, int],
        context_sec_range: Tuple[float, float],
        timestep_duration_sec: float,
        chunkwise_prob: float = 1.0,
        left_constrain_prob: float = 1.0,
        seed: Optional[int] = None,
    ) -> None:

        chunk_size_min, chunk_size_max = chunk_size_range
        if chunk_size_max < chunk_size_min:
            raise ValueError(f"Chunk size range fault: max size is {chunk_size_max} and min size is {chunk_size_min}")

        context_sec_min, context_sec_max = context_sec_range
        if context_sec_max < context_sec_min:
            raise ValueError(f"Context sec range fault: max size is {context_sec_max} and min size is {context_sec_min}")

        self.chunk_size_range = chunk_size_range
        self.context_sec_range = context_sec_range

        if timestep_duration_sec <= 0:
            raise ValueError("timestep_duration_sec must be positive.")
        self.timestep_duration_sec = timestep_duration_sec

        # Out-of-range probabilities are clamped rather than rejected.
        self.left_constrain_prob = max(0.0, min(1.0, float(left_constrain_prob)))
        self.chunkwise_prob = max(0.0, min(1.0, float(chunkwise_prob)))
        self._rng = random.Random(seed)

    def _bernoulli(self, prob: float) -> bool:
        """Return True with probability ``prob`` (expected to be in [0, 1]).

        ``random()`` lies in [0, 1), so ``random() < prob`` is never True for
        ``prob == 0`` and always True for ``prob == 1``. The result is exact at
        both ends. Consumes exactly one RNG draw.
        """
        return self._rng.random() < prob

    def _sample_range(
        self, range_values: Optional[Tuple[float, float]], dtype: str
    ) -> Optional[Union[int, float]]:
        """Sample a value from ``range_values``, cast to ``dtype``.

        Args:
            range_values: (low, high) bounds, or None.
            dtype: ``'int'`` (inclusive ``randint``) or ``'float'`` (``uniform``).

        Returns:
            None if ``range_values`` is None, and ``float('inf')`` if either
            bound is ``inf``. Otherwise a value in [max(0, low), max(low, high)].
            When the bounds coincide, that bound is returned without drawing
            from the RNG.

        Raises:
            ValueError: for an unsupported ``dtype``.
        """
        if range_values is None:
            return None
        low, high = range_values

        if low == float('inf') or high == float('inf'):
            return float('inf')

        if dtype == 'int':
            low = max(0, int(low))
            high = max(low, int(high))
            if low == high:
                return low
            return self._rng.randint(low, high)

        elif dtype == 'float':
            low = max(0.0, float(low))
            high = max(low, float(high))
            if low == high:
                return low
            return self._rng.uniform(low, high)

        else:
            raise ValueError(f"Unsupported dtype: {dtype}")

    @classmethod
    def _context_chunks_from_seconds(
        cls, context_sec: float, timestep_duration_sec: float, chunk_size: int
    ) -> int:
        """Convert a left-context duration in seconds to a number of chunks, rounded up."""
        total_context_timesteps = context_sec / timestep_duration_sec
        return max(0, math.ceil(total_context_timesteps / chunk_size - cls._CEIL_TOLERANCE))

    def sample(self) -> ChunkConfig:
        """Draw one random ``ChunkConfig`` (see the class docstring for the procedure)."""
        if not self._bernoulli(self.chunkwise_prob):
            # No chunking: full-context attention.
            return ChunkConfig(chunk_size=None, context_chunks=None)

        chunk_size = self._sample_range(self.chunk_size_range, dtype='int')
        if chunk_size == float('inf'):
            return ChunkConfig(chunk_size=None, context_chunks=None)
        chunk_size = max(1, int(chunk_size))

        # The `< 1.0` guard avoids consuming an RNG draw when the left context
        # is always constrained.
        if self.left_constrain_prob < 1.0 and not self._bernoulli(self.left_constrain_prob):
            context_chunks = None
        else:
            context_sec = self._sample_range(self.context_sec_range, dtype='float')
            if context_sec is None or context_sec == float('inf'):
                context_chunks = None
            else:
                context_chunks = self._context_chunks_from_seconds(
                    context_sec, self.timestep_duration_sec, chunk_size
                )

        return ChunkConfig(chunk_size=chunk_size, context_chunks=context_chunks)
    
    
def create_dynamic_chunk_mask(seq_len: int, config: ChunkConfig, device=None):
    """Build a boolean chunk-wise attention mask for a sequence of ``seq_len`` steps.

    Positions are grouped into consecutive chunks of ``config.chunk_size``.
    The chunk size is clamped to [1, seq_len]. A query in chunk ``c`` is
    allowed keys from chunks ``max(0, c - context_chunks) .. c``, or from all
    chunks ``0 .. c`` when ``config.context_chunks`` is None. Keys in later
    chunks are never allowed.

    Args:
        seq_len: sequence length T after padding.
        config: ChunkConfig object defining chunk_size and context_chunks.
        device: device on which the mask tensor is created.

    Returns:
        None when ``config`` requests full context. Otherwise a bool tensor of
        shape (1, 1, T, T) whose entry [0, 0, q, k] is True when key k lies in
        the allowed chunk window for query q. NOTE(review): assumes the consumer
        treats True as "may attend" (as boolean masks in
        F.scaled_dot_product_attention do); confirm.
    """
    if config.is_full_context():
        return None

    effective_chunk = int(config.chunk_size)
    effective_chunk = min(effective_chunk, seq_len)
    effective_chunk = max(effective_chunk, 1)

    positions = torch.arange(seq_len, device=device)
    chunk_index = torch.div(positions, effective_chunk, rounding_mode='floor')

    q_chunk = chunk_index[:, None]  # (T, 1)
    k_chunk = chunk_index[None, :]  # (1, T)

    if config.context_chunks is not None:
        lookback = max(0, int(config.context_chunks))
        earliest = torch.clamp(q_chunk - lookback, min=0)
    else:
        earliest = torch.zeros_like(q_chunk)

    allowed = (k_chunk >= earliest) & (k_chunk <= q_chunk)
    return allowed[None, None]

In [5]:
# Demo: sample one chunk configuration and build its mask for a 10-step sequence.
# With context_sec_range=(0.4, 0.4) and 0.1 s per timestep, the left context is
# always 4 timesteps, i.e. ceil(4 / chunk_size) chunks.
chunk_config_sampler = ChunkConfigSampler(
    chunk_size_range=(1, 3),
    context_sec_range=(0.4, 0.4),
    timestep_duration_sec=0.1,
    chunkwise_prob=1.0,
    left_constrain_prob=1.0,
    seed=10,
)

chunk_config = chunk_config_sampler.sample()
mask = create_dynamic_chunk_mask(seq_len=10, config=chunk_config, device='cpu')

print("Chunk Config:", chunk_config)
print("Chunk Mask:\n", mask)


Chunk Config: ChunkConfig(chunk_size=2, context_chunks=2)
Chunk Mask:
 tensor([[[[ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False],
          [False, False,  True,  True,  True,  True,  True,  True, False, False],
          [False, False,  True,  True,  True,  True,  True,  True, False, False],
          [False, False, False, False,  True,  True,  True,  True,  True,  True],
          [False, False, False, False,  True,  True,  True,  True,  True,  True]]]])
