In [1]:
import numpy as np
import torch
import pathlib
import re
import unicodedata
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn as nn
import sentencepiece as spm

from torch.utils.data import Dataset, DataLoader
from collections import Counter
from tqdm import tqdm
from pprint import pprint

# Gradient tracking is PyTorch's default; it is stated explicitly for clarity.
torch.set_grad_enabled(True)

# Anomaly detection is a debugging aid: it adds extra checks to every
# backward pass and noticeably slows training. Keep it off for normal runs
# and flip the flag only while hunting NaN/inf gradients.
DEBUG_AUTOGRAD = False
torch.autograd.set_detect_anomaly(DEBUG_AUTOGRAD);

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x2028e875b10>

In [2]:
# Single source of truth for the notebook's tunable settings.
config = dict(
    # size of the embedding tables / output projection
    MAX_VOCAB_SIZE = 15000,
    # number of sentence pairs per DataLoader batch
    BATCH_SIZE = 8,
    # parallel corpus: one sentence per line, same line = same pair
    train_en_path = "./dataset/consolidated/train.en",
    train_ne_path = "./dataset/consolidated/train.ne",
    # every encoded sequence is padded / cut to this many token ids
    MAX_SEQ_LEN = 32,
    # not referenced by the cells visible in this notebook
    BUFFER_SIZE = 1000,
    # embedding and LSTM hidden width
    UNITS = 256,
    # upper bound on training epochs (early stopping may end sooner)
    EPOCHS = 10,
    # prefer the GPU when one is available
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)

In [3]:
def _read_text_lines(path):
    """Return the lines of a UTF-8 text file without their trailing line breaks."""
    with open(path, 'r', encoding='utf-8') as f:
        # drop the line terminator so it is never handed to the tokenizer
        return [line.rstrip('\r\n') for line in f]


en_lines = _read_text_lines(config['train_en_path'])
ne_lines = _read_text_lines(config['train_ne_path'])

# The corpus is aligned by line number; a length mismatch means the pairs are
# shifted (and np.array below could no longer build a 2 x N array).
if len(en_lines) != len(ne_lines):
    raise ValueError(
        f"Parallel corpus is misaligned: {len(en_lines)} English lines "
        f"vs {len(ne_lines)} Nepali lines"
    )

context_en = np.array(en_lines)
target_ne = np.array(ne_lines)

# shape (2, N): row 0 = English sentences, row 1 = Nepali sentences
sentences = np.array((context_en, target_ne))

In [4]:
# Restore the previously trained BPE SentencePiece models:
# first the English one, then the Nepali one.
sp_en, sp_ne = (
    spm.SentencePieceProcessor(model_file = model_path)
    for model_path in ('en_bpe_model.model', 'ne_bpe_model.model')
)

In [5]:
# Ids the Nepali tokenizer assigns to its special pieces, in the order
# unknown, start-of-sentence, end-of-sentence, padding.
tuple(sp_ne.piece_to_id(piece) for piece in ('<unk>', '<s>', '</s>', '<pad>'))

(0, 1, 2, 3)

### Neural Machine Translation (EN - NE Translation Dataset)

