# English-Gujarati NMT: Train Teacher Model

This notebook trains the larger teacher model for knowledge distillation.

In [None]:
# Install dependencies
# NOTE(review): no versions are pinned below, so every run installs whatever
# releases are current at that moment; pin them (pkg==x.y.z) if runs must be
# reproducible.
# The PyTorch packages are pulled from the PyTorch wheel index for CUDA 11.8 (cu118).
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# No --index-url is given here, so pip's configured default index is used.
%pip install transformers tokenizers sentencepiece datasets sacrebleu pyyaml tqdm wandb requests

In [None]:
# Setup
import sys
from pathlib import Path
import torch

# Put the current working directory first on the import path so that the
# project's `src` package (imported in later cells) resolves from here.
project_root = str(Path.cwd())
sys.path.insert(0, project_root)

# Check GPU
# Query CUDA availability once and reuse the answer for device selection
# and for the diagnostic output below.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Using device: {device}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"CUDA Version: {torch.version.cuda}")

In [None]:
# Load config and tokenizers
from src.utils.config import load_config
from src.tokenization.bpe import BPETokenizer
from src.tokenization.unigram import UnigramTokenizer
from src.data.dataset import ParallelDataset, get_dataloader
from src.models.transformer import create_teacher_model
from src.training.trainer import NMTTrainer
from torch.optim import AdamW

config = load_config("config.yaml")
splits_dir = Path(config['paths']['splits_dir'])
tokenizer_type = config['tokenization']['type']

# Load tokenizers
# Pick the tokenizer class and the saved-file names for the configured type.
# Any value other than "bpe" falls through to the unigram tokenizer.
if tokenizer_type == "bpe":
    tokenizer_cls = BPETokenizer
    source_tokenizer_file = f"source_tokenizer_{tokenizer_type}.json"
    target_tokenizer_file = f"target_tokenizer_{tokenizer_type}.json"
else:
    tokenizer_cls = UnigramTokenizer
    source_tokenizer_file = f"source_tokenizer_{tokenizer_type}.model"
    target_tokenizer_file = f"target_tokenizer_{tokenizer_type}.model"

# Build both tokenizers first, then load their saved state from splits_dir.
source_tokenizer = tokenizer_cls()
target_tokenizer = tokenizer_cls()
source_tokenizer.load(splits_dir / source_tokenizer_file)
target_tokenizer.load(splits_dir / target_tokenizer_file)

# Vocabulary sizes are needed later when the model is created.
src_vocab_size = source_tokenizer.get_vocab_size()
tgt_vocab_size = target_tokenizer.get_vocab_size()

for label, vocab_size in (("Source", src_vocab_size), ("Target", tgt_vocab_size)):
    print(f"{label} vocab size: {vocab_size}")

In [None]:
# Create datasets
def _build_dataset(source_file, target_file):
    """Wrap one source/target file pair from splits_dir in a ParallelDataset."""
    return ParallelDataset(
        splits_dir / source_file,
        splits_dir / target_file,
        source_tokenizer,
        target_tokenizer,
        max_length=config['dataset']['max_length'],
    )


def _build_loader(dataset, shuffle):
    """Create a dataloader for `dataset` using the configured training batch size."""
    return get_dataloader(
        dataset,
        batch_size=config['training']['batch_size'],
        shuffle=shuffle,
    )


# Both datasets are created before either dataloader.
train_dataset = _build_dataset("train.source", "train.target")
val_dataset = _build_dataset("val.source", "val.target")

# Only the training split is shuffled; validation keeps its file order.
train_dataloader = _build_loader(train_dataset, shuffle=True)
val_dataloader = _build_loader(val_dataset, shuffle=False)

print(f"Train batches: {len(train_dataloader)}")
print(f"Val batches: {len(val_dataloader)}")

In [None]:
# Create teacher model
# All architecture hyperparameters come from the `model.teacher` section of the config.
teacher_cfg = config['model']['teacher']

# NOTE(review): only the source tokenizer's pad id is passed to the model;
# presumably the target tokenizer uses the same pad id — confirm.
teacher_model = create_teacher_model(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=teacher_cfg['d_model'],
    nhead=teacher_cfg['num_heads'],
    num_layers=teacher_cfg['num_layers'],
    dim_feedforward=teacher_cfg['d_ff'],
    max_seq_length=teacher_cfg['max_seq_length'],
    dropout=teacher_cfg['dropout'],
    pad_token_id=source_tokenizer.pad_token_id,
).to(device)

print(f"Model created and moved to {device}")

In [None]:
# Setup optimizer and trainer
# Learning rate and weight decay are read from the `training` section of the config.
training_cfg = config['training']
optimizer = AdamW(
    teacher_model.parameters(),
    lr=training_cfg['learning_rate'],
    weight_decay=training_cfg['weight_decay'],
)

# Collect the trainer arguments in one mapping before constructing it.
trainer_kwargs = dict(
    model=teacher_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    device=device,
    config=config,
    teacher_model=None,  # No teacher for teacher training
    target_tokenizer=target_tokenizer,
)
trainer = NMTTrainer(**trainer_kwargs)

print("Trainer initialized")

In [None]:
# Train teacher model
# Runs trainer.train for the number of epochs set in config['training']['num_epochs'],
# then reports trainer.best_bleu and the location of the best checkpoint.
print("Starting teacher model training...")
trainer.train(num_epochs=config['training']['num_epochs'])

print("\nTeacher model training completed!")
print(f"Best BLEU: {trainer.best_bleu:.2f}")

# Only claim the checkpoint was saved when the file is actually on disk.
# NOTE(review): presumably NMTTrainer writes 'best_model.pt' into
# config['paths']['checkpoint_dir'] — confirm against the trainer implementation.
best_checkpoint_path = Path(config['paths']['checkpoint_dir']) / 'best_model.pt'
if best_checkpoint_path.exists():
    print(f"Best model saved to: {best_checkpoint_path}")
else:
    print(f"WARNING: expected best checkpoint not found at: {best_checkpoint_path}")

## Teacher Model Training Complete!

**Next Steps:**
1. Download the teacher checkpoint `best_model.pt` from the directory set by `paths.checkpoint_dir` in `config.yaml` (e.g. `checkpoints/best_model.pt`)
2. Update `config.yaml` to set `distillation.teacher_checkpoint` to the teacher checkpoint path
3. Use notebook `03_train_student.ipynb` to train the student model with knowledge distillation