In [None]:
#@title Irula-Malayalam Model Evaluation
# User-tunable settings for this cell. The trailing "@param" markers look like
# Colab form annotations, so they are left exactly as written.
# Upper bound on the number of usable samples the dataset iterator yields per evaluation pass.
MAX_SAMPLES = 100 # @param {type:"number"}
# Fine-tuned checkpoint to evaluate after the base model; if the file is absent only the base model is evaluated.
CHECKPOINT_PATH = "checkpoints/ft_gs_m4tM.pt" # @param {type:"string"}
# Evaluation manifest, read line by line with one JSON record per line.
EVAL_MANIFEST = "manifests/valid_manifest.json" # @param {type:"string"}
# CSV of transcripts (columns 'id of speech and text', 'Irula text', 'Malayalam text') used when the manifest lacks text.
CSV_PATH = "lama.csv" # @param {type:"string"}

import os
import torch
import json
import logging
from tqdm import tqdm
import pandas as pd
import librosa
from typing import Tuple, Iterable, Dict, Any
from seamless_communication.models.unity import UnitYModel
from seamless_communication.inference import Translator
from jiwer import wer, cer

def load_text_data(csv_path: str) -> Dict[str, Dict[str, str]]:
    """Read the transcript CSV into a lookup table keyed by file id.

    Args:
        csv_path: Path to a CSV containing the columns
            'id of speech and text', 'Irula text' and 'Malayalam text'.

    Returns:
        A dict mapping each stringified id to
        ``{'irula': ..., 'malayalam': ...}``. Cells that are NaN or contain
        only whitespace are stored as the empty string; when an id occurs
        more than once the last row wins.
    """
    def _cell_to_text(cell) -> str:
        # Missing cells and blank cells both collapse to "".
        if pd.isna(cell):
            return ""
        text = str(cell)
        return text if text.strip() != '' else ""

    print("Loading text data from CSV...")
    table = pd.read_csv(csv_path)

    id_to_text = {
        str(record['id of speech and text']): {
            'irula': _cell_to_text(record['Irula text']),
            'malayalam': _cell_to_text(record['Malayalam text']),
        }
        for _, record in table.iterrows()
    }

    print(f"Loaded text data for {len(id_to_text)} entries")
    return id_to_text

def _iterate_irula_dataset(manifest_path: str, text_data: Dict) -> Iterable[Tuple[torch.Tensor, str, str]]:
    """Yield (audio, Malayalam text, Irula text) triples from a JSON-lines manifest.

    Every non-blank line of the manifest is parsed as a JSON object holding a
    ``source`` entry (``audio_local_path``, ``id`` and optionally ``text``)
    and an optional ``target`` entry (optionally with ``text``). Texts missing
    from the manifest are filled in, field by field, from ``text_data``.
    Samples that still lack either text, or whose audio cannot be loaded, are
    skipped. Iteration stops once ``MAX_SAMPLES`` samples have been yielded.

    Args:
        manifest_path: Path to the manifest file, one JSON record per line.
        text_data: Mapping from string file id to a dict with the keys
            ``'malayalam'`` and ``'irula'``, used as a fallback text source.

    Yields:
        Tuples of (mono waveform tensor loaded at 16 kHz, Malayalam text,
        Irula text).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``source``, its ``audio_local_path`` or
            its ``id``.
    """
    print(f"Loading evaluation data from {manifest_path}...")

    with open(manifest_path, 'r', encoding='utf-8') as f:
        count = 0
        for line in f:
            if count >= MAX_SAMPLES:
                break

            line = line.strip()
            if not line:
                # Blank or trailing lines are not records; json.loads would raise on them.
                continue

            sample = json.loads(line)
            source = sample['source']
            src_audio_path = source['audio_local_path']
            # The fallback table is keyed by strings, so normalise numeric ids.
            file_id = str(source['id'])

            malayalam_text = source.get('text', '')
            irula_text = (sample.get('target') or {}).get('text', '')

            # Fill only the fields the manifest did not provide, so a text that
            # is present in the manifest is never replaced (or blanked) by the CSV.
            fallback = text_data.get(file_id)
            if fallback is not None:
                malayalam_text = malayalam_text or fallback['malayalam']
                irula_text = irula_text or fallback['irula']

            # Decide whether the sample is usable before paying for audio decoding.
            if not (malayalam_text and irula_text):
                continue

            try:
                # librosa resamples to the requested rate, so the result is always 16 kHz mono.
                audio, _ = librosa.load(src_audio_path, sr=16000, mono=True)
            except Exception as e:
                print(f"Error loading audio {src_audio_path}: {e}")
                continue

            # Yield outside the try block so errors raised by the consumer are
            # not misreported as audio-loading failures.
            yield torch.from_numpy(audio), malayalam_text, irula_text
            count += 1

