In [50]:
import torch
from torch.utils.data import Dataset

# Toy corpus: three token-id sequences of different lengths.
examples: list[list[int]] = [
    [1, 2, 3, 4, 9, 2, 5, 6, 7],
    [5, 2, 1, 3, 7, 11, 23, 21],
    [4, 2, 8],
]
max_seq_len: int = 8      # width of every packed row produced by the chunking cell
pad_token_id: int = -100  # filler used to pad the short final row
eos_token_id: int = 99    # appended after each example in the stream

# Flat, token-aligned accumulators filled by the stream-building cell:
# token ids, per-sequence positions, and the id of the owning sequence.
stream_input_ids, stream_position_ids, stream_seq_ids = [], [], []

# Id handed to the next example that is appended to the stream.
current_global_seq_id = 0

In [51]:
# Rebuild the flat stream from scratch every time this cell runs.
# Without this reset, re-running the cell appends a second copy of every
# example and keeps incrementing the sequence id, silently corrupting the
# stream that the chunking cell consumes.
stream_input_ids = []
stream_position_ids = []
stream_seq_ids = []
current_global_seq_id = 0

for seq in examples:
    # 1. Input IDs: the example followed by one EOS token.
    full_seq = seq + [eos_token_id]
    stream_input_ids.extend(full_seq)

    # 2. Position IDs: [0, 1, 2, ..., len(full_seq) - 1]
    # Crucial for RoPE! We reset to 0 for every new sequence.
    stream_position_ids.extend(range(len(full_seq)))

    # 3. Sequence IDs: [0, 0, 0, ...] then [1, 1, 1, ...]
    # Every token (EOS included) is tagged with the id of its example.
    stream_seq_ids.extend([current_global_seq_id] * len(full_seq))
    current_global_seq_id += 1

In [52]:
# Show the three flat streams, one per line.
print(
    f"stream_input_ids: {stream_input_ids}",
    f"stream_position_ids: {stream_position_ids}",
    f"stream_seq_ids: {stream_seq_ids}",
    sep="\n",
)

stream_input_ids: [1, 2, 3, 4, 9, 2, 5, 6, 7, 99, 5, 2, 1, 3, 7, 11, 23, 21, 99, 4, 2, 8, 99]
stream_position_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3]
stream_seq_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2]


In [53]:
# 3. Cut the flat stream into fixed-width rows of max_seq_len tokens.
total_tokens = len(stream_input_ids)

packed_input_batches = []
packed_pos_batches = []
packed_seq_batches = []

for start in range(0, total_tokens, max_seq_len):
    window = slice(start, start + max_seq_len)
    # Slicing copies, so padding below never touches the source streams.
    blocks = (
        stream_input_ids[window],
        stream_position_ids[window],
        stream_seq_ids[window],
    )

    # Only the final row can come up short; the filler is empty otherwise.
    # The same pad value is used for token ids, positions and sequence ids.
    filler = [pad_token_id] * (max_seq_len - len(blocks[0]))

    targets = (packed_input_batches, packed_pos_batches, packed_seq_batches)
    for target, block in zip(targets, blocks):
        target.append(block + filler)

In [54]:
# Report the stream length followed by the three packed row lists.
print(
    f"total_tokens: {total_tokens}",
    f"packed_input_batches: {packed_input_batches}",
    f"packed_pos_batches: {packed_pos_batches}",
    f"packed_seq_batches: {packed_seq_batches}",
    sep="\n",
)

total_tokens: 23
packed_input_batches: [[1, 2, 3, 4, 9, 2, 5, 6], [7, 99, 5, 2, 1, 3, 7, 11], [23, 21, 99, 4, 2, 8, 99, -100]]
packed_pos_batches: [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 0, 1, 2, 3, 4, 5], [6, 7, 8, 0, 1, 2, 3, -100]]
packed_seq_batches: [[0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 1, 1, 1, 1], [1, 1, 1, 2, 2, 2, 2, -100]]


In [55]:
# Turn each list of packed rows into an int64 tensor (one row per packed block).
packed_data = {
    name: torch.tensor(rows, dtype=torch.long)
    for name, rows in (
        ("input_ids", packed_input_batches),
        ("position_ids", packed_pos_batches),
        ("seq_ids", packed_seq_batches),
    )
}
packed_data

{'input_ids': tensor([[   1,    2,    3,    4,    9,    2,    5,    6],
         [   7,   99,    5,    2,    1,    3,    7,   11],
         [  23,   21,   99,    4,    2,    8,   99, -100]]),
 'position_ids': tensor([[   0,    1,    2,    3,    4,    5,    6,    7],
         [   8,    9,    0,    1,    2,    3,    4,    5],
         [   6,    7,    8,    0,    1,    2,    3, -100]]),
 'seq_ids': tensor([[   0,    0,    0,    0,    0,    0,    0,    0],
         [   0,    0,    1,    1,    1,    1,    1,    1],
         [   1,    1,    1,    2,    2,    2,    2, -100]])}

