# Sequence packer

Sequence packing for better training efficiency.

## ModernBERT source code

From https://github.dev/AnswerDotAI/ModernBERT/blob/main/src/sequence_packer.py

In [None]:
# Copyright 2024 onwards Answer.AI, LightOn, and contributors
# License: Apache-2.0

import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Generic, Iterable, NamedTuple, Optional, TypeVar, Any, Union, Sequence

import numpy as np
import torch
from numba import njit

import math, random

# from composer.core.types import Batch
# from composer.core import Time

class BatchSizeWarmupScheduler:
    """Map a running token count to a batch size that grows during warmup.

    The warmup budget (``warmup_tokens``, divided by ``world_size`` and rounded
    up) is split into one stage per batch size in
    ``range(min_batch_size, max_batch_size)``. Each stage's share of the budget
    is proportional to its batch size. Once the budget is spent the scheduler
    always returns ``max_batch_size``.

    Args:
        min_batch_size: Batch size returned at the start of warmup.
        max_batch_size: Batch size returned once warmup is over.
        warmup_tokens: Warmup budget. Either a number, or an object exposing
            the budget through a ``value`` attribute (duck-typed, so a
            Composer ``Time`` works without importing Composer here).
        world_size: The budget is divided by this number.

    Raises:
        ValueError: If ``warmup_tokens`` is a string. Parsing time strings
            needs Composer's ``Time.from_timestring``, which is not imported
            in this file.
    """

    def __init__(
        self,
        min_batch_size: int,
        max_batch_size: int,
        warmup_tokens,
        # warmup_tokens: Union[str, Time, int],
        world_size: int,
    ):
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size

        if isinstance(warmup_tokens, str):
            # The Composer import is disabled in this file, so a time string
            # cannot be converted. Fail loudly instead of leaving the
            # attribute unset.
            raise ValueError(
                f"cannot parse {warmup_tokens=}: time strings require composer's Time.from_timestring; "
                "pass the warmup budget as a number instead"
            )
        # Time-like objects carry the count in `.value`; plain numbers are used as-is.
        self.warmup_tokens = getattr(warmup_tokens, "value", warmup_tokens)
        self.warmup_tokens = math.ceil(self.warmup_tokens / world_size)
        self._step_thresholds = self._calculate_step_thresholds()

    def _calculate_step_thresholds(self):
        """Return the cumulative budget threshold at which each warmup stage ends."""
        total_batch_sizes = sum(range(self.min_batch_size, self.max_batch_size))
        if total_batch_sizes <= 0:
            # No intermediate stage exists (e.g. min == max): skip warmup
            # rather than dividing by zero.
            return []
        steps_per_unit = self.warmup_tokens / total_batch_sizes

        thresholds = []
        cumsum = 0
        for batch_size in range(self.min_batch_size, self.max_batch_size):
            cumsum += batch_size
            thresholds.append(math.ceil(steps_per_unit * cumsum))
        return thresholds

    def __call__(self, current_step: int) -> int:
        """Return the batch size to use after ``current_step`` units of progress.

        ``current_step`` is compared directly against the per-rank warmup
        budget, so it must be expressed in the same unit as ``warmup_tokens``.
        """
        if current_step >= self.warmup_tokens:
            return self.max_batch_size

        for i, threshold in enumerate(self._step_thresholds):
            if current_step < threshold:
                return self.min_batch_size + i

        # should never hit this, but just in case
        return self.max_batch_size


class SequencePackerBatchOutputTuple(NamedTuple):
    """Tuple form of one outgoing packed batch.

    NOTE(review): nothing in the visible code constructs this tuple; the packer's
    generator yields dicts instead. Confirm whether it is still used elsewhere.
    """

    # packed (and, presumably, masked) token ids -- TODO confirm against callers
    masked_pseqs: torch.Tensor
    # training labels; may be absent
    labels: Optional[torch.Tensor]
    # one tensor of cumulative sequence boundaries per packed sequence
    cu_seq_lens: list[torch.Tensor]
    # presumably the longest segment per packed sequence -- TODO confirm
    max_cu_seq_len: list[torch.Tensor]


class SequencePacker(ABC):
    def __init__(
        self,
        # params defining the incoming batches of seqs
        src_iterable: Iterable[list[list[int]]],
        src_batch_size: int,
        src_max_seq_len: int,
        # params defining outgoing batches of pseqs
        out_batch_size: int,
        out_pseq_len: int,
        # params defining internal behavior
        buffer_size: int,
        pad_token_id: int = -1,
        mask_token_id: int = 0,
        ignore_token_id: int = -100,
        mask_prob: float = 0.3,
        seed=42,
        suppress_masking: bool = False,
        batch_size_warmup_min_size: Optional[int] = None,
        batch_size_warmup_tokens = None,
        # batch_size_warmup_tokens: Optional[Union[str, Time]] = None,
        world_size: int = 1,
    ):
        """
        Takes batches of unpacked, unpadded sequences (seqs) to batches of packed and padded sequences (pseqs).

        Every input batch must be a list of items, where each item exposes its variable-length
        token sequence as `item["input_ids"]`.

        Every output batch is a dict with the keys "input_ids", "labels", "cu_seqlens" and
        "max_seqlen", plus "attention_mask" when masking is enabled.

        It performs this streamwise, taking an iterable as the source of incoming batches, and
        presents itself as an iterable of outgoing batches.

        Args:
            src_iterable: An iterable (e.g., a DataLoader), whose iterator yields one incoming batch,
                        where a batch is a list of unpadded, variable-length Sequences of token
                        IDs. Since this only needs to be an Iterable, it could also be a generator object
                         like the result of `itertools.batched(dataset_list,batch_size))`

            src_batch_size:  This is the INCOMING batch size, the number of seqs in one batch yielded
                          from `src_iterable`'s iterator.

            src_max_seq_len: The maximum number of tokens in a seq within an incoming batch.
                          (Accepted for interface compatibility; this constructor does not store it.)

            out_batch_size: the number of pseqs (packed seqs) in one outgoing batch

            out_pseq_len: the number of tokens per packed seq, in every outgoing batch

            buffer_size: The maximum number of seqs which may be buffered internally.

            pad_token_id: The token ID used for padding the space which cannot be filled to reach out_pseq_len.

            mask_token_id: The token ID used for masking tokens in the input sequence.

            ignore_token_id: The token ID used to ignore tokens. Expected to be applied to every non-masked token, so the model only trains on predictions of masked tokens.

            mask_prob: Probability that a position is selected for the MLM objective.

            seed: Base seed; the per-epoch random generator is seeded with `epoch + seed`.

            suppress_masking: If True, the sequence packer will not perform masked language modeling.

            batch_size_warmup_min_size: If not None, the sequence packer will gradually increase the batch size from batch_size_warmup_min_size to out_batch_size over the course of the warmup_tokens.
                                    batch_size_warmup_min_size must be a multiple of micro_batch_size.

            batch_size_warmup_tokens: If not None, the sequence packer will gradually increase the batch size from batch_size_warmup_min_size to out_batch_size over the course of the warmup_tokens.

            world_size: The number of processes participating in this training run. The warmup token budget is divided by this number.
        """
        assert buffer_size >= out_batch_size, f"required that {buffer_size=} >= {out_batch_size=}"
        try:
            self.src_dataloader_len = len(src_iterable)
        except TypeError:
            # Generators and other unsized iterables are valid sources; their length is unknown.
            self.src_dataloader_len = None
        self.src_iterable = src_iterable
        self.src_batch_size = src_batch_size
        self.out_batch_size = out_batch_size
        self.out_pseq_len = out_pseq_len
        self.buffer_size = buffer_size
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.ignore_token_id = ignore_token_id
        self.mask_prob = mask_prob
        self.suppress_masking = suppress_masking
        # internals
        self.buffer = deque()  # internal buffer holds the individual seqs taken from incoming items
        # for stats to report packing efficiency.
        self._seqs_consumed = 0
        self._seqs_emitted = 0
        # Set random seed
        self.seed = seed
        self.epoch = -1
        self._token_count = 0  # running count of non-padding tokens emitted; drives batch-size warmup
        self.batch_size_scheduler = None
        if batch_size_warmup_min_size is not None and batch_size_warmup_tokens is not None:
            self.batch_size_scheduler = BatchSizeWarmupScheduler(
                batch_size_warmup_min_size, out_batch_size, batch_size_warmup_tokens, world_size
            )

    @property
    def seqs_emitted(self):
        "Number of seqs, incoming from src_iterable, which have been emitted in OUTGOING batches."
        return self._seqs_emitted

    @property
    def seqs_consumed(self):
        "Number of seqs, incoming from src_iterable, which have been consumed."
        return self._seqs_consumed

    def _reset_state(self):
        """Start a new epoch: clear the buffer and statistics and reseed the random generator."""
        self.epoch += 1
        self.buffer.clear()
        self._seqs_consumed = 0
        self._seqs_emitted = 0
        self.np_rng = np.random.default_rng(self.epoch + self.seed)

        # Update the epoch for the sampler
        if isinstance(self.src_iterable, torch.utils.data.dataloader.DataLoader):
            if isinstance(self.src_iterable.sampler, torch.utils.data.distributed.DistributedSampler):
                self.src_iterable.sampler.set_epoch(self.epoch)

    def __iter__(self):
        """Reset the per-epoch state and return a generator of outgoing batches."""
        self._reset_state()
        self.src_iterator = iter(self.src_iterable)
        return self._generate_batches()

    def __len__(self):
        # rather than estimate the packed length of the dataset, we rely on Composer's ability
        # to schedule training the using the number of batches or tokens instead of epochs.
        # Returning None makes `len(packer)` raise TypeError, i.e. the length is unknown.
        return None

    def _fill_buffer(self, max_items_to_add=float("inf")) -> int:
        """
        Refills the internal buffer.

        - max_items_to_add: an upper bound on the number of items to add; whole incoming
          batches are fetched only while they are guaranteed to stay within this bound.

        Returns: the number of items actually added.

        The default implementation of this simply appends to self.buffer, which is
        initialized as a deque in __init__. Subclasses which want to use a different data
        structure for internal buffering should override this method and also add
        code in __init__ to initialize self.buffer appropriately.

        Any implementation of this MUST never place more than self.buffer_size items
        in the internal buffer.
        """
        items_added = 0
        # NOTE: this should be >=, kept as is to match model training code
        # TODO: change if training a new model
        while (self.buffer_size - len(self.buffer)) > self.src_batch_size:
            try:
                # if pulling another batch would fetch more than the requested max, stop
                if max_items_to_add < float("inf"):
                    if (items_added + self.src_batch_size) > max_items_to_add:
                        break
                incoming_batch = next(self.src_iterator)
                assert (
                    len(incoming_batch) <= self.src_batch_size
                ), f"expected {len(incoming_batch)=} <= {self.src_batch_size=}"
                for item in incoming_batch:
                    if len(item["input_ids"]) > 0:  # ignore empty sequences
                        self.buffer.append(item["input_ids"])
                        items_added += 1
                        self._seqs_consumed += 1
            except StopIteration:
                break
        return items_added

    def _generate_batches(self):
        """
        Generates batches of packed sequences.

        The returned generator's iterator will always, when next() is called on it, either:
         - return a valid batch dict with keys "input_ids", "labels", "cu_seqlens" and
           "max_seqlen" (plus "attention_mask" when masking is enabled)
         - raise StopIteration

        Non-padding tokens of every emitted batch are added to self._token_count,
        whether or not masking is enabled, so batch-size warmup also advances
        when suppress_masking is True.
        """
        while True:
            retval = self._create_batch()
            if retval is None:
                break

            # Validate the subclass's return value before unpacking it.
            assert isinstance(retval, tuple), f"Unexpected {type(retval)=}"
            assert isinstance(retval[0], np.ndarray), f"Unexpected {type(retval[0])=}"
            assert isinstance(retval[1], list), f"Unexpected {type(retval[1])=}"
            batch, lst_cu_seq_lens = retval

            cu_seq_lens = [torch.tensor(x, dtype=torch.int32) for x in lst_cu_seq_lens]
            # longest segment (difference of consecutive boundaries) in each pseq
            max_seq_lens = [torch.max(x[1:] - x[:-1]).item() for x in cu_seq_lens]
            # 1 at real tokens, 0 at padding; computed from the unmasked batch
            attention_mask = np.where(batch == self.pad_token_id, 0, 1)
            if self.suppress_masking:
                yieldval = {
                    "input_ids": torch.from_numpy(batch),
                    "labels": None,
                    "cu_seqlens": cu_seq_lens,
                    "max_seqlen": max_seq_lens,
                }
            else:
                (masked_batch, labels) = SequencePacker.mlm_masking(
                    batch, self.mask_prob, self.mask_token_id, self.pad_token_id, self.ignore_token_id, self.np_rng
                )
                yieldval = {
                    "input_ids": torch.from_numpy(masked_batch),
                    "labels": torch.from_numpy(labels),
                    "cu_seqlens": cu_seq_lens,
                    "max_seqlen": max_seq_lens,
                    "attention_mask": torch.from_numpy(attention_mask),
                }
            self._token_count += int(attention_mask.sum())
            yield yieldval

    @staticmethod
    def mlm_masking(
        seq: np.ndarray,
        mask_prob: float,
        mask_token: int,
        pad_token: int = -1,
        ignore_index: int = -100,
        np_rng: Optional[np.random.Generator] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        This is exactly a numpy version of transformers' `DataCollatorForLanguageModeling.torch_mask_tokens`
        https://github.com/huggingface/transformers/blob/main/src/transformers/data/data_collator.py#L827

        It performs masking in a way that produces on expectation the following masked inputs:
         - (1-mask_prob) of the original positions will be untouched.
         - mask_prob * 80%  of the original positions get replaced with a mask token
         - mask_prob * 10%  of the original positions get replaced with a random token
         - mask_prob * 10%  of the original positions also remain untouched.
        This generates the masked_inputs.

        It also generates a labels array, which has ignore tokens in the (1-mask_prob) positions
        and at every position that held `pad_token`.

        These proportions are expectation values since the random transformation is performed
        independently per element. (This is why it is agnostic wrt shape.)

        Note that the random draw covers every position, padding included: a padded position of
        the returned inputs may be overwritten by the mask token or a random token, while its
        label always stays `ignore_index`.

        Args:
          seq (np.ndarray): the input token IDs (e.g., a sequence, or batch of seqs)
          mask_prob (float): probability of initially masking a token, in the first "wave" of masking
          mask_token (int): token to use for masking
          pad_token (int): token marking padded positions; their labels are set to `ignore_index`
          ignore_index (int): the token indicating that position should be ignored during training. We call it `ignore_index` to conform to the API of the cross entropy loss function.
          np_rng: numpy random Generator to draw from; a fresh unseeded generator is created when None.

        Returns:
            tuple[np.array,np.array]: (masked_seq, labels)
                masked_seq: the input seq with some tokens replaced by `mask_token`
                labels: the original input seq with non-masked tokens replaced by `ignore_index`
        """
        if np_rng is None:
            # Created per call: a generator built in the signature would be shared by every call.
            np_rng = np.random.default_rng()

        # Create labels
        labels = np.where(seq == pad_token, ignore_index, seq)

        # Create a single mask
        rand = np_rng.random(seq.shape)

        # Partition the probability space appropriately using a single mask
        # 80% of the time, we mask the token
        mask_mask = rand < mask_prob * 0.8
        # 10% of the time, we replace the token with a random token
        random_mask = (rand >= mask_prob * 0.8) & (rand < mask_prob * 0.9)
        # 10% of the time, we keep the token the same
        keep_mask = (rand >= mask_prob * 0.9) & (rand < mask_prob)

        # We only compute loss over the tokens marked for masking
        labels = np.where(mask_mask | random_mask | keep_mask, labels, ignore_index)

        # Apply masking
        seq = np.where(mask_mask, mask_token, seq)

        # Apply random replacement; the upper bound is the largest id present after masking
        random_words = np_rng.integers(0, np.max(seq) + 1, size=seq.shape)
        seq = np.where(random_mask, random_words, seq)

        return seq, labels

    @abstractmethod
    def _create_batch(self) -> Optional[tuple[np.ndarray, list[list[int]]]]:
        """
        Returns a batch of packed sequences with its cumulative seq length information.

        Or else, returns None if it cannot build a full outgoing batch.

        Must mutate self.buffer to remove the sequences that are packed into the batch.

        Returns:
            (out_batch,cumulative_seq_len):tuple[np.ndarray, list[list[int]]]
            where:
                - out_batch is a numpy array of shape (out_batch_size, out_pseq_len);
                - cum_seq_lens is a list of lists, where the outer list is of len out_batch_size,
                    and each inner list is of varying length, and contains the start positions of
                    every seq in the pseq, and the end position of the last seq in the pseq. This end
                    position is necessary to communicate if any padding tokens were added.
        """
        pass


@njit
def find_best_fit(remaining_spaces, seq_len):
    """Return the index of the tightest slot that can still hold ``seq_len`` tokens.

    Scans ``remaining_spaces`` once and keeps the smallest entry that is at
    least ``seq_len``. On ties the lowest index wins. Returns -1 when no
    entry is large enough.
    """
    best_idx = -1
    best_space = -1
    for idx in range(len(remaining_spaces)):
        space = remaining_spaces[idx]
        if seq_len <= space:
            # strict '<' keeps the first of several equally tight slots
            if best_idx == -1 or space < best_space:
                best_idx = idx
                best_space = space
    return best_idx


class GreedyBestFitSequencePacker(SequencePacker):
    """Packer that places each buffered sequence into the tightest row that still has room."""

    @classmethod
    def from_composer(
        cls,
        src_iterable: Iterable[list[list[int]]],
        batch_size: int = 512,
        micro_batch_size: int = 32,
        max_seq_len: int = 1024,
        buffer_size: int = 5120,
        # token values
        pad_token_id: int = -1,
        mask_token_id: int = 0,
        ignore_token_id: int = -100,
        mask_prob: float = 0.3,
        # transform values
        seed=42,
        suppress_masking=False,
        batch_size_warmup_min_size: Optional[int] = None,
        batch_size_warmup_tokens = None,
        # batch_size_warmup_tokens: Optional[Union[str, Time]] = None,
        world_size: int = 1,
    ) -> "GreedyBestFitSequencePacker":
        """Build a packer from batch-size / micro-batch-size style settings.

        Each outgoing packed row holds ``micro_batch_size * max_seq_len`` tokens and
        an outgoing batch has ``batch_size / micro_batch_size`` rows. The warmup
        minimum, when given, is converted to rows the same way.

        Raises:
            ValueError: If ``batch_size_warmup_min_size`` is not a multiple of
                ``micro_batch_size``.
        """
        warmup_rows = batch_size_warmup_min_size
        if warmup_rows is not None:
            if warmup_rows % micro_batch_size != 0:
                raise ValueError(f"{batch_size_warmup_min_size=} must be a multiple of {micro_batch_size=}")
            warmup_rows = int(warmup_rows / micro_batch_size)

        packer_kwargs = dict(
            # input shape
            src_iterable=src_iterable,
            src_batch_size=batch_size,
            src_max_seq_len=max_seq_len,
            # output shape
            out_batch_size=int(batch_size / micro_batch_size),
            out_pseq_len=int(micro_batch_size * max_seq_len),
            # internal
            buffer_size=buffer_size,
            # transformation
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            ignore_token_id=ignore_token_id,
            mask_prob=mask_prob,
            seed=seed,
            suppress_masking=suppress_masking,
            batch_size_warmup_min_size=warmup_rows,
            batch_size_warmup_tokens=batch_size_warmup_tokens,
            world_size=world_size,
        )
        return cls(**packer_kwargs)

    def _create_batch(self) -> Optional[tuple[np.ndarray, list[list[int]]]]:
        """Pack buffered sequences into one outgoing batch using best-fit placement.

        Returns ``(batch, cum_seq_lens)``, or None when at least one row received
        no sequence, which signals the end of iteration.
        """
        if self.batch_size_scheduler:
            self.out_batch_size = self.batch_size_scheduler(self._token_count)

        n_rows = self.out_batch_size
        row_len = self.out_pseq_len

        # rows start fully padded and are overwritten as sequences are placed
        packed = np.full((n_rows, row_len), self.pad_token_id, dtype=np.int64)
        seqs_per_row = np.zeros(n_rows, dtype=np.int32)
        free_space = np.full((n_rows,), row_len, dtype=np.int32)
        boundaries = [[0] for _ in range(n_rows)]
        skipped = []  # sequences that fit nowhere in this batch

        while True:
            # Replenish an empty buffer; stop once nothing more can be added
            if not self.buffer:
                if self._fill_buffer(self.buffer_size - len(skipped)) == 0:
                    break

            seq = self.buffer.popleft()
            seq_len = len(seq)

            row = find_best_fit(free_space, seq_len)
            if row == -1:
                # no row has room: keep it for a later batch
                skipped.append(seq)
                continue

            offset = row_len - free_space[row]
            packed[row, offset : offset + seq_len] = seq
            seqs_per_row[row] += 1
            free_space[row] -= seq_len
            boundaries[row].append(boundaries[row][-1] + seq_len)

        # Put skipped sequences back at the front of the buffer
        # (deque.extendleft inserts them in reverse order).
        self.buffer.extendleft(skipped)

        if not np.all(seqs_per_row > 0):
            # If we can't form a full batch, we return None to signal the end
            return None

        self._seqs_emitted += np.sum(seqs_per_row)
        for row_boundaries in boundaries:
            # a final boundary at row_len marks the padded tail of a row that is not full
            if row_boundaries[-1] != row_len:
                row_boundaries.append(row_len)
        return packed, boundaries


T = TypeVar("T")


class BufferedIterable(Generic[T]):
    def __init__(self, iterable: Iterable[T], buffer_size: int):
        """
        Wrap an iterable so that each iteration over it is prefetched by a background thread.

        Args:
          - iterable: an object which generates a fresh iterator on iter()
          - buffer_size: the maximum number of items prefetched ahead of the consumer
        """
        self.iterable = iterable
        self.buffer_size = buffer_size

    def __iter__(self):
        return BufferedIterator(self.iterable, self.buffer_size)


class BufferedIterator(Generic[T]):
    def __init__(self, iterable: Iterable[T], buffer_size: int):
        """
        Start a daemon thread that pulls items from `iterable` into a bounded buffer.

        Raises:
            ValueError: if buffer_size is smaller than 1 (nothing could ever be buffered).
        """
        if buffer_size < 1:
            raise ValueError(f"{buffer_size=} must be at least 1")
        self.iterator = iter(iterable)
        self.buffer = deque(maxlen=buffer_size)
        self.buffer_size = buffer_size
        self.lock = threading.Lock()
        self.exhausted = False
        # Exception raised by the source iterator, re-raised to the consumer by __next__.
        self._error = None
        self.filler_thread = threading.Thread(target=self._background_fill, daemon=True)
        self.filler_thread.start()

    def _background_fill(self):
        # Fill up the buffer, whenever possible, in the background
        while not self.exhausted:
            if len(self.buffer) < self.buffer_size:
                try:
                    item = next(self.iterator)
                except StopIteration:
                    self.exhausted = True
                    break
                except Exception as err:
                    # Hand the failure to the consumer instead of dying silently,
                    # which would leave __next__ waiting forever.
                    self._error = err
                    self.exhausted = True
                    break
                with self.lock:
                    self.buffer.append(item)
            else:
                time.sleep(0.01)  # Sleep for a bit to avoid busy waiting

    def __iter__(self):
        return self

    def __next__(self) -> T:
        while True:
            # Read the flag BEFORE inspecting the buffer: the filler only sets it after its
            # last append, so "flag set, then buffer empty" really means nothing is left.
            source_done = self.exhausted
            with self.lock:
                if self.buffer:
                    return self.buffer.popleft()
            if source_done:
                if self._error is not None:
                    err, self._error = self._error, None
                    raise err
                # We've exhausted the iterator and the buffer so we're done
                raise StopIteration
            # The buffer is empty but the iterator is not exhausted yet.
            # Let's give the filler thread a chance to add items to the buffer
            time.sleep(0.01)


def split_packed_batch(batch: Any, microbatch_size: Union[int, float], padding_tolerance=1.0) -> Sequence:
    """Split a packed batch into one microbatch dict per packed sequence.

    Args:
        batch: mapping with "input_ids", "labels" (a tensor or None) and "attention_mask"
            indexed by packed sequence along dim 0, plus per-sequence lists "cu_seqlens"
            and "max_seqlen".
        microbatch_size: unused; every packed sequence becomes its own microbatch (see NOTE).
        padding_tolerance: when the fraction of padded positions in a packed sequence
            exceeds this value, everything after the last non-padded position is cut off
            and the final entry of its cu_seqlens is dropped.

    Returns:
        A list with one dict per packed sequence, with the keys "input_ids", "labels",
        "cu_seqlens", "max_seqlen" and "attention_mask".

    Raises:
        AssertionError: if a resulting "input_ids" length differs from the last
            entry of its "cu_seqlens".
    """
    # NOTE: Packed sequences are already packed into a microbatch size worth of tokens.
    # So to correctly return a microbatch worth of data, we will simply return each item (i.e. microbatch_size 1)

    # squeeze(0) drops only the split dimension, so rows of length 1 stay 1-D
    split_inputs = [x.squeeze(0) for x in batch["input_ids"].split(1)]
    split_attention_masks = [x.squeeze(0) for x in batch["attention_mask"].split(1)]
    labels_batch = batch["labels"]
    # labels may be None (e.g. batches produced without masking)
    split_labels = None if labels_batch is None else [x.squeeze(0) for x in labels_batch.split(1)]
    split_cu_seqlens = batch["cu_seqlens"]

    result = []
    for i, (input_ids, attention_mask) in enumerate(zip(split_inputs, split_attention_masks)):
        labels = None if split_labels is None else split_labels[i]
        cu_seqlens = split_cu_seqlens[i]
        padding_amount = 1 - (attention_mask.sum() / len(attention_mask))

        if padding_amount > padding_tolerance:
            keep = int(attention_mask.nonzero().max()) + 1
            input_ids = input_ids[:keep]
            if labels is not None:
                labels = labels[:keep]
            cu_seqlens = cu_seqlens[:-1]
            attention_mask = attention_mask[:keep]

        result.append(
            {
                "input_ids": input_ids,
                "labels": labels,
                "cu_seqlens": cu_seqlens,
                "max_seqlen": batch["max_seqlen"][i],
                "attention_mask": attention_mask,
            }
        )

    assert all([item["input_ids"].shape[-1] == item["cu_seqlens"][-1] for item in result])
    return result


def get_num_samples_in_packed_batch(batch) -> int:
    """Count the segments described by ``batch["cu_seqlens"]``.

    A single boundary tensor with k entries counts as k - 1 segments; a list of
    boundary tensors contributes the sum of that quantity over its elements.
    NOTE(review): a trailing padding boundary, if present, is counted like any
    other segment -- confirm this is what callers expect.

    Raises:
        TypeError: if the "cu_seqlens" entry is neither a list nor a Tensor.
    """
    boundaries = batch["cu_seqlens"]
    if isinstance(boundaries, list):
        return sum(per_row.size()[0] - 1 for per_row in boundaries)
    if isinstance(boundaries, torch.Tensor):
        return boundaries.size()[0] - 1
    raise TypeError('Expected a batch with a "cu_seqlens" key of type list or Tensor')

In [None]:
from torch.utils.data import DataLoader
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, BoolTensor

In [None]:
# Probe for the optional flash-attn package; the flag records whether the import worked.
try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    # Flash Attention is not available
    FLASH_ATTN_AVAILABLE = False
else:
    FLASH_ATTN_AVAILABLE = True

# Pick the first available backend, in order of preference: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

In [None]:
def simple_collate(batch):
    """Identity collate function: hand the list of samples back unchanged."""
    return batch

## Dataset

**AI Response**

I'll show you a simple example of how to use the SequencePacker, specifically the GreedyBestFitSequencePacker implementation. Let's create a minimal working example:

```python
# First, let's create some sample data
import random

# Create a simple dataset of variable length sequences
class SimpleDataset:
    """Toy dataset of random token sequences, each stored as {"input_ids": [...]}."""

    def __init__(self, num_samples=100, min_len=10, max_len=100, vocab_size=1000):
        self.samples = []
        for _ in range(num_samples):
            seq_len = random.randint(min_len, max_len)
            self.samples.append({"input_ids": [random.randint(1, vocab_size-1) for _ in range(seq_len)]})

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

# Create a batch sampler
from torch.utils.data import DataLoader

dataset = SimpleDataset(num_samples=100)
# The packer reads item["input_ids"] from every element of a batch, so each batch must stay
# a plain list of sample dicts. The identity collate_fn does that; the default collate
# cannot be used with these variable-length samples.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=lambda batch: batch)

# Create the sequence packer
packer = GreedyBestFitSequencePacker(
    src_iterable=dataloader,
    src_batch_size=8,
    src_max_seq_len=100,
    out_batch_size=2,
    out_pseq_len=200,
    buffer_size=32,
    pad_token_id=-1,
    mask_token_id=0,
    ignore_token_id=-100,
    mask_prob=0.15,
    suppress_masking=False
)

# Use the packer
for i, batch in enumerate(packer):
    if i >= 3: break  # Just show first few batches

    # Unpack the batch
    input_ids = batch["input_ids"]
    labels = batch["labels"]
    cu_seqlens = batch["cu_seqlens"]

    print(f"Batch {i}:")
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Cumulative sequence lengths: {[s.tolist() for s in cu_seqlens]}")
    # len(s)-1 counts segments: when a row is not completely full, the packer appends a final
    # boundary at out_pseq_len, so the trailing padding shows up as one extra segment.
    print(f"  Segments per packed sequence (incl. trailing padding): {[len(s)-1 for s in cu_seqlens]}")
    print()

This example:

1. Creates a simple dataset with variable length sequences
2. Sets up a DataLoader to batch the sequences
3. Creates a GreedyBestFitSequencePacker that:
   - Takes batches of 8 sequences from the DataLoader
   - Produces output batches with 2 packed sequences, each with length 200
   - Uses a buffer size of 32 sequences
   - Applies 15% masking for MLM training
4. Iterates through the packed batches and prints information about each one

The key benefit of the sequence packer is efficiently packing variable-length sequences into fixed-length batches, which maximizes GPU utilization for training language models.

In [None]:
class SimpleDataset:
    """In-memory toy dataset of random token sequences.

    Every sample is a dict ``{"input_ids": [...]}`` whose length is drawn
    uniformly from ``[min_len, max_len]`` and whose token ids are drawn
    uniformly from ``[1, vocab_size - 1]``, using the global ``random`` module.
    """

    def __init__(self, num_samples=100, min_len=10, max_len=100, vocab_size=1000):
        self.samples = [
            self._random_sample(min_len, max_len, vocab_size) for _ in range(num_samples)
        ]

    @staticmethod
    def _random_sample(min_len, max_len, vocab_size):
        # The length is drawn first, then that many token ids.
        n_tokens = random.randint(min_len, max_len)
        token_ids = [random.randint(1, vocab_size - 1) for _ in range(n_tokens)]
        return {"input_ids": token_ids}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
# Build a tiny dataset (5 sequences of 10-50 tokens, ids drawn from 1..49) and show its first sample.
ds = SimpleDataset(num_samples=5, min_len=10, max_len=50, vocab_size=50)
ds[0]

{'input_ids': [35,
  28,
  22,
  27,
  44,
  40,
  28,
  41,
  14,
  10,
  15,
  26,
  10,
  20,
  32,
  41,
  33,
  33,
  1,
  13,
  14,
  24,
  18]}

In [None]:
len(ds)

5

In [None]:
[len(ds[i]['input_ids']) for i in range(len(ds))]

[23, 14, 49, 32, 45]

In [None]:
# batch_size=8 is larger than the 5 samples in `ds`, so the first batch contains the whole dataset.
# simple_collate returns the batch unchanged, i.e. a list of {"input_ids": [...]} dicts.
dl = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=simple_collate)
batch = next(iter(dl))
print(len(batch))
print([len(item['input_ids']) for item in batch])

5
[23, 49, 32, 14, 45]


In [None]:
# Pack 50 random sequences (20-50 tokens each) into batches of 2 packed rows of 100 tokens.
larger_ds = SimpleDataset(num_samples=50, min_len=20, max_len=50, vocab_size=1000)
larger_dl = DataLoader(larger_ds, batch_size=8, shuffle=True, collate_fn=simple_collate)
smaller_packer = GreedyBestFitSequencePacker(
    src_iterable=larger_dl,
    src_batch_size=8,
    src_max_seq_len=50,  # Match our max sequence length
    out_batch_size=2,    # Create 2 packed sequences per batch
    out_pseq_len=100,    # Each packed sequence can hold ~2-5 of our sequences
    buffer_size=32,
    pad_token_id=-1,
    mask_token_id=0,
    ignore_token_id=-100,
    mask_prob=0.15,
    suppress_masking=False
)

# Only batches 0 and 1 are printed. The loop variable `batch` is already bound to the batch
# with index 2 when the loop breaks, so after this cell `batch` refers to that third batch,
# while `input_ids`, `labels` and `cu_seqlens` still hold the values of batch 1.
for i, batch in enumerate(smaller_packer):
    if i >= 2: break
    
    input_ids = batch["input_ids"]
    labels = batch["labels"]
    cu_seqlens = batch["cu_seqlens"]
    
    print(f"Batch {i}:")
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Cumulative sequence lengths: {[s.tolist() for s in cu_seqlens]}")
    # NOTE: len(s)-1 counts segments. The packer appends a final boundary at out_pseq_len when a
    # row is not completely full, so a trailing padding segment is included in this count.
    print(f"  Sequences per packed sequence: {[len(s)-1 for s in cu_seqlens]}")
    print()

Batch 0:
  Input shape: torch.Size([2, 100])
  Labels shape: torch.Size([2, 100])
  Cumulative sequence lengths: [[0, 45, 93, 100], [0, 24, 68, 91, 100]]
  Sequences per packed sequence: [3, 4]

Batch 1:
  Input shape: torch.Size([2, 100])
  Labels shape: torch.Size([2, 100])
  Cumulative sequence lengths: [[0, 25, 63, 93, 100], [0, 47, 94, 100]]
  Sequences per packed sequence: [4, 3]



In [None]:
batch["input_ids"]

tensor([[251, 958, 979, 679, 380, 418, 804,   0, 425, 512, 598, 856, 256, 137,
         898, 366, 848, 409, 120,   0, 929, 768, 712, 943, 832, 800,  40,   0,
         790, 548, 943, 763,   0, 602,   0,   0, 727, 317, 234, 134, 898, 152,
           0, 569, 519, 371,   0, 283, 125,   0,  17, 854, 305, 371, 794, 508,
         468, 789, 369,  25, 622, 751, 789,   0, 920, 801, 563, 822, 474, 801,
         353, 951,  85, 960, 138, 303, 990, 723, 917, 992,   0, 330,   0, 428,
         884, 145, 278, 498, 445, 965, 261,  -1,  -1,  -1,  -1,  -1,  -1,  -1,
          -1,  -1],
        [401, 127, 558, 191, 158, 893, 490,  91, 397, 629, 262, 323,  47, 762,
         714, 132,  93,   0, 499, 491, 656, 306,   0,  33,   0,   0, 301,   0,
         238, 723, 850,   0,   0, 248,   0, 426,  50, 923, 884, 311,   0,   0,
         448,   0, 361, 694, 744, 382, 209, 244, 529,  51, 908, 555, 981, 445,
         830, 484,   0, 207, 132, 509,   0, 760, 306,   0, 158,   0,   0,  55,
         549, 290, 565, 407,   0

In [None]:
batch["labels"]

tensor([[-100, -100, -100, -100, -100, -100, -100,  435, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100,  279,   89, -100, -100, -100, -100,
         -100, -100, -100,  947, -100, -100, -100, -100,  409, -100,  903,  747,
         -100, -100, -100, -100, -100, -100,  219, -100, -100, -100,  428, -100,
         -100,  596, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   58, -100, -100, -100, -100, -100, -100,  353, -100,
         -100, -100, -100, -100, -100, -100,  793, -100,  733, -100,  680, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100,  824, -100, -100, -100, -100,  252, -100,
          327,  830, -100,  277, -100, -100, -100,  848,   32, -100,  542, -100,
         -100, -100, -100, -100,  418,  685, -100,  681, -100, -100, -100, 

In [None]:
batch["cu_seqlens"]

[tensor([  0,  44,  91, 100], dtype=torch.int32),
 tensor([  0,  45,  87, 100], dtype=torch.int32)]

**AI Prompt**

What are input_ids, labels, and cu_seqlens?

**AI Response**

Looking at the output from the SequencePacker, let me explain each of these components:

1. **input_ids**: 
   - This is a tensor containing the packed token sequences
   - Shape: [out_batch_size, out_pseq_len] (in this case [2, 100])
   - Contains the actual token IDs, with some tokens masked (replaced with mask_token_id=0)
   - Padding tokens are represented by pad_token_id=-1

2. **labels**:
   - This is a tensor used for the masked language modeling (MLM) training objective
   - Shape: Same as input_ids [2, 100]
   - Contains:
     - The original token IDs at positions that were masked
     - ignore_token_id (-100) at positions that weren't masked
     - The model is trained to predict only the masked tokens during training

3. **cu_seqlens** (cumulative sequence lengths):
   - A list of tensors (one per packed sequence) that tells you where each original sequence begins and ends within the packed sequence
   - Each tensor contains indices that mark the boundaries between sequences
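   - For example, [0, 40, 83, 100] with out_pseq_len=100 means:
     - First sequence: positions 0-39 (40 tokens)
     - Second sequence: positions 40-82 (43 tokens)
     - Positions 83-99 (17 tokens) are padding, not a third sequence: the packer appends out_pseq_len as a final boundary only when the packed sequence is not completely full
     - The last value (100) is the total length of the packed sequence, so `len(cu_seqlens) - 1` counts the padding segment as well whenever padding is present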

The packer efficiently combines multiple variable-length sequences into fixed-length packed sequences, keeping track of where each original sequence begins and ends. This improves training efficiency by reducing padding waste.

## Autoregressive

**AI Prompt**

I want to use sequence packer for autoregressive model.

**AI Response**

For using the SequencePacker with an autoregressive model, you'll need to modify a few things. The current implementation is designed for masked language modeling (MLM) rather than autoregressive training.

Here's how you can adapt it for autoregressive models:

```python
class AutoregressiveSequencePacker(GreedyBestFitSequencePacker):
    """Greedy best-fit packer that emits next-token-prediction batches.

    MLM masking is disabled in the parent. Each yielded batch carries
    model-safe ``input_ids`` (padding replaced by 0), an ``attention_mask``
    and ``labels`` that are already shifted left by one position.
    """

    def __init__(self, *args, **kwargs):
        # The parent must hand back raw token ids, so masking is always off.
        kwargs['suppress_masking'] = True
        super().__init__(*args, **kwargs)

    def _generate_batches(self):
        """Wrap the parent's batches with shifted labels and masks.

        Yields:
            dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``causal_mask``, ``cu_seqlens`` and ``max_seqlen``.
            ``labels[r, t]`` is the token at ``t + 1`` of row ``r``, or
            ``ignore_token_id`` when that token is padding, belongs to the
            next packed sequence, or does not exist (last column).
        """
        for batch_dict in super()._generate_batches():
            input_ids = batch_dict["input_ids"]
            is_pad = input_ids == self.pad_token_id

            # pad_token_id can be negative (-1 in this notebook), which is not
            # a valid embedding index; feed a placeholder id instead. The
            # attention mask still marks these positions as padding.
            inputs = input_ids.masked_fill(is_pad, 0)
            attention_mask = (~is_pad).long()

            # labels[t] = input_ids[t + 1]; the last column has no successor.
            labels = torch.full_like(input_ids, self.ignore_token_id)
            labels[:, :-1] = input_ids[:, 1:].masked_fill(is_pad[:, 1:], self.ignore_token_id)

            # The last token of every packed segment must not be trained to
            # predict the first token of the following segment.
            for row, cu_seqlens in enumerate(batch_dict["cu_seqlens"]):
                segment_ends = cu_seqlens[1:].long()
                labels[row, segment_ends - 1] = self.ignore_token_id

            # Dense lower-triangular mask: each token sees itself and earlier
            # tokens. Note this is O(seq_len^2) memory per batch.
            seq_len = inputs.size(1)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

            yield {
                "input_ids": inputs,
                "labels": labels,
                "attention_mask": attention_mask,
                "causal_mask": causal_mask,
                "cu_seqlens": batch_dict["cu_seqlens"],
                "max_seqlen": batch_dict["max_seqlen"]
            }
```

And here's how you'd use it:

```python
# Pack source batches of 8 sequences (<= 50 tokens each) into 2 rows of 100 tokens.
autoregressive_packer = AutoregressiveSequencePacker(
    src_iterable=larger_dl,
    src_batch_size=8,
    src_max_seq_len=50,
    out_batch_size=2,
    out_pseq_len=100,
    buffer_size=32,
    pad_token_id=-1,
    ignore_token_id=-100
)

# Inspect the first batch only. The `break` sits at the END of the body so the
# loop never fetches a second batch: `batch` stays bound to the batch that was
# printed, and later inspection shows the same data.
for i, batch in enumerate(autoregressive_packer):
    print(f"Batch {i}:")
    print(f"  Input shape: {batch['input_ids'].shape}")
    print(f"  Labels shape: {batch['labels'].shape}")
    print(f"  Attention mask shape: {batch['attention_mask'].shape}")
    print(f"  Causal mask shape: {batch['causal_mask'].shape}")
    print(f"  Cumulative sequence lengths: {[s.tolist() for s in batch['cu_seqlens']]}")
    break
```

This implementation:
1. Inherits from GreedyBestFitSequencePacker but turns off masking
2. Overrides _generate_batches to create proper inputs and labels for autoregressive training
3. Adds a causal attention mask to ensure tokens only attend to previous tokens
4. Sets labels at sequence boundaries to ignore_token_id to avoid the model learning to predict across sequence boundaries

In [None]:
class AutoregressiveSequencePacker(GreedyBestFitSequencePacker):
    """Greedy best-fit packer that emits next-token-prediction batches.

    MLM masking is disabled in the parent. Each yielded batch carries
    model-safe ``input_ids`` (padding replaced by 0), an ``attention_mask``
    and ``labels`` that are already shifted left by one position.
    """

    def __init__(self, *args, **kwargs):
        # The parent must hand back raw token ids, so masking is always off.
        kwargs['suppress_masking'] = True
        super().__init__(*args, **kwargs)

    def _generate_batches(self):
        """Wrap the parent's batches with shifted labels and masks.

        Yields:
            dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``causal_mask``, ``cu_seqlens`` and ``max_seqlen``.
            ``labels[r, t]`` is the token at ``t + 1`` of row ``r``, or
            ``ignore_token_id`` when that token is padding, belongs to the
            next packed sequence, or does not exist (last column).
        """
        for batch_dict in super()._generate_batches():
            input_ids = batch_dict["input_ids"]
            is_pad = input_ids == self.pad_token_id

            # pad_token_id can be negative (-1 in this notebook); passing it to
            # an embedding layer or a cross-entropy loss is an out-of-range
            # index. Feed a placeholder id; attention_mask still flags padding.
            inputs = input_ids.masked_fill(is_pad, 0)
            attention_mask = (~is_pad).long()

            # Shift by slicing instead of torch.roll: roll wraps the first
            # token of the row around into the last label position.
            labels = torch.full_like(input_ids, self.ignore_token_id)
            labels[:, :-1] = input_ids[:, 1:].masked_fill(is_pad[:, 1:], self.ignore_token_id)

            # The last token of every packed segment must not be trained to
            # predict the first token of the following segment.
            for row, cu_seqlens in enumerate(batch_dict["cu_seqlens"]):
                segment_ends = cu_seqlens[1:].long()
                labels[row, segment_ends - 1] = self.ignore_token_id

            # Dense lower-triangular mask, kept for callers that read it.
            # Note this is O(seq_len^2) memory per batch.
            seq_len = inputs.size(1)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

            yield {
                "input_ids": inputs,
                "labels": labels,
                "attention_mask": attention_mask,
                "causal_mask": causal_mask,
                "cu_seqlens": batch_dict["cu_seqlens"],
                "max_seqlen": batch_dict["max_seqlen"]
            }

In [None]:
# Pack source batches of 8 sequences (<= 50 tokens each) into 2 rows of 100 tokens.
autoregressive_packer = AutoregressiveSequencePacker(
    src_iterable=larger_dl,
    src_batch_size=8,
    src_max_seq_len=50,
    out_batch_size=2,
    out_pseq_len=100,
    buffer_size=32,
    pad_token_id=-1,
    ignore_token_id=-100
)

# Inspect the first batch only. The `break` sits at the END of the body so the
# loop never fetches a second batch: `batch` stays bound to the batch that was
# printed, and the cells below display that same batch.
for i, batch in enumerate(autoregressive_packer):
    print(f"Batch {i}:")
    print(f"  Input shape: {batch['input_ids'].shape}")
    print(f"  Labels shape: {batch['labels'].shape}")
    print(f"  Attention mask shape: {batch['attention_mask'].shape}")
    print(f"  Causal mask shape: {batch['causal_mask'].shape}")
    print(f"  Cumulative sequence lengths: {[s.tolist() for s in batch['cu_seqlens']]}")
    break

Batch 0:
  Input shape: torch.Size([2, 100])
  Labels shape: torch.Size([2, 100])
  Attention mask shape: torch.Size([2, 100])
  Causal mask shape: torch.Size([1, 1, 100, 100])
  Cumulative sequence lengths: [[0, 48, 77, 98, 100], [0, 43, 85, 100]]


In [None]:
batch['input_ids']

tensor([[186, 503, 198, 795, 831, 221, 575, 313, 627, 718, 721, 971, 570,  35,
         314, 600, 650, 394,  95, 132, 274,  72, 201, 493, 319, 141, 540, 761,
          28, 608, 131, 764, 830, 617, 272, 329, 827, 950, 731,  44, 626, 761,
         404,  16, 241, 535, 577, 360, 323, 507, 461, 183, 812, 960, 907, 982,
         314, 513, 151, 724, 578, 711, 111,  18, 701, 116, 132,  70, 447, 712,
         457, 127, 324, 932, 834, 441, 795, 831, 223, 178,  69, 256, 551, 144,
         311, 908, 634, 496, 848, 903, 210, 721, 576, 408, 764, 810, 825, 138,
         858, 959],
        [493, 216, 382, 583, 984, 896, 546, 248,  62, 904, 400, 712, 572, 850,
         358, 106, 997,  90, 482,  62, 100, 786, 367, 610, 301, 544,  51, 611,
          22,  20, 700, 785, 899, 839, 927,  43, 568, 274, 121, 532,  52, 104,
         128, 653, 149,  56, 145, 506, 776, 701, 831, 347, 158, 270, 460, 327,
         705, 790, 274, 265, 993, 107, 657, 945, 681, 255, 336, 343, 194, 654,
          72, 658, 761, 913, 908

In [None]:
batch['labels']

tensor([[ 503,  198,  795,  831,  221,  575,  313,  627,  718,  721,  971,  570,
           35,  314,  600,  650,  394,   95,  132,  274,   72,  201,  493,  319,
          141,  540,  761,   28,  608,  131,  764,  830,  617, -100,  329,  827,
          950,  731,   44,  626,  761,  404,   16,  241,  535,  577,  360,  323,
          507,  461,  183,  812,  960,  907,  982,  314,  513,  151,  724,  578,
          711,  111,   18,  701,  116,  132,   70,  447,  712,  457,  127,  324,
          932,  834,  441,  795,  831, -100,  178,   69,  256,  551,  144,  311,
          908,  634,  496,  848,  903,  210,  721,  576,  408,  764,  810,  825,
          138,  858,  959, -100],
        [ 216,  382,  583,  984,  896,  546,  248,   62,  904,  400,  712,  572,
          850,  358,  106,  997,   90,  482,   62,  100,  786,  367,  610,  301,
          544,   51,  611,   22,   20,  700,  785,  899,  839,  927,   43, -100,
          274,  121,  532,   52,  104,  128,  653,  149,   56,  145,  506, 

In [None]:
batch['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]])

In [None]:
batch['causal_mask']

tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])

**AI Prompt**

Does it look good to you?

**AI Response**

Yes, the implementation looks very good! The AutoregressiveSequencePacker is correctly configured for autoregressive training:

1. **Input/Label Shifting**: The labels are properly shifted (using `torch.roll`) so that each position predicts the next token.

2. **Sequence Boundary Handling**: The code correctly identifies the end of each sequence (using `cu_seqlens`) and sets the label at those positions to `ignore_token_id` (-100). This prevents the model from learning to predict across sequence boundaries.

3. **Attention Masking**: Two important masks are created:
   - `attention_mask`: Distinguishes real tokens (1) from padding tokens (0)
   - `causal_mask`: A lower triangular matrix ensuring each token can only attend to itself and previous tokens

4. **Packed Sequences**: The original packing functionality is preserved, efficiently combining multiple sequences into fixed-length batches.

The output shows everything is working correctly:
- The input and label tensors have the expected shapes
- The cumulative sequence lengths show we've packed multiple sequences (4 segments in the first packed row, 3 in the second)
- The labels are properly shifted, with `-100` at sequence boundaries
- The attention mask correctly identifies non-padding tokens
- The causal mask has the expected lower triangular structure

This implementation will work well for training autoregressive models like GPT with packed sequences, which should improve training efficiency by reducing padding waste.

**AI Prompt**

How do I use this now?

**AI Response**

Here's how to use the AutoregressiveSequencePacker with a language model like GPT:

```python
# 1. Create your dataset and dataloader
train_dataset = SimpleDataset(num_samples=10000, min_len=20, max_len=200)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=16, 
    shuffle=True, 
    collate_fn=simple_collate
)

# 2. Create the sequence packer
packer = AutoregressiveSequencePacker(
    src_iterable=train_dataloader,
    src_batch_size=16,
    src_max_seq_len=200,
    out_batch_size=4,
    out_pseq_len=512,
    buffer_size=64,
    pad_token_id=-1,
    ignore_token_id=-100
)

# 3. Create your model (example with a simple GPT-style model)
from transformers import GPT2Config, GPT2LMHeadModel

model_config = GPT2Config(
    vocab_size=1000,
    n_positions=512,
    n_ctx=512,
    n_embd=768,
    n_layer=6,
    n_head=12
)
model = GPT2LMHeadModel(model_config)

# 4. Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(3):
    for batch in packer:
        # Move batch to device
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # pad_token_id (-1) is neither a valid embedding index nor a valid
        # class id, so neutralise it before it reaches the model or the loss.
        inputs = inputs.masked_fill(inputs == packer.pad_token_id, 0)
        labels = labels.masked_fill(labels == packer.pad_token_id, packer.ignore_token_id)

        # Forward pass. The packer already shifted the labels by one position.
        # GPT2LMHeadModel shifts `labels` again internally, so compute the
        # loss from the logits here instead of passing `labels=` to the model.
        # NOTE(review): attention still spans packed-sequence boundaries and
        # position ids are not reset per sequence with this model.
        outputs = model(
            input_ids=inputs,
            attention_mask=attention_mask
        )
        logits = outputs.logits
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=packer.ignore_token_id
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")
```

For inference with the trained model:

```python
def generate_text(model, prompt, max_length=50):
    """Sample a continuation of ``prompt`` from ``model``.

    Args:
        model: Module exposing ``parameters()``, ``train()``/``eval()`` and a
            Hugging Face style ``generate`` method.
        prompt: Sequence of integer token ids.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first row of the object returned by ``model.generate``.
    """
    # Place the prompt on the model's own device instead of relying on a
    # module-level `device` variable that may be out of sync with the model.
    model_device = next(model.parameters()).device
    input_ids = torch.tensor([prompt], dtype=torch.long).to(model_device)

    # Sampling while the model is in training mode would keep dropout active;
    # switch to eval for generation and restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                do_sample=True,
                top_p=0.95,
                temperature=0.8
            )
    finally:
        model.train(was_training)

    return output[0]

# Example prompt (list of token IDs)
prompt = [10, 25, 36, 42]
generated = generate_text(model, prompt)
print(generated)
```

Key points:
1. The packer efficiently combines sequences for better GPU utilization
2. The model receives properly formatted inputs with causal masking
3. The labels are shifted and have sequence boundaries masked
4. The loss is automatically computed only on valid positions (ignoring padding and sequence boundaries)

In [None]:
# 1. Create your dataset and dataloader
train_dataset = SimpleDataset(num_samples=10000, min_len=20, max_len=200)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=16, 
    shuffle=True, 
    collate_fn=simple_collate
)

# 2. Create the sequence packer
packer = AutoregressiveSequencePacker(
    src_iterable=train_dataloader,
    src_batch_size=16,
    src_max_seq_len=200,
    out_batch_size=4,
    out_pseq_len=512,
    buffer_size=64,
    pad_token_id=-1,
    ignore_token_id=-100
)

# 3. Create your model (example with a simple GPT-style model)
from transformers import GPT2Config, GPT2LMHeadModel

model_config = GPT2Config(
    vocab_size=1000,
    n_positions=512,
    n_ctx=512,
    n_embd=768,
    n_layer=6,
    n_head=12
)
model = GPT2LMHeadModel(model_config)

# 4. Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(3):
    for batch in packer:
        # Move batch to device
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # pad_token_id (-1) is neither a valid embedding index nor a valid
        # class id; left in place it triggers the CUDA
        # `srcIndex < srcSelectDimSize` device-side assert.
        inputs = inputs.masked_fill(inputs == packer.pad_token_id, 0)
        labels = labels.masked_fill(labels == packer.pad_token_id, packer.ignore_token_id)

        # Forward pass. The packer already shifted the labels by one position.
        # GPT2LMHeadModel shifts `labels` again internally, so compute the
        # loss from the logits here instead of passing `labels=` to the model.
        # NOTE(review): attention still spans packed-sequence boundaries and
        # position ids are not reset per sequence with this model.
        outputs = model(
            input_ids=inputs,
            attention_mask=attention_mask
        )
        logits = outputs.logits
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=packer.ignore_token_id
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")


/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [439,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import pdb; pdb.pm()

In [None]:
def generate_text(model, prompt, max_length=50):
    """Sample a continuation of ``prompt`` from ``model``.

    Args:
        model: Module exposing ``parameters()``, ``train()``/``eval()`` and a
            Hugging Face style ``generate`` method.
        prompt: Sequence of integer token ids.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first row of the object returned by ``model.generate``.
    """
    # Place the prompt on the model's own device instead of relying on a
    # module-level `device` variable that may be out of sync with the model.
    model_device = next(model.parameters()).device
    input_ids = torch.tensor([prompt], dtype=torch.long).to(model_device)

    # Sampling while the model is in training mode would keep dropout active;
    # switch to eval for generation and restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                do_sample=True,
                top_p=0.95,
                temperature=0.8
            )
    finally:
        model.train(was_training)

    return output[0]

# The prompt is given directly as token ids; no tokenizer is involved here.
prompt = [10, 25, 36, 42]
generated = generate_text(model, prompt=prompt)
print(generated)


In [None]:
# Print the maximum token ID across every batch the packer yields.
# (No early `break`: stopping after the first batch would report the maximum
# of that single batch only, not of the data.)
max_token = 0
for batch in packer:
    curr_max = batch["input_ids"].max().item()
    max_token = max(max_token, curr_max)
print(f"Maximum token ID in data: {max_token}")

In [None]:
# Check for negative values besides the padding token
# Look at the first packed batch only: list its distinct token ids and report
# how often each negative id (the padding id included) occurs.
for batch in packer:
    unique_values, occurrences = torch.unique(batch["input_ids"], return_counts=True)
    print(f"Unique values in input_ids: {unique_values}")

    is_negative = unique_values < 0
    neg_values = unique_values[is_negative]
    print(f"Negative values: {neg_values}")

    for val, times in zip(neg_values, occurrences[is_negative]):
        count = times.item()
        print(f"Value {val} occurs {count} times")

    break


In [None]:
# Try explicitly telling the model to ignore padding in attention mask
for batch in packer:
    # Make sure padding tokens have 0 in attention mask
    attention_mask = (batch["input_ids"] != -1).long()

    # Convert -1 padding to 0 (or any valid token ID that's safe to use)
    inputs = batch["input_ids"].clone()
    inputs[inputs == -1] = 0

    # Fix labels too
    labels = batch["labels"].clone()
    labels[labels == -1] = -100  # -100 is the standard ignore index for CrossEntropyLoss

    # The model was moved to `device`; the batch tensors must live there too,
    # otherwise the forward pass fails with a device mismatch on GPU.
    outputs = model(
        input_ids=inputs.to(device),
        attention_mask=attention_mask.to(device),
        labels=labels.to(device)
    )

    break


In [None]:
# Try without using the custom causal mask
for batch in packer:
    inputs = batch["input_ids"].clone()
    inputs[inputs == -1] = 0

    # GPT models create their own causal mask internally.
    # The model was moved to `device`; the batch tensors must live there too,
    # otherwise the forward pass fails with a device mismatch on GPU.
    outputs = model(
        input_ids=inputs.to(device),
        attention_mask=batch["attention_mask"].to(device),
        labels=batch["labels"].clone().masked_fill_(batch["labels"] == -1, -100).to(device)
    )

    break


In [None]:
# Inspect the model's embedding parameters: valid token ids are the row
# indices of the embedding matrix, i.e. 0 .. num_rows - 1.
print(f"Model embedding weight shape: {model.transformer.wte.weight.shape}")
print(f"Model embedding weight indices: min={0}, max={model.transformer.wte.weight.shape[0]-1}")


In [None]:
class AutoregressiveSequencePacker(GreedyBestFitSequencePacker):
    """Greedy best-fit packer producing next-token-prediction batches.

    The parent's MLM masking is switched off; inputs get padding replaced by 0
    and labels are the raw ids rotated left by one position, with
    ``ignore_token_id`` at segment ends and wherever padding shows up.
    """

    def __init__(self, *args, **kwargs):
        # Always override whatever the caller passed for suppress_masking.
        super().__init__(*args, **{**kwargs, 'suppress_masking': True})

    def _generate_batches(self):
        """Yield the parent's batches with inputs, labels and mask rebuilt."""
        for packed in super()._generate_batches():
            raw_ids = packed["input_ids"]
            pad_positions = raw_ids == self.pad_token_id

            # Model-facing ids: same tokens, padding swapped for id 0.
            model_inputs = raw_ids.masked_fill(pad_positions, 0)

            # 1 where a real token sits, 0 on padding.
            mask = (~pad_positions).long()

            # Position t is asked to predict the token at t + 1.
            next_tokens = torch.roll(raw_ids, shifts=-1, dims=1).clone()

            # Blank out the label on the last position of every segment so no
            # prediction crosses a boundary (this also covers the wrapped
            # value that roll places in the final column).
            for row, boundaries in enumerate(packed["cu_seqlens"]):
                for segment_end in boundaries[1:].tolist():
                    next_tokens[row, segment_end - 1] = self.ignore_token_id

            # Any remaining padding id in the labels is ignored as well.
            next_tokens[next_tokens == self.pad_token_id] = self.ignore_token_id

            yield {
                "input_ids": model_inputs,
                "labels": next_tokens,
                "attention_mask": mask,
                "cu_seqlens": packed["cu_seqlens"],
                "max_seqlen": packed["max_seqlen"]
            }


In [None]:
# Training loop
# An existing instance keeps the methods of the class object it was created
# from, so after redefining AutoregressiveSequencePacker the packer has to be
# rebuilt for the new `_generate_batches` to take effect.
packer = AutoregressiveSequencePacker(
    src_iterable=train_dataloader,
    src_batch_size=16,
    src_max_seq_len=200,
    out_batch_size=4,
    out_pseq_len=512,
    buffer_size=64,
    pad_token_id=-1,
    ignore_token_id=-100
)

model.train()
for epoch in range(3):
    for batch in packer:
        # Move batch to device
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Forward pass. The packer already shifted the labels by one position.
        # GPT2LMHeadModel shifts `labels` again internally, so compute the
        # loss from the logits here instead of passing `labels=` to the model.
        outputs = model(
            input_ids=inputs,
            attention_mask=attention_mask
        )
        logits = outputs.logits
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=packer.ignore_token_id
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")


In [None]:
# Debug the exact values that might be causing issues
vocab_size = model.config.vocab_size  # valid ids are 0 .. vocab_size - 1
for batch in packer:
    inputs = batch["input_ids"].clone()

    # Replace pad tokens with 0
    inputs[inputs == packer.pad_token_id] = 0

    # Check for any values >= vocab_size
    max_val = inputs.max().item()
    print(f"Maximum value in inputs after replacing padding: {max_val}")

    # Check if there are any values outside the vocabulary range
    invalid_indices = (inputs >= vocab_size).nonzero()
    if len(invalid_indices) > 0:
        print(f"Found {len(invalid_indices)} invalid indices")
        print(f"Invalid values: {inputs[invalid_indices[:, 0], invalid_indices[:, 1]]}")

    # Try explicitly setting the range of values
    valid_inputs = inputs.clamp(0, vocab_size - 1)

    # clamp(-100, ...) would leave -1 (padding) labels untouched because
    # -1 >= -100, so map every out-of-range label to the ignore index instead.
    raw_labels = batch["labels"]
    valid_labels = raw_labels.masked_fill((raw_labels < 0) | (raw_labels >= vocab_size), -100)

    # Check if there are still any issues. The broad catch is deliberate:
    # this is a diagnostic probe that reports whatever goes wrong.
    try:
        outputs = model(
            input_ids=valid_inputs.to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=valid_labels.to(device)
        )
        print("Success with clamped values!")
    except Exception as e:
        print(f"Still encountering error: {e}")

    break


In [None]:
# Smoke-test the model on a tiny random batch that bypasses the packer.
test_input = torch.randint(0, 1000, (2, 10)).to(device)
test_labels = torch.randint(0, 1000, (2, 10)).to(device)
test_labels[:, -1] = -100  # last position carries no target
test_mask = torch.ones_like(test_input).to(device)

probe_kwargs = {
    "input_ids": test_input,
    "attention_mask": test_mask,
    "labels": test_labels,
}

# Diagnostic probe: report any failure instead of raising.
try:
    test_output = model(**probe_kwargs)
    print(f"Test successful! Loss: {test_output.loss.item()}")
except Exception as e:
    print(f"Test failed: {e}")


In [None]:
class FixedAutoregressive(GreedyBestFitSequencePacker):
    """Greedy best-fit packer that emits next-token-prediction batches.

    MLM masking is disabled in the parent. Inputs have padding replaced by 0
    and labels are shifted left by one position, with ``ignore_token_id``
    wherever there is no valid next token.
    """

    def __init__(self, *args, vocab_size=1000, **kwargs):
        # Stored for callers; not used by the batch logic in this class.
        self.vocab_size = vocab_size
        kwargs['suppress_masking'] = True
        super().__init__(*args, **kwargs)

    def _generate_batches(self):
        """Wrap the parent's batches with shifted labels and an attention mask.

        Yields:
            dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``cu_seqlens`` and ``max_seqlen``. ``labels[r, t]`` is the token
            at ``t + 1`` of row ``r``, or ``ignore_token_id`` when that token
            is padding, belongs to the next packed sequence, or does not
            exist (last column).
        """
        for batch_dict in super()._generate_batches():
            input_ids = batch_dict["input_ids"]
            padding_mask = input_ids == self.pad_token_id

            # Padding ids are not valid embedding indices; use 0 as placeholder.
            valid_inputs = input_ids.masked_fill(padding_mask, 0)

            # Create attention mask (1 for tokens, 0 for padding)
            attention_mask = (~padding_mask).long()

            # Shift by slicing instead of torch.roll: roll wraps the first
            # token of the row into the last label position, which stayed
            # there as a bogus target whenever a row was packed to full length.
            labels = torch.full_like(input_ids, self.ignore_token_id)
            labels[:, :-1] = valid_inputs[:, 1:].masked_fill(padding_mask[:, 1:], self.ignore_token_id)

            # Don't predict across sequence boundaries: ignore the last token
            # of every segment, including the one ending at the row's end.
            for i, cu_seqlens in enumerate(batch_dict["cu_seqlens"]):
                segment_ends = cu_seqlens[1:].long()
                labels[i, segment_ends - 1] = self.ignore_token_id

            yield {
                "input_ids": valid_inputs,
                "labels": labels,
                "attention_mask": attention_mask,
                "cu_seqlens": batch_dict["cu_seqlens"],
                "max_seqlen": batch_dict["max_seqlen"]
            }


In [None]:
def verify_batch(batch, vocab_size=1000):
    """Print a warning for every out-of-range value in a batch.

    Inputs must lie in ``[0, vocab_size)``. Labels must either equal -100 or
    lie in ``[0, vocab_size)``. Stray -1 values are reported separately.

    Args:
        batch: Mapping with tensor entries ``"input_ids"`` and ``"labels"``.
        vocab_size: Exclusive upper bound for valid token ids.

    Returns:
        ``"All checks passed!"`` when nothing was reported, otherwise
        ``"Issues found!"``.
    """
    inputs, labels = batch["input_ids"], batch["labels"]
    reported = []

    def warn(message):
        # Print right away so output order matches the order of the checks.
        print(message)
        reported.append(message)

    # Range of the input ids
    lowest_input = inputs.min().item()
    if lowest_input < 0:
        warn(f"Warning: inputs contain values < 0: {lowest_input}")
    highest_input = inputs.max().item()
    if highest_input >= vocab_size:
        warn(f"Warning: inputs contain values >= vocab_size: {highest_input}")

    # Range of the labels that actually contribute to the loss
    valid_label_mask = labels != -100
    if valid_label_mask.sum() > 0:
        valid_labels = labels[valid_label_mask]
        lowest_label = valid_labels.min().item()
        if lowest_label < 0:
            warn(f"Warning: labels contain values < 0 (excluding -100): {lowest_label}")
        highest_label = valid_labels.max().item()
        if highest_label >= vocab_size:
            warn(f"Warning: labels contain values >= vocab_size: {highest_label}")

    # Explicit count of leftover -1 values
    stray_inputs = inputs == -1
    if stray_inputs.any():
        warn(f"Warning: inputs contain -1 values: {stray_inputs.sum().item()}")
    stray_labels = labels == -1
    if stray_labels.any():
        warn(f"Warning: labels contain -1 values: {stray_labels.sum().item()}")

    # Every failing condition above produced a warning, so the verdict is
    # simply whether anything was reported.
    if reported:
        return "Issues found!"
    return "All checks passed!"


In [None]:
verify_batch(batch)

In [None]:
# Fix the current batch
def fix_batch(batch, vocab_size=1000):
    """Return a copy of ``batch`` whose ids and labels are safe for the model.

    Inputs: -1 becomes 0, then everything is clamped to ``[0, vocab_size - 1]``.
    Labels: -1 becomes -100; every label other than -100 is clamped to
    ``[0, vocab_size - 1]``. The remaining entries are passed through as-is.

    Args:
        batch: Mapping with ``"input_ids"``, ``"labels"``, ``"attention_mask"``,
            ``"cu_seqlens"`` and ``"max_seqlen"``.
        vocab_size: Exclusive upper bound for valid token ids.

    Returns:
        A new dict with the same five keys; the input tensors are not modified.
    """
    top_id = vocab_size - 1

    # masked_fill returns a new tensor, so the caller's tensors stay untouched.
    raw_inputs = batch["input_ids"]
    fixed_inputs = raw_inputs.masked_fill(raw_inputs == -1, 0).clamp(0, top_id)

    raw_labels = batch["labels"]
    fixed_labels = raw_labels.masked_fill(raw_labels == -1, -100)

    # Only labels that take part in the loss are clamped; -100 must survive.
    contributes = fixed_labels != -100
    if contributes.any():
        fixed_labels[contributes] = fixed_labels[contributes].clamp(0, top_id)

    passthrough = {key: batch[key] for key in ("attention_mask", "cu_seqlens", "max_seqlen")}
    return {"input_ids": fixed_inputs, "labels": fixed_labels, **passthrough}

# Sanitise the batch left over from the previous loop and re-check it
fixed = fix_batch(batch)
verify_batch(fixed)

# Feed the sanitised tensors to the model; this is a diagnostic probe, so any
# failure is printed rather than raised.
try:
    on_device = {name: fixed[name].to(device) for name in ("input_ids", "attention_mask", "labels")}
    outputs = model(**on_device)
    print(f"Success! Loss: {outputs.loss.item()}")
except Exception as e:
    print(f"Error: {e}")


In [None]:
class AutoregressiveSequencePacker(GreedyBestFitSequencePacker):
    """Greedy best-fit packer that emits next-token-prediction batches.

    MLM masking is disabled in the parent. Inputs have padding replaced by 0
    and are clamped to ``[0, vocab_size - 1]``; labels are the inputs shifted
    left by one position, with ``ignore_token_id`` wherever there is no valid
    next token.
    """

    def __init__(self, *args, vocab_size=1000, **kwargs):
        self.vocab_size = vocab_size
        kwargs['suppress_masking'] = True
        super().__init__(*args, **kwargs)

    def _generate_batches(self):
        """Wrap the parent's batches with shifted labels and an attention mask.

        Yields:
            dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``cu_seqlens`` and ``max_seqlen``. ``labels[r, t]`` is the token
            at ``t + 1`` of row ``r``, or ``ignore_token_id`` when that token
            is padding, belongs to the next packed sequence, or does not
            exist (last column).
        """
        for batch_dict in super()._generate_batches():
            input_ids = batch_dict["input_ids"]
            is_pad = input_ids == self.pad_token_id

            # Replace padding with 0 and ensure all values are in range.
            # NOTE: clamping silently maps out-of-vocabulary ids onto the
            # boundary ids; it keeps the model from crashing but can hide
            # bad source data.
            valid_inputs = input_ids.masked_fill(is_pad, 0).clamp(0, self.vocab_size - 1)

            # Create attention mask (1 for tokens, 0 for padding)
            attention_mask = (~is_pad).long()

            # Shift by slicing instead of torch.roll: roll wraps the first
            # token of the row into the last label position, which stayed
            # there as a bogus target whenever a row was packed to full length.
            labels = torch.full_like(input_ids, self.ignore_token_id)
            labels[:, :-1] = valid_inputs[:, 1:].masked_fill(is_pad[:, 1:], self.ignore_token_id)

            # Don't predict across sequence boundaries: ignore the last token
            # of every segment, including the one ending at the row's end.
            for i, cu_seqlens in enumerate(batch_dict["cu_seqlens"]):
                segment_ends = cu_seqlens[1:].long()
                labels[i, segment_ends - 1] = self.ignore_token_id

            yield {
                "input_ids": valid_inputs,
                "labels": labels,
                "attention_mask": attention_mask,
                "cu_seqlens": batch_dict["cu_seqlens"],
                "max_seqlen": batch_dict["max_seqlen"]
            }


In [None]:
# Training loop
# An existing instance keeps the methods of the class object it was created
# from, so after redefining AutoregressiveSequencePacker the packer has to be
# rebuilt. With a fresh instance the batches are already sanitised and no
# per-batch fix_batch() call is required.
packer = AutoregressiveSequencePacker(
    src_iterable=train_dataloader,
    src_batch_size=16,
    src_max_seq_len=200,
    out_batch_size=4,
    out_pseq_len=512,
    buffer_size=64,
    pad_token_id=-1,
    ignore_token_id=-100,
    vocab_size=model.config.vocab_size
)

model.train()
for epoch in range(3):
    for batch in packer:
        # Move batch to device
        inputs = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Forward pass. The packer already shifted the labels by one position.
        # GPT2LMHeadModel shifts `labels` again internally, so compute the
        # loss from the logits here instead of passing `labels=` to the model.
        outputs = model(
            input_ids=inputs,
            attention_mask=attention_mask
        )
        logits = outputs.logits
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=packer.ignore_token_id
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()}")


## Sequence packer + FA

In [None]:
class FlashCausalAttentionBlock(nn.Module):
    """
    Attention block implementing multi-head causal (masked) attention using
    Flash Attention on packed, variable-length sequences.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ):
        """
        Initialize the causal attention block with Flash Attention implementation.

        Args:
            hidden_dim: Dimension of the input and output features
            num_heads: Number of attention heads; must be positive and divide
                hidden_dim evenly
            dropout: Output dropout probability (0.0 means no dropout)

        Raises:
            ValueError: If num_heads is not positive or hidden_dim is not
                divisible by num_heads.

        Note:
            Availability of Flash Attention is checked in ``forward``, not
            here, so the block can be constructed without it installed.
        """
        super().__init__()
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        # Query, key and value projections (created in this order so parameter
        # initialisation consumes the RNG the same way as before).
        self.Wq = nn.Linear(hidden_dim, hidden_dim)
        self.Wk = nn.Linear(hidden_dim, hidden_dim)
        self.Wv = nn.Linear(hidden_dim, hidden_dim)
        # Output projection followed by output dropout.
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, cu_seqlens: Tensor, max_seqlen: int) -> Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor of shape [total_seq_len, hidden_dim].
            cu_seqlens: Cumulative sequence lengths tensor of shape [batch_size + 1]
                    Used instead of an attention mask for both masking and
                    variable-length sequences. Example:
                        cu_seqlens = torch.tensor([0, 10, 30, 60])
                    This means there are three sequences in the batch:
                        - First sequence has 10 tokens
                        - Second sequence has 20 tokens
                        - Third sequence has 30 tokens
            max_seqlen: Maximum sequence length in the batch. In the example above,
                        the maximum sequence length is 30.

        Returns:
            Tensor of shape [total_seq_len, hidden_dim] after attention.

        Raises:
            ImportError: If Flash Attention is not available.
        """
        if not FLASH_ATTN_AVAILABLE:
            raise ImportError("Flash Attention is not available. Please install it with `pip install flash-attn`")

        total_seq_len, hidden_dim = x.shape
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)

        # Split the feature dimension into heads: [total_seq_len, heads, head_dim]
        head_shape = (total_seq_len, self.num_heads, self.head_dim)
        q_reshaped = q.view(head_shape)
        k_reshaped = k.view(head_shape)
        v_reshaped = v.view(head_shape)

        # Self-attention: queries and keys share the same sequence boundaries.
        output = flash_attn_varlen_func(
            q_reshaped,
            k_reshaped,
            v_reshaped,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True
        )

        # Merge heads back, project, then apply output dropout.
        return self.dropout(self.Wo(output.reshape(total_seq_len, hidden_dim)))

In [None]:
class SimpleGPTModel(nn.Module):
    """Small GPT-style language model that consumes packed sequences.

    Token embeddings are summed with learned position embeddings whose index
    restarts at every boundary listed in ``cu_seqlens``. The sum is passed
    through ``num_layers`` ``FlashCausalAttentionBlock`` layers, a final
    LayerNorm and a linear projection onto the vocabulary.
    """

    def __init__(self, vocab_size, hidden_dim, num_heads, num_layers, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)

        blocks = [FlashCausalAttentionBlock(hidden_dim, num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

        self.ln_f = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, cu_seqlens, max_seqlen):
        """Return vocabulary logits for every token of a packed batch.

        Args:
            input_ids: Packed token ids; ``input_ids.size(0)`` is used as the
                total number of tokens (assumes a 1-D tensor - TODO confirm).
            cu_seqlens: Cumulative sequence lengths delimiting the sequences.
            max_seqlen: Longest sequence length, forwarded to every layer.
        """
        n_tokens = input_ids.size(0)

        hidden = self.token_embedding(input_ids)

        # For every token, record the start offset of the sequence it belongs
        # to; tokens not covered by cu_seqlens keep an offset of zero.
        absolute = torch.arange(n_tokens, device=input_ids.device)
        seq_start = torch.zeros_like(absolute)
        for seq in range(len(cu_seqlens) - 1):
            begin = cu_seqlens[seq]
            finish = cu_seqlens[seq + 1]
            seq_start[begin:finish] = begin

        # Position inside the owning sequence = absolute index - sequence start.
        hidden = hidden + self.position_embedding(absolute - seq_start)

        for block in self.layers:
            hidden = block(hidden, cu_seqlens, max_seqlen)

        return self.lm_head(self.ln_f(hidden))

In [None]:
class AutoregressiveDataset:
    """In-memory dataset of random token sequences for autoregressive demos.

    Every sample is a dict ``{"input_ids": tokens}`` holding the *full*
    sequence; splitting it into inputs (``tokens[:-1]``) and targets
    (``tokens[1:]``) is left to the training code. Token ids are drawn from
    ``[1, vocab_size - 1]``, so id 0 is never generated.

    Args:
        num_samples: Number of sequences to generate.
        min_len: Minimum sequence length (inclusive).
        max_len: Maximum sequence length (inclusive).
        vocab_size: Exclusive upper bound for generated token ids.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used, so equal seeds yield equal datasets. When ``None`` (the
            default) the global ``random`` module state is used.
    """

    def __init__(self, num_samples=100, min_len=10, max_len=100, vocab_size=1000, seed=None):
        # Fall back to the module-level generator so random.seed(...) still
        # controls unseeded datasets.
        rng = random if seed is None else random.Random(seed)
        self.samples = []
        for _ in range(num_samples):
            seq_len = rng.randint(min_len, max_len)
            tokens = [rng.randint(1, vocab_size - 1) for _ in range(seq_len)]
            self.samples.append({"input_ids": tokens})

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
# Random-token corpus, served in shuffled batches of eight samples.
gpt_dataset = AutoregressiveDataset(vocab_size=1000, num_samples=100)
gpt_dataloader = DataLoader(
    gpt_dataset,
    collate_fn=simple_collate,
    shuffle=True,
    batch_size=8,
)

# Packer fed by the loader above; src_batch_size mirrors the loader's batch_size.
gpt_packer = GreedyBestFitSequencePacker(
    src_iterable=gpt_dataloader,
    src_batch_size=8,
    src_max_seq_len=100,
    out_batch_size=1,
    out_pseq_len=200,
    buffer_size=32,
    pad_token_id=0,
    ignore_token_id=-100,
    mask_token_id=0,
    suppress_masking=True,  # Important for autoregressive models
    mask_prob=0.0,  # Not used with suppress_masking=True
)

In [None]:
# Initialize model
model = SimpleGPTModel(vocab_size=1000, hidden_dim=256, num_heads=8, num_layers=4, max_seq_len=100).to(device)

# Training loop
with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
    for batch in gpt_packer:
        input_ids = batch["input_ids"]
        cu_seqlens = batch["cu_seqlens"]
        max_seqlen = batch["max_seqlen"]

        # The packer is configured with out_batch_size=1, so take the single
        # packed row and its sequence boundaries.
        flat_ids = input_ids.squeeze().to(device)
        seq_bounds = cu_seqlens[0].to(device)

        # Get logits from model
        logits = model(flat_ids, seq_bounds, max_seqlen[0])

        # Autoregressive next-token prediction: position t is scored against
        # the token at position t + 1.
        targets = torch.roll(flat_ids, shifts=-1)

        # Positions that must not contribute to the loss:
        #   - padding slots (the packer pads with token id 0),
        #   - the last token of every packed sequence, whose rolled target is
        #     the first token of the *next* sequence (or, for the final slot,
        #     the wrapped-around first token of the row).
        ignore = flat_ids == 0
        ignore[seq_bounds[1:].long() - 1] = True
        # -100 is the default ignore_index of F.cross_entropy.
        targets = targets.masked_fill(ignore, -100)

        loss = F.cross_entropy(logits, targets)
        print(loss)

tensor(7.0602, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0452, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0508, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0314, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1291, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0477, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0781, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0338, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1039, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0864, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0139, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1502, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1053, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1514, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1259, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.0941, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1477, device='cuda:0', grad_fn=

In [None]:
# Open the post-mortem debugger on the most recent uncaught exception.
import pdb

pdb.pm()

> [32m/home/galopy/miniforge3/envs/fromscratch/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py[39m([92m170[39m)[36m_flash_attn_varlen_forward[39m[34m()[39m
[32m    168[39m ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
[32m    169[39m     q, k, v = [maybe_contiguous(x) [38;5;28;01mfor[39;00m x [38;5;28;01min[39;00m (q, k, v)]
[32m--> 170[39m     out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
[32m    171[39m         q,
[32m    172[39m         k,

ipdb> p q
tensor([[[ 1.0527, -0.1959, -0.8595,  ..., -0.8361,  0.3578,  1.2478],
         [ 0.5617,  0.1539, -0.3003,  ..., -2.3932, -0.9189, -0.6507],
         [-1.4578, -0.8788, -0.8945,  ...,  0.5563, -0.5626,  1.1726],
         ...,
         [ 0.2234, -1.4744, -0.7962,  ..., -1.7971, -0.6704,  0.3328],
         [-1.0220, -1.0418, -1.0343,  ..., -0.7659,  0.2398,  1.7561],
         [ 0.1045,  0.7630,  0.2812,  ..., -0.3854, -0.1902, -0.4510]],

        [[-0.2990,  

### RoPE (rotary position embeddings)

In [None]:
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Build the table of complex RoPE rotations.

    Args:
        dim: Width of the vectors to rotate; ``dim // 2`` frequencies are used.
        end: Number of positions to tabulate (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``[end, dim // 2]`` whose entry ``[p, j]`` is
        the unit complex number with angle ``p * theta ** (-2j / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude, so multiplying by an entry is a pure rotation.
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(x, freqs_cis):
    """Rotate ``x`` by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last axis of ``x`` are treated as
    (real, imaginary) parts. ``freqs_cis`` gets two leading singleton axes
    before the multiplication, so it must broadcast against the last two axes
    of the complex view of ``x``. The result has the shape and dtype of ``x``.
    """
    leading = x.shape[:-1]
    pairs = x.float().reshape(*leading, -1, 2)
    rotation = freqs_cis[None, None]
    rotated = torch.view_as_complex(pairs) * rotation
    return torch.view_as_real(rotated).reshape(*x.shape).type_as(x)

In [None]:
# NOTE(review): scratch call at notebook top level. `self` is not defined in
# this cell as shown, so this presumably only works inside a module's
# __init__ (compare FlashCausalAttentionBlock.__init__ in the cells below);
# confirm or remove.
precompute_freqs_cis(self.head_dim, max_seq_len)

In [None]:
# Dataset of random token sequences and a shuffled loader over it
gpt_dataset = AutoregressiveDataset(vocab_size=1000, num_samples=100)
gpt_dataloader = DataLoader(
    gpt_dataset,
    collate_fn=simple_collate,
    shuffle=True,
    batch_size=8,
)

# Sequence packer fed by the loader; src_batch_size mirrors the loader's batch_size
gpt_packer = GreedyBestFitSequencePacker(
    src_iterable=gpt_dataloader,
    src_batch_size=8,
    src_max_seq_len=100,
    out_batch_size=1,
    out_pseq_len=200,
    buffer_size=32,
    pad_token_id=0,
    ignore_token_id=-100,
    mask_token_id=0,
    suppress_masking=True,  # Important for autoregressive models
    mask_prob=0.0,  # Not used with suppress_masking=True
)

In [None]:
class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, with RoPE.

    Queries and keys are rotated by ``apply_rotary_emb`` using per-token
    positions that restart at every boundary in ``cu_seqlens``; attention
    itself is computed by ``flash_attn_varlen_func`` with ``causal=True``.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads:
            raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        self.Wq = nn.Linear(hidden_dim, hidden_dim)
        self.Wk = nn.Linear(hidden_dim, hidden_dim)
        self.Wv = nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # RoPE rotation table for positions 0 .. max_seq_len - 1, kept as a buffer.
        self.register_buffer("freqs_cis", precompute_freqs_cis(self.head_dim, max_seq_len))

    def _positions_in_sequence(self, cu_seqlens, n_tokens, device):
        """Return, per packed token, its index inside its own sequence.

        Tokens that no ``cu_seqlens`` interval covers keep their absolute index.
        """
        within = torch.arange(n_tokens, device=device).clone()
        for begin, finish in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            within[begin:finish] = torch.arange(finish - begin, device=device)
        return within

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence; ``x`` is [total_seq_len, hidden_dim]."""
        if not FLASH_ATTN_AVAILABLE:
            raise ImportError("Flash Attention is not available")

        n_tokens, width = x.shape
        per_head_shape = (n_tokens, self.num_heads, self.head_dim)

        # Project and split into heads.
        q = self.Wq(x).view(per_head_shape)
        k = self.Wk(x).view(per_head_shape)
        v = self.Wv(x).view(per_head_shape)

        # Look up one rotation row per token and rotate queries and keys.
        rope_index = self._positions_in_sequence(cu_seqlens, n_tokens, x.device)
        rotations = self.freqs_cis[rope_index]
        q = apply_rotary_emb(q, rotations)
        k = apply_rotary_emb(k, rotations)

        attended = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True,
        )

        return self.dropout(self.Wo(attended.reshape(n_tokens, width)))

In [None]:
class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, with RoPE.

    Same layout as a plain Flash Attention block, plus rotary embeddings whose
    position index restarts at every ``cu_seqlens`` boundary and is clamped to
    the precomputed table of ``max_seq_len`` rows.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads:
            raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        self.Wq = nn.Linear(hidden_dim, hidden_dim)
        self.Wk = nn.Linear(hidden_dim, hidden_dim)
        self.Wv = nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.max_seq_len = max_seq_len
        # RoPE rotation table, one row per position below max_seq_len.
        rope_table = precompute_freqs_cis(self.head_dim, max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence; ``x`` is [total_seq_len, hidden_dim]."""
        if not FLASH_ATTN_AVAILABLE:
            raise ImportError("Flash Attention is not available")

        n_tokens, width = x.shape
        per_head_shape = (n_tokens, self.num_heads, self.head_dim)

        # Project and split into heads.
        q = self.Wq(x).view(per_head_shape)
        k = self.Wk(x).view(per_head_shape)
        v = self.Wv(x).view(per_head_shape)

        # Index of each token inside its own sequence; tokens outside every
        # cu_seqlens interval stay at position 0.
        rope_index = torch.zeros_like(torch.arange(n_tokens, device=x.device))
        for begin, finish in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            rope_index[begin:finish] = torch.arange(finish - begin, device=x.device)

        # Keep every index inside the precomputed table.
        rope_index = rope_index.clamp(0, self.max_seq_len - 1)

        rotations = self.freqs_cis[rope_index]
        q = apply_rotary_emb(q, rotations)
        k = apply_rotary_emb(k, rotations)

        attended = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True,
        )

        merged = attended.reshape(n_tokens, width)
        return self.dropout(self.Wo(merged))


In [None]:
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Tabulate complex RoPE rotation factors.

    Returns a complex tensor of shape ``[end, dim // 2]``: row ``p`` holds the
    unit complex numbers whose angles are ``p`` times each of the ``dim // 2``
    inverse frequencies ``theta ** (-2j / dim)``.
    """
    even_idx = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (even_idx / dim))
    steps = torch.arange(end, device=inv_freq.device)
    phase = torch.outer(steps, inv_freq)
    unit = torch.ones_like(phase)
    return torch.polar(unit, phase)

def apply_rotary_emb(x, freqs_cis):
    """Apply per-token rotary embeddings to a packed, per-head tensor.

    Args:
        x: Tensor of shape [total_seq_len, num_heads, head_dim]; consecutive
            pairs along the last axis are treated as (real, imaginary) parts.
        freqs_cis: Complex tensor of shape [total_seq_len, head_dim//2] with
            one rotation row per token.

    Returns:
        The rotated tensor, with the shape and dtype of ``x``.
    """
    head_dim = x.shape[-1]
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    n_tokens, n_heads = x.shape[0], x.shape[1]

    # View each (even, odd) pair of the last axis as one complex number.
    pairs = x.float().reshape(n_tokens, n_heads, head_dim // 2, 2)
    as_complex = torch.view_as_complex(pairs)

    # Insert a head axis so the same rotation is shared by every head:
    # [total_seq_len, 1, head_dim//2].
    rotated = as_complex * freqs_cis.unsqueeze(1)

    # Back to interleaved real layout and the caller's dtype.
    return torch.view_as_real(rotated).reshape(x.shape).type_as(x)

class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, with RoPE.

    Queries and keys are rotated by ``apply_rotary_emb`` with a position index
    that restarts at every ``cu_seqlens`` boundary; attention is computed by
    ``flash_attn_varlen_func`` with ``causal=True``.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads; ``hidden_dim // num_heads`` must
            be even for RoPE.
        dropout: Dropout probability applied to the block output.
        max_seq_len: Number of positions in the precomputed RoPE table.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads != 0: raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"
        self.num_heads = num_heads
        self.Wq, self.Wk, self.Wv = nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.max_seq_len = max_seq_len
        # Precompute RoPE frequencies. precompute_freqs_cis(dim, end) already
        # returns dim // 2 complex entries per position, so it must receive the
        # full head_dim to match the [.., head_dim // 2] complex view that
        # apply_rotary_emb builds. Passing head_dim // 2 would yield only
        # head_dim // 4 columns and break the broadcast.
        self.register_buffer("freqs_cis", precompute_freqs_cis(self.head_dim, max_seq_len))

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence.

        Args:
            x: Packed activations of shape [total_seq_len, hidden_dim].
            cu_seqlens: Cumulative sequence lengths delimiting the sequences.
            max_seqlen: Length of the longest sequence in the batch.

        Returns:
            Tensor of shape [total_seq_len, hidden_dim].

        Raises:
            ImportError: If ``FLASH_ATTN_AVAILABLE`` is false.
        """
        if not FLASH_ATTN_AVAILABLE: raise ImportError("Flash Attention is not available")

        total_seq_len, hidden_dim = x.shape
        q,k,v = self.Wq(x), self.Wk(x), self.Wv(x)

        # Reshape for multi-head attention
        q = q.view(total_seq_len, self.num_heads, self.head_dim)
        k = k.view(total_seq_len, self.num_heads, self.head_dim)
        v = v.view(total_seq_len, self.num_heads, self.head_dim)

        # Position indices that reset at sequence boundaries; tokens outside
        # every cu_seqlens interval stay at position 0.
        positions = torch.arange(total_seq_len, device=x.device)
        seq_positions = torch.zeros_like(positions)
        for i in range(len(cu_seqlens)-1):
            start, end = cu_seqlens[i], cu_seqlens[i+1]
            seq_len = end - start
            seq_positions[start:end] = torch.arange(seq_len, device=x.device)

        # Ensure positions don't exceed our precomputed table
        seq_positions = seq_positions.clamp(0, self.max_seq_len - 1)

        # Apply RoPE to q and k
        freqs_cis = self.freqs_cis[seq_positions]  # [total_seq_len, head_dim//2]
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # Call Flash Attention
        output = flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True
        )

        return self.dropout(self.Wo(output.reshape(total_seq_len, hidden_dim)))


In [None]:
from flash_attn.layers.rotary import RotaryEmbedding
# The class below calls `apply_rotary_emb_func`, which was never imported
# (NameError). Bind it to flash-attn's rotary function under that name so the
# notebook's own `apply_rotary_emb` helper is not overwritten.
from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_func


class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, using
    flash-attn's rotary embedding implementation.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the block output.
        max_seq_len: Accepted for interface compatibility; not used by this
            variant (the cos/sin cache is sized from ``max_seqlen`` in
            ``forward``).
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads != 0: raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        self.Wq, self.Wk, self.Wv = nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Initialize RotaryEmbedding
        self.rotary = RotaryEmbedding(dim=self.head_dim, interleaved=False)

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence.

        Args:
            x: Packed activations of shape [total_seq_len, hidden_dim].
            cu_seqlens: Cumulative sequence lengths delimiting the sequences.
            max_seqlen: Length of the longest sequence in the batch.

        Returns:
            Tensor of shape [total_seq_len, hidden_dim].

        Raises:
            ImportError: If ``FLASH_ATTN_AVAILABLE`` is false.
        """
        if not FLASH_ATTN_AVAILABLE: raise ImportError("Flash Attention is not available")

        total_seq_len, hidden_dim = x.shape
        q,k,v = self.Wq(x), self.Wk(x), self.Wv(x)

        # Reshape for multi-head attention
        q = q.view(total_seq_len, self.num_heads, self.head_dim)
        k = k.view(total_seq_len, self.num_heads, self.head_dim)
        v = v.view(total_seq_len, self.num_heads, self.head_dim)

        # Update the rotary embedding cache if needed. The cache is requested
        # in the dtype of q (the tensor actually rotated), which can differ
        # from x.dtype when running under autocast.
        self.rotary._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)

        # Per-sequence position reset is driven by the cu_seqlens argument of
        # the rotary call, so no explicit position tensor is built here.
        rotary_kwargs = dict(
            interleaved=self.rotary.interleaved,
            inplace=True,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        q = apply_rotary_emb_func(q, self.rotary._cos_cached, self.rotary._sin_cached, **rotary_kwargs)
        k = apply_rotary_emb_func(k, self.rotary._cos_cached, self.rotary._sin_cached, **rotary_kwargs)

        # Call Flash Attention
        output = flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True
        )

        return self.dropout(self.Wo(output.reshape(total_seq_len, hidden_dim)))


In [None]:
class SimpleGPTModel(nn.Module):
    """Small GPT-style language model for packed sequences.

    Positional information comes from the attention blocks (each one is built
    with ``max_seq_len``); this model only embeds tokens, stacks the blocks,
    and applies a final LayerNorm and vocabulary projection.
    """

    def __init__(self, vocab_size, hidden_dim, num_heads, num_layers, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                FlashCausalAttentionBlock(hidden_dim, num_heads, max_seq_len=max_seq_len)
            )

        self.ln_f = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, cu_seqlens, max_seqlen):
        """Return vocabulary logits for every token in ``input_ids``."""
        hidden = self.token_embedding(input_ids)

        # Every block receives the same packing metadata.
        for block in self.layers:
            hidden = block(hidden, cu_seqlens, max_seqlen)

        return self.lm_head(self.ln_f(hidden))

In [None]:
# Initialize model. Device placement and bf16 autocast follow the first
# training cell of this notebook.
model = SimpleGPTModel(vocab_size=1000, hidden_dim=256, num_heads=8, num_layers=4, max_seq_len=100).to(device)

# Training loop
with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
    for batch in gpt_packer:
        input_ids = batch["input_ids"]
        cu_seqlens = batch["cu_seqlens"]
        max_seqlen = batch["max_seqlen"]

        # The packer is configured with out_batch_size=1: take the single
        # packed row together with its own boundaries and max length.
        flat_ids = input_ids.squeeze(0).to(device)
        seq_bounds = cu_seqlens[0].to(device)

        # Get logits from model
        logits = model(flat_ids, seq_bounds, max_seqlen[0])

        # Autoregressive next-token prediction: position t is scored against
        # the token at position t + 1.
        targets = torch.roll(flat_ids, shifts=-1)

        # Keep only positions with a valid next token inside the same sequence:
        #   - drop padding slots (the packer pads with token id 0),
        #   - drop the last token of every sequence, whose rolled target
        #     belongs to the next sequence (or wraps around to the row start).
        mask = flat_ids != 0
        mask[seq_bounds[1:].long() - 1] = False
        loss = F.cross_entropy(logits[mask], targets[mask])

NameError: name 'apply_rotary_emb_func' is not defined

In [None]:
# Inspect the most recent uncaught exception in the post-mortem debugger.
import pdb

pdb.pm()

> [32m/home/galopy/miniforge3/envs/fromscratch/lib/python3.12/site-packages/flash_attn/layers/rotary.py[39m([92m415[39m)[36m_update_cos_sin_cache[39m[34m()[39m
[32m    413[39m         [38;5;66;03m# or if we're switching from inference mode to training[39;00m
[32m    414[39m         if (
[32m--> 415[39m             seqlen > self._seq_len_cached
[32m    416[39m             [38;5;28;01mor[39;00m self._cos_cached [38;5;28;01mis[39;00m [38;5;28;01mNone[39;00m
[32m    417[39m             [38;5;28;01mor[39;00m self._cos_cached.device != device

ipdb> seqlen
[80]
ipdb> q


In [None]:
class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, using
    flash-attn's rotary embedding implementation.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the block output.
        max_seq_len: Accepted for interface compatibility; not used by this
            variant (the cos/sin cache is sized from ``max_seqlen`` in
            ``forward``).
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads != 0: raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        self.Wq, self.Wk, self.Wv = nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Initialize RotaryEmbedding
        self.rotary = RotaryEmbedding(dim=self.head_dim, interleaved=False)

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence.

        Args:
            x: Packed activations of shape [total_seq_len, hidden_dim].
            cu_seqlens: Cumulative sequence lengths delimiting the sequences.
            max_seqlen: Length of the longest sequence; an int, or a
                list/tuple of lengths whose maximum is used.

        Returns:
            Tensor of shape [total_seq_len, hidden_dim].

        Raises:
            ImportError: If ``FLASH_ATTN_AVAILABLE`` is false.
        """
        if not FLASH_ATTN_AVAILABLE: raise ImportError("Flash Attention is not available")

        # `apply_rotary_emb_func` is not defined anywhere in this cell; import
        # flash-attn's rotary function locally under that name so the class
        # works on its own.
        from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_func

        total_seq_len, hidden_dim = x.shape
        q,k,v = self.Wq(x), self.Wk(x), self.Wv(x)

        # Reshape for multi-head attention
        q = q.view(total_seq_len, self.num_heads, self.head_dim)
        k = k.view(total_seq_len, self.num_heads, self.head_dim)
        v = v.view(total_seq_len, self.num_heads, self.head_dim)

        # Collapse max_seqlen to one Python int: the cache update compares it
        # with an int, which fails for a list.
        if isinstance(max_seqlen, (list, tuple)):
            max_seq_len_value = int(max(max_seqlen))
        else:
            max_seq_len_value = int(max_seqlen)

        # Update the rotary embedding cache if needed, in the dtype of q (the
        # tensor actually rotated), which can differ from x.dtype under autocast.
        self.rotary._update_cos_sin_cache(max_seq_len_value, device=q.device, dtype=q.dtype)

        # Per-sequence position reset is driven by the cu_seqlens argument of
        # the rotary call, so no explicit position tensor is built here.
        rotary_kwargs = dict(
            interleaved=self.rotary.interleaved,
            inplace=True,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seq_len_value,
        )
        q = apply_rotary_emb_func(q, self.rotary._cos_cached, self.rotary._sin_cached, **rotary_kwargs)
        k = apply_rotary_emb_func(k, self.rotary._cos_cached, self.rotary._sin_cached, **rotary_kwargs)

        # Call Flash Attention
        output = flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seq_len_value,
            max_seqlen_k=max_seq_len_value,
            causal=True
        )

        return self.dropout(self.Wo(output.reshape(total_seq_len, hidden_dim)))

class SimpleGPTModel(nn.Module):
    """Token embedding, a stack of attention blocks, LayerNorm and LM head.

    No position embedding is added here; ``max_seq_len`` is handed to each
    ``FlashCausalAttentionBlock`` instead.
    """

    def __init__(self, vocab_size, hidden_dim, num_heads, num_layers, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)

        stack = [
            FlashCausalAttentionBlock(hidden_dim, num_heads, max_seq_len=max_seq_len)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

        self.ln_f = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, cu_seqlens, max_seqlen):
        """Map packed token ids to vocabulary logits."""
        activations = self.token_embedding(input_ids)

        # The packing metadata is passed unchanged to every block.
        for attention_block in self.layers:
            activations = attention_block(activations, cu_seqlens, max_seqlen)

        normed = self.ln_f(activations)
        return self.lm_head(normed)

# In your training loop
model_device = next(model.parameters()).device
for batch in gpt_packer:
    input_ids = batch["input_ids"]
    cu_seqlens = batch["cu_seqlens"]
    max_seqlen = batch["max_seqlen"]

    # Collapse max_seqlen to a single Python int. The rotary cache update and
    # Flash Attention both compare it with ints, so a list cannot be passed on.
    if isinstance(max_seqlen, (list, tuple)):
        max_seqlen_values = max(
            x.item() if isinstance(x, torch.Tensor) else x for x in max_seqlen
        )
    elif isinstance(max_seqlen, torch.Tensor):
        max_seqlen_values = int(max_seqlen.max().item())
    else:
        max_seqlen_values = max_seqlen

    # One packed row per batch (out_batch_size=1): pass that row and its own
    # boundaries, on the same device as the model.
    flat_ids = input_ids.squeeze(0).to(model_device)
    seq_bounds = cu_seqlens[0].to(model_device)

    # Get logits from model
    logits = model(flat_ids, seq_bounds, max_seqlen_values)

    # Rest of your training code


TypeError: '>' not supported between instances of 'list' and 'int'

In [None]:
from flash_attn.ops.triton.rotary import apply_rotary
# RotaryEmbedding and the varlen-aware rotary function come from flash-attn
# itself; the previous `rotary_embedding_source` module does not exist. The
# function is imported under an alias so the notebook's own
# `apply_rotary_emb(x, freqs_cis)` helper is not overwritten.
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb


class FlashCausalAttentionBlock(nn.Module):
    """Causal multi-head self-attention over packed sequences, using
    flash-attn's rotary embedding implementation.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the block output.
        max_seq_len: Accepted for interface compatibility; not used by this
            variant (the cos/sin cache is sized from ``max_seqlen`` in
            ``forward``).
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, max_seq_len=2048):
        super().__init__()
        if hidden_dim % num_heads != 0: raise Exception("hidden_dim not divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        self.num_heads = num_heads
        self.Wq, self.Wk, self.Wv = nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)
        self.Wo = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Initialize rotary embeddings on the device of the projections
        # created just above.
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            base=10000.0,
            interleaved=False,
            device=next(self.parameters()).device
        )

    def forward(self, x, cu_seqlens, max_seqlen):
        """Attend within each packed sequence.

        Args:
            x: Packed activations of shape [total_seq_len, hidden_dim].
            cu_seqlens: Cumulative sequence lengths delimiting the sequences.
            max_seqlen: Length of the longest sequence; an int, or a
                list/tuple of lengths whose maximum is used.

        Returns:
            Tensor of shape [total_seq_len, hidden_dim].

        Raises:
            ImportError: If ``FLASH_ATTN_AVAILABLE`` is false.
        """
        if not FLASH_ATTN_AVAILABLE: raise ImportError("Flash Attention is not available")

        total_seq_len, hidden_dim = x.shape
        q,k,v = self.Wq(x), self.Wk(x), self.Wv(x)

        # Reshape for multi-head attention
        q = q.view(total_seq_len, self.num_heads, self.head_dim)
        k = k.view(total_seq_len, self.num_heads, self.head_dim)
        v = v.view(total_seq_len, self.num_heads, self.head_dim)

        # Collapse max_seqlen to one Python int: the cache update compares it
        # with an int, which fails for a list.
        if isinstance(max_seqlen, (list, tuple)):
            max_seqlen = int(max(max_seqlen))
        else:
            max_seqlen = int(max_seqlen)

        # Update rotary embedding cache if needed, in the dtype of q (the
        # tensor actually rotated), which can differ from x.dtype under autocast.
        self.rotary_emb._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)

        # Apply rotary embeddings; the per-sequence position reset is driven by
        # the cu_seqlens argument.
        rotary_kwargs = dict(
            interleaved=False,
            inplace=True,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        q = flash_apply_rotary_emb(q, self.rotary_emb._cos_cached, self.rotary_emb._sin_cached, **rotary_kwargs)
        k = flash_apply_rotary_emb(k, self.rotary_emb._cos_cached, self.rotary_emb._sin_cached, **rotary_kwargs)

        # Call Flash Attention
        output = flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True
        )

        return self.dropout(self.Wo(output.reshape(total_seq_len, hidden_dim)))


ModuleNotFoundError: No module named 'rotary_embedding_source'