In [2]:
# Quick check to see if GPU is available before we start.
# The leading `!` runs the command in a shell; the table it prints lists the
# driver/CUDA versions and each GPU's memory usage.
!nvidia-smi


Mon Dec 29 12:40:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
# Installing the main dependencies we'll need
# transformers: HuggingFace library for models and tokenizers
# datasets: For loading and processing datasets
# evaluate: installed alongside the others; the import cell below does not import it
# accelerate: Handles mixed precision and multi-GPU training
# NOTE(review): no versions are pinned, so a later re-run may install different releases.
%pip install transformers datasets evaluate accelerate -q


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
# Standard imports for this project
from transformers import (
    AutoModelForMaskedLM, 
    AutoTokenizer, 
    get_scheduler, 
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset
from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import numpy as np
import math
import os
import json
from pathlib import Path
from tqdm.auto import tqdm
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')  # Suppress annoying warnings during training


## Configuration


In [7]:
# ========== CONFIGURATION ==========
class Config:
    """Every tunable setting for this notebook, stored as plain class attributes.

    The cells below read these values through the module-level ``config``
    instance created right after the class.
    """

    # Model checkpoint handed to AutoTokenizer / AutoModelForMaskedLM
    model_name = 'distilbert-base-uncased'
    
    # Dataset to use for fine-tuning (passed to load_dataset)
    dataset_name = 'imdb'
    # Set these to a number if you want to test with a smaller subset first;
    # None means "use the full split"
    max_samples_train = None
    max_samples_eval = None
    
    # Training hyperparameters
    num_epochs = 10
    batch_size = 32  # Used for both the train and the eval DataLoader
    learning_rate = 5e-5
    weight_decay = 0.01
    warmup_ratio = 0.1  # Fraction of total training steps used for LR warmup (10%)
    
    # Data processing settings
    max_length = 512  # Truncation length in tokenize_function and chunk length in group_texts
    mlm_probability = 0.15  # BERT standard: mask 15% of tokens
    
    # Performance optimizations
    mixed_precision = "fp16"  # Passed to Accelerator: half precision for faster training
    gradient_accumulation_steps = 1  # Passed to Accelerator; 1 means no accumulation
    num_workers = 2  # Number of DataLoader workers used for parallel data loading
    pin_memory = True  # Passed to DataLoader: faster data transfer to GPU
    
    # Checkpointing
    output_dir = "./mlm_checkpoints"
    # NOTE(review): save_steps / eval_steps / logging_steps are not referenced by
    # the cells visible here, which evaluate and checkpoint once per epoch.
    save_steps = 500
    eval_steps = 500
    logging_steps = 100
    
    # Early stopping - stops if no improvement for N epochs
    early_stopping_patience = 3
    early_stopping_threshold = 0.01  # Minimum drop in perplexity that counts as progress

# Single shared instance that the rest of the notebook reads from
config = Config()

# Create output directory (no error if it already exists)
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

print(f"Configuration loaded. Output directory: {config.output_dir}")


Configuration loaded. Output directory: ./mlm_checkpoints


In [9]:
# Fetch the pretrained tokenizer and masked-LM model named in the config
checkpoint = config.model_name
print(f"Loading tokenizer and model: {checkpoint}")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForMaskedLM.from_pretrained(checkpoint)

# Report the vocabulary size, then the total number of model parameters
vocab_size = len(tokenizer)
print(f"Model loaded. Vocabulary size: {vocab_size}")
param_count = sum(weights.numel() for weights in model.parameters())
print(f"Model parameters: {param_count:,}")


Loading tokenizer and model: distilbert-base-uncased
Model loaded. Vocabulary size: 30522
Model parameters: 66,985,530


In [10]:
# Load the dataset named in the config
print(f"Loading dataset: {config.dataset_name}")
# MLM only needs the text, so the "label" column is dropped straight away
raw_dataset = load_dataset(config.dataset_name).remove_columns(["label"])

train_size = len(raw_dataset['train'])
test_size = len(raw_dataset['test'])
print(f"Dataset loaded. Train: {train_size}, Test: {test_size}")


Loading dataset: imdb


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset loaded. Train: 25000, Test: 25000


In [11]:
# Tokenization function with truncation
def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch, truncating at ``config.max_length`` tokens."""
    texts = examples['text']
    encoded = tokenizer(texts, truncation=True, max_length=config.max_length)
    return encoded

print("Tokenizing dataset...")
# Batched map over the dataset; the raw "text" column is removed from the result
tokenize_options = dict(batched=True, remove_columns=["text"], desc="Tokenizing")
tokenized_dataset = raw_dataset.map(tokenize_function, **tokenize_options)
print("Tokenization complete.")


Tokenizing dataset...


Tokenizing:   0%|          | 0/25000 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/25000 [00:00<?, ? examples/s]

Tokenizing:   0%|          | 0/50000 [00:00<?, ? examples/s]

Tokenization complete.


In [12]:
# Group texts into chunks of max_length
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size chunks.

    Args:
        examples: Mapping of column name -> list of per-example token lists.
            Must contain an ``'input_ids'`` column.
        chunk_size: Length of every output chunk. When omitted, falls back to
            ``config.max_length`` (the original behaviour).

    Returns:
        A dict with the same columns as ``examples`` plus ``'labels'``. Every
        column is a list of ``chunk_size``-long lists; tokens left over after
        the last full chunk are dropped.
    """
    if chunk_size is None:
        chunk_size = config.max_length

    # Flatten each column in a single pass. The previous `sum(lists, [])`
    # rebuilt the accumulated list on every addition, which is quadratic in
    # the number of tokens in the batch.
    concatenated = {
        k: [token for sequence in examples[k] for token in sequence]
        for k in examples.keys()
    }

    # Round down to a multiple of chunk_size so we don't have partial chunks
    total_length = len(concatenated['input_ids'])
    total_length = (total_length // chunk_size) * chunk_size

    # Split into equal-sized chunks
    result = {
        k: [t[i:i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated.items()
    }

    # For MLM, labels start out as a copy of input_ids (before masking).
    # Each chunk is copied so labels never alias the input_ids lists.
    result['labels'] = [chunk.copy() for chunk in result['input_ids']]

    return result

print("Grouping texts into chunks...")
# Batched map: group_texts receives a batch of examples per call
chunk_options = dict(batched=True, desc="Chunking texts")
chunked_dataset = tokenized_dataset.map(group_texts, **chunk_options)
print("Chunking complete.")


Grouping texts into chunks...


Chunking texts:   0%|          | 0/25000 [00:00<?, ? examples/s]

Chunking texts:   0%|          | 0/25000 [00:00<?, ? examples/s]

Chunking texts:   0%|          | 0/50000 [00:00<?, ? examples/s]

Chunking complete.


In [13]:
# Data collator handles batching and masking
# For training: dynamic masking means different masks each epoch = better learning
# (the masking probability comes from config.mlm_probability)
train_data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=config.mlm_probability,
    return_tensors="pt"
)

# For eval, we'll pre-mask the data once, so we just need a simple collator
# NOTE(review): this import sits mid-notebook; it would normally live in the
# import cell at the top with the other transformers imports.
from transformers import default_data_collator
eval_data_collator = default_data_collator

print(f"Train data collator created with MLM probability: {config.mlm_probability}")
print("Eval data collator: default (will use pre-masked dataset)")


Train data collator created with MLM probability: 0.15
Eval data collator: default (will use pre-masked dataset)


In [14]:
def _limit_split(dataset, max_samples, split_name):
    """Return ``dataset`` capped at ``max_samples`` rows, or unchanged when ``max_samples`` is None.

    Prints one status line naming the split and its resulting size.
    """
    if max_samples is None:
        print(f"✓ Using full {split_name} dataset: {len(dataset)} examples")
        return dataset

    keep = min(max_samples, len(dataset))
    dataset = dataset.select(range(keep))
    print(f"⚠️  Limited {split_name} dataset to {len(dataset)} examples (for faster iteration)")
    return dataset


# Split into train and eval sets, optionally capped for quick testing
train_dataset = _limit_split(chunked_dataset['train'], config.max_samples_train, "train")
eval_dataset = _limit_split(chunked_dataset['test'], config.max_samples_eval, "eval")

# Pre-mask the eval dataset with fixed masks
# Why? So we're evaluating on the same masked tokens every epoch
# This makes it fair to compare perplexity across epochs
# Training still uses dynamic masking (different masks each epoch = better)
def pre_mask_eval_dataset(examples):
    """Mask one batch of eval examples after re-seeding the RNGs to 42.

    The column-oriented batch is turned into per-example dicts, passed through
    ``train_data_collator``, and converted back to plain Python lists. Only
    ``input_ids``, ``attention_mask`` and ``labels`` are returned (whichever
    of them the collator produced).
    """
    import random
    import numpy as np

    # Re-seed Python, NumPy and torch so the masks are the same on every run.
    # NOTE(review): this runs once per call, so when used with a batched map
    # every batch restarts from the same seed.
    seed = 42
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Column-oriented batch -> list of per-example feature dicts
    column_names = examples.keys()
    features = [dict(zip(column_names, row)) for row in zip(*examples.values())]

    # Apply masking
    masked = train_data_collator(features)

    # Back to lists (datasets library format), keeping only the model inputs
    result = {}
    for key in ('input_ids', 'attention_mask', 'labels'):
        if key not in masked:
            continue
        column = masked[key]
        if isinstance(column, torch.Tensor):
            result[key] = [column[row].tolist() for row in range(len(features))]
        else:
            result[key] = [column[row] for row in range(len(features))]

    return result

for status_line in (
    "\nPre-masking evaluation dataset with fixed masks...",
    "This ensures consistent evaluation across epochs for fair comparison.",
    "Using fixed random seed (42) for reproducibility.",
):
    print(status_line)

# Replace the eval split with its pre-masked version
premask_options = dict(batched=True, desc="Pre-masking eval dataset")
eval_dataset = eval_dataset.map(pre_mask_eval_dataset, **premask_options)
print("✓ Evaluation dataset pre-masked successfully.")


✓ Using full train dataset: 13428 examples
✓ Using full eval dataset: 13231 examples

Pre-masking evaluation dataset with fixed masks...
This ensures consistent evaluation across epochs for fair comparison.
Using fixed random seed (42) for reproducibility.


Pre-masking eval dataset:   0%|          | 0/13231 [00:00<?, ? examples/s]

✓ Evaluation dataset pre-masked successfully.


## Setup Training


In [15]:
# Options shared by the training and evaluation loaders
loader_options = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=config.pin_memory,
    # Keep workers alive between epochs only when there are workers at all
    persistent_workers=config.num_workers > 0,
)

# Training: shuffled, masks are applied on the fly by the MLM collator
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=train_data_collator,
    **loader_options,
)

# Evaluation: fixed order, data is already masked so the collator only batches
eval_dataloader = DataLoader(
    eval_dataset,
    shuffle=False,
    collate_fn=eval_data_collator,
    **loader_options,
)

print(f"DataLoaders created:")
print(f"  Train: {len(train_dataloader)} batches (dynamic masking)")
print(f"  Eval: {len(eval_dataloader)} batches (fixed masks)")


DataLoaders created:
  Train: 420 batches (dynamic masking)
  Eval: 414 batches (fixed masks)


In [16]:
# Step budget: one optimizer step per training batch, for every epoch
steps_per_epoch = len(train_dataloader)
num_training_steps = steps_per_epoch * config.num_epochs
num_warmup_steps = int(num_training_steps * config.warmup_ratio)

# AdamW over all model parameters
optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

# Linear schedule: warm up for num_warmup_steps, then decay over the remaining steps
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

print(f"Optimizer and scheduler configured.")
print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")


Optimizer and scheduler configured.
Total training steps: 4200
Warmup steps: 420


In [17]:
# Accelerator takes care of mixed precision and device placement.
# TensorBoard is only requested as a tracker when ./logs already exists.
tracker = "tensorboard" if os.path.exists("./logs") else None
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    log_with=tracker,
    project_dir="./logs",
)

