In [2]:
# !pip install torch transformers datasets

In [3]:
from itertools import chain

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Work on a tiny slice of AG News (the first 500 training articles) so every
# downstream cell finishes in seconds.
dataset = load_dataset("ag_news", split="train[:500]")
print(
    f"Number of examples: {len(dataset)}",
    f"First text preview: {dataset[0]['text'][:200]}...",
    sep="\n",
)

Generating train split: 100%|██████████| 120000/120000 [00:00<00:00, 2258391.76 examples/s]
Generating test split: 100%|██████████| 7600/7600 [00:00<00:00, 2550340.86 examples/s]

Number of examples: 500
First text preview: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again....





In [7]:
# Pretrained GPT-2 tokenizer. GPT-2 presumably ships without a pad token
# (hence the alias below), so EOS is reused as the pad token. Note: the cells
# below never request padding, so this only matters for padding-aware calls.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

print(f"Vocab size: {len(tokenizer)}")

Vocab size: 50257


In [8]:
def tokenize_function(examples):
    """Tokenize a batch of raw articles with the notebook-level ``tokenizer``.

    Args:
        examples: batch dict passed by ``Dataset.map(batched=True)``;
            ``examples["text"]`` is the list of article strings.

    Returns:
        The tokenizer output for the batch (``input_ids`` and, as used
        downstream, ``attention_mask``), with each article cut to at most
        512 tokens.
    """
    # NOTE(review): truncating before packing discards the tail of any article
    # longer than 512 tokens; confirm that is acceptable for longer corpora.
    texts = examples["text"]
    return tokenizer(texts, max_length=512, truncation=True)


# Drop the raw text column; only token ids/masks (plus any other original columns) remain.
tokenized_ds = dataset.map(tokenize_function, remove_columns=["text"], batched=True)
print(f"Sample tokenized output (first 20 tokens): {tokenized_ds[0]['input_ids'][:20]}")

Map: 100%|██████████| 500/500 [00:00<00:00, 18039.55 examples/s]

Sample tokenized output (first 20 tokens): [22401, 520, 13, 15682, 30358, 5157, 20008, 262, 2619, 357, 12637, 8, 8428, 532, 10073, 12, 7255, 364, 11, 5007]





In [10]:
block_size = 128


def group_texts(examples, chunk_size=None):
    """Pack a batch of tokenized articles into fixed-length LM training blocks.

    All ``input_ids`` in the batch, and the parallel ``attention_mask`` lists,
    are concatenated end to end and then cut into consecutive chunks of
    exactly one block length. The trailing remainder that does not fill a
    whole block is dropped, so every returned row has the same length.

    Args:
        examples: batch dict from ``Dataset.map(batched=True)``. Must contain
            ``"input_ids"`` and ``"attention_mask"``, each a list of per-article
            token lists. Any other keys (e.g. a label column) are ignored.
        chunk_size: block length to use. ``None`` (the default) falls back to
            the module-level ``block_size``, so ``map(group_texts, ...)`` calls
            behave exactly as before.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"``, each a list of
        equal-length lists. Both lists are empty when the batch holds fewer
        tokens than one block.

    Raises:
        ValueError: if the block length is not positive, or if the flattened
            ids and masks have different lengths.
    """
    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"block size must be positive, got {size}")

    # chain.from_iterable flattens in O(total tokens); sum(lists, []) re-copies
    # the growing list on every addition and is O(n^2).
    all_ids = list(chain.from_iterable(examples["input_ids"]))
    all_masks = list(chain.from_iterable(examples["attention_mask"]))
    if len(all_ids) != len(all_masks):
        raise ValueError(
            f"input_ids ({len(all_ids)} tokens) and attention_mask "
            f"({len(all_masks)} tokens) are misaligned"
        )

    # Largest multiple of `size` that fits. When it is 0 the range is empty and
    # both outputs are empty lists (no special case needed).
    usable = (len(all_ids) // size) * size
    starts = range(0, usable, size)
    return {
        "input_ids": [all_ids[i:i + size] for i in starts],
        "attention_mask": [all_masks[i:i + size] for i in starts],
    }

# Packing changes the row count (many articles -> fewer fixed-length blocks),
# so every original per-article column is removed rather than carried along.
# Leftover tokens are dropped per map batch; 1000 exceeds the 500-row slice,
# so here the whole dataset is packed as one batch.
lm_ds = tokenized_ds.map(
    group_texts,
    remove_columns=tokenized_ds.column_names,
    batched=True,
    batch_size=1000,
)
print(f"Number of training sequences: {len(lm_ds)}")

Map: 100%|██████████| 500/500 [00:00<00:00, 6511.18 examples/s]

Number of training sequences: 222





In [11]:
def collate_fn(batch):
    """Stack fixed-length LM blocks into tensors for a causal-LM forward pass.

    Args:
        batch: list of dataset rows, each with an ``"input_ids"`` list. All
            rows must have the same length (the packing step guarantees this).
            If every row also has ``"attention_mask"``, it is stacked too.

    Returns:
        dict with ``"input_ids"`` and ``"labels"`` (an independent copy of the
        ids) as ``torch.long`` tensors of shape ``(len(batch), seq_len)``, plus
        ``"attention_mask"`` when the rows provide one.
    """
    input_ids = torch.tensor([row["input_ids"] for row in batch], dtype=torch.long)
    # Labels are the unshifted inputs. This assumes the model shifts them
    # internally (Hugging Face causal-LM convention); confirm for other models.
    collated = {"input_ids": input_ids, "labels": input_ids.clone()}
    # The packing step carries the mask through; forward it instead of silently
    # dropping it, so model(**batch) receives it.
    if batch and all("attention_mask" in row for row in batch):
        collated["attention_mask"] = torch.tensor(
            [row["attention_mask"] for row in batch], dtype=torch.long
        )
    return collated

# Fixed seed for the shuffle order: without a seeded generator, the batch order
# (and the batch shown in the next cell) changes on every Restart & Run All.
SEED = 42
train_loader = DataLoader(
    lm_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)
print(f"Total batches: {len(train_loader)}")

Total batches: 28


In [12]:
# Peek at the first batch only, to sanity-check tensor shapes and contents.
for batch in train_loader:
    print(
        f"Input shape: {batch['input_ids'].shape}",
        f"Labels shape: {batch['labels'].shape}",
        f"Sample tokens: {batch['input_ids'][0][:10]}",
        sep="\n",
    )
    break

Input shape: torch.Size([8, 128])
Labels shape: torch.Size([8, 128])
Sample tokens: tensor([ 530,  326,  338, 1762, 1377,  318, 9431,  465,  393,  607])