In [6]:
class NMT_dataset(Dataset):
    """Parallel-sentence dataset producing fixed-length id tensors for teacher forcing.

    Every item is a triple of ``torch.long`` tensors, each of length
    ``max_seq_len``:

    * encoder input  -- ``<s> source </s> <pad> ...``
    * decoder input  -- ``<s> target </s> <pad> ...``
    * decoder target -- ``target </s> <pad> ...`` (decoder input shifted left by one)

    Sentences that are too long are cut *before* the special tokens are
    attached, so the encoder input and the decoder target always keep their
    closing ``</s>``.

    Args:
        translation_pairs: 2 x N array indexed as ``[:, idx]``; row 0 holds the
            source sentences and row 1 the target sentences.
        src_tokenizer: tokenizer for the source side; must provide ``bos_id()``,
            ``eos_id()``, ``unk_id()``, ``piece_to_id(piece)`` and
            ``encode(text, out_type=int)``.
        tgt_tokenizer: tokenizer for the target side, same interface.
        max_seq_len: length of every returned tensor.
        device: device the returned tensors are moved to.
    """

    def __init__(self, translation_pairs, src_tokenizer, tgt_tokenizer, max_seq_len, device = 'cpu'):

        self.translation_pairs = translation_pairs
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_seq_len = max_seq_len
        self.device = device

        # source-side special ids (attribute names kept for existing callers)
        self.sos_id = self.src_tokenizer.bos_id()
        self.eos_id = self.src_tokenizer.eos_id()
        # looked up by piece name rather than through the tokenizer's pad_id()
        self.pad_id = self.src_tokenizer.piece_to_id('<pad>')
        self.oov_id = self.src_tokenizer.unk_id()

        # target-side special ids: taken from the target tokenizer so the two
        # vocabularies are not required to agree on them
        self.tgt_sos_id = self.tgt_tokenizer.bos_id()
        self.tgt_eos_id = self.tgt_tokenizer.eos_id()
        self.tgt_pad_id = self.tgt_tokenizer.piece_to_id('<pad>')

    def __len__(self):
        """Number of sentence pairs (size of the last axis of the pair array)."""
        return self.translation_pairs.shape[-1]

    def _to_fixed_length(self, token_ids, pad_id):
        """Right-pad (or cut) ``token_ids`` to ``max_seq_len`` and return a long tensor."""
        # a negative repeat count yields an empty list, so over-long input is simply cut
        padded = token_ids + [pad_id] * (self.max_seq_len - len(token_ids))
        return torch.tensor(padded[:self.max_seq_len], dtype = torch.long).to(self.device)

    def __getitem__(self, idx):
        """Return ``(encoder_input, decoder_input, decoder_target)`` for pair ``idx``."""
        # src_translation -> English; tgt_translation -> Nepali (for this case)
        src_translation, tgt_translation = self.translation_pairs[:, idx]

        context_tokens = self.src_tokenizer.encode(src_translation, out_type = int)
        target_tokens = self.tgt_tokenizer.encode(tgt_translation, out_type = int)

        # leave room for <s> and </s> on the source, and for </s> on the target,
        # so truncation never removes the end-of-sentence marker
        context_tokens = context_tokens[:max(self.max_seq_len - 2, 0)]
        target_tokens = target_tokens[:max(self.max_seq_len - 1, 0)]

        encoder_input = [self.sos_id] + context_tokens + [self.eos_id]
        # pre-attention decoder input: target shifted right by <s>
        pre_decoder_input = [self.tgt_sos_id] + target_tokens + [self.tgt_eos_id]
        # post-attention decoder output: what the decoder must predict at each step
        post_decoder_output = target_tokens + [self.tgt_eos_id]

        return (
            self._to_fixed_length(encoder_input, self.pad_id),
            self._to_fixed_length(pre_decoder_input, self.tgt_pad_id),
            self._to_fixed_length(post_decoder_output, self.tgt_pad_id),
        )

In [7]:
# train and val_dataset
# Draw the ~85/15 membership mask from a seeded generator so that
# Restart & Run All reproduces the same partition (and the same val loss).
split_rng = np.random.default_rng(config.get('SEED', 42))
is_train = split_rng.uniform(size = (sentences.shape[-1],)) < 0.85

# columns are sentence pairs, so the boolean mask selects along the last axis
train_raw_set = sentences[:, is_train]
val_raw_set = sentences[:, ~is_train]

print(train_raw_set.shape, val_raw_set.shape)

train_dataset = NMT_dataset(train_raw_set, sp_en, sp_ne, config['MAX_SEQ_LEN'])
val_dataset = NMT_dataset(val_raw_set, sp_en, sp_ne, config['MAX_SEQ_LEN'])

(2, 128972) (2, 22965)


In [8]:
# Both loaders use the configured batch size; only the training loader
# reshuffles its samples, the validation loader keeps a fixed order.
loader_kwargs = {'batch_size': config['BATCH_SIZE']}

