# English-Gujarati NMT: Train Student Model with Knowledge Distillation

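This notebook trains the small student model (under 50M parameters) with knowledge distillation from a previously trained teacher model. Distillation settings are read from `config.yaml`. If distillation is disabled or the teacher checkpoint cannot be found, the student is trained without a teacher.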

In [None]:
# Install dependencies
# torch/torchvision/torchaudio are pulled from PyTorch's cu118 wheel index (CUDA 11.8 builds);
# change the index URL if the runtime's CUDA version differs.
# torchvision, torchaudio, wandb and requests are not imported directly in this notebook —
# presumably they are used by the `src` package; confirm before removing any of them.
# TODO: pin versions (e.g. `torch==X.Y.Z`) so a re-run installs the same environment.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install transformers tokenizers sentencepiece datasets sacrebleu pyyaml tqdm wandb requests

In [None]:
# Setup: make the project package (`src.*`) importable and pick the compute device.
import sys
from pathlib import Path
import torch

# Put the current working directory first on sys.path so `src.*` resolves to this project.
# Existing copies are removed first, so re-running the cell keeps exactly one entry
# (and it stays at highest priority) instead of stacking duplicates.
project_root = str(Path.cwd())
sys.path[:] = [entry for entry in sys.path if entry != project_root]
sys.path.insert(0, project_root)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Load config and check distillation settings
from src.utils.config import load_config

config = load_config("config.yaml")

# `distillation:` may be missing or left empty (YAML null) in config.yaml;
# normalise both cases to {} so the lookups below never hit `None.get`.
distill_cfg = config.get('distillation') or {}

# Verify distillation is enabled and the teacher checkpoint is reachable
if not distill_cfg.get('enabled', False):
    print("WARNING: Distillation is not enabled in config.yaml!")
    print("Set distillation.enabled: true and provide teacher_checkpoint path")
else:
    teacher_checkpoint = distill_cfg.get('teacher_checkpoint')
    if not teacher_checkpoint or not Path(teacher_checkpoint).exists():
        print("WARNING: Teacher checkpoint not found!")
        print(f"Expected at: {teacher_checkpoint}")
        print("Please upload the teacher checkpoint or update config.yaml")
    else:
        print(f"Distillation enabled. Teacher checkpoint: {teacher_checkpoint}")

# NOTE(review): 4.0 / 0.5 are only the fallbacks shown here when the keys are absent;
# confirm they match the defaults the trainer actually applies.
print("\nDistillation config:")
print(f"  Enabled: {distill_cfg.get('enabled', False)}")
print(f"  Temperature: {distill_cfg.get('temperature', 4.0)}")
print(f"  Alpha: {distill_cfg.get('alpha', 0.5)}")

In [None]:
# Load the source/target tokenizers saved alongside the data splits.
# "bpe" tokenizers are stored as .json files; every other tokenization type is
# loaded with UnigramTokenizer from a .model file.
from src.tokenization.bpe import BPETokenizer
from src.tokenization.unigram import UnigramTokenizer

splits_dir = Path(config['paths']['splits_dir'])
tokenizer_type = config['tokenization']['type']

if tokenizer_type == "bpe":
    tokenizer_cls, tokenizer_suffix = BPETokenizer, "json"
else:
    tokenizer_cls, tokenizer_suffix = UnigramTokenizer, "model"

source_tokenizer = tokenizer_cls()
target_tokenizer = tokenizer_cls()
for side_name, side_tokenizer in (("source", source_tokenizer), ("target", target_tokenizer)):
    side_tokenizer.load(splits_dir / f"{side_name}_tokenizer_{tokenizer_type}.{tokenizer_suffix}")

src_vocab_size = source_tokenizer.get_vocab_size()
tgt_vocab_size = target_tokenizer.get_vocab_size()

print(f"Source vocab size: {src_vocab_size}")
print(f"Target vocab size: {tgt_vocab_size}")

In [None]:
# Create datasets and dataloaders for the train and validation splits.
from src.data.dataset import ParallelDataset, get_dataloader


def _build_split(split_name, shuffle):
    """Return (dataset, dataloader) for `<split_name>.source` / `<split_name>.target` in splits_dir."""
    split_dataset = ParallelDataset(
        splits_dir / f"{split_name}.source",
        splits_dir / f"{split_name}.target",
        source_tokenizer,
        target_tokenizer,
        max_length=config['dataset']['max_length'],
    )
    split_loader = get_dataloader(
        split_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=shuffle,
    )
    return split_dataset, split_loader


# Only the training split is shuffled; validation order stays fixed.
train_dataset, train_dataloader = _build_split("train", shuffle=True)
val_dataset, val_dataloader = _build_split("val", shuffle=False)

print(f"Train batches: {len(train_dataloader)}")
print(f"Val batches: {len(val_dataloader)}")

In [None]:
# Create student model from config['model']['student'] and move it to the compute device.
from src.models.transformer import create_student_model, create_teacher_model

student_cfg = config['model']['student']

# The model receives a single pad id (the source tokenizer's); warn if the target side differs.
if target_tokenizer.pad_token_id != source_tokenizer.pad_token_id:
    print(
        f"WARNING: source pad id {source_tokenizer.pad_token_id} != "
        f"target pad id {target_tokenizer.pad_token_id}; the model uses the source pad id."
    )

student_model = create_student_model(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=student_cfg['d_model'],
    nhead=student_cfg['num_heads'],
    num_layers=student_cfg['num_layers'],
    dim_feedforward=student_cfg['d_ff'],
    max_seq_length=student_cfg['max_seq_length'],
    dropout=student_cfg['dropout'],
    pad_token_id=source_tokenizer.pad_token_id
)

student_model = student_model.to(device)
print(f"Student model created and moved to {device}")

# This notebook targets a student under 50M parameters; report the actual size so a
# config change that overshoots the budget is visible before training starts.
STUDENT_PARAM_BUDGET = 50_000_000
num_student_params = sum(p.numel() for p in student_model.parameters())
print(f"Student parameters: {num_student_params:,}")
if num_student_params >= STUDENT_PARAM_BUDGET:
    print(f"WARNING: student has {num_student_params:,} parameters, "
          f"over the {STUDENT_PARAM_BUDGET:,} budget")

In [None]:
# Load teacher model for distillation.
# `teacher_model` stays None when distillation is disabled or the checkpoint is missing;
# the trainer receives it either way.
teacher_model = None
# Treat a missing or empty (YAML null) `distillation:` section as {}.
distill_cfg = config.get('distillation') or {}
if distill_cfg.get('enabled', False):
    teacher_checkpoint_path = distill_cfg.get('teacher_checkpoint')
    if teacher_checkpoint_path and Path(teacher_checkpoint_path).exists():
        print(f"Loading teacher model from {teacher_checkpoint_path}...")
        # Load onto the CPU. With map_location=device the state dict would be a second
        # full copy of the teacher on the GPU next to the model itself.
        # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
        checkpoint = torch.load(teacher_checkpoint_path, map_location='cpu')

        teacher_cfg = config['model']['teacher']
        teacher_model = create_teacher_model(
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            d_model=teacher_cfg['d_model'],
            nhead=teacher_cfg['num_heads'],
            num_layers=teacher_cfg['num_layers'],
            dim_feedforward=teacher_cfg['d_ff'],
            max_seq_length=teacher_cfg['max_seq_length'],
            dropout=teacher_cfg['dropout'],
            pad_token_id=source_tokenizer.pad_token_id
        )
        teacher_model.load_state_dict(checkpoint['model_state_dict'])
        # Drop the raw checkpoint so its copy of the weights is not kept alive by a global.
        del checkpoint
        teacher_model = teacher_model.to(device)
        teacher_model.eval()
        # The teacher is never optimized (the optimizer only sees student parameters);
        # freezing it stops autograd from tracking and accumulating grads for its weights.
        teacher_model.requires_grad_(False)
        print("Teacher model loaded successfully!")
    else:
        print("WARNING: Teacher checkpoint not found. Training without distillation.")
else:
    print("Distillation disabled. Training student model normally.")

In [None]:
# Setup optimizer and trainer.
# Only the student's parameters are optimized; the teacher (if any) is passed to the
# trainer as a fixed source of soft targets.
from torch.optim import AdamW
from src.training.trainer import NMTTrainer

optimizer = AdamW(
    student_model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay']
)

trainer = NMTTrainer(
    model=student_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    device=device,
    config=config,
    teacher_model=teacher_model,
    target_tokenizer=target_tokenizer
)

print("Trainer initialized")
# Compare against None explicitly: an nn.Module's truthiness depends on __len__ for
# container modules, so `if teacher_model:` is not a reliable "was a teacher loaded" test.
if teacher_model is not None:
    print("Knowledge distillation enabled!")

In [None]:
# Train the student for the configured number of epochs, then report the best
# validation BLEU recorded by the trainer.
print("Starting student model training...")
num_epochs = config['training']['num_epochs']
trainer.train(num_epochs=num_epochs)

print("\nStudent model training completed!")
print(f"Best BLEU: {trainer.best_bleu:.2f}")
# NOTE(review): this path is rebuilt here, not returned by the trainer — confirm the
# trainer really writes `best_model.pt` into paths.checkpoint_dir.
best_model_path = Path(config['paths']['checkpoint_dir']) / 'best_model.pt'
print(f"Best model saved to: {best_model_path}")

## Student Model Training Complete!

**Next Steps:**
1. Download the student checkpoint (`best_model.pt` in the `paths.checkpoint_dir` directory set in `config.yaml`)
2. Use notebook `04_evaluate_and_report.ipynb` to evaluate models and generate reports