This script processes text data into a Hugging Face dataset and saves it to disk.

This script fixes the issue where the HF dataset would be split on paragraphs instead of whole stories. It also loads the dataset directly from Hugging Face.

In [1]:
from tokenizers import Tokenizer

# Load the tokenizer from its JSON file in the working directory.
tokenizer = Tokenizer.from_file("./TinyStories_tokenizer.json")

# Token ids obtained by encoding the end-of-text marker. This is a *list* of ids:
# later cells append the whole list to every story and use element 0 as the padding value.
# NOTE(review): presumably "<|endoftext|>" encodes to exactly one id — confirm.
endoftext_token = tokenizer.encode("<|endoftext|>").ids


In [2]:
from datasets import load_dataset_builder

# Create a dataset builder for TinyStories.
# NOTE(review): `dataset_builder` is not referenced by any later cell visible in this
# notebook; the data itself is loaded separately with `load_dataset`.
dataset_builder = load_dataset_builder("roneneldan/TinyStories")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from datasets import load_dataset

# Load the TinyStories dataset. No splitting is done here: later cells index the
# "train" and "validation" splits that the loaded object already provides.
dataset = load_dataset("roneneldan/TinyStories")

In [4]:
def tokenize_function(examples):
    """
    Tokenizes a batch of texts and appends the end-of-text ids to each one.

    Args:
        examples: Batch mapping whose "text" entry holds the strings to encode

    Returns:
        Dictionary whose "input_ids" entry holds one list of token ids per text,
        each ending with the ids in `endoftext_token`
    """
    batch_encodings = tokenizer.encode_batch_fast(examples["text"])

    input_ids = []
    for enc in batch_encodings:
        # List concatenation builds a new list, so `endoftext_token` itself is untouched.
        input_ids.append(enc.ids + endoftext_token)

    return {"input_ids": input_ids}

# Tokenize the dataset in batches across worker processes. The raw "text" column
# is dropped, so only the "input_ids" returned by tokenize_function remains.
dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    num_proc=23
)

In [5]:
%%time

import numpy as np

# Summary statistics of the tokenized story lengths in the validation split.
# The lengths are computed once and reused: reading the "input_ids" column and
# measuring every row for each statistic would repeat the same work four times.
validation_lengths = [len(ids) for ids in dataset["validation"]["input_ids"]]

print("Dataset size:", len(dataset["validation"]))
print("mean length:", np.mean(validation_lengths))
print("stdev length:", np.std(validation_lengths))
print("max length:", np.max(validation_lengths))
print("min length:", np.min(validation_lengths))

Dataset size: 21990
mean length: 213.37717144156434
stdev length: 99.91620630737907
max length: 1087
min length: 16
CPU times: user 3.15 s, sys: 152 ms, total: 3.3 s
Wall time: 2.13 s


In [6]:
%%time

# First token id of the first training row (row lookup first, then the column).
dataset["train"][0]["input_ids"][0]

CPU times: user 894 μs, sys: 65 μs, total: 959 μs
Wall time: 760 μs


330

In [7]:
# %%time

# dataset["train"]["input_ids"][0][0]

In [None]:
def pack_token_lists(stories, max_length=513):
    """
    Packs tokenized stories into sequences of at most max_length tokens.

    Stories are processed longest first, and each one is appended to the first
    existing sequence that still has room for it (first-fit decreasing). A story
    with max_length tokens or more gets a sequence of its own, truncated to
    max_length. The lists inside stories["input_ids"] are never modified.

    Args:
        stories: Batch mapping whose "input_ids" entry is a list of token-id lists
        max_length: Maximum allowed length for each packed sequence (default: 513)

    Returns:
        Dictionary whose "packed_inputs" entry holds the packed sequences, none
        longer than max_length (not padded)
    """
    # Longest first: large stories open sequences early, short ones fill the gaps.
    stories_len_sorted = sorted(stories["input_ids"], key=len, reverse=True)

    packed = []

    for story in stories_len_sorted:
        story_length = len(story)

        if story_length >= max_length:
            # Too long to share a sequence: keep only the first max_length tokens.
            packed.append(list(story[:max_length]))
            continue

        # First fit: extend the earliest sequence that can still take this story.
        for sequence in packed:
            if len(sequence) + story_length <= max_length:
                sequence.extend(story)
                break
        else:
            # No sequence had room. Start a new one from a *copy*, because it will
            # be extended in place later and must not alias the caller's list.
            packed.append(list(story))

    return {
        "packed_inputs": packed,
    }


In [None]:
# Pack the tokenized stories. With batched=True the function receives a batch of
# rows at a time, so stories are only packed together within a batch. The
# per-story "input_ids" column is replaced by the returned "packed_inputs".
dataset = dataset.map(
  pack_token_lists,
  batched=True,
  remove_columns=["input_ids"],
  num_proc=23,
)

Map (num_proc=23): 100%|██████████| 2119719/2119719 [00:10<00:00, 199224.90 examples/s]
Map (num_proc=23): 100%|██████████| 21990/21990 [00:00<00:00, 96749.09 examples/s]


