## Fine Tuining model (tuskbyte/nepali_male_v1)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from transformers import VitsModel, VitsTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from transformers.modeling_outputs import ModelOutput
import librosa 
import numpy as np
import pandas as pd
from datasets import Dataset as HFDataset
import soundfile as sf
from tqdm import tqdm
import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import logging
from datetime import datetime


  from .autonotebook import tqdm as notebook_tqdm
2025-09-23 08:41:03.981361: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-23 08:41:05.019685: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-09-23 08:41:09.509656: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
#logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
## AUDIO PROCESSING AND MEL SPECTROGRAM

def audio_to_mel_spectrogram(audio, sr=16000, n_mels=80, n_fft=1024, hop_length=256, win_length=1024):
    """
    Convert audio waveform to mel spectrogram
    Args: 
        audio (np.array or torch.Tensor) : The 1D audio waveform
        sr (int): Sampling rate of the audio in Hz. 
        n_mels (int): Number of mel frequency bands. 
        n_fft (int): Number of samples per FFT window. Determines frequency resolution.
        hop_length(int): Number of samples to step between successive FFT windows. Controls time resolution
        win_length(int): Size of each FFT window in samples. Typically equal to n_fft.

        Returns:
            mel_spectrogram (torch.Tensor): The mel spectrogram of shape (n_mels, time_frames)
    """
    #ensure audio is in a numpy array
    if torch.is_tensor(audio):
        audio = audio.numpy()

    # compute mel spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        fmin=0.0,
        fmax=sr / 2.0
    )


    # convert to log scale (dB)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    return torch.FloatTensor(mel_spec_db)


In [4]:

# AUDIO PRE-PROCESSING

def preprocess_audio(audio_path, target_sr=16000, target_db=-20):
    """
    Preprocessing audio file for VITS training.
    - Normalize sampling rate to 22050 Hz (VITS standard)
    - Normalize volume to -20 db (consistent loudness, good for training)
    - Trim silence from beginning/end (focus on speech)
    """
    try:
        #loading audio
        audio, sr = librosa.load(audio_path, sr=target_sr)

        #removing silence from beginning and end
        audio, _ = librosa.effects.trim(audio, top_db=20)

        # applying target db normalization
        eps = 1e-9
        # rms = torch.sqrt(torch.mean(torch.tensor(audio)**2))
        rms = np.sqrt(np.mean(audio.astype(np.float64)**2) + eps)
        if rms > 0:
            target_rms = 10 ** (target_db/20)
            audio = audio * (target_rms / rms)

        # ensure minimun length of audio
        min_sample = int(0.5 * target_sr)  # 0.5 seconds
        if len(audio) < min_sample:
            # audio = torch.nn.functional.pad(audio, (0, min_sample - len(audio)), mode='constant')
            audio = np.pad(audio, (0, min_sample - len(audio)), mode='constant')
        
        # ensure maximun length of audio
        max_sample = int(10 * target_sr)  # 10 seconds
        if len(audio) > max_sample:
            audio = audio[:max_sample]
        
        return audio, target_sr, True

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None, None, False

In [5]:
def check_model_type(model_name):
    """
    Check if the model is VITS or not
    """
    try:
        model = VitsModel.from_pretrained(model_name)
        logger.info("Using VITS model")
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name}: {e}")
        raise ValueError("Currently only VITS model is supported.")


In [6]:
def veryify_tokenizer_compatibility(tokenizer, sample_texts):
    """
    Verify that the tokenizer can works with Nepali dataset vocabulary 
    """
    logger.info("Verifying tokenizer compatibility...")

    issues = []
    for text in sample_texts[:5]:  # check first 5 samples
        try:
            tokens = tokenizer(text, return_tensors='pt')
            #check for excessive unknown tokens
            token_ids = tokens['input_ids'].squeeze() 

            #count potential issues
            if len(token_ids) == 0:
                issues.append(f"Empty tokenization for: {text}")
            
            elif len(token_ids) > 150: #very long tokenization might indicate issues
                issues.append(f"Excessively long tokenization ({len(token_ids)} tokens) for: {text}")
            
        except Exception as e:
            logger.error(f"Error processing text '{text}': {e}")
            issues.append(f"Error processing text '{text}': {e}")

    if issues:
        logger.warning(f"Tokenizer compatibility issues found: {issues}")
        for issue in issues:
            logger.warning(f"  - {issue}")
    else:
        logger.info("Tokenizer is compatible with Nepali dataset vocabulary.")

    return issues


