<a href="https://colab.research.google.com/github/bodadineshreddy/indictrans2/blob/main/mBART.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# !git clone https://github.com/AI4Bharat/IndicTrans2.git
# %cd /content/IndicTransToolkit
# !git clone https://github.com/VarunGumma/IndicTransToolkit.git
# !pip install git+https://github.com/VarunGumma/IndicTransToolkit.git
# # !python3 -m pip install --editable ./
# !python3 -c "import nltk; nltk.download('punkt')"

!pip install transformers datasets torch sentencepiece sacrebleu bitsandbytes scipy accelerate
!pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer

In [5]:
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, TrainingArguments, Trainer
from datasets import load_dataset, Dataset, concatenate_datasets

# ==============================
# Configuration
# ==============================

BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

# Load tokenizer and model
tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)

# ==============================
# Define Language Pairs (All Pairs from 4 Languages)
# ==============================
lang_pairs = [
    ("eng_Latn", "tel_Telu"),
    ("eng_Latn", "tam_Taml"),
    ("eng_Latn", "hin_Deva"),
    ("hin_Deva", "tam_Taml"),
    ("hin_Deva", "tel_Telu"),
    ("tam_Taml", "tel_Telu"),
]

# Convert to mBART language codes
nllb_to_mbart = {
    "eng_Latn": "en_XX",
    "tel_Telu": "te_IN",
    "tam_Taml": "ta_IN",
    "hin_Deva": "hi_IN"
}

# ==============================
# Convert Streaming Dataset to In-Memory Before Preprocessing
# ==============================

def convert_to_inmemory(streaming_dataset, num_samples=2000):
    """
    Converts an entire streaming dataset into an in-memory dataset.
    Handles cases where dataset has fewer samples than expected.
    """
    dataset_iter = iter(streaming_dataset)  # Create an iterator
    batch_samples = []

    for _ in range(num_samples):
        try:
            batch_samples.append(next(dataset_iter))  # Fetch samples
        except StopIteration:
            break  # Stop if dataset has fewer samples than expected

    return Dataset.from_list(batch_samples)  # Convert to Hugging Face in-memory dataset


# ==============================
# Preprocessing Function (Without .map())
# ==============================

def preprocess_function(dataset):
    """
    Tokenizes dataset examples for all available translation pairs dynamically.
    Returns a Hugging Face Dataset instead of using .map().
    """
    processed_examples = []

    for example in dataset:
        translation_data = example["translation"]
        available_langs = list(translation_data.keys())

        if len(available_langs) < 2:
            continue  # Skip if fewer than 2 languages are available

        # Generate all possible translation pairs
        for src_lang_code in available_langs:
            for tgt_lang_code in available_langs:
                if src_lang_code == tgt_lang_code:
                    continue  # Skip same-language pairs

                src_text = translation_data[src_lang_code]
                tgt_text = translation_data[tgt_lang_code]

                if not src_text or not tgt_text:
                    continue  # Skip empty translations

                # Convert to mBART language codes
                src_mbart = nllb_to_mbart.get(src_lang_code, None)
                tgt_mbart = nllb_to_mbart.get(tgt_lang_code, None)

                if not src_mbart or not tgt_mbart:
                    continue  # Skip if language codes are missing

                # ====== Tokenize Forward Pair ======
                tokenizer.src_lang = src_mbart
                tokenizer.tgt_lang = tgt_mbart
                forward_model_inputs = tokenizer(src_text, truncation=True, padding="max_length", max_length=128)
                forward_labels = tokenizer(text_target=tgt_text, truncation=True, padding="max_length", max_length=128)
                forward_model_inputs["labels"] = forward_labels["input_ids"]

                processed_examples.append(forward_model_inputs)

                # ====== Tokenize Reversed Pair ======
                tokenizer.src_lang = tgt_mbart
                tokenizer.tgt_lang = src_mbart
                reversed_model_inputs = tokenizer(tgt_text, truncation=True, padding="max_length", max_length=128)
                reversed_labels = tokenizer(text_target=src_text, truncation=True, padding="max_length", max_length=128)
                reversed_model_inputs["labels"] = reversed_labels["input_ids"]

                processed_examples.append(reversed_model_inputs)

    return Dataset.from_list(processed_examples)  # Return a fully processed dataset


# ==============================
# Apply Preprocessing to In-Memory Datasets
# ==============================

