In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

# Load tokenized data
print("Loading tokenized data...")
checkpoint = torch.load('tokenized_data.pt')
data = checkpoint['data']
char_to_idx = checkpoint['char_to_idx']
idx_to_char = checkpoint['idx_to_char']
vocab_size = checkpoint['vocab_size']

print(f"✓ Data loaded: {len(data):,} tokens")
print(f"✓ Vocabulary size: {vocab_size}")

print("\n" + "="*60)
print("CREATING DATASET")
print("="*60)

class CharDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len
    
    def __len__(self):
        return len(self.data) - self.seq_len
    
    def __getitem__(self, idx):
        # Get sequence of seq_len + 1 tokens
        chunk = self.data[idx:idx + self.seq_len + 1]
        
        # Input: first seq_len tokens
        x = chunk[:-1]
        
        # Target: shifted by 1 (next character prediction)
        y = chunk[1:]
        
        return x, y

# Create dataset
seq_len = 64  # Context length
dataset = CharDataset(data, seq_len)

print(f"Sequence length: {seq_len}")
print(f"Total sequences: {len(dataset):,}")

# Split into train/val
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

print(f"\nTrain sequences: {len(train_dataset):,}")
print(f"Val sequences: {len(val_dataset):,}")

# Create dataloaders
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"\nBatch size: {batch_size}")
print(f"Train batches: {len(train_loader):,}")
print(f"Val batches: {len(val_loader):,}")

print("\n" + "="*60)
print("EXAMINING ONE BATCH")
print("="*60)

# Get one batch
x_batch, y_batch = next(iter(train_loader))

print(f"Input batch shape: {x_batch.shape}")  # [batch_size, seq_len]
print(f"Target batch shape: {y_batch.shape}")  # [batch_size, seq_len]

# Show first sequence
print("\n" + "="*60)
print("FIRST SEQUENCE IN BATCH")
print("="*60)

x_sample = x_batch[0]
y_sample = y_batch[0]

# Decode to text
x_text = ''.join([idx_to_char[i.item()] for i in x_sample])
y_text = ''.join([idx_to_char[i.item()] for i in y_sample])

print("Input text:")
print(f"'{x_text}'")
print("\nTarget text (shifted by 1):")
print(f"'{y_text}'")

# Show token-by-token
print("\n" + "="*60)
print("TOKEN-BY-TOKEN VIEW (first 10 positions)")
print("="*60)
print("Position | Input Char | Input ID | Target Char | Target ID")
print("-" * 60)
for i in range(10):
    in_char = idx_to_char[x_sample[i].item()]
    in_id = x_sample[i].item()
    tgt_char = idx_to_char[y_sample[i].item()]
    tgt_id = y_sample[i].item()
    print(f"{i:8d} | {in_char:10s} | {in_id:8d} | {tgt_char:11s} | {tgt_id:9d}")

print("\n✓ Notice: Target is input shifted by 1 position")
print("  Model learns: given input, predict next character")

# Save for next step
torch.save({
    'train_loader': train_loader,
    'val_loader': val_loader,
    'vocab_size': vocab_size,
    'idx_to_char': idx_to_char,
    'char_to_idx': char_to_idx,
    'seq_len': seq_len
}, 'data_loaders.pt')

print("\n" + "="*60)
print("STEP 3 COMPLETE ✓")
print("="*60)
print("✓ Dataset created")
print("✓ Split into train/val")
print("✓ DataLoaders ready")
print("✓ Ready to build model!")

Loading tokenized data...
✓ Data loaded: 54,050 tokens
✓ Vocabulary size: 34

CREATING DATASET
Sequence length: 64
Total sequences: 53,986

Train sequences: 48,587
Val sequences: 5,399

Batch size: 32
Train batches: 1,519
Val batches: 169

EXAMINING ONE BATCH
Input batch shape: torch.Size([32, 64])
Target batch shape: torch.Size([32, 64])

FIRST SEQUENCE IN BATCH
Input text:
' cat gently. The cat would purr loudly. Tom loved his cat very m'

Target text (shifted by 1):
'cat gently. The cat would purr loudly. Tom loved his cat very mu'

TOKEN-BY-TOKEN VIEW (first 10 positions)
Position | Input Char | Input ID | Target Char | Target ID
------------------------------------------------------------
       0 |            |        1 | c           |        14
       1 | c          |       14 | a           |        12
       2 | a          |       12 | t           |        29
       3 | t          |       29 |             |         1
       4 |            |        1 | g           |        18
  