# 04: Train Encoder-Decoder Model (T5/BART-style)
This notebook demonstrates how to train an encoder-decoder transformer model using PyTorch.

In [None]:
!pip install torch transformers

In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from models.encoder_decoder import EncoderDecoderModel
from tqdm import tqdm

## Prepare toy translation-like dataset

In [3]:
tokenizer = AutoTokenizer.from_pretrained("t5-small")

examples = [
    ("Translate English to French: Hello.", "Bonjour."),
    ("Translate English to French: How are you?", "Comment ça va ?"),
    ("Translate English to French: Thank you!", "Merci !"),
    ("Translate English to French: Goodbye.", "Au revoir."),
    ("Translate English to French: Yes.", "Oui."),
]

## Create custom dataset and dataloader

In [4]:
class Seq2SeqDataset(Dataset):
    def __init__(self, examples, tokenizer, max_length=64):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, idx):
        src, tgt = self.examples[idx]
        src_ids = self.tokenizer(src, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")["input_ids"].squeeze(0)
        tgt_ids = self.tokenizer(tgt, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")["input_ids"].squeeze(0)
        return src_ids, tgt_ids

dataset = Seq2SeqDataset(examples, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

## Initialize encoder-decoder model

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EncoderDecoderModel(
    vocab_size=tokenizer.vocab_size,
    embed_dim=512,
    enc_layers=4,
    dec_layers=4,
    heads=8,
    ff_dim=1024,
    max_len=64
).to(device)

## Train with teacher forcing

In [6]:
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
for epoch in range(5):
    model.train()
    total_loss = 0
    for src_ids, tgt_ids in dataloader:
        src_ids, tgt_ids = src_ids.to(device), tgt_ids.to(device)
        logits = model(src_ids, tgt_ids[:, :-1])
        loss = criterion(logits.view(-1, logits.size(-1)), tgt_ids[:, 1:].contiguous().view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: loss = {total_loss / len(dataloader):.4f}")

torch.save(model.state_dict(), "encoder_decoder_translation.pt")
print("Model saved.")

Epoch 1: loss = 9.2512
Epoch 2: loss = 5.6492
Epoch 3: loss = 5.2768
Epoch 4: loss = 4.3701
Epoch 5: loss = 4.2125
Model saved.
