# Sample Transliteration Task
### Training a Transformer for transliteration on a small sample of Hindi-to-English word pairs

In [12]:
from tqdm.notebook import tqdm
import torch
from torch import nn
from transformer import Transformer

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

## Dataset

In [3]:
# Toy parallel corpus: each key is a word written in Devanagari and each value is
# its uppercase Latin transliteration. The dict's insertion order is significant,
# because later cells pick examples by position (e.g. `idx = 8`).
dataset = {'रासविहारी': 'RASVIHARI', 'देवगन': 'DEVGAN', 'रोड': 'ROAD', 'शत्रुमर्दन': 'SHATRUMARDAN', 'महिजुबा': 'MAHIJUBA', 'सैबिन': 'SAIBIN', 'बिल': 'BILL', 'कॉस्बी': 'COSBY', 'रिश्ता': 'RISTA', 'कागज़': 'KAGAZ', 'का': 'KA', 'हातिम': 'HATIM', 'श्रीमयी': 'SRIMAYI', 'फरीहाह': 'FARIHAH', 'मैरीटाइम': 'MARITIME', 'म्युज़ियम': 'MUSIUM', 'ऑफ': 'OF', 'ग्रीस': 'GREECE', 'मंथन': 'MANTHAN', 'फ्रेंकोरशियन': 'FRANCORUSSIAN', 'वार': 'BAR', 'तन्मया': 'TANMYA', 'मल्ली': 'MALLI', 'केलीमुटु': 'KELIMUTU', 'मुटाटकर': 'MUTATAKAR', 'गंगा': 'GANGA', 'मैया': 'MAIYA', 'फरीदाह': 'FARIDAH', 'तहमीना': 'TAHMEENA', 'दुर्रानी': 'DURANII', 'डान्यूब': 'DANUBE', 'बलील': 'BALEEL'}

In [4]:
# English (target) vocabulary: ids 0-2 are reserved for the special tokens,
# and the letters A-Z follow from id 3 onwards.
english_alphabets = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

eng_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
eng_vocab.update(
    (letter, token_id)
    for token_id, letter in enumerate(english_alphabets, start=3)
)

In [5]:
# Hindi (source) vocabulary: ids 0-2 are reserved for the special tokens.
# Code points 2304..2435 (U+0900..U+0983: the Devanagari block plus a few
# following code points) are mapped to consecutive ids starting at 3.
hin_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
hin_vocab.update({
    chr(code_point): code_point - 2301
    for code_point in range(2304, 2436)
})

In [6]:
# Split the dataset into two parallel lists: Hindi words (keys) and their
# English transliterations (values), both in the dict's insertion order.
source = list(dataset.keys())
target = list(dataset.values())

## Preprocessing

In [7]:
def encode_sequence(sequence, vocab, max_len):
    '''Convert one string into a tensor of token ids.

    The ids are wrapped in `<sos>` / `<eos>` markers and right-padded with
    `<pad>` up to `max_len`. A sequence that is already `max_len` or longer
    is returned unpadded and is NOT truncated.

    Args:
        sequence: iterable of characters; each one must be a key of `vocab`.
        vocab: mapping from character (and the special tokens) to integer id.
        max_len: target length including the `<sos>` and `<eos>` markers.

    Returns:
        A 1-D `torch.LongTensor` of token ids.

    Raises:
        KeyError: if a character is missing from `vocab`.
    '''
    token_ids = [
        vocab['<sos>'],
        *(vocab[symbol] for symbol in sequence),
        vocab['<eos>'],
    ]

    missing = max_len - len(token_ids)
    if missing > 0:
        token_ids += [vocab['<pad>']] * missing

    return torch.LongTensor(token_ids)

In [8]:
def encode(sequences, vocab):
    '''Encode a list of strings into one padded batch tensor.

    Every sequence is padded to the length of the longest one plus two,
    leaving room for the `<sos>` and `<eos>` markers.

    Args:
        sequences: non-empty list of strings to encode.
        vocab: mapping from character (and the special tokens) to integer id.

    Returns:
        A 2-D tensor with one row of token ids per input sequence.
    '''
    # +2 accounts for the <sos> and <eos> tokens added to each sequence.
    padded_len = max(map(len, sequences)) + 2

    return torch.stack(
        [encode_sequence(text, vocab, padded_len) for text in sequences]
    )

In [9]:
# Encode the Hindi inputs and English outputs as padded id batches and place
# both on the selected device.
X = encode(sequences=source, vocab=hin_vocab).to(device=device)
y = encode(sequences=target, vocab=eng_vocab).to(device=device)