train_loader = DataLoader(train_dataset, shuffle = True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle = False, **loader_kwargs)

In [9]:
class Encoder(nn.Module):
    """Embeds source token ids and encodes them with a bidirectional LSTM.

    Args:
        vocab_size: number of rows in the embedding table.
        units: embedding width and LSTM hidden size.
        padding_idx: id of the padding token; its embedding stays a zero
            vector and receives no gradient. Defaults to 3, the ``<pad>`` id
            of the tokenizers used in this notebook.
    """

    def __init__(self, vocab_size, units, padding_idx = 3):
        super(Encoder, self).__init__()

        # input embedding
        self.embedding = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = units,
            padding_idx = padding_idx
        )

        # bi-directional LSTM
        self.rnn = nn.LSTM(
            input_size = units,
            hidden_size = units,
            batch_first = True,
            bidirectional = True
        )

    def forward(self, context):
        """Encode ``context`` (batch, seq_len) of token ids into (batch, seq_len, units)."""
        x = self.embedding(context)
        x, _ = self.rnn(x)

        # the LSTM concatenates both directions on the last axis; summing the
        # two halves brings the width back to `units`
        forward_half, backward_half = x.chunk(2, dim = -1)
        return forward_half + backward_half
    
class CrossAttention(nn.Module):
    """Attention from decoder states (queries) over encoder states (keys/values).

    The attended context and the original decoder states are each passed
    through their own linear layer, summed, and layer-normalised.

    Args:
        units: feature width of both inputs and of the output.
        num_heads: number of attention heads; ``units`` must be divisible by
            it. Defaults to 1.
    """

    def __init__(self, units, num_heads = 1):

        super(CrossAttention, self).__init__()

        self.multihead_attn = nn.MultiheadAttention(
            embed_dim = units,  # the size of Q, K, V dims is the embedding dimension
            num_heads = num_heads,
            batch_first = True  # inputs are (batch_size, sequence_length, embedding_dim)
        )

        self.layernorm = nn.LayerNorm(units)

        # two projections: index 0 for the decoder states, index 1 for the
        # attention output, combined by addition in forward()
        self.add = nn.ModuleList([nn.Linear(units, units) for _ in range(2)])

    def forward(self, context, target, key_padding_mask = None):
        """Attend from ``target`` over ``context``.

        Args:
            context: encoder output, (batch, src_len, units).
            target: pre-attention decoder output, (batch, tgt_len, units).
            key_padding_mask: optional bool tensor (batch, src_len); ``True``
                marks encoder positions (e.g. padding) that must be ignored.

        Returns:
            Tensor of shape (batch, tgt_len, units).
        """
        # query is the target from the pre-attention decoder, while key/value
        # are the encoded context from the encoder
        attn_output, _ = self.multihead_attn(
            query = target,
            key = context,
            value = context,
            key_padding_mask = key_padding_mask
        )

        x = self.add[0](target) + self.add[1](attn_output)
        return self.layernorm(x)