samples_per_pair = 10000  # Reduce dataset size for efficiency
processed_datasets = []

for pair in lang_pairs:
    print(f"Loading dataset for {pair}")

    # Load dataset in streaming mode
    dataset = load_dataset("allenai/nllb", f"{pair[0]}-{pair[1]}", split="train", streaming=True, trust_remote_code=True)

    # Convert to in-memory dataset before preprocessing
    dataset = convert_to_inmemory(dataset, samples_per_pair)

    # Apply preprocessing without .map()
    processed_dataset = preprocess_function(dataset)

    processed_datasets.append(processed_dataset)

# ==============================
# Concatenate All Processed Datasets
# ==============================

combined_dataset = concatenate_datasets(processed_datasets)

# ==============================
# Fine-Tuning Step
# ==============================

# Check if bf16 is supported by GPU
bf16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

training_args = TrainingArguments(
    output_dir="./fine_tuned_mbart",
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=2,
    save_steps=500,
    logging_steps=1000,
    evaluation_strategy="no",
    report_to=None,  # Disable external logging
    bf16=bf16_supported if bf16_supported else False,  # Use bf16 only if supported
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=combined_dataset,
    tokenizer=tokenizer,
)

# Fine-tune the model
trainer.train()

# Save fine-tuned model & tokenizer
trainer.save_model("./fine_tuned_mbart")
tokenizer.save_pretrained("./fine_tuned_mbart")


Loading dataset for ('eng_Latn', 'tel_Telu')


Repo card metadata block was not found. Setting CardData to empty.


Loading dataset for ('eng_Latn', 'tam_Taml')


Repo card metadata block was not found. Setting CardData to empty.


Loading dataset for ('eng_Latn', 'hin_Deva')


Repo card metadata block was not found. Setting CardData to empty.


Loading dataset for ('hin_Deva', 'tam_Taml')


Repo card metadata block was not found. Setting CardData to empty.


Loading dataset for ('hin_Deva', 'tel_Telu')


Repo card metadata block was not found. Setting CardData to empty.


Loading dataset for ('tam_Taml', 'tel_Telu')