In [7]:

## DATASET WITH MEL SPECTROGRAMS

class NepaliVITSDataset(Dataset):
    """
    Dataset with proper mel spectrograms computation 
    """
    def __init__(self, dataframe, tokenizer, max_length=150, sample_rate=16000):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.sample_rate = sample_rate

        #verify tokenizer compatibility
        veryify_tokenizer_compatibility(tokenizer, self.data['sentence'].tolist())

        logger.info(f"Dataset initialized with {len(self.data)} samples.")

    def __len__(self):
        return len(self.data)
    

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        #get text and audio path
        text = str(row['sentence']).strip()
        audio_path = str(row['path'])

        ## audio is already pre-processed, so only getting audio
        audio, sr, _ = preprocess_audio(audio_path, target_sr=self.sample_rate)
        # audio, sr = librosa.load(audio_path, sr=self.sample_rate)

    
        # compute mel spectrogram for proper loss computation
        mel_spec = audio_to_mel_spectrogram(audio, sr=self.sample_rate)

        #tokenize text
        try:
            tokenized = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )

            input_ids = tokenized['input_ids'].squeeze(0)  # remove batch dimension
            attention_mask = tokenized['attention_mask'].squeeze(0) # remove batch dimension
        
        except Exception as e:
            logger.warning(f"Error tokenizing text '{text}': {e}")
            #fallback
            input_ids = torch.zeros(self.max_length, dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'audio':torch.tensor(audio),
            'mel_spectrogram': mel_spec,
            'text': text,
            'audio_path': audio_path
        }



In [8]:
class VITSDataCollator:
    """
    Data collator for VITS model, that handls mel spectrogram properly
    """
    def __init__(self, tokenizer, max_audio_length=16000*8):
        self.tokenizer = tokenizer
        self.max_audio_length = max_audio_length  #  8 seconds at 16kHz

    def __call__(self, batch):
        #handle variable length audio and mel spectrograms
        audio_lengths = [len(item['audio']) for item in batch]
        max_audio_len = min(max(audio_lengths), self.max_audio_length)

        #handle mel spectrograms (variable time dimension)
        mel_time_lengths = [item['mel_spectrogram'].shape[1] for item in batch]
        max_mel_time = max(mel_time_lengths)

        #pad audio sequences
        padded_audios = []
        padded_mels = []
        actual_lengths = []

        for item in batch:
            #pad audio
            audio = item['audio']
            actual_length = len(audio)

            if actual_length > max_audio_len:
                audio = audio[:max_audio_len]
                actual_length = max_audio_len
            elif actual_length < max_audio_len:
                padding = max_audio_len - actual_length
                audio = torch.cat([audio, torch.zeros(padding)], dim=0)

            padded_audios.append(audio)
            actual_lengths.append(actual_length)

            #pad mel spectrogram
            mel = item['mel_spectrogram'] # shape: (n_mels, time_frames)
            mel_time = mel.shape[1]

            if mel_time < max_mel_time:
                padding = max_mel_time - mel_time
                mel = torch.cat([mel, torch.zeros((mel.shape[0], padding))], dim=1) 
            elif mel_time > max_mel_time:
                mel = mel[:, :max_mel_time]

            padded_mels.append(mel)

        return {
            "input_ids": torch.stack([item['input_ids'] for item in batch]),
            "attention_mask": torch.stack([item['attention_mask'] for item in batch]),
            "audio": torch.stack(padded_audios),
            "mel_spectrogram": torch.stack(padded_mels),
            # "audio_lengths": torch.tensor(actual_lengths),
            "labels": torch.stack([item['input_ids'] for item in batch])  # for VITS, labels are same as input_ids
        }