In [46]:
# pad inputs to max length
def pad_sequences(examples, max_length=513, padding_value=endoftext_token[0]):
    """
    Brings every packed sequence to exactly max_length tokens.

    Args:
        examples: Batch mapping whose "packed_inputs" entry is a list of token-id lists
        max_length: Target length of every returned sequence (default: 513)
        padding_value: Token id appended to short sequences (default: first
            id in `endoftext_token`)

    Returns:
        Dictionary whose "input_ids" entry holds the fixed-length sequences
    """
    fixed_length = []
    for seq in examples["packed_inputs"]:
        shortfall = max_length - len(seq)
        if shortfall > 0:
            # Too short: fill the tail with the padding id.
            fixed_length.append(seq + [padding_value] * shortfall)
        else:
            # Long enough already: keep the first max_length tokens.
            fixed_length.append(seq[:max_length])

    return {"input_ids": fixed_length}


In [None]:
# Pad (or cut) every packed sequence to the fixed length and store the result
# under "input_ids"; the intermediate "packed_inputs" column is dropped.
dataset = dataset.map(
  pad_sequences,
  batched=True,
  remove_columns=["packed_inputs"],
  num_proc=23,
)

Map (num_proc=23): 100%|██████████| 942538/942538 [00:09<00:00, 99052.37 examples/s] 
Map (num_proc=23): 100%|██████████| 9489/9489 [00:00<00:00, 41550.83 examples/s]


In [None]:
# Have rows read from the dataset come back as torch tensors.
dataset.set_format("torch")

In [49]:
# Show the splits, their columns and their row counts after padding.
dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 942538
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 9489
    })
})

In [126]:
import numpy as np

# create square attention mask for sequence packed inputs
def create_attention_mask(example, padding_value=endoftext_token[0]):
    """
    Creates a square, block-diagonal attention mask for one packed example.

    Every occurrence of padding_value closes a segment: position i may attend to
    position j only when no padding_value token lies between them. A delimiter
    belongs to the segment it closes, so consecutive padding tokens at the end
    of a row each form a segment of their own. The mask is symmetric; no causal
    masking is applied here.

    Args:
        example: Single example whose "input_ids" entry is a 1-D sequence of
            token ids (the notebook maps this function with batched=False)
        padding_value: Token id that delimits segments (default: first id in
            `endoftext_token`)

    Returns:
        Dictionary with the unchanged ids under "packed_inputs" and a boolean
        array of shape (len(input_ids), len(input_ids)) under "attention_mask"
    """
    input_ids = np.asarray(example["input_ids"])

    # Honour the padding_value argument instead of reading the global directly.
    is_delimiter = input_ids == padding_value

    # Segment id = number of delimiters strictly before a position (exclusive
    # running count), so a delimiter keeps the id of the story it terminates.
    delimiter_counts = is_delimiter.astype(np.int64)
    segment_ids = np.cumsum(delimiter_counts) - delimiter_counts

    # Two positions may attend to each other exactly when they share a segment.
    # The comparison yields a boolean array directly, which avoids the `np.bool`
    # alias that is missing from NumPy 1.24-1.26.
    attention_mask = segment_ids[:, None] == segment_ids[None, :]

    return {
        "packed_inputs": example["input_ids"],
        "attention_mask": attention_mask,
    }

In [None]:
# Build one attention mask per row (batched=False, so the function sees a single
# example at a time). "input_ids" is removed because the function returns the
# same ids under "packed_inputs".
dataset = dataset.map(
    create_attention_mask,
    batched=False,
    remove_columns=["input_ids"],
    num_proc=23,
)

Map (num_proc=23): 100%|██████████| 942538/942538 [00:55<00:00, 17087.73 examples/s]
Map (num_proc=23): 100%|██████████| 9489/9489 [00:02<00:00, 4440.32 examples/s]


In [142]:
# Persist the dataset, including the per-row attention masks, to a local directory.
dataset.save_to_disk("packed_dataset_with_mask")

Saving the dataset (74/74 shards): 100%|██████████| 942538/942538 [00:57<00:00, 16255.35 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 9489/9489 [00:00<00:00, 22509.12 examples/s]


In [143]:
# Usage:

from datasets import load_from_disk
# Reload the saved dataset and have rows returned as torch tensors.
packed_dataset = load_from_disk("packed_dataset_with_mask")
packed_dataset.set_format('torch')

from torch.utils.data import DataLoader
# Create one DataLoader per split: batches of 16 rows, shuffled for training only.
dataloader_train = DataLoader(packed_dataset["train"], batch_size=16, shuffle=True)
dataloader_valid = DataLoader(packed_dataset["validation"], batch_size=16, shuffle=False)


In [144]:
# look at the first batch
# Pull exactly one batch from the loader. Unlike a `for input in ...` loop, this
# does not rebind the built-in `input` at notebook scope.
first_batch = next(iter(dataloader_train))

In [145]:
# Inspect the batch: a dict holding the "packed_inputs" and "attention_mask" tensors.
first_batch

{'packed_inputs': tensor([[  337,   228,   426,  ..., 30000, 30000, 30000],
         [  694,   228,   337,  ..., 30000, 30000, 30000],
         [  374,   379,    68,  ...,   250,    18, 30000],
         ...,
         [ 7610,   238,   523,  ..., 30000, 30000, 30000],
         [  374,   379,    68,  ..., 30000, 30000, 30000],
         [  324,   228,   442,  ...,   339,     5, 30000]]),
 'attention_mask': tensor([[[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [False, False, False,  ...,  True, False, False],
          [False, False, False,  ..., False,  True, False],
          [False, False, False,  ..., False, False,  True]],
 
         [[ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          [ True,  True,  True,  ..., False, False, False],
          ...,
          [False, Fals