In [3]:
pip install -q datasets torch matplotlib sacrebleu

^C
[31mERROR: Operation cancelled by user[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import os
from sacrebleu import corpus_bleu

In [None]:
# Download the English–Nepali parallel corpus; only its "train" split is used below.
data = load_dataset("CohleM/english-to-nepali")
eng_data, nep_data = (data["train"][column] for column in ("en", "ne"))

NameError: name 'load_dataset' is not defined

In [None]:
def _unpickle(path):
    """Return the single object pickled in ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only use this on files you created
    or otherwise trust.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


eng_tok = _unpickle("eng_tokenizer_50k.pkl")
nep_tok = _unpickle("nep_tokenizer_50k.pkl")

In [None]:
# Special-token ids shared by both tokenizers. NOTE(review): presumably chosen just past the
# tokenizers' learned ids — confirm that no learned id is >= 1500.
PAD_ID, SOS_ID, EOS_ID = 1500, 1501, 1502

# Register each special id in both tokenizers' `vocab` with an empty byte string
# (presumably so `decode` emits nothing for it — confirm against the tokenizer class).
for special_id in (PAD_ID, SOS_ID, EOS_ID):
    for tokenizer in (eng_tok, nep_tok):
        tokenizer.vocab[special_id] = b''

In [None]:
class TranslationDataset(Dataset):
    """Map-style dataset of aligned (encoder input, decoder input, decoder target) sequences.

    Args:
        input_data: per-example encoder input id sequences.
        target_data: per-example decoder input id sequences.
        output_data: per-example decoder target id sequences.

    Raises:
        ValueError: if the three sequences do not have the same number of examples. Without this
            check a shorter ``target_data``/``output_data`` only failed later with an IndexError,
            because ``__len__`` looks at ``input_data`` alone.
    """

    def __init__(self, input_data, target_data, output_data):
        lengths = (len(input_data), len(target_data), len(output_data))
        if len(set(lengths)) != 1:
            raise ValueError(
                f"input/target/output must have the same number of examples, got {lengths}"
            )
        self.input_data = input_data
        self.target_data = target_data
        self.output_data = output_data

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        """Return the ``(input, target, output)`` triple at ``idx``."""
        return self.input_data[idx], self.target_data[idx], self.output_data[idx]

In [None]:
def pad_batch(batch, pad_id=None):
    """Collate ``(src, tgt_in, tgt_out)`` id sequences into three right-padded int64 tensors.

    The source side is padded to the longest source in the batch. ``tgt_in`` and ``tgt_out``
    are padded together to their joint longest length, so they stay position-aligned for the
    loss. Source and target lengths may therefore differ; the model embeds each with its own
    positions.

    Args:
        batch: sequence of ``(src_ids, tgt_in_ids, tgt_out_ids)`` triples of int sequences.
        pad_id: id used for padding; ``None`` (default) uses the notebook-level ``PAD_ID``.

    Returns:
        ``(src, tgt_in, tgt_out)`` LongTensors of shape (batch, src_len), (batch, tgt_len) and
        (batch, tgt_len).
    """
    if pad_id is None:
        pad_id = PAD_ID
    input_seqs, target_seqs, output_seqs = zip(*batch)

    def _pad(seqs, length):
        # Explicit dtype: torch.tensor of all-empty lists would otherwise default to float32.
        return torch.tensor(
            [list(seq) + [pad_id] * (length - len(seq)) for seq in seqs], dtype=torch.long
        )

    # A single max over the lengths; the previous max(max(<generator>)) raised TypeError.
    src_len = max(len(seq) for seq in input_seqs)
    tgt_len = max(len(seq) for seq in target_seqs + output_seqs)
    return _pad(input_seqs, src_len), _pad(target_seqs, tgt_len), _pad(output_seqs, tgt_len)

# Build the three id streams the dataset expects. They are defined here so the notebook runs
# top-to-bottom on a fresh kernel:
#   encoder input : English ids, no special tokens (the same way `translate` encodes its input)
#   decoder input : SOS + Nepali ids (teacher forcing; `translate` also starts from SOS)
#   decoder target: Nepali ids + EOS (`translate` stops at EOS)
# Sequences are truncated so no position exceeds the model's 500-row learned positional table.
MAX_SEQ_LEN = 500
enco_eng_data = [list(eng_tok.encode(s))[:MAX_SEQ_LEN] for s in tqdm(eng_data, desc="encode en")]
nep_ids = [list(nep_tok.encode(s))[:MAX_SEQ_LEN - 1] for s in tqdm(nep_data, desc="encode ne")]
deco_nep_data_sos = [[SOS_ID] + ids for ids in nep_ids]
deco_nep_data_eos = [ids + [EOS_ID] for ids in nep_ids]

dataset = TranslationDataset(enco_eng_data, deco_nep_data_sos, deco_nep_data_eos)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=pad_batch)

In [None]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer: self-attention, then a position-wise feed-forward net.

    Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``.

    Args:
        embd: model width; must be divisible by ``heads`` (an ``nn.MultiheadAttention`` requirement).
        heads: number of attention heads.
        dropout: dropout probability, applied inside the attention layer only.
        ff_dim: hidden width of the feed-forward sub-layer. Defaults to 2048, the previously
            hard-coded value, so parameter shapes (and saved checkpoints) are unchanged.
    """

    def __init__(self, embd, heads, dropout, ff_dim=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(embd, heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embd, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embd),
        )
        self.norm1 = nn.LayerNorm(embd)
        self.norm2 = nn.LayerNorm(embd)

    def forward(self, x, key_padding_mask=None):
        """Encode ``x`` of shape (batch, seq, embd) (the attention layer uses ``batch_first=True``).

        Args:
            x: input activations.
            key_padding_mask: optional (batch, seq) bool tensor, ``True`` at padding positions that
                attention must ignore. ``None`` (the default, and what ``nn.Sequential`` passes)
                attends to every position, exactly as before.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ff(x))

In [None]:
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer applies causal self-attention, then cross-attention over the encoder output, then a
    position-wise feed-forward net. Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``.

    Args:
        embd: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability, applied inside both attention layers only.
        ff_dim: hidden width of the feed-forward sub-layer. Defaults to 2048, the previously
            hard-coded value, so parameter shapes (and saved checkpoints) are unchanged.
    """

    def __init__(self, embd, heads, dropout, ff_dim=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embd, heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embd, heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embd, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embd),
        )
        self.norm1 = nn.LayerNorm(embd)
        self.norm2 = nn.LayerNorm(embd)
        self.norm3 = nn.LayerNorm(embd)

    def forward(self, x, enc_out, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode ``x`` (batch, tgt_seq, embd) while attending to ``enc_out`` (batch, src_seq, embd).

        Args:
            x: decoder input activations.
            enc_out: encoder output.
            tgt_key_padding_mask: optional (batch, tgt_seq) bool tensor, ``True`` at padded
                target positions.
            memory_key_padding_mask: optional (batch, src_seq) bool tensor, ``True`` at padded
                source positions that cross-attention must ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # Boolean causal mask (True = blocked): position i attends only to positions <= i.
        # It is equivalent to the previous float mask (-inf strictly above the diagonal, 0 elsewhere).
        # Building it directly on x's device avoids a CPU allocation plus a copy on every call.
        # Using bool also matches bool key-padding masks, so PyTorch emits no mixed-mask warning.
        causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
        self_attn_out, _ = self.self_attn(
            x, x, x, attn_mask=causal_mask, key_padding_mask=tgt_key_padding_mask
        )
        x = self.norm1(x + self_attn_out)
        cross_attn_out, _ = self.cross_attn(
            x, enc_out, enc_out, key_padding_mask=memory_key_padding_mask
        )
        x = self.norm2(x + cross_attn_out)
        return self.norm3(x + self.ff(x))

In [None]:
class TranslationModel(nn.Module):
    """Encoder-decoder Transformer that maps source ids to per-position target-vocabulary logits.

    Args:
        vocab_size: rows of the token-embedding table and width of the output projection.
        embd: model width.
        heads: attention heads per layer.
        layers: number of encoder layers and, separately, of decoder layers.
        dropout: dropout probability forwarded to every layer.
        max_len: number of learned positions. Source or target sequences longer than this raise
            ValueError. Defaults to 500, the previously hard-coded value.
    """

    def __init__(self, vocab_size, embd, heads, layers, dropout, max_len=500):
        super().__init__()
        # NOTE(review): a single table embeds both source ids and target ids. In this notebook
        # those come from two different tokenizers, so the same id denotes different text on each
        # side. Splitting it would change state_dict keys and the code that reads `model.embed`.
        self.embed = nn.Embedding(vocab_size, embd)
        self.pe = nn.Embedding(max_len, embd)
        self.encoder = nn.Sequential(*[EncoderBlock(embd, heads, dropout) for _ in range(layers)])
        self.decoder = nn.ModuleList([DecoderBlock(embd, heads, dropout) for _ in range(layers)])
        self.ln = nn.LayerNorm(embd)
        self.out = nn.Linear(embd, vocab_size)

    def _embed(self, ids):
        """Return token plus learned positional embeddings for a (batch, seq) id tensor."""
        seq_len = ids.size(1)
        max_positions = self.pe.num_embeddings
        if seq_len > max_positions:
            # Fail clearly here instead of with an embedding index error (a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_positions} learned positions (max_len)"
            )
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        return self.embed(ids) + self.pe(pos)

    def forward(self, src, tgt):
        """Run the encoder on ``src`` and the decoder on ``tgt``; both are (batch, seq) id tensors.

        Returns:
            Logits of shape (batch, tgt_seq, vocab_size).

        NOTE(review): no key-padding masks are passed. Encoder self-attention and decoder
        cross-attention therefore also attend to padded positions of ``src``.
        """
        enc_out = self.encoder(self._embed(src))
        x = self._embed(tgt)
        for layer in self.decoder:
            x = layer(x, enc_out)
        return self.out(self.ln(x))

In [None]:
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 5

# The special ids are the largest ids (PAD_ID..EOS_ID), so EOS_ID + 1 (== 1503) rows cover every id.
# This assumes all tokenizer ids are < PAD_ID.
model = TranslationModel(vocab_size=EOS_ID + 1, embd=256, heads=8, layers=4, dropout=0.1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)  # padded target positions contribute no loss

train_loss = []
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for src, tgt, out in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        src, tgt, out = src.to(DEVICE), tgt.to(DEVICE), out.to(DEVICE)
        logits = model(src, tgt)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab); reshape also tolerates non-contiguous tensors.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), out.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg = total_loss / len(train_loader)
    train_loss.append(avg)
    print(f"Epoch {epoch+1}, Loss: {avg:.4f}")

    # Save model each epoch
    torch.save(model.state_dict(), f"transformer_epoch{epoch+1}.pth")