class Decoder(nn.Module):
    """Target-side network: embedding -> LSTM -> cross-attention -> LSTM -> vocabulary projection.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output distribution.
        units: embedding width and hidden size of both LSTMs.
        padding_idx: id of the padding token. Defaults to 3, the ``<pad>`` id
            the dataset pads with and the Encoder uses (id 0 is ``<unk>``,
            which must stay trainable).
    """

    def __init__(self, vocab_size, units, padding_idx = 3):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, units, padding_idx = padding_idx)
        self.pre_attention_rnn = nn.LSTM(units, units, batch_first = True)
        self.attention = CrossAttention(units)
        self.post_attention_rnn = nn.LSTM(units, units, batch_first = True)
        self.output_layer = nn.Linear(units, vocab_size)

    def forward(self, context, target, state = None, return_state = False):
        """Score the next token at every target position.

        Args:
            context: encoder output to attend over.
            target: decoder input token ids, (batch, tgt_len).
            state: optional ``(hidden, cell)`` tuple used to initialise the
                pre-attention LSTM; zeros are used when ``None``.
            return_state: when true, also return the pre-attention LSTM's
                final ``(hidden, cell)`` state.

        Returns:
            Log-probabilities of shape (batch, tgt_len, vocab_size), or a
            tuple ``(log_probs, (hidden, cell))`` when ``return_state`` is set.
        """
        x = self.embedding(target)

        # pre-attention-LSTM (decoder)
        if state is None:
            x, (hidden_state, cell_state) = self.pre_attention_rnn(x)
        else:
            x, (hidden_state, cell_state) = self.pre_attention_rnn(x, state)

        # cross-attention between pre-attention-LSTM output and the encoded context
        x = self.attention(context, x)

        # post-attention-LSTM (decoder); note it always starts from a fresh
        # state -- only the pre-attention LSTM state is carried via `state`
        x, _ = self.post_attention_rnn(x)

        # last linear (dense) layer, normalised to log-probabilities
        logits = self.output_layer(x)
        logits = F.log_softmax(logits, dim = -1)

        if return_state:
            return logits, (hidden_state, cell_state)

        return logits

class NMT_Translator(nn.Module):
    """Full translation model: an Encoder feeding a Decoder.

    Both halves are built with the same ``vocab_size`` and ``units``.
    """

    def __init__(self, vocab_size, units):
        super().__init__()

        self.encoder = Encoder(vocab_size, units)
        self.decoder = Decoder(vocab_size, units)

    def forward(self, context, target):
        """Run one teacher-forced pass.

        ``context`` holds the source token ids and ``target`` the decoder
        *input* ids (the shifted-by-one expected output is only needed by the
        loss). Returns the decoder scores, (batch_size, max_seq_len, vocab_size).
        """
        return self.decoder(self.encoder(context), target)

In [10]:
class MaskedAcc(nn.Module):
    """Token-level accuracy that ignores padding positions.

    Args:
        pad_id: target id that marks padding and is excluded from the score.
            Defaults to 3, the ``<pad>`` id used throughout this notebook
            (id 0 is ``<unk>``, a real token that must be scored).
    """

    def __init__(self, pad_id = 3):
        super(MaskedAcc, self).__init__()
        self.pad_id = pad_id

    def forward(self, y_pred, y_true):
        """Return the fraction of non-padding positions predicted correctly.

        Args:
            y_pred: per-token scores, (batch, seq_len, vocab_size).
            y_true: expected token ids, (batch, seq_len); integer or float.

        Returns:
            Scalar tensor; 0 when the batch contains only padding.
        """
        y_pred = y_pred.argmax(dim = -1)

        mask = (y_true != self.pad_id).float()
        correct = (y_true == y_pred).float() * mask

        # clamp so an all-padding batch yields 0 instead of NaN
        return correct.sum() / mask.sum().clamp(min = 1)
    
class SparseCategoricalMaskedLoss(nn.Module):
    """Cross-entropy over integer targets, averaged over non-padding tokens only.

    Args:
        pad_id: target id that marks padding and contributes nothing to the
            loss. Defaults to 3, the ``<pad>`` id used throughout this
            notebook (id 0 is ``<unk>``, a real token that must be trained on).
    """

    def __init__(self, pad_id = 3):
        super(SparseCategoricalMaskedLoss, self).__init__()
        self.pad_id = pad_id

    def forward(self, y_pred, y_true):
        """Return the mean per-token loss over non-padding positions.

        Args:
            y_pred: per-token scores, (batch, seq_len, vocab_size).
            y_true: expected token ids, (batch, seq_len); integer or float.

        Returns:
            Scalar tensor; 0 when the batch contains only padding.
        """
        batch_size, seq_len, vocab_size = y_pred.shape

        # 1 for real tokens, 0 for padding
        mask = (y_true != self.pad_id).to(y_pred.dtype)

        # flatten to (batch * seq_len, vocab) / (batch * seq_len,) for
        # cross_entropy; reshape also copes with non-contiguous inputs
        per_token_loss = F.cross_entropy(
            y_pred.reshape(-1, vocab_size),
            y_true.reshape(-1).long(),
            reduction = 'none'
        ).view(batch_size, seq_len)

        # out-of-place multiply keeps the autograd graph free of in-place edits;
        # clamp so an all-padding batch yields 0 instead of NaN
        return (per_token_loss * mask).sum() / mask.sum().clamp(min = 1)