# Hand every training object to the accelerator and rebind the prepared versions
prepared = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = prepared

print(f"Accelerator initialized with mixed precision: {config.mixed_precision}")
print(f"Device: {accelerator.device}")


Accelerator initialized with mixed precision: fp16
Device: cuda


## Training Loop with Early Stopping


In [18]:
# One list per tracked metric; each list gets one entry per epoch
training_history = {
    metric: []
    for metric in ('train_loss', 'eval_loss', 'perplexity', 'learning_rate')
}

# Early-stopping state: best perplexity so far and epochs since it last improved
best_perplexity = math.inf
patience_counter = 0

banner = "=" * 80
print(banner)
print("Starting Training")
print(banner)


Starting Training


In [19]:
# Main training loop: one training pass and one evaluation pass per epoch.
# A checkpoint is written whenever perplexity improves by more than
# config.early_stopping_threshold; training stops after
# config.early_stopping_patience epochs without such an improvement.
for epoch in range(config.num_epochs):
    # ---- Training phase ----
    train_losses = []
    model.train()
    progress_bar = tqdm(
        train_dataloader,
        desc=f"Epoch {epoch} [Train]",
        disable=not accelerator.is_local_main_process
    )

    for batch in progress_bar:
        # accumulate() is what makes config.gradient_accumulation_steps take
        # effect: inside it the prepared optimizer/scheduler only really step
        # on synchronisation boundaries. Without it the optimizer stepped on
        # every batch regardless of the configured value. With the default of
        # 1 the behaviour is unchanged.
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)  # Handles mixed precision automatically
            optimizer.step()
            lr_scheduler.step()
            # Zero after stepping so gradients survive across accumulated batches
            optimizer.zero_grad()

        # Read the loss back once per step (each .item() forces a device sync)
        step_loss = loss.item()
        train_losses.append(step_loss)
        progress_bar.set_postfix({"loss": f"{step_loss:.4f}"})

    # Plain float keeps the history JSON-friendly
    train_result = float(np.mean(train_losses))
    training_history['train_loss'].append(train_result)
    training_history['learning_rate'].append(lr_scheduler.get_last_lr()[0])
    print(f"Epoch {epoch} - Train Loss: {train_result:.4f}")

    # ---- Evaluation phase ----
    model.eval()
    eval_losses = []
    progress_bar = tqdm(
        eval_dataloader,
        desc=f"Epoch {epoch} [Eval]",
        disable=not accelerator.is_local_main_process
    )

    for batch in progress_bar:
        with torch.no_grad():  # No gradients needed for eval
            outputs = model(**batch)
        loss = outputs.loss
        # Repeat the batch-mean loss once per example so that, after gathering
        # across processes, every example carries equal weight in the average
        current_batch_size = batch["input_ids"].shape[0]
        eval_losses.append(accelerator.gather(loss.repeat(current_batch_size)))

    eval_losses = torch.cat(eval_losses)
    # In multi-process runs the final batch can be padded with repeated
    # examples; keep only as many losses as there are eval examples.
    # This is a no-op on a single process.
    eval_losses = eval_losses[:len(eval_dataset)]

    avg_eval_loss = torch.mean(eval_losses).item()
    try:
        perplexity = math.exp(avg_eval_loss)  # Perplexity = exp(loss)
    except OverflowError:
        # A diverged loss should count as "no improvement", not crash the run
        perplexity = float("inf")

    training_history['eval_loss'].append(avg_eval_loss)
    training_history['perplexity'].append(perplexity)

    print(f'perplexity {perplexity}')

    # Check if this is the best model so far
    if perplexity < best_perplexity - config.early_stopping_threshold:
        best_perplexity = perplexity
        patience_counter = 0

        # Every process reaches this branch (perplexity is computed from
        # gathered losses), so sync before the main process writes to disk
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint_dir = os.path.join(config.output_dir, f"checkpoint-epoch-{epoch+1}")
            unwrapped_model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)
            print(f"  ✓ Saved best model (perplexity: {perplexity:.4f}) to {checkpoint_dir}")
    else:
        patience_counter += 1
        if patience_counter >= config.early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs.")
            print(f"Best perplexity: {best_perplexity:.4f}")
            break