def eval_speech_to_text(translator: Translator, text_data: Dict) -> Tuple[float, float]:
    """Score speech-to-text output against the Irula reference transcripts.

    Each sample from the evaluation manifest is passed to
    ``translator.predict`` with the ``s2tt`` task, and the first returned text
    is compared with the sample's Irula text. A failed prediction is recorded
    as an empty hypothesis so references and predictions stay aligned.

    Args:
        translator: Translator used to produce the hypotheses.
        text_data: Fallback transcripts forwarded to the dataset iterator.

    Returns:
        ``(WER, CER)`` over all scored samples, or ``(1.0, 1.0)`` when no
        sample could be scored.
    """
    references, predictions = [], []

    samples = _iterate_irula_dataset(EVAL_MANIFEST, text_data)
    for wav, _malayalam_text, irula_text in tqdm(samples, desc="S2T Evaluation"):
        # A blank reference cannot be scored.
        if not irula_text.strip():
            continue

        references.append(irula_text)

        try:
            # Source and target codes are both "mal".
            # NOTE(review): the original author suggested "und" may be needed
            # for Irula if no dedicated code is supported -- confirm.
            prediction = translator.predict(
                input=wav,
                task_str="s2tt",
                src_lang="mal",
                tgt_lang="mal",
            )
            predicted_text = str(prediction[0][0]) if prediction and len(prediction) > 0 else ""
        except Exception as e:
            print(f"Error in S2T prediction: {e}")
            predicted_text = ""

        predictions.append(predicted_text)

    # Nothing scored: report the worst-case rates instead of calling jiwer on empty input.
    if not (predictions and references):
        return 1.0, 1.0

    wer_score = wer(reference=references, hypothesis=predictions)
    cer_score = cer(reference=references, hypothesis=predictions)
    return wer_score, cer_score

def eval_speech_to_speech(translator: Translator, text_data: Dict) -> Tuple[float, float]:
    """Score the text produced for the "S2S" pass against the Irula references.

    NOTE(review): despite the name, the request issued here uses the text
    task ``s2tt`` with source and target codes both set to "mal"; no speech is
    synthesised or transcribed in this function. A genuine speech-to-speech
    evaluation would presumably need generated audio plus an ASR step --
    confirm the intended behaviour.

    Args:
        translator: Translator used to produce the hypotheses.
        text_data: Fallback transcripts forwarded to the dataset iterator.

    Returns:
        ``(WER, CER)`` over all scored samples, or ``(1.0, 1.0)`` when no
        sample could be scored.
    """
    references, predictions = [], []

    samples = _iterate_irula_dataset(EVAL_MANIFEST, text_data)
    for wav, _malayalam_text, irula_text in tqdm(samples, desc="S2S Evaluation"):
        # A blank reference cannot be scored.
        if not irula_text.strip():
            continue

        references.append(irula_text)

        try:
            prediction = translator.predict(
                input=wav,
                task_str="s2tt",
                src_lang="mal",
                tgt_lang="mal",
            )
            # The first element of the first output is taken as the hypothesis text.
            predicted_text = str(prediction[0][0]) if prediction and len(prediction) > 0 else ""
        except Exception as e:
            print(f"Error in S2S prediction: {e}")
            # Keep references and predictions the same length.
            predicted_text = ""

        predictions.append(predicted_text)

    # Nothing scored: report the worst-case rates instead of calling jiwer on empty input.
    if not (predictions and references):
        return 1.0, 1.0

    wer_score = wer(reference=references, hypothesis=predictions)
    cer_score = cer(reference=references, hypothesis=predictions)
    return wer_score, cer_score

