# 🔡 Aksharantar Transliteration — LSTM + Bahdanau Attention (GPU-ready)
**Colab-ready notebook**: train and evaluate a character-level Seq2Seq transliteration model (Hindi example).
This notebook is designed to run on Google Colab; a GPU runtime is recommended. Run the cells in order.


## Notebook overview
Steps included:
1. Verify GPU and install dependencies  
2. Create folders and upload dataset (or mount Drive)  
3. Define utility functions and dataset class (`utils`)  
4. Define model components (`encoder`, `decoder`, `seq2seq`)  
5. Training loop (with progress printing)  
6. Save checkpoint and run inference  
7. Small evaluation (sample predictions)

**Dataset expectation**:  
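Place `hin_train.csv` and `hin_valid.csv` into `/content/data/aksharantar_sampled/hin/`.  
Each CSV must start with a header row; the first column is read as the source word and the second column as the target word.  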
The notebook also contains instructions to upload files manually or mount Google Drive.


In [None]:
# Verify GPU availability and install dependencies (PyTorch is usually preinstalled on Colab)
# Lines starting with `!` are IPython/Colab shell escapes, so this cell only runs inside a notebook kernel.
# `|| true` makes the shell command succeed even when nvidia-smi is missing or fails.
!nvidia-smi || true
import torch
# Report the installed PyTorch build and whether it can see a CUDA device.
print('Torch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    # Index 0 is the first CUDA device visible to this process.
    print('GPU:', torch.cuda.get_device_name(0))

# Install/update required packages (quiet)
# NOTE(review): torch is already imported above; if pip replaces the installed build,
# the runtime presumably has to be restarted before the new build is used — confirm.
!pip install --quiet --upgrade pip
!pip install --quiet torch torchvision torchaudio pandas tqdm nbformat
print('Dependencies installed (or already present).')

In [None]:
# Create the working directories used by the later cells:
# dataset location, model folder, and checkpoint folder.
import os

for folder in (
    '/content/data/aksharantar_sampled/hin',
    '/content/models',
    '/content/checkpoints',
):
    # exist_ok=True makes re-running this cell harmless.
    os.makedirs(folder, exist_ok=True)
print('Folders created under /content')

### Upload dataset files
Use the left-side **Files** panel in Colab and click **Upload** to upload:
- `hin_train.csv`
- `hin_valid.csv`

Upload them into `/content/data/aksharantar_sampled/hin/`.

Alternatively, mount your Google Drive and copy files from Drive:
```python
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/path/to/hin_train.csv /content/data/aksharantar_sampled/hin/
!cp /content/drive/MyDrive/path/to/hin_valid.csv /content/data/aksharantar_sampled/hin/
```

In [None]:
# utils: data loading, vocab, dataset, collate
import os, pandas as pd, torch
from collections import Counter
from torch.utils.data import Dataset

def _read_pair_file(path):
    """Read one CSV and return ``((src_col, tgt_col), pairs)`` from its first two columns."""
    # dtype=str + keep_default_na=False keep every cell exactly as written.
    # With pandas defaults, words such as "null", "NA" or "None" are parsed as
    # missing values and numeric-looking words such as "007" are parsed as
    # numbers, silently corrupting the vocabulary.
    frame = pd.read_csv(path, header=0, dtype=str, keep_default_na=False)
    if frame.shape[1] < 2:
        raise ValueError(
            f"{path}: expected at least two columns (source, target), "
            f"found {frame.shape[1]}"
        )
    src_col, tgt_col = frame.columns[:2]
    # Select by position so each file may use its own header names.
    # Rows with an empty source or target carry no training signal and are dropped.
    pairs = [
        (src, tgt)
        for src, tgt in zip(frame.iloc[:, 0], frame.iloc[:, 1])
        if src and tgt
    ]
    return (src_col, tgt_col), pairs


def load_language_pairs(data_dir, lang):
    """Load the train/validation word pairs for one language.

    Reads ``<data_dir>/<lang>/<lang>_train.csv`` and
    ``<data_dir>/<lang>/<lang>_valid.csv``. Each file must have a header row;
    the first column is the source word and the second the target word.

    Args:
        data_dir: Root directory containing one sub-folder per language.
        lang: Language code, used both as folder name and file-name prefix.

    Returns:
        ``(train_pairs, valid_pairs)``: two lists of ``(source, target)``
        string tuples. Rows with an empty source or target are skipped.

    Raises:
        FileNotFoundError: If either CSV file is missing.
        ValueError: If a file has fewer than two columns.
    """
    base = os.path.join(data_dir, lang)
    (src_col, tgt_col), train_pairs = _read_pair_file(
        os.path.join(base, f"{lang}_train.csv")
    )
    _, valid_pairs = _read_pair_file(os.path.join(base, f"{lang}_valid.csv"))
    print(f"✅ Detected columns: source='{src_col}', target='{tgt_col}'")
    print(f"✅ Loaded {len(train_pairs)} train / {len(valid_pairs)} valid samples")
    return train_pairs, valid_pairs

def build_vocab_from_pairs(pairs):
    """Build one character vocabulary covering both sides of ``pairs``.

    Ids 0-3 are reserved for ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``;
    every distinct character found in any source or target string follows
    in sorted order.

    Args:
        pairs: Iterable of ``(source, target)`` strings.

    Returns:
        ``(stoi, itos)``: symbol-to-id and id-to-symbol dictionaries.
    """
    seen = set()
    for source, target in pairs:
        seen.update(source)
        seen.update(target)

    stoi = {}
    for symbol in ('<pad>', '<sos>', '<eos>', '<unk>', *sorted(seen)):
        # setdefault leaves an already-registered symbol (and the next id) untouched.
        stoi.setdefault(symbol, len(stoi))

    itos = {index: symbol for symbol, index in stoi.items()}
    return stoi, itos

class TransliterationDataset(Dataset):
    """Map-style dataset turning ``(source, target)`` strings into fixed-length id tensors.

    Every sequence is laid out as ``<sos> chars... <eos>`` and right-padded
    with ``<pad>`` up to ``max_len``; at most ``max_len - 2`` characters are kept.
    """

    def __init__(self, pairs, src_stoi, tgt_stoi, max_len=30):
        self.pairs = pairs
        self.src_stoi = src_stoi
        self.tgt_stoi = tgt_stoi
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def encode(self, text, stoi):
        """Return ``text`` as a 1-D ``torch.long`` tensor of ids from ``stoi``."""
        limit = self.max_len
        body = []
        for ch in text:
            # Characters missing from the vocabulary fall back to <unk>.
            body.append(stoi.get(ch, stoi['<unk>']))
        sequence = [stoi['<sos>'], *body[:limit - 2], stoi['<eos>']]
        shortfall = limit - len(sequence)
        if shortfall > 0:
            sequence.extend([stoi['<pad>']] * shortfall)
        return torch.tensor(sequence, dtype=torch.long)

    def __getitem__(self, idx):
        source, target = self.pairs[idx]
        encoded_source = self.encode(source, self.src_stoi)
        encoded_target = self.encode(target, self.tgt_stoi)
        return encoded_source, encoded_target

def collate_fn(batch):
    """Stack a list of ``(src, tgt)`` tensor pairs into two batch tensors.

    Returns:
        ``(srcs, tgts)`` where each tensor gains a new leading batch dimension.
    """
    # zip(*batch) regroups the samples into (all sources, all targets).
    src_batch, tgt_batch = map(torch.stack, zip(*batch))
    return src_batch, tgt_batch

# Printed once the definitions above have executed, as a visible marker in the cell output.
print('utils loaded')

In [None]:
# models: encoder, attention, decoder, seq2seq
import torch, torch.nn as nn
class EncoderRNN(nn.Module):
    """Embedding layer followed by a batch-first recurrent encoder.

    Args:
        input_dim: Size of the source vocabulary (number of embedding rows).
        hidden_dim: Width used for both the embeddings and the recurrent state.
        cell_type: ``'gru'`` (case-insensitive) selects ``nn.GRU``; any other
            value selects ``nn.LSTM``.
    """

    def __init__(self, input_dim, hidden_dim, cell_type='lstm'):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.cell_type = cell_type.lower()
        rnn_cls = nn.GRU if self.cell_type == 'gru' else nn.LSTM
        self.rnn = rnn_cls(hidden_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        """Encode token ids ``x`` (batch first).

        Returns:
            ``(outputs, hidden)`` exactly as produced by the recurrent layer:
            per-step outputs plus the final state (an ``(h, c)`` tuple for LSTM,
            a single tensor for GRU).
        """
        return self.rnn(self.embedding(x))

class BahdanauAttention(nn.Module):
    """Additive attention: ``score = V(tanh(W1(enc_outputs) + W2(dec_state)))``."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(hidden_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.V = nn.Linear(hidden_dim, 1)

    def forward(self, dec_hidden, enc_outputs):
        """Attend over ``enc_outputs`` using the decoder state as the query.

        Args:
            dec_hidden: Decoder state; either a tensor or an ``(h, c)`` tuple,
                in which case ``h`` is used. Its last entry along dim 0 is the query.
            enc_outputs: Encoder outputs, batch first.

        Returns:
            ``(context, attn_weights)``: the attention-weighted sum of
            ``enc_outputs`` and the weights, softmax-normalised over dim 1.
        """
        state = dec_hidden[0] if isinstance(dec_hidden, tuple) else dec_hidden
        # Add a length-1 time axis so the query broadcasts over encoder steps.
        query = state[-1].unsqueeze(1)
        energy = torch.tanh(self.W1(enc_outputs) + self.W2(query))
        logits = self.V(energy).squeeze(-1)
        attn_weights = torch.softmax(logits, dim=1)
        weighted = torch.bmm(attn_weights.unsqueeze(1), enc_outputs)
        context = weighted.squeeze(1)
        return context, attn_weights

class DecoderRNN(nn.Module):
    """Single-step recurrent decoder with optional Bahdanau attention.

    Args:
        output_dim: Size of the target vocabulary (embedding rows and logits).
        hidden_dim: Width of the embeddings and of the recurrent state.
        cell_type: ``'gru'`` (case-insensitive) selects ``nn.GRU``; any other
            value selects ``nn.LSTM``. It must match the encoder's cell type,
            because the encoder's final state is used as the initial decoder state.
        use_attention: When true, an attention context vector is concatenated
            to the token embedding before the recurrent layer.
    """

    def __init__(self, output_dim, hidden_dim, cell_type='lstm', use_attention=True):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, hidden_dim)
        self.cell_type = cell_type.lower()
        self.use_attention = use_attention
        # Honour cell_type the same way EncoderRNN does. Previously the decoder
        # was always an LSTM, so a GRU encoder's tensor state could not be fed to it.
        rnn_cls = nn.GRU if self.cell_type == 'gru' else nn.LSTM
        rnn_input_dim = hidden_dim
        if use_attention:
            self.attention = BahdanauAttention(hidden_dim)
            # Embedding and context vector are concatenated along the feature axis.
            rnn_input_dim += hidden_dim
        self.rnn = rnn_cls(rnn_input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_tok, hidden, enc_outputs):
        """Run one decoding step.

        Args:
            input_tok: 1-D tensor with the previous token id of each batch item.
            hidden: Recurrent state (``(h, c)`` tuple for LSTM, tensor for GRU).
            enc_outputs: Encoder outputs, batch first; only read when attention
                is enabled.

        Returns:
            ``(pred, hidden)``: logits over the target vocabulary for this step
            and the updated recurrent state.
        """
        emb = self.embedding(input_tok).unsqueeze(1)
        if self.use_attention:
            h = hidden[0] if isinstance(hidden, tuple) else hidden
            context, _ = self.attention(h, enc_outputs)
            rnn_input = torch.cat((emb, context.unsqueeze(1)), dim=2)
        else:
            rnn_input = emb
        out, hidden = self.rnn(rnn_input, hidden)
        pred = self.fc(out.squeeze(1))
        return pred, hidden

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with optional teacher forcing.

    Args:
        encoder: Callable returning ``(enc_outputs, hidden)`` for a source batch.
        decoder: Callable taking ``(input_tok, hidden, enc_outputs)`` and returning
            ``(logits, hidden)``; must expose ``fc.out_features`` (vocabulary size).
        device: Device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Return logits of shape ``(batch, tgt_len, vocab)``.

        Position 0 is left as zeros because decoding starts from ``tgt[:, 0]``.
        At each step a fresh random draw decides, for the whole batch, whether
        the next input is the gold token or the model's own argmax.
        """
        batch_size, _ = src.size()
        steps = tgt.size(1)
        n_classes = self.decoder.fc.out_features
        logits = torch.zeros(batch_size, steps, n_classes).to(self.device)

        memory, state = self.encoder(src)
        prev_tok = tgt[:, 0]
        for step in range(1, steps):
            step_logits, state = self.decoder(prev_tok, state, memory)
            logits[:, step, :] = step_logits
            use_gold = torch.rand(1).item() < teacher_forcing_ratio
            if use_gold:
                prev_tok = tgt[:, step]
            else:
                prev_tok = step_logits.argmax(1)
        return logits

# Printed once the model classes above have been defined, as a visible marker in the cell output.
print('models loaded')

In [None]:
# Training cell (adjust hyperparams as needed)
# Requires the earlier cells to have defined load_language_pairs,
# build_vocab_from_pairs, TransliterationDataset, collate_fn and the model classes.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

DATA_DIR = '/content/data/aksharantar_sampled'
LANG = 'hin'
MAX_LEN = 30   # fixed sequence length, including <sos>/<eos>
BATCH = 128
HID = 256      # embedding width and recurrent hidden size
EPOCHS = 5  # increase to 25+ for final training

train_pairs, valid_pairs = load_language_pairs(DATA_DIR, LANG)
pairs = train_pairs + valid_pairs
# Build a separate vocabulary per side. Passing an empty string for the other
# side makes build_vocab_from_pairs see only source (or only target) characters,
# so the decoder's output layer no longer spends logits on source-script symbols.
src_stoi, src_itos = build_vocab_from_pairs([(src_word, '') for src_word, _ in pairs])
tgt_stoi, tgt_itos = build_vocab_from_pairs([('', tgt_word) for _, tgt_word in pairs])

train_ds = TransliterationDataset(train_pairs, src_stoi, tgt_stoi, max_len=MAX_LEN)
valid_ds = TransliterationDataset(valid_pairs, src_stoi, tgt_stoi, max_len=MAX_LEN)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_size=BATCH, shuffle=False, collate_fn=collate_fn)

enc = EncoderRNN(len(src_stoi), HID, cell_type='lstm').to(device)
dec = DecoderRNN(len(tgt_stoi), HID, cell_type='lstm', use_attention=True).to(device)
model = Seq2Seq(enc, dec, device).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# The loss is computed over target tokens, so the ignored padding id must come
# from the target vocabulary.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi['<pad>'])

for epoch in range(1, EPOCHS+1):
    model.train()
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        out = model(src, tgt)
        # Position 0 holds <sos> and is never predicted, so it is excluded from the loss.
        loss = criterion(out[:,1:].reshape(-1, out.shape[-1]), tgt[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg = total_loss / len(train_loader)
    print(f'Epoch {epoch}/{EPOCHS} — Train Loss: {avg:.4f}')
    # quick validation: no teacher forcing, so the model consumes its own predictions
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in valid_loader:
            src, tgt = src.to(device), tgt.to(device)
            out = model(src, tgt, teacher_forcing_ratio=0.0)
            val_loss += criterion(out[:,1:].reshape(-1, out.shape[-1]), tgt[:,1:].reshape(-1)).item()
    print(f'Val Loss: {val_loss / len(valid_loader):.4f}')
    # One checkpoint per epoch (weights only; the vocabularies are not saved).
    torch.save(model.state_dict(), f'/content/checkpoints/{LANG}_epoch{epoch}.pt')
print('Training complete')

In [None]:
# Simple inference example (after training)
# Invert the symbol->id vocabularies so predicted ids can be turned back into characters.
src_itos = dict(zip(src_stoi.values(), src_stoi.keys()))
tgt_itos = dict(zip(tgt_stoi.values(), tgt_stoi.keys()))

def predict_word(word, max_len=30):
    """Greedily transliterate a single word with the trained model.

    Uses the notebook globals ``model``, ``enc``, ``dec``, ``device``,
    ``src_stoi``, ``tgt_stoi`` and ``tgt_itos``.

    Args:
        word: Source-script word; characters missing from ``src_stoi`` map to ``<unk>``.
        max_len: Padded source length and maximum number of decoding steps.
            Should match the value used for training.

    Returns:
        The predicted target string. Decoding stops at ``<eos>`` or ``<pad>``;
        ``<sos>`` and ``<unk>`` predictions are fed back to the decoder but not
        written to the output, so no special-token text leaks into the result.
    """
    ids = [src_stoi.get(ch, src_stoi['<unk>']) for ch in word][:max_len-2]
    n_pad = max_len - (len(ids) + 2)
    src_ids = [src_stoi['<sos>']] + ids + [src_stoi['<eos>']] + [src_stoi['<pad>']] * n_pad
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)

    stop_ids = {tgt_stoi[name] for name in ('<eos>', '<pad>') if name in tgt_stoi}
    skip_ids = {tgt_stoi[name] for name in ('<sos>', '<unk>') if name in tgt_stoi}

    chars = []
    model.eval()
    with torch.no_grad():
        enc_out, hidden = enc(src_tensor)
        input_tok = torch.tensor([tgt_stoi['<sos>']], dtype=torch.long, device=device)
        for _ in range(max_len):
            pred, hidden = dec(input_tok, hidden, enc_out)
            tok = pred.argmax(1).item()
            if tok in stop_ids:
                break
            if tok not in skip_ids:
                chars.append(tgt_itos.get(tok, ''))
            input_tok = torch.tensor([tok], dtype=torch.long, device=device)
    # Join once at the end instead of growing a string inside the loop.
    return ''.join(chars)

# Quick smoke test of predict_word; only meaningful once the training cell has run.
print(f"Sample: {predict_word('bindhya')}")

In [None]:
# Save final notebook model to Drive if desired (mount first)
# The commands below are left commented out; uncomment them to copy a checkpoint to Drive.
# NOTE: the training cell writes /content/checkpoints/{LANG}_epoch{N}.pt for N = 1..EPOCHS,
# so `hin_epoch25.pt` only exists after a 25-epoch run (EPOCHS defaults to 5) — adjust the
# file name to a checkpoint that was actually produced.
# from google.colab import drive
# drive.mount('/content/drive')
# !cp /content/checkpoints/hin_epoch25.pt /content/drive/MyDrive/
print('Use Drive copy commands to save checkpoints to your Drive')