In [1]:
import os, json, random

PROJECT_ROOT = "/Users/jasleenkaur/Desktop/translit-consistency"
os.chdir(PROJECT_ROOT)

with open("data/processed/aligned_pairs_high_conf.json", encoding="utf-8") as f:
    raw_pairs = json.load(f)

print("Total raw pairs:", len(raw_pairs))

Total raw pairs: 41044


In [2]:
import re

def is_valid_pair(en, hi):
    # Hindi must be Devanagari only
    if not re.fullmatch(r"[ऀ-ॿ।]+", hi):
        return False

    # Avoid extreme length mismatch
    if abs(len(en) - len(hi)) > 10:
        return False

    # Too many halants = noisy
    if hi.count("्") > 3:
        return False

    # Remove junk repetitions
    if len(hi) - len(set(hi)) > 6:
        return False

    return True


pairs = []
for en, hi, _ in raw_pairs:
    en = en.lower()
    if is_valid_pair(en, hi):
        pairs.append((en, hi))

print("Filtered pairs:", len(pairs))

Filtered pairs: 40982


In [3]:
random.seed(42)
random.shuffle(pairs)

n = len(pairs)
train = pairs[:int(0.8 * n)]
val = pairs[int(0.8 * n):int(0.9 * n)]
test = pairs[int(0.9 * n):]

print("Train: ", len(train))
print("Val: ", len(val))
print("Test: ", len(test))

Train:  32785
Val:  4098
Test:  4099


In [4]:
def build_char_set(words):
    chars = set()
    for w in words:
        chars.update(w)
    return sorted(chars)

en_words = [en for en, hi in pairs]
hi_words = [hi for en, hi in pairs]

en_chars = build_char_set(en_words)
hi_chars = build_char_set(hi_words)

print("English chars: ", len(en_chars))
print("Hindi chars: ", len(hi_chars))

English chars:  61
Hindi chars:  80


In [5]:
PAD = "<pad>"
SOS = "<s>"
EOS = "</s>"

def build_vocab_with_tokens(chars):
    vocab = [PAD, SOS, EOS] + chars
    stoi = {c: i for i, c in enumerate(vocab)}
    itos = {i: c for c, i in stoi.items()}
    return vocab, stoi, itos

en_vocab, en_stoi, en_itos = build_vocab_with_tokens(en_chars)
hi_vocab, hi_stoi, hi_itos = build_vocab_with_tokens(hi_chars)

In [6]:
def normalize_english(word):
    return {
        "delhi": "dilli",
        "bangalore": "bangalor",
        "bengaluru": "bangalor",
        "maharashta": "maharashtra",
    }.get(word.lower(), word.lower())

In [7]:
def encode(word, stoi):
    if stoi is en_stoi:
        word = normalize_english(word)
    return [stoi[SOS]] + [stoi[c] for c in word] + [stoi[EOS]]

In [8]:
train_enc = [(encode(en, en_stoi), encode(hi, hi_stoi)) for en, hi in train]
val_enc   = [(encode(en, en_stoi), encode(hi, hi_stoi)) for en, hi in val]
test_enc  = [(encode(en, en_stoi), encode(hi, hi_stoi)) for en, hi in test]

print("Encoded train example:")
print(train_enc[0])

Encoded train example:
([1, 39, 21, 40, 45, 21, 27, 38, 28, 21, 2], [1, 51, 34, 67, 44, 54, 21, 67, 45, 52, 2])


In [9]:
def pad(seq, max_len, pad_id):
    return seq + [pad_id] * (max_len - len(seq))

max_en_len = max(len(x[0]) for x in train_enc)
max_hi_len = max(len(x[1]) for x in train_enc)

train_pad = [
    (pad(src, max_en_len, en_stoi[PAD]),
     pad(tgt, max_hi_len, hi_stoi[PAD]))
    for src, tgt in train_enc
]

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device: ", device)

Using device:  mps


In [11]:
train_src = torch.tensor(
    [src for src, tgt in train_pad],
    dtype = torch.long
).to(device)

train_tgt = torch.tensor(
    [tgt for src, tgt in train_pad],
    dtype = torch.long
).to(device)

In [12]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx = 0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first = True)

    def forward(self, x):
        emb = self.embedding(x)
        outputs, (h, c) = self.lstm(emb)
        return outputs, h, c

