# Improvement: sequence packing

## Data

### Dataset 

In [None]:
from datasets import load_dataset
import tiktoken
import torch

from minai import *

## Sequence packing

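Following the walkthrough at https://huggingface.co/blog/sirluk/llm-sequence-packing: several tokenized sentences are concatenated into one sequence separated by EOS tokens. We then build an attention mask and position ids so that each token only attends within its own sentence.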

In [None]:
# Setup: GPT-2 tokenizer and config, plus a model built from that config.
# NOTE(review): per the transformers docs, `from_config` builds the architecture
# without loading pretrained weights — confirm this is intended.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

torch.set_printoptions(linewidth=200)  # wider print lines for the mask tensors shown below

checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
config = AutoConfig.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_config(config)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

2025-04-02 20:40:25.674457: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-02 20:40:25.731906: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 20:40:26.099996: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-02 20:40:26.351344: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743644426.469865    9667 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743644426.49

In [None]:
# Three short example sentences to pack into one sequence.
sentence1 = "The cat sat on the mat"
sentence2 = "The dog ate my homework"
sentence3 = "My aunt is a teacher"
sentences = [sentence1, sentence2, sentence3]

# Tokenize each sentence on its own, without special tokens or attention mask;
# the EOS separators are appended by hand in the packing cell.
encoded = tokenizer(sentences, add_special_tokens=False, return_attention_mask=False)
tokenized_sentences = encoded["input_ids"]
tokenized_sentences

[[464, 3797, 3332, 319, 262, 2603],
 [464, 3290, 15063, 616, 26131],
 [3666, 25949, 318, 257, 4701]]

In [None]:
# GPT-2's EOS token id; used below as the separator between packed sentences.
tokenizer.eos_token_id

50256

In [None]:
# Pack the sentences: append the EOS separator to each token list and
# concatenate everything into one flat list of ids.
# NOTE: this rebinds `tokenized_sentences` (list of lists -> flat list), so the
# cell only works once, directly after the tokenization cell.
eos_id = tokenizer.eos_token_id
packed = []
for ids in tokenized_sentences:
    packed.extend(ids)
    packed.append(eos_id)
tokenized_sentences = packed
tokenized_sentences

[464,
 3797,
 3332,
 319,
 262,
 2603,
 50256,
 464,
 3290,
 15063,
 616,
 26131,
 50256,
 3666,
 25949,
 318,
 257,
 4701,
 50256]

In [None]:
# Decode the packed ids back to text to check that each sentence ends with <|endoftext|>.
tokenizer.decode(tokenized_sentences)

'The cat sat on the mat<|endoftext|>The dog ate my homework<|endoftext|>My aunt is a teacher<|endoftext|>'

In [None]:
# Convert the packed ids to a 1-D tensor (rebinds the name; run once).
tokenized_sentences = torch.tensor(tokenized_sentences)
# Baseline: a plain causal (lower-triangular) mask over the whole packed
# sequence. It still lets tokens attend across sentence boundaries.
n_tokens = len(tokenized_sentences)
attn_mask = torch.tril(torch.ones(n_tokens, n_tokens, dtype=torch.int))
attn_mask

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,

In [None]:
# Length of the packed sequence.
T = len(tokenized_sentences)
T

19

In [None]:
# Indices of all EOS tokens, always as a 1-D tensor.
# `nonzero(as_tuple=True)[0]` stays 1-D even with a single EOS, whereas
# `.nonzero().squeeze()` collapses to a 0-d tensor in that case and breaks the
# `eos_indices[[0]]` / slicing used in the following cells.
eos_indices = (tokenized_sentences == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
eos_indices

tensor([ 6, 12, 18])

In [None]:
# EOS positions without the first one (ends of sentences 2..n).
eos_indices[1:]

tensor([12, 18])

In [None]:
# EOS positions without the last one (ends of sentences 1..n-1).
eos_indices[:-1]

tensor([ 6, 12])

In [None]:
# Length of the first sentence: its EOS index + 1. Indexing with [[0]] keeps a
# 1-D tensor so it can be concatenated with the other lengths.
eos_indices[[0]]+1

tensor([7])

In [None]:
# Sequence lengths from EOS positions. Take differences between consecutive
# EOS indices, with a virtual EOS at -1 so the first length is eos_indices[0] + 1.
reps = torch.diff(eos_indices, prepend=eos_indices.new_tensor([-1]))
reps

tensor([7, 6, 6])

In [None]:
# Repeat each EOS index by its sentence length, so token k is labeled with the
# EOS index that closes its sentence.
torch.repeat_interleave(eos_indices, reps)

tensor([ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18])

In [None]:
# Same labels reshaped to a single row of shape (1, T), ready to be expanded.
torch.repeat_interleave(eos_indices, reps).view(1,-1)

tensor([[ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18]])

In [None]:
# For every key column j: the EOS index of the sentence containing token j,
# copied down all T rows -> shape (T, T). `expand` returns a view (no copy).
eos_of_token = torch.repeat_interleave(eos_indices, reps)
repeated_idx = eos_of_token.unsqueeze(0).expand(T, -1)
repeated_idx

tensor([[ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 1

In [None]:
# Row-index grid of shape (T, T): mask_indices[i, j] == i for every column j
# (the indices 0..T-1 repeated T times along dimension 1).
mask_indices = torch.arange(T).unsqueeze(1).expand(-1, T)
mask_indices

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7],
        [ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8],
        [ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
        [11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1

In [None]:
# Plain causal mask (float 1/0) over the packed sequence.
# Tokens from preceding sequences are masked out in a later cell.
mask = torch.ones(T, T).tril()
mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [None]:
# 1 where query row i lies past the EOS of key j's sentence, i.e. key j belongs
# to an earlier sentence than row i. These are the entries to mask out.
(mask_indices > repeated_idx).int()

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,

In [None]:
# Zero out entries where the query row is past the end of the key's sentence,
# which gives a block-diagonal causal mask.
# Not in-place, so `mask` (displayed above) keeps its plain causal values and
# its recorded output stays accurate.
packed_mask = mask.masked_fill(mask_indices > repeated_idx, 0)
packed_mask

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1

In [None]:
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    """Block-diagonal causal attention mask for ONE packed sequence.

    Args:
        x: 1-D tensor of token ids, shape (T,).
        token_id: separator token id.
        eos: True if `token_id` ends a sequence (boundary right after it);
            False if it starts one, e.g. BOS (boundary right before it).

    Returns:
        Bool tensor of shape (T, T). Entry [i, j] is True iff query i may
        attend to key j, i.e. j <= i and both tokens belong to the same
        packed sequence.
    """
    T = x.size(0)
    # Exclusive end position of every sequence. Appending T closes a trailing
    # sequence that is not terminated by the separator.
    ends = (x == token_id).nonzero(as_tuple=True)[0] + int(eos)
    ends = torch.cat([ends, ends.new_tensor([T])]).unique()  # unique() also sorts
    ends = ends[ends > 0]  # a BOS at position 0 does not close a sequence
    # Length of each sequence.
    reps = torch.diff(ends, prepend=ends.new_tensor([0]))
    # For every key column j: the end of the sequence containing j, shape (T, T).
    repeated_idx = torch.repeat_interleave(ends, reps).view(1, T).expand(T, -1)
    # Row-index grid: mask_indices[i, j] == i.
    mask_indices = torch.arange(T, device=x.device).view(T, 1).expand(-1, T)
    # Causal mask, then drop keys whose sequence ended at or before query row i.
    mask = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
    return mask.masked_fill(mask_indices >= repeated_idx, False)

# Apply the single-sequence helper to the packed example; the result should
# match the step-by-step mask built above (shown here as True/False).
get_attention_mask_for_packed_sequence(tokenized_sentences, tokenizer.eos_token_id)

tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False],

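### Position ids

Position ids must restart at 0 at the start of every packed sentence, rather than counting straight through the whole packed sequence.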

In [None]:
# Position ids restart at 0 for every packed sentence: each token's global
# index minus the start index of the sentence it belongs to.
seq_starts = torch.cat([torch.tensor([0]), eos_indices[:-1] + 1])
pos_ids = torch.arange(T) - torch.repeat_interleave(seq_starts, reps)
pos_ids

tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])

### Doing it in a batch

In [None]:
# A second packed row with two more sentences, so we can build a batch of two.
sentence4 = "Rome wasn't built in a day"
sentence5 = "My hovercraft is full of eels"
sentences = [sentence4, sentence5]

eos_id = tokenizer.eos_token_id
ids_per_sentence = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
tokenized_sentences2 = torch.tensor([tok for ids in ids_per_sentence for tok in ids + [eos_id]])

# Right-pad the shorter row with EOS so both rows have the same length.
# Each pad token then forms its own length-1 "sequence".
batch = torch.nn.utils.rnn.pad_sequence(
    [tokenized_sentences, tokenized_sentences2],
    batch_first=True,
    padding_value=eos_id,
)
batch

tensor([[  464,  3797,  3332,   319,   262,  2603, 50256,   464,  3290, 15063,   616, 26131, 50256,  3666, 25949,   318,   257,  4701, 50256],
        [   49,   462,  2492,   470,  3170,   287,   257,  1110, 50256,  3666, 20599,  3323,   318,  1336,   286,   304,  1424, 50256, 50256]])

In [None]:
# Row lengths before padding (19 vs 18): the second row receives one pad token.
tokenized_sentences.shape, tokenized_sentences2.shape

(torch.Size([19]), torch.Size([18]))

In [None]:
# pad_sequence defaults to batch_first=False, so the output is (max_len, batch).
torch.nn.utils.rnn.pad_sequence([torch.ones(3), torch.ones(10)], padding_value=99).shape

torch.Size([10, 2])

In [None]:
# With batch_first=True the output is (batch, max_len).
torch.nn.utils.rnn.pad_sequence([torch.ones(3), torch.ones(10)], padding_value=99, batch_first=True).shape

torch.Size([2, 10])

In [None]:
# Batch size and padded sequence length.
B, T = batch.size()
B, T

(2, 19)

In [None]:
# Flatten the (B, T) batch into one long sequence of B*T tokens.
batch.view(-1).shape

torch.Size([38])

In [None]:
# Boolean mask of EOS positions in the flattened batch (pad tokens are EOS too).
(batch.view(-1) == tokenizer.eos_token_id)

tensor([False, False, False, False, False, False,  True, False, False, False, False, False,  True, False, False, False, False, False,  True, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False, False, False, False,  True,  True])

In [None]:
# Flat indices of the EOS tokens; as_tuple=True returns one index tensor per dimension.
(batch.view(-1) == tokenizer.eos_token_id).nonzero(as_tuple=True)

(tensor([ 6, 12, 18, 27, 36, 37]),)

In [None]:
# Shift by one: each sequence ends right after its EOS (exclusive end index).
(batch.view(-1) == tokenizer.eos_token_id).nonzero().flatten() + 1

tensor([ 7, 13, 19, 28, 37, 38])

In [None]:
# Exclusive end index (flat, within the B*T sequence) of every EOS-terminated sequence.
eos_flat = (batch.view(-1) == tokenizer.eos_token_id).nonzero().flatten()
eos_idx = eos_flat + 1
eos_idx

tensor([ 7, 13, 19, 28, 37, 38])

In [None]:
# Add the row boundaries 0, T, 2T, ..., B*T so that no sequence spans two rows.
# `unique()` already returns sorted values, so the extra `.sort()` is dropped.
# The boundaries are created on eos_idx's device so `cat` also works on GPU.
eos_idx_expanded = torch.cat(
  [eos_idx, torch.arange(0, B*T+1, T, device=eos_idx.device)]
).unique()
eos_idx_expanded

tensor([ 0,  7, 13, 19, 28, 37, 38])

In [None]:
# Position of each boundary within its row. Multiples of T (row start/end)
# map to T rather than 0.
normalized_idx = eos_idx_expanded % T
normalized_idx[normalized_idx == 0] = T
normalized_idx

tensor([19,  7, 13, 19,  9, 18, 19])

In [None]:
# Length of each segment is the difference between consecutive in-row
# boundaries. A difference < 1 means the boundary lies in a new row, so the
# length is the boundary's own in-row position.
diffs = torch.diff(normalized_idx)
reps = torch.where(diffs < 1, normalized_idx[1:], diffs)
reps

tensor([7, 6, 6, 9, 9, 1])

In [None]:
# For each row b and key column j: the in-row end of the segment containing j,
# broadcast over all T query rows -> shape (B, T, T).
segment_end = torch.repeat_interleave(normalized_idx[1:], reps)
repeated_idx = segment_end.reshape(B, 1, T).expand(B, T, T)
repeated_idx

tensor([[[ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13, 19, 19, 19, 19, 19, 19],
         [ 7,  7,  7,  7,  7,  7,  7, 13, 1

In [None]:
# Query-row grid of shape (B, T, T): mask_indices[b, i, j] == i.
mask_indices = torch.arange(T)[None, :, None].expand(B, -1, T)
# Causal mask per row, minus keys whose segment ended at or before query row i.
causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).expand(B, -1, -1)
mask = causal.masked_fill(mask_indices >= repeated_idx, False)
mask

tensor([[[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, 

In [None]:
# Same mask shown as 0/1 integers for readability.
mask.int()

tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [None]:
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    """Block-diagonal causal attention mask for a batch of packed sequences.

    Note: this redefines (and shadows) the single-sequence helper defined
    earlier in the notebook, and it also accepts 1-D input.

    Args:
        x: token ids of shape (B, T). A 1-D tensor of shape (T,) is treated as
            a batch of one, and the batch dimension is dropped from the result.
        token_id: separator token id.
        eos: True if `token_id` ends a sequence (boundary right after it);
            False if it starts one, e.g. BOS (boundary right before it).

    Returns:
        Bool tensor of shape (B, T, T), or (T, T) for 1-D input. Entry
        [b, i, j] is True iff, in row b, query i may attend to key j, i.e.
        j <= i and both tokens are in the same packed sequence. Sequences
        never span two rows.
    """
    single = x.dim() == 1
    if single:
        x = x.unsqueeze(0)
    B, T = x.shape
    device = x.device
    if B * T == 0:
        # nothing to pack; avoids arange(..., step=0) below
        empty = torch.ones(B, T, T, dtype=torch.bool, device=device)
        return empty[0] if single else empty
    # flat exclusive end index of every separator-delimited sequence
    sep_ends = (x.reshape(-1) == token_id).nonzero(as_tuple=True)[0] + int(eos)
    # add the row boundaries 0, T, ..., B*T; unique() also sorts
    row_bounds = torch.arange(0, B * T + 1, T, device=device)
    bounds = torch.cat([sep_ends, row_bounds]).unique()
    # in-row position of each boundary; multiples of T map to T (end of a row)
    normalized_idx = bounds % T
    normalized_idx = normalized_idx.masked_fill(normalized_idx == 0, T)
    # segment lengths; a non-positive difference means a new row has started,
    # so the length is the in-row position itself
    reps = normalized_idx[1:] - normalized_idx[:-1]
    reps = torch.where(reps < 1, normalized_idx[1:], reps)
    # for each row b and key column j: in-row end of the segment containing j
    repeated_idx = torch.repeat_interleave(normalized_idx[1:], reps).view(B, 1, T).expand(-1, T, -1)
    # query-row grid: mask_indices[b, i, j] == i
    mask_indices = torch.arange(T, device=device).view(1, -1, 1).expand(B, -1, T)
    mask = torch.ones(T, T, dtype=torch.bool, device=device).tril().expand(B, -1, -1)
    mask = mask.masked_fill(mask_indices >= repeated_idx, False)
    return mask[0] if single else mask

In [None]:
# Position ids restart at 0 at every segment start within each row. Each
# flat index minus the flat start of its segment, reshaped to (B, T).
seg_start = torch.repeat_interleave(eos_idx_expanded[:-1], reps)
pos_ids = (torch.arange(B * T) - seg_start).view(B, T)
pos_ids

tensor([[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0]])

In [None]:
# Batched mask for the two packed rows, shown as 0/1 integers.
get_attention_mask_for_packed_sequence(batch, tokenizer.eos_token_id).int()

tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,