In [2]:
# !pip install transformers, AutoTokenizer, torch


In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch

In [4]:
# 1. Load a text dataset (we use a small example dataset for demonstration)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")  # raw text WikiText-2
print(f"Number of lines in dataset: {len(dataset)}")

README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Number of lines in dataset: 36718


In [5]:
# 2. Initialize a tokenizer (we'll use GPT-2's tokenizer for compatibility with a GPT-2 model)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad by default

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [9]:
# 3. Tokenize the dataset efficiently using `.map` with batched processing
def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=False)

tokenized_ds = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# The dataset now has columns like 'input_ids' and 'attention_mask'

print(tokenized_ds[0]["input_ids"][:20])  # print first 20 token IDs of first example for sanity check

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (647 > 512). Running this sequence through the model will result in indexing errors


[101, 102]


In [16]:
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
print(f"Number of lines in dataset: {len(dataset)}")



Number of lines in dataset: 36718


In [17]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding=False  # IMPORTANT
    )

tokenized_ds = dataset.map(tokenize_function, batched=True)
print("Tokenization complete.")


Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Tokenization complete.


In [22]:
block_size = 256

def group_texts(examples):
    # Concatenate ALL tokens
    concatenated_ids = sum(examples["input_ids"], [])
    concatenated_mask = sum(examples["attention_mask"], [])

    # Round down to nearest block
    total_len = (len(concatenated_ids) // block_size) * block_size
    concatenated_ids = concatenated_ids[:total_len]
    concatenated_mask = concatenated_mask[:total_len]

    # Split into chunks
    ids_chunks = [
        concatenated_ids[i:i + block_size]
        for i in range(0, total_len, block_size)
    ]
    mask_chunks = [
        concatenated_mask[i:i + block_size]
        for i in range(0, total_len, block_size)
    ]

    # MUST RETURN FLAT LIST OF CHUNKS
    return {
        "input_ids": ids_chunks,
        "attention_mask": mask_chunks
    }


In [23]:
lm_ds = tokenized_ds.map(
    group_texts,
    batched=True,
    batch_size=1000,
    remove_columns=tokenized_ds.column_names
)

print("LM training sequences:", len(lm_ds))


Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

LM training sequences: 9260


In [20]:
def group_texts(examples):
    # Concatenate
    concatenated = []
    for ids in examples["input_ids"]:
        concatenated.extend(ids)

    total_len = (len(concatenated) // block_size) * block_size
    concatenated = concatenated[:total_len]

    # Split
    chunks = [
        concatenated[i:i + block_size]
        for i in range(0, total_len, block_size)
    ]

    return {"input_ids": chunks}


In [25]:
from torch.utils.data import Dataset, DataLoader
import torch

# Wrap HF dataset to use with PyTorch DataLoader
class HFDatasetWrapper(Dataset):
    def __init__(self, hf_dataset):
        self.ds = hf_dataset
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        return self.ds[idx]

wrapped_ds = HFDatasetWrapper(lm_ds)

# Collate function
def collate_fn(batch):
    input_ids = torch.tensor([example["input_ids"] for example in batch], dtype=torch.long)
    return {"input_ids": input_ids, "labels": input_ids.clone()}

train_dataloader = DataLoader(
    wrapped_ds,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)


In [27]:

# 6. Iterate through a couple of batches to see that it works
for batch in train_dataloader:
    print(batch["input_ids"].shape, batch["labels"].shape)
    break


torch.Size([16, 256]) torch.Size([16, 256])