In [9]:

class VITSTrainerOutput(ModelOutput):
    """
    Wrapper for VITS model outputs compatible with HuggingFace Trainer
    """
    loss: torch.Tensor = None
    waveform: torch.Tensor = None
    logits: torch.Tensor = None  # optional, if your model returns prediction_scores

    def __init__(self, loss=None, waveform=None, logits=None):
        self.loss = loss
        self.waveform = waveform
        self.logits = logits


class VITSFineTuner(nn.Module):
    """
    VITS wrapper with mel spectrogram loss computation 
    """
    def __init__(self, model_name="tuskbyte/nepali_male_v1"):
        super().__init__()

        # check model type and load appropriate base model
        self.base_model = check_model_type(model_name)
        self.model_name = model_name

        #freeze early layers for stable fine-tuning
        self.freeze_early_layers(num_layers_to_freeze=3)

    def freeze_early_layers(self, num_layers_to_freeze=3):
        """ 
        Better layer freezing strategy
        """
        frozen_params = 0
        total_params = 0

        for name, param in self.base_model.named_parameters():
            total_params += 1

            # Freeze text encoder layers (preserve language understanding)
            if "text_encoder" in name and any(f"layers.{i}" in name for i in range(num_layers_to_freeze)):
                param.requires_grad = False
                frozen_params += 1
            
            # Freeze early flow layers (preserve basic acoustic modeling)
            elif "flow" in name and any(f"layers.{i}" in name for i in range(num_layers_to_freeze//2)):
                param.requires_grad = False
                frozen_params += 1

        logger.info(f"Frozen {frozen_params}/{total_params} parameters ({100 * frozen_params / total_params:.2f}%) for stable training.")

    
    def compute_mel_loss(self, generated_audio, target_audio):
        """
        Proper mel spectrogram loss computation 
        """
        try:
            #convert both to mel spectrograms
            generated_mel = []
            target_mel = []

            for gen_audio, tgt_audio in zip(generated_audio, target_audio):
                #ensure same length
                min_len = min(len(gen_audio), len(tgt_audio))
                gen_audio = gen_audio[:min_len]
                tgt_audio = tgt_audio[:min_len]

                #convert to mel spectrogram
                gen_mel = audio_to_mel_spectrogram(gen_audio)
                tgt_mel = audio_to_mel_spectrogram(tgt_audio)

                generated_mel.append(gen_mel)
                target_mel.append(tgt_mel)

            #pad mel spectrograms to same shape
            max_time = max(mel.shape[1] for mel in generated_mel + target_mel)

            padded_gen_mels = []
            padded_tgt_mels = []

            for gen_mel, tgt_mel in zip(generated_mel, target_mel):
                #pad time dimension
                if gen_mel.shape[1] < max_time:
                    gen_mel = torch.cat([gen_mel, torch.zeros((gen_mel.shape[0], max_time - gen_mel.shape[1]))], dim=1)
                if tgt_mel.shape[1] < max_time:
                    tgt_mel = torch.cat([tgt_mel, torch.zeros((tgt_mel.shape[0], max_time - tgt_mel.shape[1]))], dim=1) 
                
                padded_gen_mels.append(gen_mel)
                padded_tgt_mels.append(tgt_mel)

            
            # stack and compute L1 loss
            generated_mels = torch.stack(padded_gen_mels)
            target_mels = torch.stack(padded_tgt_mels)

            # Move to same device
            if generated_audio.is_cuda:
                target_mels = target_mels.cuda()
                generated_mels = generated_mels.cuda()

            mel_loss = F.l1_loss(generated_mels, target_mels)
            return mel_loss


        except Exception as e:
            logger.warning(f"Mel loss computation error: {e}")
            # fallback to simple loss
            return F.mse_loss(generated_audio.mean(), target_audio.mean())
        

    def forward(self, input_ids, attention_mask=None, audio=None, mel_spectrogram=None, **kwargs):
        """
        Forward pass with mel spectrogram loss 
        """

        clean_kwargs = {k: v for k, v in kwargs.items() 
                   if k not in ['audio_lengths', 'audio', 'mel_spectrogram']}
        
        try:
            #get model outputs
            outputs = self.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **clean_kwargs  
            )

            #compute loss if training data is provided
            loss = None
            if audio is not None:
                if hasattr(outputs, 'waveform') and outputs.waveform is not None:
                    loss = self.compute_mel_loss(outputs.waveform, audio)
                
                elif hasattr(outputs, 'prediction_scores'):
                    # if models that give output mel spectrogram directly
                    if mel_spectrogram is not None:
                        pred_mels = outputs.prediction_scores
                        loss = F.l1_loss(pred_mels, mel_spectrogram)
                    else:
                        # fallback loss 
                        loss = torch.tensor(0.01, requires_grad=True, device=input_ids.device) #dummy loss value for preventing errors in backpropagation
                else:
                    # create learning signal even without direct audio output
                    # use text reconstruction loss as a proxy
                    if hasattr(outputs, 'last_hidden_state'):
                        #simple regularization loss to keep model learning
                        hidden_states = outputs.last_hidden_state
                        loss = 0.01 * torch.mean(hidden_states.pow(2))  # L2 regularization
                    else:
                        loss = torch.tensor(0.01, requires_grad=True, device=input_ids.device) #dummy loss value for preventing errors in backpropagation

            # return outputs and loss
            # class CustomOutput:
            #     def __init_(self, loss, original_outputs):
            #         self.loss = loss
            #         #copy all original outputs attributes
            #         for attr_name in dir(original_outputs):
            #             if not attr_name.startswith("_"):
            #                 setattr(self, attr_name, getattr(original_outputs, attr_name))
            
            waveform = getattr(outputs, "waveform", None)
            logits = getattr(outputs, "prediction_scores", None)

            # Fallback: always ensure a scalar loss exists
            if loss is None:
                loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)

            return VITSTrainerOutput(loss=loss, waveform=waveform, logits=logits)


        except Exception as e:
            logger.warning(f"Forward pass error: {e}")
            # Return stable fallback
            dummy_loss = torch.tensor(0.01, requires_grad=True, device=input_ids.device)

            # class DummyOutput:
            #     def __init__(self, loss):
            #         self.loss = loss

            # return DummyOutput(dummy_loss)

            return VITSTrainerOutput(loss=dummy_loss, waveform=None, logits=None)



