# Train Yorùbá Seq2Seq Diacritizer

This notebook trains a **Seq2Seq (Encoder-Decoder)** model to restore diacritics to Yorùbá text.

**Why Seq2Seq?**
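- Handles different input/output lengths (even after NFC normalization some diacritized characters, such as ẹ́, take two code points, so the output can be longer than the input)
- Uses nearly all of the data (every non-empty pair shorter than 200 characters, out of ~676k pairs) instead of the ~16% usable when input and output must have equal length
- Target accuracy: **93-95%** (an estimate, not a measured result)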

**Expected results:**
- Training time: ~1-2 hours on Colab GPU
- Word accuracy: 93-95%
- Model size: 3-5MB

In [None]:
# Step 1: Install dependencies (`datasets` for the HuggingFace corpus, `torch` for the model).
# The leading "!" is IPython/Colab shell-escape syntax: this line only runs inside a notebook.
!pip install -q datasets torch

In [None]:
# Report the PyTorch build and pick the compute device: the first CUDA GPU when
# one is visible, otherwise the CPU.
import torch

print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("No GPU, using CPU (will be slower)")
    device = torch.device("cpu")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")

In [None]:
# Step 2: Load the HuggingFace dataset and build NFC-normalized sentence pairs.
# Seq2Seq handles input/output pairs of different lengths; only empty pairs and
# pairs with a side of MAX_CHARS or more characters are dropped.
from datasets import load_dataset
import unicodedata

# Longer pairs are skipped to bound the cost of the step-by-step decoder loop.
MAX_CHARS = 200

print("Loading dataset...")
ds = load_dataset("bumie-e/Yoruba-diacritics-vs-non-diacritics", split="train")
print(f"Total samples: {len(ds)}")

undiacritized = []
diacritized = []
skipped = 0

for item in ds:
    # NOTE: the column names are spelled "no_diacritcs" / "diacritcs" here; do not
    # "correct" them without checking the dataset schema.
    # `or ""` also covers None cells, which unicodedata.normalize would reject.
    undiac = unicodedata.normalize("NFC", item.get("no_diacritcs") or "")
    diac = unicodedata.normalize("NFC", item.get("diacritcs") or "")

    if undiac and diac and len(undiac) < MAX_CHARS and len(diac) < MAX_CHARS:
        undiacritized.append(undiac)
        diacritized.append(diac)
    else:
        skipped += 1

print(f"Valid pairs: {len(undiacritized)} (skipped {skipped} empty or over-length pairs)")
if not undiacritized:
    raise RuntimeError("No valid sentence pairs were extracted; check the dataset column names.")

print("\nExample:")
print(f"  Input:  {undiacritized[0]} (len={len(undiacritized[0])})")
print(f"  Output: {diacritized[0]} (len={len(diacritized[0])})")

In [None]:
# Step 3: Build vocabulary with special tokens for Seq2Seq
import json
from typing import Dict, List, Tuple

class CharVocab:
    """Character vocabulary with the special tokens a Seq2Seq model needs.

    Ids 0-3 are reserved for <PAD>, <UNK>, <SOS> and <EOS>. Every other character
    gets the next free id, in the order it is first added.
    """

    PAD = "<PAD>"
    UNK = "<UNK>"
    SOS = "<SOS>"  # Start of sequence
    EOS = "<EOS>"  # End of sequence

    def __init__(self):
        self.char2idx: Dict[str, int] = {
            self.PAD: 0,
            self.UNK: 1,
            self.SOS: 2,
            self.EOS: 3,
        }
        self.idx2char: Dict[int, str] = {v: k for k, v in self.char2idx.items()}

    @classmethod
    def from_char2idx(cls, char2idx: Dict[str, int]) -> "CharVocab":
        """Rebuild a vocabulary from a saved ``char2idx`` mapping (e.g. a checkpoint's "vocab").

        The mapping is copied, so later ``add_char`` calls do not mutate the caller's dict.

        Raises:
            ValueError: if any of the four special tokens is missing from ``char2idx``.
        """
        missing = [tok for tok in (cls.PAD, cls.UNK, cls.SOS, cls.EOS) if tok not in char2idx]
        if missing:
            raise ValueError(f"char2idx is missing special tokens: {missing}")
        vocab = cls()
        vocab.char2idx = dict(char2idx)
        vocab.idx2char = {idx: ch for ch, idx in vocab.char2idx.items()}
        return vocab

    def add_char(self, char: str) -> int:
        """Register ``char`` if it is unseen and return its id."""
        if char not in self.char2idx:
            idx = len(self.char2idx)
            self.char2idx[char] = idx
            self.idx2char[idx] = char
        return self.char2idx[char]

    def encode(self, text: str, add_special: bool = False) -> List[int]:
        """Map ``text`` to ids. Unknown characters become the <UNK> id.

        With ``add_special=True`` the ids are wrapped as [<SOS>, ..., <EOS>].
        """
        unk_id = self.char2idx[self.UNK]
        ids = [self.char2idx.get(c, unk_id) for c in text]
        if add_special:
            ids = [self.char2idx[self.SOS]] + ids + [self.char2idx[self.EOS]]
        return ids

    def decode(self, indices: List[int]) -> str:
        """Map ids back to text.

        Decoding stops at the first <EOS>. <PAD>, <SOS> and <UNK> are dropped, and
        ids with no known character decode to "".
        """
        eos_id = self.char2idx[self.EOS]
        # Hoisted out of the loop: the ids never change while decoding.
        skip_ids = {self.char2idx[self.PAD], self.char2idx[self.SOS], self.char2idx[self.UNK]}
        chars = []
        for i in indices:
            if i == eos_id:
                break
            if i not in skip_ids:
                chars.append(self.idx2char.get(i, ""))
        return "".join(chars)

    def __len__(self) -> int:
        return len(self.char2idx)

# Assign ids to every character in the corpus: undiacritized sentences first,
# then diacritized ones, each in order of first appearance.
vocab = CharVocab()
for corpus in (undiacritized, diacritized):
    for sentence in corpus:
        for ch in sentence:
            vocab.add_char(ch)

print(f"Vocabulary size: {len(vocab)} (including PAD, UNK, SOS, EOS)")

In [None]:
# Step 4: Create PyTorch datasets with SOS/EOS for decoder
import random
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class DiacritizationDataset(Dataset):
    """Map-style dataset of (undiacritized, diacritized) sentence pairs as id tensors.

    Each item is ``(src, tgt_in, tgt_out)`` of dtype long:
      src     -- ids of the undiacritized sentence, no special tokens (encoder input)
      tgt_in  -- <SOS> + diacritized ids (decoder input for teacher forcing)
      tgt_out -- diacritized ids + <EOS> (decoder target, i.e. tgt_in shifted by one)
    """

    def __init__(self, undiac: List[str], diac: List[str], vocab: CharVocab):
        self.undiac, self.diac, self.vocab = undiac, diac, vocab

    def __len__(self) -> int:
        return len(self.undiac)

    def __getitem__(self, idx: int):
        src_ids = self.vocab.encode(self.undiac[idx])
        # framed = [SOS, c1, ..., cn, EOS]. Dropping the last id gives the decoder
        # input; dropping the first gives the decoder target.
        framed = self.vocab.encode(self.diac[idx], add_special=True)
        return (
            torch.tensor(src_ids, dtype=torch.long),
            torch.tensor(framed[:-1], dtype=torch.long),
            torch.tensor(framed[1:], dtype=torch.long),
        )

def collate_fn(batch):
    """Pad a list of ``(src, tgt_in, tgt_out)`` samples into batch-first tensors.

    Every field is right-padded with id 0 (<PAD>). Returns
    ``(src, tgt_in, tgt_out, src_lens)``, where ``src_lens`` holds the unpadded
    source lengths.
    """
    src_list = [sample[0] for sample in batch]
    tgt_in_list = [sample[1] for sample in batch]
    tgt_out_list = [sample[2] for sample in batch]
    lengths = torch.tensor([len(seq) for seq in src_list])

    padded = [
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (src_list, tgt_in_list, tgt_out_list)
    ]
    return padded[0], padded[1], padded[2], lengths

# Reproducible 95/5 train/validation split over shuffled pair indices.
random.seed(42)
indices = list(range(len(undiacritized)))
random.shuffle(indices)

split = int(0.95 * len(indices))
train_idx, val_idx = indices[:split], indices[split:]


def _subset_dataset(idx_list):
    """Build a DiacritizationDataset restricted to the given pair indices."""
    return DiacritizationDataset(
        [undiacritized[i] for i in idx_list],
        [diacritized[i] for i in idx_list],
        vocab,
    )


train_dataset = _subset_dataset(train_idx)
val_dataset = _subset_dataset(val_idx)

BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Step 5: Define the Encoder-Decoder model with Attention
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder over character ids.

    The last layer's forward and backward final states are concatenated and
    projected (with tanh) to ``hidden_dim``, which initializes the decoder.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_layers: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 if num_layers > 1 else 0
        )
        # Project bidirectional hidden states to decoder size
        self.fc_h = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_c = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x, lengths=None):
        """Encode a batch of right-padded character ids.

        Args:
            x: [batch, src_len] ids, right-padded with 0 (<PAD>).
            lengths: optional [batch] true sequence lengths. When omitted they are
                derived by counting non-PAD ids per row.

        Returns:
            outputs: [batch, src_len, 2 * hidden_dim], zero at padded positions.
            h, c: [1, batch, hidden_dim] initial decoder state.
        """
        emb = self.embedding(x)
        if lengths is None:
            lengths = (x != 0).sum(dim=1)
        # Packing keeps the LSTM from stepping over PAD. Without it, the backward
        # direction would start on padding and the final states of shorter rows
        # would depend on how much padding the batch added. pack_padded_sequence
        # needs int64 lengths >= 1 on the CPU.
        lengths = torch.as_tensor(lengths, dtype=torch.long).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, (h, c) = self.lstm(packed)
        # total_length keeps the time axis equal to x.size(1) so attention masks line up.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )
        # h/c are [num_layers * 2, batch, hidden]; [-2] and [-1] are the last
        # layer's forward and backward states.
        h = torch.tanh(self.fc_h(torch.cat([h[-2], h[-1]], dim=1)))
        c = torch.tanh(self.fc_c(torch.cat([c[-2], c[-1]], dim=1)))
        return outputs, h.unsqueeze(0), c.unsqueeze(0)

class Attention(nn.Module):
    """Additive attention: score_j = v^T tanh(W [dec_hidden; enc_output_j])."""

    def __init__(self, enc_dim: int, dec_dim: int):
        super().__init__()
        self.attn = nn.Linear(enc_dim + dec_dim, dec_dim)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return softmax attention weights of shape [batch, src_len].

        Args:
            hidden: [1, batch, dec_dim] current decoder hidden state.
            encoder_outputs: [batch, src_len, enc_dim].
            mask: optional [batch, src_len]. Positions where ``mask == 0`` get zero weight.
        """
        src_len = encoder_outputs.size(1)
        # expand() broadcasts without copying, unlike repeat().
        hidden = hidden.squeeze(0).unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        scores = self.v(energy).squeeze(2)

        if mask is not None:
            # Most negative finite value of the dtype: it behaves like -inf for the
            # softmax but, unlike -1e10, does not overflow in fp16. A fully masked
            # row still yields finite (uniform) weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        return F.softmax(scores, dim=1)


class Decoder(nn.Module):
    """Single-step LSTMCell decoder that attends over the encoder outputs."""

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, enc_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.attention = Attention(enc_dim, hidden_dim)
        self.lstm = nn.LSTMCell(embed_dim + enc_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim + enc_dim + embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_char, hidden, cell, encoder_outputs, mask=None):
        """Run one decoding step.

        Args:
            input_char: [batch] ids of the previous character (or <SOS>).
            hidden, cell: [batch, hidden_dim] LSTMCell state.
            encoder_outputs: [batch, src_len, enc_dim].
            mask: optional [batch, src_len]. 0 marks padded source positions that
                must not be attended to; None attends to every position.

        Returns:
            (logits [batch, vocab_size], hidden, cell)
        """
        emb = self.dropout(self.embedding(input_char))  # [batch, embed]

        attn_weights = self.attention(hidden.unsqueeze(0), encoder_outputs, mask)  # [batch, src_len]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [batch, enc_dim]

        lstm_input = torch.cat([emb, context], dim=1)
        hidden, cell = self.lstm(lstm_input, (hidden, cell))

        output = self.fc(torch.cat([hidden, context, emb], dim=1))
        return output, hidden, cell


class Seq2SeqDiacritizer(nn.Module):
    """Encoder-decoder diacritizer: undiacritized char ids in, per-step char logits out."""

    def __init__(self, vocab_size: int, embed_dim: int = 64, hidden_dim: int = 128, num_layers: int = 2):
        super().__init__()
        self.encoder = Encoder(vocab_size, embed_dim, hidden_dim, num_layers)
        self.decoder = Decoder(vocab_size, embed_dim, hidden_dim, hidden_dim * 2)  # enc is bidirectional

    def forward(self, src, tgt_in, teacher_forcing_ratio=0.5):
        """Decode ``tgt_in.size(1)`` steps and return logits [batch, tgt_len, vocab_size].

        Args:
            src: [batch, src_len] source ids, right-padded with 0 (<PAD>).
            tgt_in: [batch, tgt_len] decoder inputs starting with <SOS>.
            teacher_forcing_ratio: probability, drawn once per step for the whole
                batch, of feeding the gold next character instead of the model's argmax.
        """
        batch_size, tgt_len = tgt_in.shape
        vocab_size = self.decoder.fc.out_features

        # Encode
        encoder_outputs, h, c = self.encoder(src)
        h = h.squeeze(0)
        c = c.squeeze(0)
        # Keep attention off padded source positions (PAD id is 0).
        src_mask = src != 0

        # Decode
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=src.device)
        input_char = tgt_in[:, 0]  # SOS token

        for t in range(tgt_len):
            output, h, c = self.decoder(input_char, h, c, encoder_outputs, src_mask)
            outputs[:, t] = output

            if t < tgt_len - 1:
                teacher_force = torch.rand(1).item() < teacher_forcing_ratio
                input_char = tgt_in[:, t + 1] if teacher_force else output.argmax(1)

        return outputs

# Instantiate the model; these hyperparameters are also written into the checkpoint config.
EMBED_DIM = 64
HIDDEN_DIM = 128
NUM_LAYERS = 2

model_kwargs = dict(
    vocab_size=len(vocab),
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
)
model = Seq2SeqDiacritizer(**model_kwargs).to(device)

total_params = 0
for param in model.parameters():
    total_params += param.numel()
bytes_fp32 = total_params * 4  # 4 bytes per float32 parameter
print(f"Model parameters: {total_params:,}")
print(f"Estimated model size: {bytes_fp32 / 1024 / 1024:.1f} MB")

In [None]:
# Step 6: Training loop with teacher forcing
import time

criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 is <PAD>: padded targets add no loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Halve the LR once validation loss has not improved for 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

NUM_EPOCHS = 10
TF_START, TF_END = 1.0, 0.5  # teacher-forcing ratio at the first / last epoch
best_word_acc = 0
best_model_state = None


def _teacher_forcing_ratio(epoch: int) -> float:
    """Linearly decay teacher forcing from TF_START (epoch 1) to TF_END (epoch NUM_EPOCHS)."""
    if NUM_EPOCHS <= 1:
        return TF_START
    frac = (epoch - 1) / (NUM_EPOCHS - 1)
    return TF_START + (TF_END - TF_START) * frac


def _count_word_matches(pred_text: str, tgt_text: str) -> Tuple[int, int]:
    """Return ``(correct, total)`` for a position-by-position word comparison.

    ``total`` is the number of reference words, so words the prediction is missing
    count as errors instead of being silently skipped by ``zip``.
    """
    tgt_words = tgt_text.split()
    correct = sum(pw == tw for pw, tw in zip(pred_text.split(), tgt_words))
    return correct, len(tgt_words)


print("Starting training...\n")
start_time = time.time()

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.time()

    # --- Train ---
    model.train()
    train_loss = 0.0
    teacher_forcing_ratio = _teacher_forcing_ratio(epoch)

    for batch_idx, (src, tgt_in, tgt_out, _) in enumerate(train_loader):
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

        optimizer.zero_grad()
        logits = model(src, tgt_in, teacher_forcing_ratio)

        # Flatten for loss: [batch * seq, vocab] vs [batch * seq]
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()

        if (batch_idx + 1) % 500 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    train_loss /= max(1, len(train_loader))

    # --- Evaluate (free-running: no teacher forcing) ---
    model.eval()
    val_loss = 0.0
    total_words = correct_words = 0

    with torch.no_grad():
        for src, tgt_in, tgt_out, _ in val_loader:
            src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

            logits = model(src, tgt_in, teacher_forcing_ratio=0.0)  # No teacher forcing
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            val_loss += loss.item()

            # One device->host transfer per batch instead of one per sample.
            pred_rows = logits.argmax(dim=-1).cpu().tolist()
            tgt_rows = tgt_out.cpu().tolist()
            for pred_seq, tgt_seq in zip(pred_rows, tgt_rows):
                correct, total = _count_word_matches(vocab.decode(pred_seq), vocab.decode(tgt_seq))
                correct_words += correct
                total_words += total

    val_loss /= max(1, len(val_loader))
    word_acc = correct_words / total_words if total_words > 0 else 0

    scheduler.step(val_loss)
    epoch_time = time.time() - epoch_start

    print(f"Epoch {epoch}/{NUM_EPOCHS} ({epoch_time:.0f}s) | "
          f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"Word Acc: {word_acc:.1%} | TF: {teacher_forcing_ratio:.0%}")

    # Always keep at least one snapshot, so later cells never see None even if accuracy stays 0.
    if best_model_state is None or word_acc > best_word_acc:
        best_word_acc = word_acc
        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        print("  ↑ New best model!")

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time/60:.1f} minutes")
print(f"Best word accuracy: {best_word_acc:.1%}")

In [None]:
# Step 7: Save the model
import os

CHECKPOINT_PATH = "yoruba_seq2seq_diacritizer.pt"

# Hyperparameters needed to rebuild Seq2SeqDiacritizer before loading the weights.
model_config = {
    "vocab_size": len(vocab),
    "embed_dim": EMBED_DIM,
    "hidden_dim": HIDDEN_DIM,
    "num_layers": NUM_LAYERS,
}
checkpoint = {
    "model_type": "seq2seq",
    "model_state_dict": best_model_state,
    "vocab": vocab.char2idx,
    "config": model_config,
    "metrics": {"word_acc": best_word_acc},
}
torch.save(checkpoint, CHECKPOINT_PATH)

model_size = os.path.getsize(CHECKPOINT_PATH) / (1024 * 1024)
print(f"Model saved: yoruba_seq2seq_diacritizer.pt ({model_size:.1f} MB)")

In [None]:
# Step 8: Test the model with autoregressive decoding
# Reload the best-scoring weights (stored on CPU during training) and switch to inference mode.
best_state_on_device = {name: tensor.to(device) for name, tensor in best_model_state.items()}
model.load_state_dict(best_state_on_device)
model.eval()

def diacritize_seq2seq(text: str, max_len: int = 300) -> str:
    """Diacritize ``text`` with greedy autoregressive decoding.

    Args:
        text: Undiacritized input. It is NFC-normalized first, matching the
            normalization applied to the training pairs. Characters missing from
            the vocabulary are encoded as <UNK>.
        max_len: Maximum number of generated steps if <EOS> is never predicted.

    Returns:
        The predicted diacritized string. Decoding stops at <EOS>, and the special
        tokens <PAD>/<SOS>/<UNK> are never emitted as text. Empty input returns "".
    """
    import unicodedata

    text = unicodedata.normalize("NFC", text)
    if not text:
        # The LSTM encoder cannot run on a zero-length sequence.
        return ""

    sos_id = vocab.char2idx[vocab.SOS]
    eos_id = vocab.char2idx[vocab.EOS]

    with torch.no_grad():
        # Encode
        src = torch.tensor([vocab.encode(text)], dtype=torch.long, device=device)
        encoder_outputs, h, c = model.encoder(src)
        h = h.squeeze(0)
        c = c.squeeze(0)

        # Decode autoregressively, feeding back each greedy prediction.
        predicted_ids = []
        input_char = torch.tensor([sos_id], device=device)
        for _ in range(max_len):
            output, h, c = model.decoder(input_char, h, c, encoder_outputs)
            pred_idx = output.argmax(1).item()
            if pred_idx == eos_id:
                break
            predicted_ids.append(pred_idx)
            input_char = torch.tensor([pred_idx], device=device)

    # vocab.decode drops <PAD>/<SOS>/<UNK>. The previous hand-rolled loop appended
    # the literal text "<UNK>" whenever UNK was predicted.
    return vocab.decode(predicted_ids)

# Test examples: plain-ASCII forms of Yorùbá sentences with the diacritics removed
# (e.g. "iṣẹ́" -> "ise", "fẹ́ràn" -> "feran"). Informal spellings such as "ishe"
# do not match how diacritic removal renders "ṣ".
test_sentences = [
    "Ojo dara pupo",
    "E ku ise o",
    "Mo feran re",
    "Bawo ni o se wa",
    "Olorun a bukun fun e",
]

print("\nTest Results:")
print("=" * 60)
for sent in test_sentences:
    result = diacritize_seq2seq(sent)
    print(f"Input:  {sent}")
    print(f"Output: {result}")
    print("-" * 40)

In [None]:
# Step 9: Download the model (Colab)
try:
    from google.colab import files
except ImportError:
    # Only a missing Colab module means "not on Colab". A failing download is no
    # longer misreported by a bare except; it raises its own error.
    print("Not running on Colab. Model saved locally.")
else:
    files.download("yoruba_seq2seq_diacritizer.pt")
    print("Download started! Check your downloads folder.")