In [None]:
# Suppress PyTorch deprecation warnings from the transformers library
import warnings
warnings.filterwarnings("ignore", message="`torch.utils._pytree._register_pytree_node` is deprecated")

import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gc

from time import time
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import TextDataset, DataCollatorForLanguageModeling
from datasets import load_dataset
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): no cell visible in this notebook uploads anything to the Hub;
# confirm whether authentication is actually required before keeping this cell,
# since it blocks a non-interactive "Restart & Run All".
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Configuration parameters for optimization
MAX_SEQ_LENGTH = 64      # Reduced from 128
BATCH_SIZE = 16          # Increased from 8
GRADIENT_ACCUMULATION = 4 # Effective batch size = 16 × 4 = 64
USE_FP16 = True          # Enable mixed precision training
NUM_EPOCHS = 1           # Keep as 1 for demonstration
TEACHER_MODEL = 'distilgpt2'  # Smaller teacher model
DATASET_SIZE_LIMIT = 5000    # Limit dataset size
NUM_WORKERS = 2          # Parallel data loading
EVAL_STEPS = 200         # Evaluate less frequently
SEED = 42                # Seed for shuffling, weight initialisation and sampling

# Seed the RNGs so that data order, student initialisation and sampled text
# are repeatable across "Restart & Run All".
torch.manual_seed(SEED)
np.random.seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configure Hugging Face cache
# NOTE(review): these variables are set after `transformers` was imported in the
# import cell, so they may not take effect for this session -- confirm, or move
# them above the imports.
import os
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['HF_HOME'] = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

# Load tokenizer with error handling: try the teacher's tokenizer first, then
# fall back to the plain GPT-2 tokenizer.
try:
    tokenizer = GPT2Tokenizer.from_pretrained(TEACHER_MODEL, local_files_only=False)
    print(f"Successfully loaded the {TEACHER_MODEL} tokenizer")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    print("Trying alternative model: gpt2")
    try:
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)
        print("Successfully loaded the GPT-2 tokenizer")
    except Exception as e2:
        print(f"Error loading alternative tokenizer: {e2}")
        # Chain the original error so the root cause stays in the traceback.
        raise RuntimeError("Could not load any tokenizer.") from e2

# The tokenizer pads with its EOS token; downstream code must therefore rely on
# the attention mask (not the token id) to tell padding from real text.
tokenizer.pad_token = tokenizer.eos_token

