In [1]:
import os
import torch
import torchaudio
import pandas as pd
from datasets import load_dataset
from seamless_communication.inference import Translator
from sacrebleu import corpus_bleu, corpus_chrf
from jiwer import wer
from pprint import pprint
import warnings
warnings.filterwarnings("ignore")

In [2]:

# -------------------------------
# Setup
# -------------------------------
model_name = "seamlessM4T_v2_large"
vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"

translator = Translator(
    model_name,
    vocoder_name,
    device=torch.device("cuda"),
    dtype=torch.float16,
)

Using the cached checkpoint of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached tokenizer of seamlessM4T_v2_large. Set `force` to `True` to download again.
Using the cached checkpoint of vocoder_v2. Set `force` to `True` to download again.


In [3]:

# Language mapping: SeamlessM4T → FLEURS
lang_map_full = {
    "tel": "te_in",   # Telugu
    "urd": "ur_pk",   # Urdu
}
lang_map_partial = {
    "tam": "ta_in",   # Tamil
    "ory": "or_in",   # Odia
}

fleurs_src_lang = "hi_in"   # Hindi (FLEURS)
sm4t_src_lang = "hin"       # Hindi (SM4T)

# Output folder
OUTPUT_DIR = "./translation_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [4]:
def resample_to_16k(audio, orig_sr):
    """Resample audio to 16kHz."""
    return torchaudio.transforms.Resample(orig_sr, 16000)(torch.tensor(audio)).numpy()

In [5]:
def save_dataframe(df, lang_code):
    """Save DataFrame to CSV."""
    out_path = os.path.join(OUTPUT_DIR, f"{lang_code}_results.csv")
    df.to_csv(out_path, index=False, encoding="utf-8")
    print(f"💾 Saved {out_path}")