In [10]:

## TRAINING ARGUMENTS

def setup_training_arguments(output_dir="./nepali_vits_finetuned"):
    """
    Setup training arguments with proper logging and saving strategies 
    """
    
    return TrainingArguments(
        output_dir=output_dir,

        num_train_epochs=25,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,

        #learning rate scheduler
        learning_rate=5e-5,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,

        #scheduler
        lr_scheduler_type="cosine",
        warmup_steps=200,

        #logging and saving
        logging_dir=f"{output_dir}/logs",
        logging_steps=100,
        save_steps=1000,
        save_total_limit=5,
        evaluation_strategy="steps",
        eval_steps=500,

        #memory and stability improvements
        fp16=False, 
        dataloader_pin_memory=False, # this can be set to True if system has enough memory
        dataloader_num_workers=0, 

        #early stopping and best model 
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,

        #reporting
        report_to='tensorboard',
        run_name=f"nepali_vits_finetune_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )   

In [11]:

## TRAINING FUNCTION

def train_nepali_vits(dataframe, model_name="tuskbyte/nepali_male_v1", output_dir="./nepali_vits_finetuned"):
    """
    Complete training function to fine-tune VITS on Nepali dataset 
    """
    logger.info("Starting Nepali VITS fine-tuning...")
    logger.info("="*60)

    #create output dir
    os.makedirs(output_dir, exist_ok=True)

    # first step: load model and tokenizer
    logger.info("Loading model: %s and tokenizer...", model_name)

    try:
        tokenizer = VitsTokenizer.from_pretrained(model_name)
        model = VITSFineTuner(model_name=model_name)
    except Exception as e:
        logger.error(f"Error loading model or tokenizer: {e}")
        raise

    #second step: prepare datasets with better splitting
    logger.info("Preparing datasets...")
    #shuffle and split
    shuffled_df = dataframe.sample(frac=1, random_state=42).reset_index(drop=True)
    train_size = int(0.85 * len(shuffled_df)) # 85% for training, 15% for validation

    train_df = shuffled_df.iloc[:train_size]
    val_df = shuffled_df.iloc[train_size:]

    logger.info(f"Training samples: {len(train_df)}")
    logger.info(f"Validation samples: {len(val_df)}")

    #create datasets
    train_dataset = NepaliVITSDataset(train_df, tokenizer)
    val_dataset = NepaliVITSDataset(val_df, tokenizer)

    #setup data collator
    data_collator = VITSDataCollator(tokenizer)

    #third step: setup training arguments
    logger.info("Setting up training arguments...")
    training_args = setup_training_arguments(output_dir=output_dir)

    #fourth step: setup trainer with early stopping
    logger.info("Initializing Trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] #stop if no improvement in 5 eval steps
    )
    

    #fifth step: start training
    logger.info("==> Starting training...")
    try:
        #clear cuda cache before training
        torch.cuda.empty_cache()

        trainer.train()
        logger.info("==> Training completed successfully.")

    except Exception as e:
        logger.error(f"Training error: {e}")
        raise


    #final step: save final model and tokenizer
    logger.info("Saving final model and tokenizer...")
    model.base_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


    # save training info
    training_info = {
        "model_name": model_name,
        "model_type": "VITS",
        "train_samples": len(train_df),
        "val_samples": len(val_df),
        "final_train_loss": trainer.state.log_history[-1].get("train_loss", "unknown"),
        "final_eval_loss": trainer.state.log_history[-1].get("eval_loss", "unknown"),
        "timestamp": datetime.now().isoformat()
    }


    with open(os.path.join(output_dir, "training_info.json"), "w") as f:
        json.dump(training_info, f, indent=4)
    
    logger.info("Fine-tuning completed! Model saved to %s", output_dir)

    return trainer, model, tokenizer


#### Making data ready

In [12]:
## need only one time

# import librosa
# import soundfile as sf
# from glob import glob

# mp3_files = glob("../../data/common_voice/**/clips/*.mp3", recursive=True)

# for mp3_path in mp3_files:
#     wav_path = mp3_path.replace(".mp3", ".wav")

#     # Load with librosa (auto-decoder via audioread)
#     y, sr = librosa.load(mp3_path, sr=16000, mono=True)  # target 16kHz, mono

#     # Save as wav
#     sf.write(wav_path, y, 16000)

# print("Conversion completed!")


In [13]:
df = pd.read_csv("./vits_training_data/train_filelist.txt", sep="|", names=["path", "speaker_id", "sentence"])

# Replace .mp3 with .wav in the path column
df["path"] = df["path"].str.replace(".mp3", ".wav", regex=False)

# Keep only useful columns
df = df[["path", "sentence"]]

# Save back to DataFrame-ready CSV
df.to_csv("train_clean.csv", index=False)

df.head()


Unnamed: 0,path,sentence
0,../../data/common_voice/cv-corpus-20.0-2024-12...,हाम्रो वाक्यहरू गयो त?
1,../../data/common_voice/cv-corpus-21.0-2025-03...,हाम्रो झण्डा फरर .
2,../../data/common_voice/cv-corpus-21.0-2025-03...,त्यस समितिले हामीलाई भेटेर कुराकानी गरेको थियो .
3,../../data/common_voice/cv-corpus-22.0-2025-06...,खाली ठाउँ जोगाउनुपर्छ भन्ने आवाज पनि उठिरहेको छ .
4,../../data/common_voice/cv-corpus-22.0-2025-06...,म मुसाको तर्फबाट दाइ .


In [15]:
train_nepali_vits(df)

INFO:__main__:Starting Nepali VITS fine-tuning...
INFO:__main__:Loading model: tuskbyte/nepali_male_v1 and tokenizer...
INFO:__main__:Using VITS model
INFO:__main__:Frozen 78/762 parameters (10.24%) for stable training.
INFO:__main__:Preparing datasets...
INFO:__main__:Training samples: 2398
INFO:__main__:Validation samples: 424
INFO:__main__:Verifying tokenizer compatibility...
INFO:__main__:Tokenizer is compatible with Nepali dataset vocabulary.
INFO:__main__:Dataset initialized with 2398 samples.
INFO:__main__:Verifying tokenizer compatibility...
INFO:__main__:Tokenizer is compatible with Nepali dataset vocabulary.
INFO:__main__:Dataset initialized with 424 samples.
INFO:__main__:Setting up training arguments...
INFO:__main__:Initializing Trainer...
  trainer = Trainer(
INFO:__main__:==> Starting training...
ERROR:__main__:Training error: The model did not return a loss from the inputs, only the following keys: . For reference, the inputs it received are input_ids,attention_mask,aud

ValueError: The model did not return a loss from the inputs, only the following keys: . For reference, the inputs it received are input_ids,attention_mask,audio,mel_spectrogram,labels.

In [24]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else None)

CUDA available: True
CUDA device count: 1
Current device: 0


### Test 

In [None]:
def test_finetuned_model(model_path, test_text=None):
    """
    Proper inference handling for fine-tuned model 
    """
    if test_text is None:
        test_text = [
            "नमस्कार, म एक नेपाली बोल्छु।",
            "मैले एक पटक कलेज मा कुरा गर्नु पर्छ।", 
            "न्छ भन्ने शब्द बोल्न गाह्रो छ।",
            "आज मौसम राम्रो छ।",
            "धन्यवाद र नमस्कार।"
        ]

    logger.info("Testing fine-tuned model")

    try:
        #load model info
        with open(os.path.join(model_path, "training_info.json"), "r") as f:
            training_info = json.load(f)
            model_type = training_info.get('model_type', 'unknown')

        #load tokenizer and model
        model = VitsModel.from_pretrained(model_path)
        tokenizer = VitsTokenizer.from_pretrained(model_path)

        model.eval()
        logger.info(f"Loaded fine-tuned model (type: {model_type})")

        # generation 
        for i, text in enumerate(test_text):
            try:
                logger.info(f"Generating audio for: {text}")

                inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)

                with torch.no_grad():
                    output = model(**inputs) #forward pass 
                
                audio_generated = False

                if hasattr(output, 'waveform') and output.waveform is not None:
                    #direct waveform output
                    audio = output.waveform.squeeze().cpu().numpy()
                    output_path = f"{model_path}/test_output_{i+1}.wav"
                    sf.write(output_path, audio, 16000)
                    logger.info(f"Audio saved to {output_path}")
                    audio_generated = True

                elif hasattr(output, 'audio') and output.audio is not None:
                    #alternative audio output
                    audio = output.audio.squeeze().cpu().numpy()
                    output_path = f"{model_path}/test_output_{i+1}.wav"
                    sf.write(output_path, audio, 16000)
                    logger.info(f"Audio saved to {output_path}")
                    audio_generated = True

                elif hasattr(output, 'last_hidden_state'):
                    #mel spectrogram output - need vocoder
                    logger.warning(f"Model output is mel spectrogram. Vocoder needed for waveform generation for text: {text}.")
                    logger.info("Use a vocoder like HiFi-GAN or WaveGlow to convert mel spectrogram to waveform.")

                if not audio_generated:
                    logger.warning(f"No valid audio output generated for text: {text}")
                    logger.info(f" Output type: {type(output)}")
                    logger.info(f" Available attributes: {[attr for attr in dir(output) if not attr.startswith('_')]}")


            except Exception as e:
                logger.warning(f"Error generating audio for text '{text}': {e}")


    except Exception as e:
        logger.error(f"Error testing fine-tuned model: {e}")
        