In [56]:
class PackedDataset(Dataset):
    """Dataset over pre-packed rows.

    Each item is one packed row, returned as a dict holding the row's token
    ids, position ids and sequence ids.

    Args:
        packed_data: Mapping with the keys ``"input_ids"``, ``"position_ids"``
            and ``"seq_ids"``. Each value must support ``len()`` and integer
            indexing along its first axis (in this notebook they are 2-D
            ``torch.long`` tensors with one row per packed block).
    """

    def __init__(self, packed_data):
        self.input_ids = packed_data["input_ids"]
        self.position_ids = packed_data["position_ids"]
        self.seq_ids = packed_data["seq_ids"]

    def __len__(self):
        """Return the number of packed rows."""
        # The attribute set in __init__ is `input_ids`; the previous
        # `self.inputs` did not exist, so len(dataset) raised AttributeError.
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the packed row at ``index`` as a dict of three aligned entries."""
        return {
            "input_ids": self.input_ids[index],
            "position_ids": self.position_ids[index],
            "seq_ids": self.seq_ids[index]
        }

In [57]:
# Wrap the packed tensors in the Dataset and look at the row at index 2.
training_data = PackedDataset(packed_data)

training_data[2]

{'input_ids': tensor([  23,   21,   99,    4,    2,    8,   99, -100]),
 'position_ids': tensor([   6,    7,    8,    0,    1,    2,    3, -100]),
 'seq_ids': tensor([   1,    1,    1,    2,    2,    2,    2, -100])}

### Mask Generation

In [58]:
# Per-token sequence ids of every packed row, plus the two tensor dimensions.
seq_ids = packed_data["seq_ids"]

bsz, seq_len = seq_ids.size()
(bsz, seq_len)

(3, 8)

In [59]:
# ---------------------------------------------------------
# Step 1: Expand IDs to create a grid for comparison
# ---------------------------------------------------------
# Every token has to be compared against every other token of its row.
# Adding a trailing axis gives the "query" view.
# Shape: (Batch, Seq_Len, 1)

seq_ids_row = seq_ids[..., None]
seq_ids_row

tensor([[[   0],
         [   0],
         [   0],
         [   0],
         [   0],
         [   0],
         [   0],
         [   0]],

        [[   0],
         [   0],
         [   1],
         [   1],
         [   1],
         [   1],
         [   1],
         [   1]],

        [[   1],
         [   1],
         [   1],
         [   2],
         [   2],
         [   2],
         [   2],
         [-100]]])

In [60]:
# "Key" view: a new axis inserted just before the last one.
# Shape: (Batch, 1, Seq_Len)
seq_ids_col = seq_ids[..., None, :]
seq_ids_col

tensor([[[   0,    0,    0,    0,    0,    0,    0,    0]],

        [[   0,    0,    1,    1,    1,    1,    1,    1]],

        [[   1,    1,    1,    2,    2,    2,    2, -100]]])

In [61]:
# ---------------------------------------------------------
# Step 2: Create the "Same Sequence" Mask (Block Diagonal)
# ---------------------------------------------------------
# Check: Does Token[i] belong to the same sequence as Token[j]?
# Result: (Batch, Seq_Len, Seq_Len) boolean matrix
# Padded slots carry pad_token_id (-100) as their sequence id, so we also
# require the key column not to be padding: no token ever attends to padding.
# NOTE: as a consequence the row of a padded *query* token is all False.

same_seq_mask = (seq_ids_row == seq_ids_col) & (seq_ids_col != pad_token_id)
same_seq_mask

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False],
         [False, False,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True,  True,  

In [62]:
# ---------------------------------------------------------
# Step 3: Create the "Causal" Mask (Lower Triangular)
# ---------------------------------------------------------
# Standard look-ahead mask: query i may see key j only when j <= i.
# Comparing a column of indices against a row of indices yields True on and
# below the diagonal, i.e. a (Seq_Len, Seq_Len) lower-triangular boolean
# matrix that is later broadcast over the batch.
token_index = torch.arange(seq_len, device=seq_ids.device)
causal_mask = token_index.unsqueeze(1) >= token_index.unsqueeze(0)
causal_mask

tensor([[ True, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True]])

In [63]:
# ---------------------------------------------------------
# Step 4: Combine Masks
# ---------------------------------------------------------
# A key is visible IF (Same Sequence) AND (In the Past).
# The extra axis at dim 1 gives (batch_size, 1, seq_len, seq_len).
combined_mask = torch.logical_and(same_seq_mask, causal_mask)[:, None]

In [64]:
# ---------------------------------------------------------
# Step 5: Convert to Additive Attention Mask
# ---------------------------------------------------------
# An additive mask is summed with the attention scores: 0.0 keeps a score
# as it is, -inf removes it.
# Begin with an all-zero float tensor shaped like the combined boolean mask.

dtype = torch.float32

final_mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=seq_ids.device)
final_mask

tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0

In [65]:
# Inspect the boolean visibility mask (True = this key may be attended to).
combined_mask

tensor([[[[ True, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False],
          [False, False,  True, False, False, False, False, False],
          [False, False,  True,  True, False, False, False, False],
          [False, False,  True,  True,  True, False, False, False],
          [False, False,  True,  True,  True,  True, False, False],
          [False, False,  True,  True,  True

In [66]:
# Visible entries (True in combined_mask) keep 0.0, so adding the mask leaves
# their attention score unchanged.
# Blocked entries become -inf: the masked score is -inf and softmax turns it
# into a weight of exactly zero.
# NOTE(review): the padded query position has no visible key, so its whole
# row is -inf; softmax over such a row is expected to give NaN. Confirm that
# downstream code ignores or guards padded rows.

final = torch.where(combined_mask, final_mask, torch.full_like(final_mask, float("-inf")))
final

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [-inf, -inf, 0., -inf, -inf, -inf, -inf, -inf],
          [-inf, -inf, 0., 0., -inf, -inf, -inf, -inf],
          [-inf, -inf, 0., 0., 0., -inf, -inf, -inf],
          [-inf, -inf, 0., 0., 0., 0., -inf, -inf],
          [-inf, -inf, 0., 0., 0., 0., 0., -inf],
          [-inf, -inf, 0., 0., 0., 0., 0., 0.]]],


        [[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -in

In [67]:
# Sanity check: the mask should be (bsz, 1, seq_len, seq_len).
final.size()

torch.Size([3, 1, 8, 8])

In [4]:
def pack_sequences(sequences, max_len, eos_id, pad_id=-100):
    """Pack variable-length sequences into fixed-width rows.

    Every sequence gets ``eos_id`` appended, all sequences are concatenated
    into one flat stream, and the stream is cut into rows of ``max_len``
    tokens. A sequence may therefore continue across a row boundary. The
    last row is right-padded with ``pad_id``.

    Args:
        sequences: Iterable of token-id sequences (lists or tuples of ints).
        max_len: Width of every packed row; must be positive.
        eos_id: Token id appended after each sequence.
        pad_id: Value used to pad the final row in all three outputs.

    Returns:
        Dict of ``torch.long`` tensors, each of shape ``(num_rows, max_len)``:
        ``"input_ids"`` (token ids), ``"position_ids"`` (position inside the
        owning sequence, restarting at 0 for each sequence) and ``"seq_ids"``
        (index of the owning sequence). With no sequences, ``num_rows`` is 0.

    Raises:
        ValueError: If ``max_len`` is not positive.
    """
    if max_len <= 0:
        raise ValueError("max_len must be a positive integer")

    buffer_input_ids = []
    buffer_position_ids = []
    buffer_seq_ids = []

    for current_seq_id, seq in enumerate(sequences):
        # 1. Input IDs = Seq + EOS token
        full_seq = list(seq) + [eos_id]
        buffer_input_ids.extend(full_seq)

        # 2. Position IDs: [0, 1, 2, ..., len(full_seq) - 1]
        # Crucial for RoPE! We reset to 0 for every new sequence.
        buffer_position_ids.extend(range(len(full_seq)))

        # 3. Sequence IDs: [0, 0, 0, ...] then [1, 1, 1, ...]
        buffer_seq_ids.extend([current_seq_id] * len(full_seq))

    def _to_rows(stream):
        # Cut one flat stream into max_len-wide rows, padding the last one.
        rows = []
        for start in range(0, len(stream), max_len):
            row = stream[start:start + max_len]  # slicing copies the data
            row.extend([pad_id] * (max_len - len(row)))
            rows.append(row)
        # The explicit reshape keeps the (0, max_len) shape for empty input.
        return torch.tensor(rows, dtype=torch.long).reshape(len(rows), max_len)

    return {
        "input_ids": _to_rows(buffer_input_ids),
        "position_ids": _to_rows(buffer_position_ids),
        "seq_ids": _to_rows(buffer_seq_ids),
    }

In [68]:
# Start from a 10x10 matrix in which every entry is -inf (everything blocked).
mask = torch.full(size=(10, 10), fill_value=float("-inf"))
mask

tensor([[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]])

In [69]:
# Keep -inf strictly above the diagonal; the diagonal and everything below
# it become 0, which gives the additive causal mask.
mask = mask.triu(diagonal=1)
mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])