In [13]:
class LuongAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.scale = 1.0 / (hidden_dim ** 0.5)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: (B, H)
        encoder_outputs: (B, src_len, H)
        """
        # (B, src_len)
        scores = torch.bmm(
            encoder_outputs,
            decoder_hidden.unsqueeze(2)
        ).squeeze(2)

        attn_weights = torch.softmax(scores * self.scale, dim=1)

        # (B, H)
        context = torch.bmm(
            attn_weights.unsqueeze(1),
            encoder_outputs
        ).squeeze(1)

        return context, attn_weights

In [14]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.attention = LuongAttention(hidden_dim)
        self.lstm = nn.LSTM(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h, c, encoder_outputs):
        # x: (B, 1)
        emb = self.embedding(x)  # (B, 1, E)

        # Attention
        context, _ = self.attention(h[-1], encoder_outputs)  # (B, H)
        context = context.unsqueeze(1)  # (B, 1, H)

        # LSTM
        lstm_input = torch.cat([emb, context], dim=2)  # (B, 1, E+H)
        output, (h, c) = self.lstm(lstm_input, (h, c))  # output: (B, 1, H)

        logits = self.fc(output.squeeze(1))  # (B, vocab)
        return logits, h, c

In [15]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """
        src: (B, src_len)
        tgt: (B, tgt_len)
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, tgt_len, vocab_size).to(src.device)

        encoder_outputs, h, c = self.encoder(src)

        input_tok = tgt[:, 0].unsqueeze(1)

        for t in range(1, tgt_len):
            logits, h, c = self.decoder(input_tok, h, c, encoder_outputs)
            outputs[:, t] = logits

            teacher_force = random.random() < teacher_forcing_ratio
            top1 = logits.argmax(1).unsqueeze(1)  

            input_tok = tgt[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

In [16]:
EMBED_DIM = 128
HIDDEN_DIM = 256

encoder = Encoder(len(en_vocab), EMBED_DIM, HIDDEN_DIM).to(device)
decoder = Decoder(len(hi_vocab), EMBED_DIM, HIDDEN_DIM).to(device)

model = Seq2Seq(encoder, decoder, hi_stoi[PAD]).to(device)

In [17]:
optimizer = optim.Adam(model.parameters(), lr = 0.0005)

In [18]:
criterion = nn.CrossEntropyLoss(
    ignore_index=hi_stoi[PAD],
    label_smoothing=0.1
)

In [19]:
def pad_dataset(enc_data, max_en_len, max_hi_len):
    return[
        (
            pad(src, max_en_len, en_stoi[PAD]),
            pad(tgt, max_hi_len, hi_stoi[PAD])
        )
        for src, tgt in enc_data
    ]

val_pad = pad_dataset(val_enc, max_en_len, max_hi_len)

In [20]:
EPOCHS = 60
BATCH_SIZE = 64

def train_epoch(model, data, optimizer, criterion, teacher_forcing_ratio):
    model.train()
    total_loss = 0

    random.shuffle(data)

    for i in range(0, len(data), BATCH_SIZE):
        batch = data[i:i+BATCH_SIZE]
        src = torch.tensor([x[0] for x in batch], dtype=torch.long).to(device)
        tgt = torch.tensor([x[1] for x in batch], dtype=torch.long).to(device)

        optimizer.zero_grad()

        output = model(src, tgt, teacher_forcing_ratio=teacher_forcing_ratio)

        loss = criterion(
            output[:, 1:].reshape(-1, output.size(-1)),
            tgt[:, 1:].reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / (len(data) // BATCH_SIZE + 1)

In [21]:
def evaluate(model, data, criterion):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for i in range(0, len(data), BATCH_SIZE):
            batch = data[i:i+BATCH_SIZE]
            src = torch.tensor([x[0] for x in batch], dtype=torch.long).to(device)
            tgt = torch.tensor([x[1] for x in batch], dtype=torch.long).to(device)

            output = model(src, tgt, teacher_forcing_ratio=0.0)

            loss = criterion(
                output[:, 1:].reshape(-1, output.size(-1)),
                tgt[:, 1:].reshape(-1)
            )

            total_loss += loss.item()

    return total_loss / (len(data) // BATCH_SIZE + 1)

In [29]:
COMMON_FIXES = {
    "िि": "ि",
    "ाा": "ा",
    "ुु": "ु",
    "ूू": "ू",
    "ेे": "े",
    "ोो": "ो"
}

def normalize_hindi(text):
    for b, g in COMMON_FIXES.items():
        text = text.replace(b, g)
    return text

def postprocess_hindi(text):
    # Collapse duplicated matras (िि → ि, etc.)
    for m in ["ि", "ी", "ा", "ु", "ू", "े", "ो"]:
        text = text.replace(m + m, m)

    # Fix duplicated final consonant (ष्ट्ट → ष्ट)
    if len(text) >= 2 and text[-1] == text[-2]:
        text = text[:-1]

    # PROTECT common Hindi suffixes
    protected_suffixes = (
        "स्थान", "पुर", "नगर", "गंज", "गढ़", "पुरम"
    )
    for suf in protected_suffixes:
        if text.endswith(suf):
            return text   # DO NOTHING further

    # Trim hallucinated trailing junk ONLY if long
    if len(text) >= 7:
        # remove trailing vowels like "जो", "यी", "ऊ"
        if text[-1] in {"ो", "ू", "ी"}:
            text = text[:-1]

        # remove trailing filler consonants
        if text[-1] in {"य", "र"}:
            text = text[:-1]

    return text

def fix_common_seq2seq_errors(text):
    # Remove hallucinated trailing syllables
    if text.endswith("का") and len(text) > 5:
        text = text[:-2]

    # Fix vowel echo (रीरे → री, रिरे → री)
    text = text.replace("रीरे", "री")
    text = text.replace("रिरे", "री")

    # Fix double syllable drift (ल्ली → ली)
    text = text.replace("ल्लि", "ल्ली")
    text = text.replace("मिल्लि", "दिल्ली")

    # Remove dangling halant at end
    if text.endswith("्"):
        text = text[:-1]

    return text

In [33]:
def char_accuracy(pred, gold):
    correct = 0
    total = max(len(pred), len(gold))

    for p, g in zip(pred, gold):
        if p == g:
            correct += 1

    return correct / total if total > 0 else 0.0

In [34]:
def beam_transliterate(model, word, beam_width=4, max_len=40):
    model.eval()

    src = torch.tensor([encode(word.lower(), en_stoi)], dtype=torch.long).to(device)
    with torch.no_grad():
        encoder_outputs, h, c = model.encoder(src)

    beams = [([hi_stoi[SOS]], 0.0, h, c)]

    for _ in range(max_len):
        new_beams = []
        for seq, score, h, c in beams:
            if seq[-1] == hi_stoi[EOS]:
                new_beams.append((seq, score, h, c))
                continue

            input_tok = torch.tensor([[seq[-1]]], device=device)
            with torch.no_grad():
                logits, h_new, c_new = model.decoder(input_tok, h, c, encoder_outputs)

            log_probs = torch.log_softmax(logits[0], dim=-1)
            topk = torch.topk(log_probs, beam_width)

            for idx, val in zip(topk.indices, topk.values):
                next_char = hi_itos[idx.item()]

                penalty = 0.0
                if seq.count(idx.item()) >= 2:
                    penalty = 1.5
                    
                new_beams.append(
                    (seq + [idx.item()],score + val.item() - penalty,h_new,c_new)
                )

        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if all(seq[-1] == hi_stoi[EOS] for seq, _, _, _ in beams):
            break

    best = beams[0][0]
    out = "".join(
        hi_itos[i] for i in best
        if i not in {hi_stoi[SOS], hi_stoi[EOS], hi_stoi[PAD]}
    )
    out = normalize_hindi(out)
    out = postprocess_hindi(out)
    out = fix_common_seq2seq_errors(out)

    if len(out) > 2 * len(word):
        out = out[:2 * len(word)]
    
    return out

In [35]:
def evaluate_char_accuracy(model, data, limit=None):
    scores = []

    for i, (en, hi) in enumerate(data):
        if limit and i >= limit:
            break

        pred = beam_transliterate(model, en)
        scores.append(char_accuracy(pred, hi))

    return sum(scores) / len(scores)

In [38]:
# for epoch in range(EPOCHS):
#     teacher_forcing_ratio = max(0.25, 0.6 * (0.97 ** epoch))
#     # teacher_forcing_ratio = 0.5
#     train_loss = train_epoch(model, train_pad, optimizer, criterion, teacher_forcing_ratio)
#     val_loss = evaluate(model, val_pad, criterion)

#     print(f"Epoch {epoch+1}/{EPOCHS}")
#     print(f"Train Loss: {train_loss:.4f}")
#     print(f"Val Loss: {val_loss:.4f}")
#     print("-" * 40)
num_epochs = EPOCHS

best_val = 0.0
patience_counter = 0
PATIENCE = 6

for epoch in range(num_epochs):

    train_loss = train_epoch(
        model,
        train_pad,
        optimizer,
        criterion,
        teacher_forcing_ratio
    )

    val_char_acc = evaluate_char_accuracy(model, val)

    print(f"Epoch {epoch+1}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Char Accuracy: {val_char_acc * 100:.2f}%")

    # Early stopping logic
    if val_char_acc > best_val:
        best_val = val_char_acc
        patience_counter = 0
        torch.save(model.state_dict(), "seq2seq_best.pt")
        print("New best model saved")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{PATIENCE})")

    if patience_counter >= PATIENCE:
        print("Early stopping triggered")
        break
        
    print("-" * 40)

Epoch 1
Train Loss: 1.4436
Val Char Accuracy: 43.85%
New best model saved
----------------------------------------
Epoch 2
Train Loss: 1.4005
Val Char Accuracy: 42.49%
No improvement (1/6)
----------------------------------------
Epoch 3
Train Loss: 1.3678
Val Char Accuracy: 44.83%
New best model saved
----------------------------------------
Epoch 4
Train Loss: 1.3354
Val Char Accuracy: 45.25%
New best model saved
----------------------------------------
Epoch 5
Train Loss: 1.3108
Val Char Accuracy: 45.03%
No improvement (1/6)
----------------------------------------
Epoch 6
Train Loss: 1.2949
Val Char Accuracy: 45.11%
No improvement (2/6)
----------------------------------------
Epoch 7
Train Loss: 1.2754
Val Char Accuracy: 43.66%
No improvement (3/6)
----------------------------------------
Epoch 8
Train Loss: 1.2556
Val Char Accuracy: 44.25%
No improvement (4/6)
----------------------------------------
Epoch 9
Train Loss: 1.2398
Val Char Accuracy: 45.29%
New best model saved
------

In [39]:
model.load_state_dict(torch.load("seq2seq_best.pt"))
model.eval()
print("Best Seq2Seq model loaded")

Best Seq2Seq model loaded


  model.load_state_dict(torch.load("seq2seq_best.pt"))


In [40]:
tests = [
    "Delhi",
    "Kolkata",
    "Bangalore",
    "Rajasthan",
    "Chandrakala",
    "Vishnupuram",
    "Maharashta",
    "Kaveri"
]

for w in tests:
    print("Input :", w)
    print("Beam:  ", beam_transliterate(model, w))
    # print("Greedy: ", greedy_transliterate(model,w))
    print("-" * 30)

Input : Delhi
Beam:   दिल्ली
------------------------------
Input : Kolkata
Beam:   कोलकत
------------------------------
Input : Bangalore
Beam:   बंगलोरो
------------------------------
Input : Rajasthan
Beam:   राजस्थान
------------------------------
Input : Chandrakala
Beam:   चंद्रकलाल
------------------------------
Input : Vishnupuram
Beam:   विषुणपुरम
------------------------------
Input : Maharashta
Beam:   मारहष्ट
------------------------------
Input : Kaveri
Beam:   केवीरी
------------------------------


In [41]:
def evaluate_char_accuracy(model, data):
    model.eval()
    scores = []

    for en, hi in data:
        pred = beam_transliterate(model, en)
        scores.append(char_accuracy(pred, hi))

    return sum(scores) / len(scores)

In [42]:
val_acc = evaluate_char_accuracy(model, val)
test_acc = evaluate_char_accuracy(model, test)

print("Validation Char Accuracy: ", round(val_acc * 100, 2), "%")
print("Test Char Accuracy: ", round(test_acc * 100, 2), "%")

Validation Char Accuracy:  45.29 %
Test Char Accuracy:  45.55 %


In [44]:
torch.save({
    "model_state": model.state_dict(),
    "en_stoi": en_stoi,
    "hi_stoi": hi_stoi,
    "hi_itos": hi_itos,
    "PAD": PAD,
    "SOS": SOS,
    "EOS": EOS,
    "EMBED_DIM": EMBED_DIM,
    "HIDDEN_DIM": HIDDEN_DIM
}, "seq2seq_best.pt")

print("Model saved as seq2seq_best.pt")

Model saved as seq2seq_best.pt