# Load a small text dataset and limit its size
try:
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    print("Successfully loaded the dataset")

    # Limit dataset size for faster training
    if len(dataset['train']) > DATASET_SIZE_LIMIT:
        dataset['train'] = dataset['train'].select(range(DATASET_SIZE_LIMIT))
    if len(dataset['validation']) > DATASET_SIZE_LIMIT // 10:
        dataset['validation'] = dataset['validation'].select(range(DATASET_SIZE_LIMIT // 10))

    print(f"Limited training dataset to {len(dataset['train'])} examples")
    print(f"Limited validation dataset to {len(dataset['validation'])} examples")
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Creating a simple example dataset instead")
    # Create a simple dataset
    example_texts = [
        {"text": "This is an example text for language modeling."},
        {"text": "Knowledge distillation helps create smaller models."},
        {"text": "AI research focuses on efficient model deployment."}
    ] * 100

    # Imported under an alias so it does not shadow torch.utils.data.Dataset
    # from the import cell.
    from datasets import Dataset as HFDataset, DatasetDict

    # A DatasetDict (not a plain dict) is required because the next step calls
    # `dataset.map(...)` on the whole collection of splits.
    dataset = DatasetDict({
        'train': HFDataset.from_list(example_texts),
        'validation': HFDataset.from_list(example_texts[:10])
    })

# Batched tokenization callback used with `dataset.map`
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch of examples.

    Every sequence is truncated and padded to exactly MAX_SEQ_LENGTH tokens.

    Args:
        examples: Mapping with a 'text' entry holding the raw strings.

    Returns:
        Whatever the module-level `tokenizer` returns for that batch.
    """
    encode_options = dict(
        padding='max_length',
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    raw_texts = examples['text']
    return tokenizer(raw_texts, **encode_options)

# Tokenize every split in batches of 1000 rows and drop the raw text column
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    remove_columns=['text'],
)

# Expose both splits with the "torch" output format
train_dataset = tokenized_datasets['train'].with_format("torch")
eval_dataset = tokenized_datasets['validation'].with_format("torch")

# Settings common to the training and evaluation loaders
loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Only the training loader shuffles; evaluation keeps dataset order
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
eval_dataloader = DataLoader(eval_dataset, **loader_options)

In [None]:
# Teacher model (Using distilgpt2 by default for faster training)
class TeacherLLM(nn.Module):
    """Wrapper around a GPT-2 language model that acts as the distillation teacher.

    Loading is attempted in this order: the requested checkpoint, then 'gpt2'
    (unless 'gpt2' was the requested checkpoint), and finally a randomly
    initialised GPT-2 built from a default config sized to the tokenizer.
    """

    def __init__(self, model_name=TEACHER_MODEL):
        super().__init__()
        backup_name = 'gpt2'

        try:
            self.model = GPT2LMHeadModel.from_pretrained(model_name, local_files_only=False)
            print(f"Successfully loaded {model_name} model")
        except Exception as primary_error:
            print(f"Error loading {model_name} model: {primary_error}")
            if model_name == backup_name:
                # The requested model already was the fallback: nothing else to try.
                self.model = self._build_untrained_model()
            else:
                print(f"Trying fallback model: {backup_name}")
                try:
                    self.model = GPT2LMHeadModel.from_pretrained(backup_name, local_files_only=False)
                    print(f"Successfully loaded {backup_name} model")
                except Exception as backup_error:
                    print(f"Error loading fallback model: {backup_error}")
                    self.model = self._build_untrained_model()

        # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(tokenizer))

    @staticmethod
    def _build_untrained_model():
        """Create a randomly initialised GPT-2 whose vocabulary matches the tokenizer."""
        print("Creating a basic model from scratch")
        scratch_config = GPT2Config(vocab_size=len(tokenizer))
        return GPT2LMHeadModel(scratch_config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

In [None]:
def evaluate(model, dataloader, max_batches=None):
    """Compute the average language-modelling loss and perplexity of `model`.

    Padding positions (attention_mask == 0) are excluded from the loss by
    setting their label to -100. This matters because the tokenizer pads with
    its EOS token, so copying `input_ids` as labels would score the model on
    predicting padding. Batch losses are combined weighted by their number of
    target tokens, so the result is a per-token average.

    Args:
        model: Module called as `model(input_ids=..., attention_mask=...,
            labels=...)`; the loss is read from `.loss` or from element 0 of
            the output.
        dataloader: Iterable of batches providing 'input_ids' and
            'attention_mask' tensors.
        max_batches: If truthy, stop after this many batches; otherwise the
            whole dataloader is consumed.

    Returns:
        Tuple `(avg_loss, perplexity)` of Python floats. Both are `inf` when
        no target token was evaluated.

    Note:
        The model is left in eval mode; callers that continue training must
        call `model.train()` afterwards.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if max_batches and i >= max_batches:
                break

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Ignore padding in the loss (-100 is the ignore index of the LM head).
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # The LM head predicts token t+1 from position t, so the first
            # position is never a target.
            num_targets = int((labels[:, 1:] != -100).sum().item())
            if num_targets == 0:
                # A batch made only of empty lines has no targets; its mean
                # loss would be NaN and poison the running total.
                continue

            with autocast(enabled=USE_FP16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]

            total_loss += loss.item() * num_targets
            total_tokens += num_targets

    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return avg_loss, perplexity

In [None]:
# Build the teacher, move it to the selected device, then create its
# optimizer and mixed-precision loss scaler.
teacher_model = TeacherLLM()
teacher_model = teacher_model.to(device)
teacher_optimizer = optim.AdamW(params=teacher_model.parameters(), lr=5e-5)
teacher_scaler = GradScaler(enabled=USE_FP16)

In [None]:
# Optional fine-tuning with optimized setup
num_epochs = NUM_EPOCHS  # Driven by the configuration cell
teacher_steps = 0  # Track total steps

for epoch in range(num_epochs):
    teacher_model.train()
    running_loss = 0.0
    has_pending_grads = False  # True once a backward pass has run since the last optimizer step
    # Drop gradients possibly left over from an interrupted earlier run of this cell.
    teacher_optimizer.zero_grad(set_to_none=True)

    progress_bar = tqdm(train_dataloader, desc=f"Teacher Training Epoch {epoch+1}")
    for step, batch in enumerate(progress_bar):
        # Limit steps if needed (can uncomment to make even faster)
        # if step >= 100: break

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # The tokenizer pads with EOS, so padding must be masked out of the
        # labels explicitly; -100 is ignored by the LM loss.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        # The model shifts labels by one internally. A batch without any
        # target token (e.g. only empty lines) would yield a NaN loss, so skip
        # its forward/backward pass.
        if (labels[:, 1:] != -100).any():
            # Mixed precision forward pass
            with autocast(enabled=USE_FP16):
                outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
                loss = loss / GRADIENT_ACCUMULATION  # Scale loss for accumulation

            # Mixed precision backward pass
            teacher_scaler.scale(loss).backward()
            running_loss += loss.item() * GRADIENT_ACCUMULATION
            has_pending_grads = True

        # Update weights with gradient accumulation (also on the last, possibly partial, group)
        if (step + 1) % GRADIENT_ACCUMULATION == 0 or (step + 1) == len(train_dataloader):
            if has_pending_grads:
                teacher_scaler.step(teacher_optimizer)
                teacher_scaler.update()
                teacher_optimizer.zero_grad(set_to_none=True)  # More memory efficient
                has_pending_grads = False

            # Update progress
            progress_bar.set_postfix({"loss": running_loss / (step + 1)})

        teacher_steps += 1

        # Evaluate periodically rather than every epoch
        if teacher_steps % EVAL_STEPS == 0:
            # Limit eval batches for faster feedback
            avg_loss, perplexity = evaluate(teacher_model, eval_dataloader, max_batches=5)
            print(f"\nStep {teacher_steps}, Eval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")
            teacher_model.train()  # evaluate() leaves the model in eval mode

            # Clean up memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final evaluation
    avg_loss, perplexity = evaluate(teacher_model, eval_dataloader)
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_dataloader):.4f}, ",
          f"Eval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")