In [10]:
# Model hyperparameters derived from the vocabularies.
src_vocab_size, trg_vocab_size = len(hin_vocab), len(eng_vocab)
# Both vocabularies map '<pad>' to id 0, so a single padding index is used.
pad_idx = 0

In [13]:
# Build the model from the local `transformer` module and move it to `device`.
# NOTE(review): `pad_idx` is passed twice, presumably as the source and target
# padding indices. Confirm against the `Transformer` constructor signature.
model = Transformer(src_vocab_size, trg_vocab_size,
                    pad_idx, pad_idx, device=device).to(device)

## Training

In [15]:
# Training configuration.
epochs = 200
# Padding positions are excluded from the loss via `ignore_index`.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [16]:
# Full-batch training: every step feeds the whole dataset through the model.
model.train()  # ensure training-mode behaviour even if the model was put in eval mode earlier

# The context manager closes the progress bar even if a step raises.
with tqdm(total=epochs) as pbar:
    for epoch in range(epochs):
        # Reset gradients from the previous step. Without this, `backward()`
        # keeps adding to the existing `.grad` buffers and each update uses
        # the sum of all past gradients.
        optimizer.zero_grad()

        # Teacher forcing: the decoder input is the target without its last
        # position, so position t is used to predict token t + 1.
        out = model(X, y[:, :-1])
        out = out.reshape(-1, out.shape[2])

        # Labels are the target without its first position (the <sos> token),
        # flattened to line up with the flattened model output.
        # Named `labels` so the `target` list of English words built earlier
        # is not overwritten.
        labels = y[:, 1:].reshape(-1)

        loss = criterion(out, labels)
        loss.backward()

        # Clip the global gradient norm to 1 before stepping.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        pbar.set_postfix({'Loss': loss.item()})
        pbar.update(1)

  0%|          | 0/200 [00:00<?, ?it/s]

## Prediction

In [19]:
def decode_sequence(sequence, vocab):
    '''Map a sequence of token ids back to text.

    Ids 0, 1 and 2 (the `<pad>`, `<sos>` and `<eos>` tokens) are skipped, and
    every other id is replaced by the vocabulary entry that maps to it.

    Args:
        sequence: iterable of integer token ids.
        vocab: mapping from token to integer id (inverted internally).

    Returns:
        The decoded string.

    Raises:
        KeyError: if an id greater than 2 has no entry in `vocab`.
    '''
    id_to_token = dict(zip(vocab.values(), vocab.keys()))
    return ''.join(
        id_to_token[token_id] for token_id in sequence if token_id > 2
    )

In [30]:
def predict(source, max_gen_len=15):
    '''Greedily decode a transliteration for one encoded source sequence.

    Starts from the `<sos>` token and repeatedly feeds the tokens generated
    so far back into the model, appending the highest-scoring next token,
    until `<eos>` is produced or `max_gen_len` tokens have been generated.

    The model is switched to eval mode for the duration of decoding, so any
    train-only layers (e.g. dropout) do not randomise the output. Its
    previous mode is restored afterwards.

    Args:
        source: encoded source batch containing exactly one sequence (the
            notebook passes `X[idx].unsqueeze(0)`). A larger batch fails at
            `.item()`.
        max_gen_len: maximum number of tokens to generate after `<sos>`.

    Returns:
        List of generated token ids, starting with the `<sos>` id and ending
        with the `<eos>` id if it was produced within the limit.
    '''
    sos_idx, eos_idx = eng_vocab['<sos>'], eng_vocab['<eos>']
    preds = [sos_idx]  # Initialize predictions with SOS

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_gen_len):
                input_tensor = torch.tensor(preds).unsqueeze(0).to(device)

                # Predict scores for every position of the current prefix.
                out = model(source, input_tensor)

                # Keep the highest-scoring token at the last position only.
                word_idx = out.argmax(dim=-1)[:, -1].item()
                preds.append(word_idx)

                if word_idx == eos_idx:  # Stop once EOS is generated
                    break
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return preds

In [37]:
# Choose one example by its position in the dataset.
idx = 8
# Add a leading batch dimension so the single sequence looks like a batch of one.
testX = X[idx][None, :].to(device)
# Ground-truth token ids as a plain Python list.
testY = y[idx].tolist()

In [38]:
# Generate a prediction for the chosen example and decode everything to text.
preds = predict(testX)

input_text = decode_sequence(testX.squeeze(0).tolist(), hin_vocab)
predicted_text = decode_sequence(preds, eng_vocab)
expected_text = decode_sequence(testY, eng_vocab)

print("Input:", input_text)
print("Prediction:", predicted_text)
print("Ground Truth:", expected_text)

Input: रिश्ता
Prediction: RISTA
Ground Truth: RISTA
