In [1]:
import os

# Run from the project root: the project imports in the next cell (models/, utils/,
# training/, evaluation/) and relative paths such as checkpoints/ are resolved against
# the working directory. Set the KDSML_ROOT environment variable to use a different
# checkout (e.g. on another machine) instead of editing this cell.
os.chdir(os.environ.get("KDSML_ROOT", "/home/work/repos/4_02_2025/kdsml"))

In [2]:
# Third-party: PyTorch and tqdm progress bars
import torch
from tqdm import tqdm
# Project modules; presumably resolved from the project root selected by os.chdir in the first cell
from models.teacher_model import TeacherModel
from models.student_model import StudentModel
from utils.loss_functions import KnowledgeDistillationLoss
from training.trainer import Trainer
from training.validator import Validator
# NOTE(review): Evaluator is not used in the cells shown here.
from evaluation.evaluator import Evaluator
# Project DataLoader wrapper (used via .get_dataloader below), not torch.utils.data.DataLoader
from utils.data_loader import DataLoader

# Limit Hugging Face transformers logging to errors only.
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()


# NOTE(review): matplotlib is not used in the cells shown here; presumably intended for plotting the metric histories.
import matplotlib.pyplot as plt

In [3]:
# Build the teacher and student networks and place each on the GPU.
# NOTE(review): no cell shown here calls teacher_model.eval() or freezes its parameters;
# presumably Trainer takes care of that. Confirm, because a teacher left in train mode
# would keep any dropout layers active while producing distillation targets.
teacher_model = TeacherModel()
teacher_model = teacher_model.to('cuda')
student_model = StudentModel()
student_model = student_model.to('cuda')



In [4]:
# Distillation loss, moved to the GPU alongside the models, plus an Adam optimizer over
# the student's parameters only (the teacher is never handed to the optimizer).
# NOTE(review): if KnowledgeDistillationLoss owns learnable parameters (e.g. a projection
# of teacher hidden states), they are not optimized here and are not saved in the
# checkpoint. Confirm that is intended.
loss_fn = KnowledgeDistillationLoss()
loss_fn = loss_fn.to('cuda')
optimizer = torch.optim.Adam(params=student_model.parameters(), lr=1e-4)

In [5]:
# The trainer drives teacher+student optimization steps; the validator scores the
# student alone. Both share the same loss function and run on the GPU.
trainer, validator = (
    Trainer(teacher_model, student_model, loss_fn, optimizer, 'cuda'),
    Validator(student_model, loss_fn, 'cuda'),
)

In [6]:
# Training and validation loaders for WikiText-103. A separate project DataLoader wrapper
# is built for each split, train first, then validation.
train_loader, val_loader = (
    DataLoader(dataset_name='wikitext').get_dataloader(split=split_name, batch_size=256)
    for split_name in ('train', 'validation')
)

In [7]:
# Load and preprocess evaluation data (SQuAD)
# eval_loader = DataLoader(dataset_name='squad').get_dataloader(split='validation', batch_size=16)

In [8]:
# --- Dummy Forward Pass Test ---
def test_dummy_forward():
    # Create dummy inputs that mimic a real batch
    batch_size = 2
    seq_length = 10
    vocab_size = 30522

    # Dummy input_ids and attention_mask on device
    dummy_input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).to(device)
    dummy_attention_mask = torch.ones_like(dummy_input_ids).to(device)
    # Create dummy labels by shifting the input_ids (simulating language modeling)
    dummy_labels = dummy_input_ids.clone()
    dummy_labels[:, :-1] = dummy_input_ids[:, 1:]
    dummy_labels[:, -1] = 0  # Assume pad token id is 0

    # Teacher forward pass (ensure outputs are on the same device)
    with torch.no_grad():
        teacher_outputs = teacher_model(dummy_input_ids, dummy_attention_mask)
    teacher_hidden_states = teacher_outputs['last_hidden_state']

    # Student forward pass
    student_logits = student_model(dummy_input_ids, dummy_attention_mask)

    # Compute loss to verify that all tensors are on the same device
    loss_val = loss_fn(student_logits, teacher_hidden_states, dummy_labels)
    print("Dummy forward pass successful. Loss:", loss_val.item())