Repo card metadata block was not found. Setting CardData to empty.
  trainer = Trainer(


Step,Training Loss
1000,0.5224
2000,0.2798
3000,0.2565
4000,0.2414
5000,0.2259
6000,0.2154
7000,0.2042
8000,0.1962
9000,0.1847
10000,0.1805




('./fine_tuned_mbart/tokenizer_config.json',
 './fine_tuned_mbart/special_tokens_map.json',
 './fine_tuned_mbart/sentencepiece.bpe.model',
 './fine_tuned_mbart/added_tokens.json',
 './fine_tuned_mbart/tokenizer.json')

In [7]:
import torch
import sacrebleu
from sacrebleu.metrics import TER
from nltk.translate.meteor_score import meteor_score
import nltk
nltk.download('wordnet')

# ==============================
# Translation Function
# ==============================

def translate_text(input_sentences, model, tokenizer, src_lang, tgt_lang, device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Translates a batch of input sentences from src_lang to tgt_lang using the given model and tokenizer.

    :param input_sentences: List of input sentences.
    :param model: Pre-trained translation model.
    :param tokenizer: Tokenizer for the model.
    :param src_lang: Source language code (e.g., 'en_XX').
    :param tgt_lang: Target language code (e.g., 'te_IN').
    :param device: Device to run the model on ('cuda' or 'cpu').

    :return: List of translated sentences.
    """
    tokenizer.src_lang = src_lang  # Set source language
    inputs = tokenizer(
        input_sentences,
        truncation=True,
        padding="longest",
        max_length=256,  # Ensures consistency
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],  # Target language for mBART
            max_length=256,
            num_beams=5,
        )

    # Decode generated translations
    translated_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    return translated_texts


# ==============================
# Translation + Evaluation Function
# ==============================

def batch_translate_with_evaluation(input_sentences, reference_sentences, model, tokenizer, src_lang, tgt_lang):
    """
    Translates a batch of input sentences and evaluates them using BLEU, METEOR, and TER scores.
    """
    # Get translated sentences using the same function
    generated_translations = translate_text(input_sentences, model, tokenizer, src_lang, tgt_lang)

    # ==============================
    # Evaluation Metrics
    # ==============================

    # Compute BLEU Score
    bleu_score = sacrebleu.corpus_bleu(generated_translations, [reference_sentences]).score

    # Compute METEOR Score (Fix: Tokenized Inputs)
    meteor_scores = [
        meteor_score([ref.split()], pred.split()) for ref, pred in zip(reference_sentences, generated_translations)
    ]
    avg_meteor_score = sum(meteor_scores) / len(meteor_scores)

    # Compute TER Score (Translation Edit Rate) using SacreBLEU's TER
    ter_metric = TER()
    ter_scores = [
        ter_metric.sentence_score(pred, [ref]).score for ref, pred in zip(reference_sentences, generated_translations)
    ]
    avg_ter_score = sum(ter_scores) / len(ter_scores)

    # Print Scores
    print(f"- {src_lang} → {tgt_lang}")
    print(f"- BLEU Score: {bleu_score:.2f}")
    print(f"- METEOR Score: {avg_meteor_score:.2f}")
    print(f"- TER Score: {avg_ter_score:.2f}")

    # Print Translations
    for src, ref, pred in zip(input_sentences, reference_sentences, generated_translations):
        print(f"- Source ({src_lang}): {src}")
        print(f"- Reference ({tgt_lang}): {ref}")
        print(f"- Model Translation ({tgt_lang}): {pred}")
    print("-" * 50)

    return generated_translations


# ==============================
# Test Translation + Evaluation for All Pairs
# ==============================

test_sentences = {
    "en_XX": [
        "Hello, how are you?",
        "This is a beautiful day.",
        "I love learning new languages."
    ],
    "te_IN": [
        "హలో, మీరు ఎలా ఉన్నారు?",
        "ఇది ఒక అందమైన రోజు.",
        "నేను కొత్త భాషలు నేర్చుకోవాలని ఇష్టపడుతున్నాను."
    ],
    "ta_IN": [
        "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?",
        "இது ஒரு அழகான நாள்.",
        "எனக்கு புதிய மொழிகளை கற்க விருப்பம்."
    ],
    "hi_IN": [
        "नमस्ते, आप कैसे हैं?",
        "यह एक सुंदर दिन है।",
        "मुझे नई भाषाएँ सीखना पसंद है।"
    ]
}

# Define language pairs to test
test_pairs = [
    ("en_XX", "te_IN"),
    ("en_XX", "ta_IN"),
    ("en_XX", "hi_IN"),
    ("te_IN", "en_XX"),
    ("te_IN", "ta_IN"),
    ("ta_IN", "hi_IN"),
    ("hi_IN", "te_IN"),
]
print("-" * 50)
# Perform translation and evaluation for each pair
for src_lang, tgt_lang in test_pairs:
    print(f"- Translating from {src_lang} → {tgt_lang}")

    # Get reference translations
    reference_translations = test_sentences.get(tgt_lang, [""] * len(test_sentences[src_lang]))

    # Run translation & evaluation
    batch_translate_with_evaluation(test_sentences[src_lang], reference_translations, model, tokenizer, src_lang, tgt_lang)


[nltk_data] Downloading package wordnet to /root/nltk_data...


--------------------------------------------------
- Translating from en_XX → te_IN
- en_XX → te_IN
- BLEU Score: 25.58
- METEOR Score: 0.36
- TER Score: 51.67
- Source (en_XX): Hello, how are you?
- Reference (te_IN): హలో, మీరు ఎలా ఉన్నారు?
- Model Translation (te_IN): హలో, ఎలా ఉన్నావు?
- Source (en_XX): This is a beautiful day.
- Reference (te_IN): ఇది ఒక అందమైన రోజు.
- Model Translation (te_IN): ఈ ఒక అందమైన రోజు.
- Source (en_XX): I love learning new languages.
- Reference (te_IN): నేను కొత్త భాషలు నేర్చుకోవాలని ఇష్టపడుతున్నాను.
- Model Translation (te_IN): నేను క్రొత్త భాషలను నేర్చుకోవడం ప్రేమిస్తున్నాను.
--------------------------------------------------
- Translating from en_XX → ta_IN
- en_XX → ta_IN
- BLEU Score: 32.72
- METEOR Score: 0.41
- TER Score: 51.67
- Source (en_XX): Hello, how are you?
- Reference (ta_IN): வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?
- Model Translation (ta_IN): ஹாய், நீ எப்படி இருக்கிறாய்?
- Source (en_XX): This is a beautiful day.
- Reference (ta_IN): இத