In [6]:
def run_translation_for_language(sm4t_tgt_lang, fleurs_tgt_lang, full_tasks=True):
    """
    Run translation pipeline for one language.
    full_tasks=True → Run all 4 tasks
    full_tasks=False → Run only S2TT, T2TT
    """
    print("\n" + "="*60)
    print(f"🔹 Processing Target Language: {sm4t_tgt_lang.upper()} ({fleurs_tgt_lang})")
    print("="*60)

    # Load datasets
    src_dataset = load_dataset("google/fleurs", fleurs_src_lang, split="test")
    tgt_dataset = load_dataset("google/fleurs", fleurs_tgt_lang, split="test")

    src_by_id = {item["id"]: item for item in src_dataset}
    tgt_by_id = {item["id"]: item for item in tgt_dataset}
    common_ids = sorted(set(src_by_id.keys()) & set(tgt_by_id.keys()))

    print(f"Found {len(common_ids)} parallel sentences")

    references, hypotheses_s2tt, hypotheses_t2tt = [], [], []
    predicted_s2s, predicted_t2s = [], []
    source_texts = []

    for sentence_id in common_ids:
        src = src_by_id[sentence_id]
        tgt = tgt_by_id[sentence_id]

        src_audio = src["audio"]["array"]
        src_sr = src["audio"]["sampling_rate"]
        src_text = src["transcription"]
        tgt_text = tgt["transcription"]

        references.append([tgt_text])
        source_texts.append(src_text)

        if src_sr != 16000:
            src_audio = resample_to_16k(src_audio, src_sr)
        base_dir = "/scratch/aj/Bhavna/bhav_venv_311/Project/input_audios"

            # Create language-specific folder
        lang_dir = os.path.join(base_dir, sm4t_tgt_lang)
        os.makedirs(lang_dir, exist_ok=True)

        audio_path = os.path.join(lang_dir, f"input_{sm4t_tgt_lang}_{sentence_id}.wav")
        #audio_path = f"/scratch/aj/Bhavna/bhav_venv_311/Project/input_audios/input{sm4t_tgt_lang}_{sentence_id}.wav"
        torchaudio.save(audio_path, torch.tensor(src_audio).unsqueeze(0), 16000)

        # --- S2TT ---
        s2tt_out, _ = translator.predict(
            input=audio_path, task_str="s2tt", tgt_lang=sm4t_tgt_lang
        )
        hypotheses_s2tt.append(s2tt_out[0])

        # --- T2TT ---
        t2tt_out, _ = translator.predict(
            input=src_text, task_str="t2tt", src_lang=sm4t_src_lang, tgt_lang=sm4t_tgt_lang
        )
        hypotheses_t2tt.append(t2tt_out[0])

        if full_tasks:
            # --- S2ST + ASR ---
            _, s2s_audio_out = translator.predict(
                input=audio_path, task_str="s2st", tgt_lang=sm4t_tgt_lang
            )
            # Define base directory
            base_dir = "/scratch/aj/Bhavna/bhav_venv_311/Project/s2s_outputs"

            # Create language-specific folder
            lang_dir = os.path.join(base_dir, sm4t_tgt_lang)
            os.makedirs(lang_dir, exist_ok=True)

            # Final file path
            s2s_path = os.path.join(lang_dir, f"s2s_{sm4t_tgt_lang}_{sentence_id}.wav")

            #s2s_path = f"/scratch/aj/Bhavna/bhav_venv_311/Project/s2s_outputs/s2s{sm4t_tgt_lang}_{sentence_id}.wav"
            torchaudio.save(
                s2s_path,
                s2s_audio_out.audio_wavs[0][0].to(torch.float32).cpu(),
                s2s_audio_out.sample_rate,
            )
            s2s_asr_out, _ = translator.predict(
                input=s2s_path, task_str="asr", tgt_lang=sm4t_tgt_lang
            )
            predicted_s2s.append(s2s_asr_out[0])

            # --- T2ST + ASR ---
            _, t2s_audio_out = translator.predict(
                input=src_text, task_str="t2st", src_lang=sm4t_src_lang, tgt_lang=sm4t_tgt_lang
            )
            # Define base directory
            base_dir = "/scratch/aj/Bhavna/bhav_venv_311/Project/t2s_outputs"

            # Create language-specific folder
            lang_dir = os.path.join(base_dir, sm4t_tgt_lang)
            os.makedirs(lang_dir, exist_ok=True)

            # Final file path
            t2s_path = os.path.join(lang_dir, f"t2s_{sm4t_tgt_lang}_{sentence_id}.wav")

            #t2s_path = f"/scratch/aj/Bhavna/bhav_venv_311/Project/t2s_outputs/t2s{sm4t_tgt_lang}_{sentence_id}.wav"
            torchaudio.save(
                t2s_path,
                t2s_audio_out.audio_wavs[0][0].to(torch.float32).cpu(),
                t2s_audio_out.sample_rate,
            )
            t2s_asr_out, _ = translator.predict(
                input=t2s_path, task_str="asr", tgt_lang=sm4t_tgt_lang
            )
            predicted_t2s.append(t2s_asr_out[0])

    # Build dataframe
    data = {
        "source_text": source_texts,
        "reference_text": [r[0] for r in references],
        "S2TT_prediction": hypotheses_s2tt,
        "T2TT_prediction": hypotheses_t2tt,
    }
    if full_tasks:
        data["S2ST_ASR"] = predicted_s2s
        data["T2ST_ASR"] = predicted_t2s

    df = pd.DataFrame(data)
    save_dataframe(df, sm4t_tgt_lang)

    return references, hypotheses_s2tt, hypotheses_t2tt, predicted_s2s, predicted_t2s

In [7]:
from sacrebleu import corpus_bleu, corpus_chrf
from jiwer import wer

