In [None]:
!pip install -q transformers datasets evaluate huggingface_hub sentencepiece

In [None]:
import os

import evaluate
import numpy as np
import torch
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login
from transformers import (MBartForConditionalGeneration,
                          MBartTokenizer,
                          Seq2SeqTrainer,
                          Seq2SeqTrainingArguments,
                          TrainingArguments,
                          Trainer)

In [None]:
# Authenticate with the Hugging Face Hub interactively. The token must allow
# writes, because the final cell pushes the trained model to the Hub.
notebook_login()

In [None]:
# 1. Load the dataset
dataset = load_dataset("SKNahin/bengali-transliteration-data")

# Hold-out settings, used only when the dataset ships without its own
# "validation" split.
VAL_FRACTION = 0.2  # 80/20 train/validation split
SPLIT_SEED = 42     # fixed so the split (and therefore the reported BLEU) is reproducible

raw_train_dataset = dataset["train"]
if "validation" in dataset:
    # Respect an official split when one exists instead of carving our own.
    train_dataset = raw_train_dataset
    val_dataset = dataset["validation"]
else:
    split_dataset = raw_train_dataset.train_test_split(
        test_size=VAL_FRACTION,
        seed=SPLIT_SEED,
    )
    train_dataset = split_dataset["train"]
    val_dataset = split_dataset["test"]

# Preprocessing reads the romanized ("rm") and Bengali ("bn") columns;
# fail fast here rather than deep inside Dataset.map.
missing_columns = {"rm", "bn"} - set(train_dataset.column_names)
assert not missing_columns, f"Dataset is missing expected columns: {sorted(missing_columns)}"

print("Number of training samples:", len(train_dataset))
print("Number of validation samples:", len(val_dataset))

# Explore a sample (optional)
print("\nSample data:", train_dataset[0])

2. Data Preprocessing

In [None]:
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

# Checkpoint shared by the tokenizer here and the model-load cell below; the
# two must come from the same checkpoint so token ids line up.
model_name = "google/mt5-small"
tokenizer = MT5Tokenizer.from_pretrained(model_name)

# Unlike mBART-50, the mT5 tokenizer has no language codes, so there is no
# src_lang / tgt_lang to configure (assigning them would be a silent no-op).
# The task is conveyed only through the text prefix added during preprocessing.

In [None]:
max_length = 64  # token budget applied to both the source and the target sequences
PREFIX = "translate Banglish to Bengali: "


def preprocess_function(examples):
    """Tokenize a batch of Banglish -> Bengali pairs for seq2seq training.

    Args:
        examples: batched columns as passed by ``Dataset.map(batched=True)``.
            Reads the romanized source column ``"rm"`` and the Bengali target
            column ``"bn"``.

    Returns:
        The tokenizer output for the prefixed sources (``input_ids``,
        ``attention_mask``) plus ``labels``. ``labels`` holds the target token
        ids, padded/truncated to ``max_length``. Padding positions are replaced
        by -100 so that the cross-entropy loss ignores them.
    """
    model_inputs = tokenizer(
        [PREFIX + text for text in examples["rm"]],
        text_target=examples["bn"],  # replaces the deprecated as_target_tokenizer()
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    # Without masking, every pad position contributes to the loss and the
    # model is trained to emit padding. -100 is the loss's ignore index.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in model_inputs["labels"]
    ]
    return model_inputs


# Apply the preprocessing; drop every raw column so only model inputs remain.
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
val_dataset = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names,
)

print("\nTokenized train sample:\n", train_dataset[0])


3. Load the Model

In [None]:

# Pretrained mT5 weights for the same checkpoint the tokenizer was loaded from.
model = MT5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

4. Training Setup

In [None]:
bleu = evaluate.load("bleu")


def compute_metrics(eval_preds):
    """Corpus BLEU between generated transliterations and the references.

    Expects ``eval_preds`` from a ``Seq2SeqTrainer`` running with
    ``predict_with_generate=True``. ``predictions`` are generated token ids and
    ``label_ids`` are the padded reference ids. Either may contain -100 (label
    masking / cross-batch padding). -100 is not a valid token id, so it is
    mapped back to the pad id before decoding.

    Returns:
        ``{"bleu": float}`` in [0, 1].
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        # Models can return extra outputs; the sequences come first.
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # evaluate's BLEU divides by the total prediction length, so an all-empty
    # set of predictions (common early in training) would raise ZeroDivisionError.
    if not any(pred.split() for pred in decoded_preds):
        return {"bleu": 0.0}

    references = [[label] for label in decoded_labels]
    results = bleu.compute(predictions=decoded_preds, references=references)
    return {"bleu": results["bleu"]}


hf_username = "torr20"
repo_name = "another-avro"
hub_model_id = f"{hf_username}/{repo_name}"

training_args = Seq2SeqTrainingArguments(
    # NOTE(review): directory name still says "mbart" although the model is mT5.
    output_dir="banglish2bangla-mbart",
    eval_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    hub_model_id=hub_model_id,
    logging_steps=100,
    report_to="none",
    # Generate during evaluation so compute_metrics receives token ids rather
    # than raw logits; cap generation at the training sequence length.
    predict_with_generate=True,
    generation_max_length=max_length,
    # mT5 is known to overflow in fp16 (NaN/zero loss); use bf16 only where the
    # GPU supports it, otherwise train in fp32.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # successor of the deprecated `tokenizer=` argument
    compute_metrics=compute_metrics,
)


In [None]:
# Fine-tune the model; with eval_strategy="epoch", evaluation (and
# compute_metrics) runs at the end of every epoch.
trainer.train()

In [None]:
# Upload the final trained checkpoint to the Hub repository named by
# hub_model_id in the training arguments (requires the earlier notebook_login).
trainer.push_to_hub()