In [22]:
from tokenizers import Tokenizer

# Build the tokenizer from its serialized JSON definition.
tokenizer_path = "./TinyStories_tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)

# Cap every encoding at 512 tokens. Padding is deliberately left off here;
# the packing step is responsible for producing fixed-length rows.
tokenizer.enable_truncation(max_length=512)

# Ids obtained by encoding the literal "<|endoftext|>" marker (kept as a list).
endoftext_encoding = tokenizer.encode("<|endoftext|>")
endoftext_token = endoftext_encoding.ids


In [23]:
import numpy as np
from itertools import chain

def pack_sequences(examples, max_length=512, pad_token_id=None, drop_remainder=True):
    """Concatenate a batch of tokenized examples and cut it into training rows.

    All ``input_ids`` of the batch (and, in parallel, all ``attention_mask``
    values) are flattened into one stream. Row ``i`` is the window of
    ``max_length + 1`` tokens that starts at ``i * max_length``: its first
    ``max_length`` tokens are the inputs and its last ``max_length`` tokens
    are the labels, so consecutive rows overlap by exactly one token.

    Args:
        examples: Batch mapping with "input_ids" and "attention_mask", each an
            iterable of per-example integer sequences.
        max_length: Number of input positions per row; must be >= 1.
        pad_token_id: Id used to fill a padded tail row. ``None`` (default)
            resolves to ``endoftext_token[0]`` at call time, and only when
            padding is actually required.
        drop_remainder: When True (default), tokens left over after the last
            full window are discarded. When False, a leftover of at least two
            tokens is emitted as one extra row, filled up with
            ``pad_token_id`` and attention-mask 0. Labels at padded positions
            contain ``pad_token_id``; no loss masking is applied here.

    Returns:
        dict with "input_ids", "attention_mask" and "labels", each an integer
        array of shape ``(num_rows, max_length)``. ``num_rows`` is 0 when the
        batch does not contain enough tokens for a single row.

    Raises:
        ValueError: if ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    # Each row needs one token beyond max_length to serve as the final target.
    window = max_length + 1

    flat_ids = list(chain.from_iterable(examples["input_ids"]))
    flat_mask = list(chain.from_iterable(examples["attention_mask"]))

    # Window i spans [i * max_length, i * max_length + window), so exactly
    # (total - 1) // max_length windows fit completely inside the stream.
    num_full = max(len(flat_ids) - 1, 0) // max_length

    id_rows = [
        flat_ids[i * max_length : i * max_length + window] for i in range(num_full)
    ]
    mask_rows = [
        flat_mask[i * max_length : i * max_length + window] for i in range(num_full)
    ]

    if not drop_remainder:
        tail_start = num_full * max_length
        tail_ids = flat_ids[tail_start:]
        # A single leftover token cannot form an (input, target) pair.
        if len(tail_ids) >= 2:
            if pad_token_id is None:
                # Looked up lazily so the notebook-level token is only needed
                # when a row really has to be padded.
                pad_token_id = endoftext_token[0]
            tail_mask = flat_mask[tail_start:]
            id_rows.append(tail_ids + [pad_token_id] * (window - len(tail_ids)))
            mask_rows.append(tail_mask + [0] * (window - len(tail_mask)))

    # The reshape keeps the arrays 2-D even when there are no rows at all;
    # a bare np.array([]) would be 1-D and break the column slicing below.
    ids_arr = np.array(id_rows, dtype=int).reshape(-1, window)
    mask_arr = np.array(mask_rows, dtype=int).reshape(-1, window)

    return {
        "input_ids": ids_arr[:, :-1],  # every token except the last
        "attention_mask": mask_arr[:, :-1],
        "labels": ids_arr[:, 1:],  # every token except the first (shifted by 1)
    }

In [24]:
from datasets import load_dataset


# Read the train/valid text files with the "text" loader.
data_files = {
    "train": "../data/TinyStories/TinyStoriesV2-GPT4-train.txt",
    "valid": "../data/TinyStories/TinyStoriesV2-GPT4-valid.txt",
}
dataset = load_dataset("text", data_files=data_files)

# Tokenization step: no padding is added here.
def tokenize_function(examples):
    """Tokenize a batch of raw text rows.

    Args:
        examples: Batch mapping whose "text" entry holds the strings to
            encode; it is passed as-is to ``tokenizer.encode_batch_fast``.

    Returns:
        dict with "input_ids" and "attention_mask", each a list holding one
        entry per encoding (its ``ids`` and ``attention_mask`` respectively).
    """
    ids_per_text = []
    mask_per_text = []
    for enc in tokenizer.encode_batch_fast(examples["text"]):
        ids_per_text.append(enc.ids)
        mask_per_text.append(enc.attention_mask)

    return {"input_ids": ids_per_text, "attention_mask": mask_per_text}

# Tokenize every split; the raw "text" column is dropped once it is encoded.
tokenized = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Pack in file order. The "text" loader yields one example per line, so
# shuffling *before* packing would interleave lines of unrelated documents
# inside each packed row (and detach any separator lines from the text they
# terminate). Packing concatenates the whole batch anyway, so row order has
# no influence on packing efficiency.
packed = tokenized.map(
    pack_sequences,
    batched=True,
    batch_size=1000,  # rows concatenated per packing call; adjust based on your memory
)

# Shuffle the fixed-length packed rows instead; text inside a row stays contiguous.
packed_dataset = packed.shuffle(seed=42)

# Convert to PyTorch format
packed_dataset.set_format("torch")

# Persist all splits to the "packed_dataset" directory.
packed_dataset.save_to_disk("packed_dataset")

Map: 100%|██████████| 15600057/15600057 [05:41<00:00, 45656.77 examples/s]
Map: 100%|██████████| 157832/157832 [00:03<00:00, 47356.61 examples/s]
Saving the dataset (14/14 shards): 100%|██████████| 1022039/1022039 [00:14<00:00, 68402.30 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 10326/10326 [00:00<00:00, 117611.04 examples/s]


In [25]:
# Usage: reload the packed splits from disk and wrap them in DataLoaders.

from torch.utils.data import DataLoader
from datasets import load_from_disk

packed_dataset2 = load_from_disk("packed_dataset")
packed_dataset2.set_format("torch")

# Same batch size for both loaders; only the training loader shuffles.
loader_kwargs = {"batch_size": 8}
dataloader_train = DataLoader(packed_dataset2["train"], shuffle=True, **loader_kwargs)
dataloader_valid = DataLoader(packed_dataset2["valid"], shuffle=False, **loader_kwargs)


In [26]:
# Capture the first batch the training loader yields, then stop iterating.
for first_batch in dataloader_train:
    batch = first_batch  # `batch` is bound to the same object
    break

In [39]:
# Display the batch captured from the training DataLoader.
first_batch

{'input_ids': tensor([[  16,    6, 4113,  ...,   18, 4145,  256],
         [ 375,  932,  367,  ...,  231,  550,  324],
         [ 987,  326,   68,  ...,  277,  225,  408],
         ...,
         [ 269,  289, 1652,  ...,  238,   68,  334],
         [ 354,  579,  890,  ...,  417,  604,  279],
         [ 258,  429,  289,  ..., 4852,  228, 1381]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'labels': tensor([[   6, 4113,  227,  ..., 4145,  256,  407],
         [ 932,  367,  256,  ...,  550,  324,  566],
         [ 326,   68,  879,  ...,  225,  408,   16],
         ...,
         [ 289, 1652,  227,  ...,   68,  334,  394],
         [ 579,  890,  707,  ...,  604,  279,  517],
         [ 429,  289, 5088,  ...,  228, 1381,  277]])}