print("\nTraining Complete!")


Epoch 0 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 0 - Train Loss: 2.5248


Epoch 0 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 9.891653327026745
  ✓ Saved best model (perplexity: 9.8917) to ./mlm_checkpoints/checkpoint-epoch-1


Epoch 1 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 1 - Train Loss: 2.3565


Epoch 1 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 9.200568011682796
  ✓ Saved best model (perplexity: 9.2006) to ./mlm_checkpoints/checkpoint-epoch-2


Epoch 2 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 2 - Train Loss: 2.2773


Epoch 2 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.820623438740023
  ✓ Saved best model (perplexity: 8.8206) to ./mlm_checkpoints/checkpoint-epoch-3


Epoch 3 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 3 - Train Loss: 2.2325


Epoch 3 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.577262847692973
  ✓ Saved best model (perplexity: 8.5773) to ./mlm_checkpoints/checkpoint-epoch-4


Epoch 4 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 4 - Train Loss: 2.1976


Epoch 4 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.381961392673047
  ✓ Saved best model (perplexity: 8.3820) to ./mlm_checkpoints/checkpoint-epoch-5


Epoch 5 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 5 - Train Loss: 2.1692


Epoch 5 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.224712997318674
  ✓ Saved best model (perplexity: 8.2247) to ./mlm_checkpoints/checkpoint-epoch-6