In [11]:
# Model dimensions come straight from the shared config.
units = config['UNITS']
vocab_size = config['MAX_VOCAB_SIZE']

# Translator plus an Adam optimiser over all of its parameters.
model = NMT_Translator(vocab_size = vocab_size, units = units)
optimizer = optim.Adam(params = model.parameters(), lr = 1e-3)

# Plain cross-entropy criterion; the training cell uses the masked variants below.
loss = nn.CrossEntropyLoss()

# Padding-aware loss and accuracy used during training and validation.
sparse_masked_loss_fn = SparseCategoricalMaskedLoss()
masked_acc_fn = MaskedAcc()

In [13]:
# Train with early stopping on the validation loss; the best weights seen so
# far are checkpointed to 'best_model.pth'.
best_val_loss = float('inf')
patience_counter = 0
# epochs without validation improvement tolerated before stopping
patience = config.get('PATIENCE', 3)

# Batches are moved to wherever the model's weights live, so the loop works
# unchanged whether the model sits on the CPU or on a GPU.
device = next(model.parameters()).device

for epoch in range(config['EPOCHS']):
    model.train()
    total_loss = 0
    total_acc = 0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['EPOCHS']}", leave = False)
    for batch in train_loader_tqdm:

        # encoder input, pre-attention-decoder input and post-attention-decoder target output
        context, target_in, target_out = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()

        output = model(context, target_in)

        # `batch_loss` (not `loss`) so the notebook-level `loss` criterion is
        # not silently replaced by a tensor; targets stay integer ids
        batch_loss = sparse_masked_loss_fn(output, target_out)
        acc = masked_acc_fn(output, target_out)

        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        total_acc += acc.item()
        train_loader_tqdm.set_postfix(loss = batch_loss.item(), accuracy = acc.item())

    # averages are taken over batches, not over tokens
    avg_loss = total_loss / len(train_loader)
    avg_acc = total_acc / len(train_loader)

    print(f"Epoch {epoch + 1}/{config['EPOCHS']}")
    print(f"Training loss: {avg_loss:.4f}, accuracy: {avg_acc:.4f}")

    # Validation
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            context, target_in, target_out = (tensor.to(device) for tensor in batch)
            output = model(context, target_in)
            val_batch_loss = sparse_masked_loss_fn(output, target_out)
            val_loss += val_batch_loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation loss: {avg_val_loss:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping")
            break

# Load the best model
# model.load_state_dict(torch.load('best_model.pth'))


                                                                                                                       

KeyboardInterrupt: 

In [None]:
# Stand-alone illustration of nn.Embedding's index-range requirement
# (torch is already imported in the first cell).

# Example embedding layer
embedding = torch.nn.Embedding(num_embeddings = vocab_size, embedding_dim = units)

# Example indices (replace with your actual input indices)
input_indices = torch.tensor([    1,  4979,   803,  3175, 14939,    54,   382,     5,  2137,   114,
          5162, 14940,   235,  7879,  8428,  1283, 10056,  6046,  2374, 14938,
           358,   144,     5,   873,  7676,    33,  3878,    28,  1851,   261,
          9110,  5308, -1])  # Notice the trailing -1 is out of range

# An embedding lookup only accepts ids in [0, num_embeddings); anything else
# raises an IndexError, so drop the invalid ids before the forward pass.
is_valid = (input_indices >= 0) & (input_indices < embedding.num_embeddings)

# Forward pass (use only valid indices)
output = embedding(input_indices[is_valid])