In [9]:
# --- Configuration ---
# Training length, target device, and where the per-epoch checkpoint is written/read.
num_epochs = 10
device = 'cuda'
checkpoint_path = os.path.join("checkpoints", "checkpoint.pth")

# Per-epoch metric histories; the resume cell replaces these from a checkpoint if one exists.
train_losses, val_losses, val_accuracies = [], [], []

In [10]:
# Resume from the last saved checkpoint if one exists; otherwise start from scratch.
# The keys read here are the ones the training loop writes at the end of every epoch.
start_epoch = 0
if os.path.exists(checkpoint_path):
    # map_location loads the tensors straight onto the configured device, even if the
    # checkpoint was written from a different one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    student_model.load_state_dict(checkpoint['student_model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # 'epoch' is the 0-based index of the last completed epoch.
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    # Presumably older checkpoints predate accuracy tracking, hence the default.
    val_accuracies = checkpoint.get('val_accuracies', [])
    # Report 1-based, matching the "Epoch N/num_epochs" bar and "Checkpoint saved at epoch N".
    print(f"Resuming training from epoch {start_epoch + 1}")
    if start_epoch >= num_epochs:
        print(f"Checkpoint already covers all {num_epochs} epochs; the training loop will not run.")
else:
    print('New run with new weights')

New run with new weights


In [11]:
# --- Training loop ---
# Each epoch: one pass over train_loader through trainer.train_step, a validation pass,
# then a checkpoint holding model/optimizer state and the metric histories, so the
# resume cell can continue from the next epoch.
for epoch in range(start_epoch, num_epochs):
    student_model.train()
    total_train_loss = 0.0
    num_batches = 0
    epoch_steps = len(train_loader)

    # Progress bar for batches in the epoch
    pbar = tqdm(enumerate(train_loader), total=epoch_steps, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    for i, batch in pbar:
        # Ensure batch tensors are on the correct device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # float() accepts either a Python number or a 0-d tensor. It keeps the running sum
        # and the saved loss history as plain floats rather than accumulated device tensors.
        train_loss = float(trainer.train_step(input_ids, attention_mask, labels))
        total_train_loss += train_loss
        num_batches += 1

        # Update progress bar with current iteration loss
        pbar.set_postfix({
            'Iter': f"{i+1}/{epoch_steps}",
            'Loss': f"{train_loss:.4f}"
        })

    # Average over the batches actually processed; an empty loader yields NaN
    # instead of raising ZeroDivisionError.
    avg_train_loss = total_train_loss / num_batches if num_batches else float('nan')
    train_losses.append(avg_train_loss)

    # --- Validation ---
    # Assumes validator.validate returns (loss, accuracy) with accuracy in percent.
    # NOTE(review): the older commented-out loop at the end of the notebook treats
    # validate() as returning a single value. Confirm against Validator.
    val_loss, val_accuracy = validator.validate(val_loader)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Print epoch summary with additional metrics
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  Val Acc:    {val_accuracy:.2f}%")

    # Save checkpoint including current epoch, model, optimizer, and metric histories
    checkpoint = {
        'epoch': epoch,
        'student_model_state_dict': student_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }
    # torch.save does not create missing parent directories; create the checkpoint
    # folder first so the first save of a fresh run does not fail.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1} to '{checkpoint_path}'\n")

Epoch 1/10:   1%|▏         | 95/7037 [02:54<3:32:48,  1.84s/batch, Iter=95/7037, Loss=1.3828]


KeyboardInterrupt: 

In [None]:
# Superseded by the tqdm training loop above; kept for reference only (not executed).
# train_losses, val_losses = [], []
# for epoch in range(10):
#     # Train on WikiText-103
#     total_train_loss = 0.0
#     for batch in train_loader:
#         input_ids, attention_mask, labels = (
#             batch['input_ids'].to('cuda'),
#             batch['attention_mask'].to('cuda'),
#             batch['labels'].to('cuda')
#         )
#         train_loss = trainer.train_step(input_ids, attention_mask, labels)
#         total_train_loss += train_loss
#     avg_train_loss = total_train_loss / len(train_loader)
#     train_losses.append(avg_train_loss)

#     # Validate on WikiText-103
#     val_loss = validator.validate(val_loader)
#     val_losses.append(val_loss)

#     print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss}, Val Loss = {val_loss}")