Epoch 6 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 6 - Train Loss: 2.1433


Epoch 6 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.111735064806625
  ✓ Saved best model (perplexity: 8.1117) to ./mlm_checkpoints/checkpoint-epoch-7


Epoch 7 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 7 - Train Loss: 2.1270


Epoch 7 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 8.011981436307622
  ✓ Saved best model (perplexity: 8.0120) to ./mlm_checkpoints/checkpoint-epoch-8


Epoch 8 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 8 - Train Loss: 2.1088


Epoch 8 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 7.948975941557036
  ✓ Saved best model (perplexity: 7.9490) to ./mlm_checkpoints/checkpoint-epoch-9


Epoch 9 [Train]:   0%|          | 0/420 [00:00<?, ?it/s]

Epoch 9 - Train Loss: 2.0973


Epoch 9 [Eval]:   0%|          | 0/414 [00:00<?, ?it/s]

perplexity 7.920665055462317
  ✓ Saved best model (perplexity: 7.9207) to ./mlm_checkpoints/checkpoint-epoch-10

Training Complete!


## Save Training History


In [20]:
# Persist the per-epoch metrics as JSON next to the checkpoints
history_path = os.path.join(config.output_dir, "training_history.json")
with open(history_path, 'w') as history_file:
    json.dump(training_history, history_file, indent=2)

print(f"Training history saved to {history_path}")

# Summary: best eval metrics over all epochs plus the last training loss
print("\nFinal Metrics:")
best_eval_loss = min(training_history['eval_loss'])
print(f"  Best Eval Loss: {best_eval_loss:.4f}")
best_seen_perplexity = min(training_history['perplexity'])
print(f"  Best Perplexity: {best_seen_perplexity:.4f}")
final_train_loss = training_history['train_loss'][-1]
print(f"  Final Train Loss: {final_train_loss:.4f}")


Training history saved to ./mlm_checkpoints/training_history.json

Final Metrics:
  Best Eval Loss: 2.0695
  Best Perplexity: 7.9207
  Final Train Loss: 2.0973