In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the "Epoch N" lines printed during training.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss, marker='o')
ax.set(title="Transformer NMT Training Loss", xlabel="Epoch", ylabel="Loss")
ax.grid(True)
plt.show()


In [None]:
def translate(sentence, max_len=30):
    """Greedily decode a Nepali translation of one English sentence with the global ``model``.

    Args:
        sentence: English source text, encoded with ``eng_tok`` (no special tokens).
        max_len: maximum number of decoding steps (generated tokens, EOS included).

    Returns:
        The generated ids decoded with ``nep_tok``. The leading SOS is removed, and a trailing
        EOS is removed when one was produced. If generation stops at ``max_len`` without EOS,
        every generated token is kept. An empty encoding of ``sentence`` returns ``""``.
    """
    model.eval()
    # Run on whichever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    # model.pe has a fixed number of learned positions; neither side may exceed it.
    # The source is truncated to fit.
    max_positions = model.pe.num_embeddings
    src_ids = list(eng_tok.encode(sentence))[:max_positions]
    if not src_ids:
        return ""
    steps = min(max_len, max_positions)

    with torch.no_grad():
        input_ids = torch.tensor([src_ids], dtype=torch.long, device=device)
        pos = torch.arange(input_ids.size(1), device=device).unsqueeze(0)
        enc_out = model.encoder(model.embed(input_ids) + model.pe(pos))

        tgt_ids = torch.tensor([[SOS_ID]], dtype=torch.long, device=device)
        for _ in range(steps):
            pos_t = torch.arange(tgt_ids.size(1), device=device).unsqueeze(0)
            x = model.embed(tgt_ids) + model.pe(pos_t)
            for layer in model.decoder:
                x = layer(x, enc_out)
            logits = model.out(model.ln(x))
            next_token = logits[:, -1, :].argmax(dim=-1)
            tgt_ids = torch.cat([tgt_ids, next_token.unsqueeze(1)], dim=1)
            if next_token.item() == EOS_ID:
                break

    generated = tgt_ids[0, 1:].tolist()  # drop SOS
    # Strip EOS only if it was actually generated; a blind [1:-1] lost the last real token otherwise.
    if generated and generated[-1] == EOS_ID:
        generated.pop()
    return nep_tok.decode(generated)

In [None]:
# sacrebleu's corpus_bleu(hypotheses, references) takes `references` as a list of reference
# *streams*, each holding one reference per hypothesis: [[ref_0, ref_1, ...]].
# The old per-sentence [[ref_0], [ref_1], ...] layout is 50 one-line streams, which sacrebleu
# rejects or mis-scores.
# NOTE(review): these rows come from the "train" split, which appears to be the training data.
# If so, this BLEU is optimistic; score a held-out split for a real estimate.
N_BLEU_SAMPLES = min(50, len(eng_data))
preds = [translate(eng_data[i]) for i in range(N_BLEU_SAMPLES)]
refs = [[nep_data[i] for i in range(N_BLEU_SAMPLES)]]
print("BLEU Score:", corpus_bleu(preds, refs).score)