def compute_metrics(lang_code, references, hypotheses_s2tt, hypotheses_t2tt, predicted_s2s, predicted_t2s):
    # Normalize hyps
    hypotheses_s2tt = [" ".join(h) if isinstance(h, list) else str(h) for h in hypotheses_s2tt]
    hypotheses_t2tt = [" ".join(h) if isinstance(h, list) else str(h) for h in hypotheses_t2tt]
    predicted_s2s   = [" ".join(h) if isinstance(h, list) else str(h) for h in predicted_s2s]
    predicted_t2s   = [" ".join(h) if isinstance(h, list) else str(h) for h in predicted_t2s]

    # Normalize refs
    references_norm = []
    for refset in references:
        references_norm.append([" ".join(r) if isinstance(r, list) else str(r) for r in refset])
    multi_references = list(zip(*references_norm))  # multiple references per sentence

    metrics = {}
    # Text-to-text & speech-to-text
    metrics["S2TT_SacreBLEU"] = corpus_bleu(hypotheses_s2tt, multi_references).score
    metrics["T2TT_chrF2++"]   = corpus_chrf(hypotheses_t2tt, multi_references).score
    metrics["S2TT_WER"]       = sum(wer(ref[0], hyp) for ref, hyp in zip(references_norm, hypotheses_s2tt)) / len(references_norm)
    metrics["T2TT_WER"]       = sum(wer(ref[0], hyp) for ref, hyp in zip(references_norm, hypotheses_t2tt)) / len(references_norm)

    # Speech-to-speech (decoded to text for scoring)
    if predicted_s2s:
        metrics["S2ST_ASR_WER"]  = sum(wer(ref[0], hyp) for ref, hyp in zip(references_norm, predicted_s2s)) / len(references_norm)
        metrics["S2ST_ASR_BLEU"] = corpus_bleu(predicted_s2s, multi_references).score

    if predicted_t2s:
        metrics["T2ST_ASR_WER"]  = sum(wer(ref[0], hyp) for ref, hyp in zip(references_norm, predicted_t2s)) / len(references_norm)
        metrics["T2ST_ASR_BLEU"] = corpus_bleu(predicted_t2s, multi_references).score
        
    return metrics


In [8]:
# -------------------------------
# Main Driver
# -------------------------------
all_metrics = {}

# Full tasks: Telugu + Urdu
for sm4t_tgt, fleurs_tgt in lang_map_full.items():
    refs, s2tt, t2tt, s2s, t2s = run_translation_for_language(sm4t_tgt, fleurs_tgt, full_tasks=True)
    all_metrics[sm4t_tgt] = compute_metrics(sm4t_tgt, refs, s2tt, t2tt, s2s, t2s)
    print(f"🔹 {sm4t_tgt.upper()} Metrics: {all_metrics[sm4t_tgt]}")
# Partial tasks: Tamil + Odia
for sm4t_tgt, fleurs_tgt in lang_map_partial.items():
    refs, s2tt, t2tt, _, _ = run_translation_for_language(sm4t_tgt, fleurs_tgt, full_tasks=False)
    all_metrics[sm4t_tgt] = compute_metrics(sm4t_tgt, refs, s2tt, t2tt, [], [])
    print(f"🔹 {sm4t_tgt.upper()} Metrics: {all_metrics[sm4t_tgt]}")

print("\n✅ All processing complete.")


🔹 Processing Target Language: TEL (te_in)
Found 228 parallel sentences
💾 Saved ./translation_results/tel_results.csv
🔹 TEL Metrics: {'S2TT_SacreBLEU': 6.791687079616478, 'T2TT_chrF2++': 49.22213115874377, 'S2TT_WER': 0.9214235330398954, 'T2TT_WER': 0.85229959325962, 'S2ST_ASR_WER': 0.9207014212978243, 'S2ST_ASR_BLEU': 6.703186032040358, 'T2ST_ASR_WER': 0.8816193817330296, 'T2ST_ASR_BLEU': 9.072877840222192}

🔹 Processing Target Language: URD (ur_pk)
Found 176 parallel sentences
💾 Saved ./translation_results/urd_results.csv
🔹 URD Metrics: {'S2TT_SacreBLEU': 14.627013884604041, 'T2TT_chrF2++': 44.07527441374375, 'S2TT_WER': 0.7461084350137203, 'T2TT_WER': 0.7154703562371476, 'S2ST_ASR_WER': 0.7457257921461948, 'S2ST_ASR_BLEU': 14.802831740006779, 'T2ST_ASR_WER': 0.7231635694724213, 'T2ST_ASR_BLEU': 16.53967566437534}

🔹 Processing Target Language: TAM (ta_in)
Found 253 parallel sentences
💾 Saved ./translation_results/tam_results.csv
🔹 TAM Metrics: {'S2TT_SacreBLEU': 4.273384897001094, '