In [None]:
import torch
from datasets import load_dataset
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm

# Fixed seed so the train/validation split is identical on every Restart & Run All.
SPLIT_SEED = 42

dataset = load_dataset("SKNahin/bengali-transliteration-data")

# Hold out 20% of the dataset's 'train' split for validation. Without a seed the
# held-out rows change on every run, making losses/BLEU impossible to compare.
data_split = dataset['train'].train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_data = data_split['train']
val_data = data_split['test']

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")

def tokenize_function(batch):
    """Encode a batched slice of the dataset for seq2seq fine-tuning.

    The 'rm' column is encoded as the model input and the 'bn' column as the
    target; both are truncated / padded to exactly 128 tokens. ``labels`` are the
    raw target token ids, so padded positions carry the tokenizer's pad id
    (they are NOT replaced with -100 here).

    Args:
        batch: mapping with list-valued 'rm' and 'bn' columns, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        dict with "input_ids" and "attention_mask" (encoded from 'rm') and
        "labels" (the "input_ids" encoded from 'bn').
    """
    encode_kwargs = dict(padding="max_length", truncation=True, max_length=128)
    source = tokenizer(batch['rm'], **encode_kwargs)
    target = tokenizer(batch['bn'], **encode_kwargs)

    features = {name: source[name] for name in ("input_ids", "attention_mask")}
    features["labels"] = target["input_ids"]
    return features

def _to_torch_features(split):
    """Tokenize one split, drop the raw text columns and return torch tensors."""
    encoded = split.map(tokenize_function, batched=True).remove_columns(['bn', 'rm'])
    encoded.set_format("torch")
    return encoded


train_dataset = _to_torch_features(train_data)
val_dataset = _to_torch_features(val_data)

# Only the training stream is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, batch_size=16)

# Seed torch's global RNG so dropout and DataLoader shuffling are reproducible.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small").to(device)

# transformers.AdamW is deprecated (and removed in recent transformers releases);
# torch.optim.AdamW is the supported replacement. eps=1e-6 and weight_decay=0.0
# reproduce the defaults of the old transformers.AdamW so optimization is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

epochs = 5

# Label positions set to -100 are ignored by the model's cross-entropy loss.
IGNORE_INDEX = -100


def _loss_labels(batch):
    """Move batch["labels"] to `device`, replacing pad ids with IGNORE_INDEX.

    The tokenized labels are padded with the tokenizer's pad id; without this
    mapping the loss would also be computed on (and dominated by) padding.
    """
    labels = batch["labels"].to(device)
    return labels.masked_fill(labels == tokenizer.pad_token_id, IGNORE_INDEX)


for epoch in range(epochs):
    # --- training pass ---
    model.train()
    loop = tqdm(train_loader, leave=True, desc=f"Epoch {epoch}")
    for batch in loop:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = _loss_labels(batch)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())

    # --- validation pass (same masking so the numbers are comparable) ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = _loss_labels(batch)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += outputs.loss.item()

    # max(..., 1) avoids ZeroDivisionError if the validation split is empty.
    print(f"Validation Loss after Epoch {epoch}: {val_loss / max(len(val_loader), 1)}")

# Write the fine-tuned model and its tokenizer into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./banglish-to-bangla-model")


In [None]:
from nltk.translate.bleu_score import corpus_bleu

def evaluate_model(model, val_loader, device, tokenizer, max_length=128):
    """Compute mean validation loss and corpus BLEU for a seq2seq model.

    For each batch the model is run once with teacher forcing (for the loss) and
    once through ``model.generate`` (for free-running predictions). Decoded
    predictions and references are whitespace-tokenized and scored with NLTK's
    ``corpus_bleu``.

    Args:
        model: seq2seq model whose call with ``labels=`` returns an object with
            ``.loss`` and which exposes ``generate``.
        val_loader: iterable of batches with "input_ids", "attention_mask" and
            "labels" tensors. Padded label positions may hold either the
            tokenizer's pad id or -100; both are handled.
        device: torch device the batch tensors are moved to.
        tokenizer: tokenizer used to decode generated ids and reference labels.
        max_length: maximum length passed to ``model.generate`` (default 128).

    Returns:
        (avg_loss, bleu_score): mean per-batch loss (0.0 if ``val_loader`` yields
        no batches) and corpus BLEU (0.0 if there are no samples to score).
    """
    ignore_index = -100  # label value ignored by the model's loss
    pad_id = tokenizer.pad_token_id
    model.eval()

    all_predictions = []
    all_references = []
    val_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Exclude padding from the loss by mapping pad ids to ignore_index.
            loss_labels = labels.masked_fill(labels == pad_id, ignore_index)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels)
            val_loss += outputs.loss.item()
            num_batches += 1

            predictions = model.generate(
                input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
            )

            # -100 is not a valid token id and would make decoding fail; map it back
            # to pad so skip_special_tokens drops it.
            decodable_labels = labels.masked_fill(labels == ignore_index, pad_id)
            all_predictions.extend(tokenizer.batch_decode(predictions, skip_special_tokens=True))
            all_references.extend(tokenizer.batch_decode(decodable_labels, skip_special_tokens=True))

    avg_loss = val_loss / num_batches if num_batches else 0.0

    if not all_references:
        # Nothing to score; corpus_bleu is not meaningful on an empty corpus.
        return avg_loss, 0.0

    # corpus_bleu expects, per sample, a list of tokenized references and one
    # tokenized hypothesis.
    references = [[ref.split()] for ref in all_references]
    hypotheses = [pred.split() for pred in all_predictions]
    bleu_score = corpus_bleu(references, hypotheses)

    return avg_loss, bleu_score

# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

validation_loss, bleu = evaluate_model(
    model=model,
    val_loader=val_loader,
    device=device,
    tokenizer=tokenizer,
)

print(f"Validation Loss: {validation_loss}")
print(f"BLEU Score: {bleu}")