In [None]:
# Student model (even smaller than before)
class StudentLLM(nn.Module):
    """Compact, randomly initialised GPT-2 that is trained to imitate the teacher."""

    def __init__(self):
        super().__init__()
        # Reduced architecture: 4 layers, 4 heads, 256-dim embeddings, and a
        # position table twice the training sequence length.
        student_hparams = {
            'vocab_size': len(tokenizer),
            'n_positions': MAX_SEQ_LENGTH * 2,
            'n_embd': 256,
            'n_layer': 4,
            'n_head': 4,
        }
        self.model = GPT2LMHeadModel(GPT2Config(**student_hparams))

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

In [None]:
# Build the student on the selected device with its own optimizer (higher
# learning rate than the teacher's) and its own mixed-precision loss scaler.
student_model = StudentLLM()
student_model = student_model.to(device)
student_optimizer = optim.AdamW(params=student_model.parameters(), lr=1e-4)
student_scaler = GradScaler(enabled=USE_FP16)

In [None]:
def kd_loss_fn(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Knowledge-distillation loss for a causal language model.

    The loss is `alpha * hard_loss + (1 - alpha) * kd_loss`, both expressed as
    per-token averages so that `alpha` really balances the two terms:

    * `kd_loss`: KL divergence between the temperature-softened teacher and
      student distributions at every position whose label is not -100,
      multiplied by `temperature ** 2`.
    * `hard_loss`: next-token cross-entropy of the student. The logits at
      position t are compared with the label at position t + 1, because a
      causal LM predicts the following token.

    Args:
        student_logits: Tensor `(..., seq_len, vocab)` of student logits.
        teacher_logits: Tensor of the same shape with the teacher logits.
        labels: Integer tensor `(..., seq_len)` of token ids aligned with the
            input positions; -100 marks positions to ignore (e.g. padding).
        temperature: Softmax temperature applied to both sets of logits.
        alpha: Weight of the hard loss; `1 - alpha` weighs the KD term.

    Returns:
        Scalar tensor. It is exactly 0 (and still attached to the graph) when
        no position is valid, instead of NaN.
    """
    vocab_size = student_logits.size(-1)

    # ----- Soft targets: match the teacher at every non-ignored position -----
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)

    # KL per position (sum over the vocabulary), so padding can be masked out.
    per_token_kl = F.kl_div(soft_student, soft_teacher, reduction='none').sum(dim=-1)
    mask = (labels != -100).to(per_token_kl.dtype)
    num_valid_tokens = mask.sum().clamp(min=1.0)  # avoid 0/0 on all-padding input
    kd_loss = (per_token_kl * mask).sum() / num_valid_tokens * (temperature ** 2)

    # ----- Hard targets: standard shifted next-token cross-entropy -----
    shift_logits = student_logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    if (shift_labels != -100).any():
        hard_loss = F.cross_entropy(shift_logits.reshape(-1, vocab_size),
                                    shift_labels.reshape(-1),
                                    ignore_index=-100)
    else:
        # No target at all: cross_entropy would return NaN; use a zero that
        # keeps dtype, device and the autograd graph.
        hard_loss = student_logits.sum() * 0.0

    # Combine the two losses
    loss = alpha * hard_loss + (1 - alpha) * kd_loss
    return loss

In [None]:
# Optimized student training with early stopping
num_epochs = NUM_EPOCHS  # Driven by the configuration cell
temperature = 2.0
alpha = 0.5
patience = 3  # Early stopping patience (in consecutive periodic evaluations)
best_loss = float('inf')
patience_counter = 0
student_steps = 0

# The teacher only provides targets here: make sure dropout is disabled,
# whatever mode an earlier cell left it in.
teacher_model.eval()

for epoch in range(num_epochs):
    student_model.train()
    running_loss = 0.0
    has_pending_grads = False  # True once a backward pass has run since the last optimizer step
    # Drop gradients possibly left over from an interrupted earlier run of this cell.
    student_optimizer.zero_grad(set_to_none=True)

    progress_bar = tqdm(train_dataloader, desc=f"Student Training Epoch {epoch+1}")
    for step, batch in enumerate(progress_bar):
        # Limit steps if needed (uncomment to make even faster)
        # if step >= 200: break

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # The tokenizer pads with EOS, so padding must be masked out of the
        # labels explicitly; -100 is the ignore index used by kd_loss_fn.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        # A batch consisting only of padding has nothing to learn from and
        # could produce a NaN loss; skip its forward/backward pass.
        if (labels != -100).any():
            # Mixed precision training
            with autocast(enabled=USE_FP16):
                # Get student outputs
                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits if hasattr(student_outputs, 'logits') else student_outputs[0]

                # Get teacher outputs (without gradient tracking)
                with torch.no_grad():
                    teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                    teacher_logits = teacher_outputs.logits if hasattr(teacher_outputs, 'logits') else teacher_outputs[0]

                # Calculate knowledge distillation loss
                loss = kd_loss_fn(student_logits, teacher_logits, labels, temperature=temperature, alpha=alpha)
                loss = loss / GRADIENT_ACCUMULATION  # Scale for accumulation

            # Backward pass with mixed precision
            student_scaler.scale(loss).backward()
            running_loss += loss.item() * GRADIENT_ACCUMULATION
            has_pending_grads = True

        # Update weights with gradient accumulation (also on the last, possibly partial, group)
        if (step + 1) % GRADIENT_ACCUMULATION == 0 or (step + 1) == len(train_dataloader):
            if has_pending_grads:
                student_scaler.step(student_optimizer)
                student_scaler.update()
                student_optimizer.zero_grad(set_to_none=True)
                has_pending_grads = False

            # Update progress bar
            progress_bar.set_postfix({"loss": running_loss / (step + 1)})

        student_steps += 1

        # Evaluate periodically
        if student_steps % EVAL_STEPS == 0:
            avg_loss, perplexity = evaluate(student_model, eval_dataloader, max_batches=5)
            print(f"\nStep {student_steps}, Eval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")

            # Early stopping check
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
                # Could save best model here
                # torch.save(student_model.state_dict(), "best_student_model.pt")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at step {student_steps}")
                    break

            student_model.train()  # evaluate() leaves the model in eval mode

            # Clean up memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Early stopping check for epoch: propagate the inner-loop break
    if patience_counter >= patience:
        print("Training stopped early due to no improvement")
        break

    # Final evaluation
    avg_loss, perplexity = evaluate(student_model, eval_dataloader)
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_dataloader):.4f}, ",
          f"Eval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")

In [None]:
# Compare model sizes
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of `model`.

    Parameters whose `requires_grad` flag is False are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Trainable parameter counts of both models
teacher_params = count_parameters(teacher_model)
student_params = count_parameters(student_model)

# Share of the teacher's parameters that the student does without, in percent
size_reduction_pct = (1 - student_params / teacher_params) * 100

print(f"Teacher model parameters: {teacher_params:,}")
print(f"Student model parameters: {student_params:,}")
print(f"Size reduction: {size_reduction_pct:.2f}%")

In [None]:
# Shared prompt for the generation comparison, tokenized once and moved to
# the device both models live on.
input_text = "In this article, we discuss the importance of"
inputs = tokenizer(input_text, return_tensors="pt")
inputs = inputs.to(device)

def generate_text(model, prompt, max_length=50):
    """Sample a continuation of `prompt` from `model` and return it as text.

    Args:
        model: Wrapper exposing the underlying generator as `model.model`.
        prompt: Tokenized prompt providing 'input_ids' and 'attention_mask'.
        max_length: Forwarded to `generate` as `max_length`.

    Returns:
        The first generated sequence decoded with special tokens removed.
    """
    # Release unreferenced objects and cached GPU blocks before generating.
    gc.collect()
    torch.cuda.empty_cache()

    # Top-k / nucleus sampling with the key/value cache enabled; EOS doubles
    # as the padding id.
    sampling_options = dict(
        max_length=max_length,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        with autocast(enabled=USE_FP16):
            token_ids = model.model.generate(
                input_ids=prompt['input_ids'],
                attention_mask=prompt['attention_mask'],
                **sampling_options,
            )

    first_sequence = token_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Time teacher model
# `%time` is an IPython line magic that times the single statement after it;
# the assigned name is used by the print below. generate_text samples
# (do_sample=True), so the text may differ between runs.
%time teacher_output = generate_text(teacher_model, inputs)
print(f"Teacher output:\n{teacher_output}\n")

# Time student model
# Same prompt and generation settings as the teacher, for a like-for-like timing.
%time student_output = generate_text(student_model, inputs)
print(f"Student output:\n{student_output}\n")

In [None]:
# Optimized memory usage comparison
def check_model_memory(model_name, model):
    """Print the extra peak GPU memory used by one inference forward pass.

    The measurement is the CUDA peak allocation during the forward pass minus
    the allocation just before it, so it excludes the model weights already
    resident on the device. The pass runs under `torch.no_grad()` so that
    activations kept for backpropagation do not inflate the figure.

    Args:
        model_name: Label used in the printed line.
        model: Module called with the module-level `inputs` prompt.

    Returns:
        None. The result is printed. On a non-CUDA device a notice is printed
        instead, because the CUDA memory statistics are unavailable there.
    """
    if device.type != "cuda":
        # The notebook supports a CPU fallback; torch.cuda statistics do not.
        print(f"{model_name} memory usage: n/a (CUDA memory statistics require a GPU)")
        return

    # Force garbage collection and start from a clean peak counter
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    # Track memory before model forward pass
    memory_before = torch.cuda.memory_allocated(device)

    # Run an inference forward pass with mixed precision
    with torch.no_grad(), autocast(enabled=USE_FP16):
        _ = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

    # Get peak memory during the forward pass
    memory_peak = torch.cuda.max_memory_allocated(device)
    memory_used = memory_peak - memory_before

    print(f"{model_name} memory usage: {memory_used / 1024**2:.2f} MB")
    
# Report the teacher first, then the student
for label, measured_model in (("Teacher", teacher_model), ("Student", student_model)):
    check_model_memory(label, measured_model)