def load_checkpoint(model: UnitYModel, path: str, device="cuda") -> None:
    """Load fine-tuned weights from a checkpoint file into ``model`` in place.

    The checkpoint must be a mapping with a ``"model"`` entry holding a state
    dict whose keys start with ``model.<submodule>.``. The speech encoder
    frontend and speech encoder are always loaded; the text decoder frontend
    and text decoder are loaded only when the model has them.

    Args:
        model: Model whose submodules receive the weights.
        path: Path to the checkpoint file.
        device: Device passed to ``torch.load`` as ``map_location``.

    Raises:
        Exception: Any error raised while reading or applying the checkpoint
            is printed and then re-raised unchanged.
    """
    print(f"Loading checkpoint from {path}...")

    def _select_keys(weights: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        # Slice the prefix off instead of using str.replace: replace() would
        # also delete later occurrences of the prefix text inside a key and
        # so corrupt nested parameter names.
        return {key[len(prefix):]: value for key, value in weights.items() if key.startswith(prefix)}

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(path, map_location=device)["model"]

        model.speech_encoder_frontend.load_state_dict(_select_keys(state_dict, "model.speech_encoder_frontend."))
        model.speech_encoder.load_state_dict(_select_keys(state_dict, "model.speech_encoder."))

        if model.text_decoder_frontend is not None:
            model.text_decoder_frontend.load_state_dict(_select_keys(state_dict, "model.text_decoder_frontend."))

        if model.text_decoder is not None:
            model.text_decoder.load_state_dict(_select_keys(state_dict, "model.text_decoder."))

        print("Checkpoint loaded successfully!")

    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise

def _print_banner(title: str) -> None:
    """Print a section title framed above and below by a 50-character rule."""
    print("\n" + "="*50)
    print(title)
    print("="*50)


# Driver: evaluate the base model, then (if a checkpoint exists) the
# fine-tuned weights loaded into the same translator, and report the deltas.
if os.path.exists(EVAL_MANIFEST):
    # Fallback transcripts for samples whose manifest entry has no text.
    text_data = load_text_data(CSV_PATH)

    print("Initializing translator...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No vocoder is requested for this translator.
    translator = Translator(
        model_name_or_card="seamlessM4T_v2_large",
        vocoder_name_or_card=None,
        device=device
    )

    _print_banner("EVALUATING BASE MODEL")

    # Each base evaluation falls back to the worst-case rates on failure so
    # the improvement summary below can still be computed.
    try:
        base_s2t_wer, base_s2t_cer = eval_speech_to_text(translator, text_data)
        print(f"Base Model S2T - WER: {base_s2t_wer:.4f}, CER: {base_s2t_cer:.4f}")
    except Exception as e:
        print(f"Error in base S2T evaluation: {e}")
        base_s2t_wer, base_s2t_cer = 1.0, 1.0

    try:
        base_s2s_wer, base_s2s_cer = eval_speech_to_speech(translator, text_data)
        print(f"Base Model S2S - WER: {base_s2s_wer:.4f}, CER: {base_s2s_cer:.4f}")
    except Exception as e:
        print(f"Error in base S2S evaluation: {e}")
        base_s2s_wer, base_s2s_cer = 1.0, 1.0

    if not os.path.exists(CHECKPOINT_PATH):
        print(f"\nCheckpoint not found at {CHECKPOINT_PATH}")
        print("Only base model evaluation performed.")
    else:
        _print_banner("EVALUATING FINE-TUNED MODEL")

        try:
            # The checkpoint overwrites the weights of the translator's model
            # in place, so the base model must already have been evaluated.
            load_checkpoint(translator.model, CHECKPOINT_PATH, device)

            tuned_s2t_wer, tuned_s2t_cer = eval_speech_to_text(translator, text_data)
            print(f"Fine-tuned S2T - WER: {tuned_s2t_wer:.4f}, CER: {tuned_s2t_cer:.4f}")

            tuned_s2s_wer, tuned_s2s_cer = eval_speech_to_speech(translator, text_data)
            print(f"Fine-tuned S2S - WER: {tuned_s2s_wer:.4f}, CER: {tuned_s2s_cer:.4f}")

            # Positive values mean the fine-tuned model has the lower error rate.
            _print_banner("IMPROVEMENT SUMMARY")
            print(f"S2T WER Improvement: {base_s2t_wer - tuned_s2t_wer:.4f}")
            print(f"S2T CER Improvement: {base_s2t_cer - tuned_s2t_cer:.4f}")
            print(f"S2S WER Improvement: {base_s2s_wer - tuned_s2s_wer:.4f}")
            print(f"S2S CER Improvement: {base_s2s_cer - tuned_s2s_cer:.4f}")

        except Exception as e:
            print(f"Error with fine-tuned model: {e}")

    # Drop the translator and release cached CUDA memory.
    del translator
    torch.cuda.empty_cache()

    print("\nEvaluation completed!")
else:
    print(f"Error: Evaluation manifest not found at {EVAL_MANIFEST}")
    print("Please run your manifest creation script first.")