# Fine-Tuning Parler TTS with Your Nepali Voice

This notebook guides you through fine-tuning the Parler TTS model (`ai4bharat/indic-parler-tts`) with your own Nepali voice recordings.

**Pipeline:**
1.  **Install Dependencies:** Set up the required libraries.
2.  **Record Voice Samples:** Use a Gradio interface to record or upload audio samples paired with Nepali text prompts.
3.  **Prepare Dataset:** Preprocess the recorded audio into audio tokens (using DAC) and pair them with text tokens.
4.  **Fine-Tune Model:** Train the Parler TTS model on your custom dataset.
5.  **Generate Speech:** Use the fine-tuned model to synthesize speech in your voice.

## With VITS

In [28]:
# Set environment variable to fall back to CPU for unsupported MPS operations
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import numpy as np
import pandas as pd
import torchaudio
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from IPython.display import Audio, display

# --- Simplified dataset with fixed-length processing ---
class NepaliTTSDataset(Dataset):
    def __init__(self, metadata_path, audio_dir, max_text_length=100, fixed_audio_length=16000*3):
        """
        Args:
            metadata_path: Path to CSV with text and file columns
            audio_dir: Directory with audio files
            max_text_length: Maximum text length in characters
            fixed_audio_length: Fixed length for audio samples (will be padded/trimmed)
        """
        self.df = pd.read_csv(metadata_path)
        self.audio_dir = audio_dir
        self.max_text_length = max_text_length
        self.fixed_audio_length = fixed_audio_length
        
        # Build vocabulary from all text
        all_text = " ".join(self.df["text"].dropna().tolist())
        self.chars = sorted(list(set(all_text)))
        self.char_to_idx = {c: i+1 for i, c in enumerate(self.chars)}
        self.char_to_idx["<pad>"] = 0  # Padding token
        
        print(f"Loaded dataset with {len(self.df)} entries and {len(self.char_to_idx)} characters")
        
    def __len__(self):
        return len(self.df)
    
    def encode_text(self, text):
        """Convert text to token IDs with fixed length"""
        if not isinstance(text, str):
            ids = [0] * self.max_text_length  # All padding
        else:
            # Convert characters to indices
            ids = [self.char_to_idx.get(c, 0) for c in text[:self.max_text_length]]
            # Pad if needed
            ids = ids + [0] * (self.max_text_length - len(ids))
        return torch.tensor(ids)
    
    def process_audio(self, waveform, sample_rate):
        """Process audio to fixed length and consistent format"""
        # Convert to mono if needed
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True).squeeze(0)
        elif waveform.dim() > 1:
            waveform = waveform.squeeze(0)
            
        # Resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        
        # Pad or trim to fixed length
        if waveform.shape[0] > self.fixed_audio_length:
            waveform = waveform[:self.fixed_audio_length]
        elif waveform.shape[0] < self.fixed_audio_length:
            padding = torch.zeros(self.fixed_audio_length - waveform.shape[0])
            waveform = torch.cat([waveform, padding])
            
        return waveform
    
    def __getitem__(self, idx):
        """Get standardized sample with fixed shapes"""
        row = self.df.iloc[idx]
        text = row["text"] if pd.notna(row["text"]) else ""
        
        # Process text
        input_ids = self.encode_text(text)
        
        try:
            # Load and process audio
            audio_path = os.path.join(self.audio_dir, row["file"])
            waveform, sample_rate = torchaudio.load(audio_path)
            waveform = self.process_audio(waveform, sample_rate)
            
            # Return sample
            return {
                "input_ids": input_ids,
                "waveform": waveform,
                "text": text
            }
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Return a dummy sample
            return {
                "input_ids": input_ids,
                "waveform": torch.zeros(self.fixed_audio_length),
                "text": text
            }

# --- Ultra-Simple TTS model with careful dimension handling ---
class UltraSimpleTTSModel(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        
        # Text encoder - simple embedding
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        
        # Simple feed-forward layers
        self.fc1 = torch.nn.Linear(hidden_size, 256)
        self.fc2 = torch.nn.Linear(256, 512)
        self.fc3 = torch.nn.Linear(512, 1024)
        
        # Activation
        self.relu = torch.nn.ReLU()
        
        # Output projection
        self.output_proj = torch.nn.Linear(1024, 480)  # Project to 480 audio samples per character
        
        # Loss function
        self.loss_fn = torch.nn.MSELoss()
        
    def forward(self, input_ids, waveform=None):
        # Get batch size and sequence length
        batch_size, seq_len = input_ids.shape
        
        # Encode text
        embedded = self.embedding(input_ids)  # [batch_size, seq_len, hidden_size]
        
        # Pass through feed-forward layers
        x = self.relu(self.fc1(embedded))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        
        # Project to audio samples
        x = self.output_proj(x)  # [batch_size, seq_len, 480]
        
        # Reshape to match waveform length (each token generates 480 samples)
        # This creates a [batch_size, seq_len * 480] tensor
        audio_output = x.reshape(batch_size, seq_len * 480)
        
        # Trim to fixed length of 48000 samples (3s at 16kHz)
        # or pad if needed
        if audio_output.shape[1] > 48000:
            audio_output = audio_output[:, :48000]
        elif audio_output.shape[1] < 48000:
            padding = torch.zeros(batch_size, 48000 - audio_output.shape[1], device=audio_output.device)
            audio_output = torch.cat([audio_output, padding], dim=1)
        
        # Compute loss if targets provided
        loss = None
        if waveform is not None:
            loss = self.loss_fn(audio_output, waveform)
            
        return {
            "waveform": audio_output,
            "loss": loss
        }

# --- Training function ---
def train_nepali_tts_simple():
    # Configuration
    DATA_DIR = "voice_data"
    METADATA_FILE = "metadata.csv"
    OUTPUT_DIR = "nepali_tts_model_simple"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    # Fixed shapes for our data
    MAX_TEXT_LENGTH = 100  # Maximum text length in characters
    FIXED_AUDIO_LENGTH = 48000  # 3 seconds at 16kHz
    
    # Create dataset with fixed shapes
    metadata_path = os.path.join(DATA_DIR, METADATA_FILE)
    dataset = NepaliTTSDataset(
        metadata_path, 
        DATA_DIR,
        max_text_length=MAX_TEXT_LENGTH,
        fixed_audio_length=FIXED_AUDIO_LENGTH
    )
    
    # Create model
    model = UltraSimpleTTSModel(len(dataset.char_to_idx))
    
    # Create dataloader
    batch_size = 2  # Start small
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    
    # Check a single batch to confirm shapes
    print("\nValidating shapes:")
    sample_batches = []
    for batch in train_loader:
        sample_batches.append(batch)
        if len(sample_batches) >= 2:
            break
    
    # Test first batch
    batch = sample_batches[0]
    input_ids = batch["input_ids"]
    waveform = batch["waveform"]
    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Waveform shape: {waveform.shape}")
    
    # Test forward pass
    print("\nTesting forward pass...")
    outputs = model(input_ids)
    print(f"Model output shape: {outputs['waveform'].shape}")
    
    # Test loss computation
    print("\nTesting loss computation...")
    outputs_with_loss = model(input_ids, waveform)
    print(f"Loss: {outputs_with_loss['loss'].item():.4f}")
    
    # Setup training
    device = torch.device("cpu")  # Start with CPU to avoid MPS issues
    print(f"\nTraining on device: {device}")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    num_epochs = 5  # Just a few epochs for testing
    
    print(f"\nStarting training for {num_epochs} epochs...")
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for i, batch in enumerate(train_loader):
            # Get data
            input_ids = batch["input_ids"].to(device)
            waveform = batch["waveform"].to(device)
            
            # Forward and backward pass
            optimizer.zero_grad()
            outputs = model(input_ids, waveform)
            loss = outputs["loss"]
            loss.backward()
            optimizer.step()
            
            # Track progress
            running_loss += loss.item()
            if i % 5 == 0:  # Log every 5 batches
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {loss.item():.4f}")
        
        # End of epoch
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} complete, Average Loss: {epoch_loss:.4f}")
        
        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
        }, os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch+1}.pt"))
        
    # Save final model and vocabulary
    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "final_model.pt"))
    
    # Save vocabulary
    import json
    with open(os.path.join(OUTPUT_DIR, "vocab.json"), "w") as f:
        json.dump(dataset.char_to_idx, f)
    
    print(f"Training complete! Model saved to {OUTPUT_DIR}")
    return model, dataset

# --- Inference function ---
def generate_speech_simple(model, dataset, text, device="cpu"):
    """Generate speech from text using the simple model"""
    model.eval()
    model.to(device)
    
    # Process text
    input_ids = dataset.encode_text(text).unsqueeze(0).to(device)
    
    # Generate waveform
    with torch.no_grad():
        outputs = model(input_ids)
        waveform = outputs["waveform"].cpu().numpy()[0]
    
    # Normalize audio
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-6)
    
    # Play audio
    display(Audio(waveform, rate=16000))
    
    return waveform

# Run training with proper error handling
try:
    model, dataset = train_nepali_tts_simple()
    
    # Test generation
    test_text = "नमस्ते, यो नेपाली टीटीएस नमुना हो"
    waveform = generate_speech_simple(model, dataset, test_text)
except Exception as e:
    print(f"Error during training: {e}")
    import traceback
    traceback.print_exc()

Loaded dataset with 132 entries and 55 characters

Validating shapes:
Input IDs shape: torch.Size([2, 100])
Waveform shape: torch.Size([2, 48000])

Testing forward pass...
Model output shape: torch.Size([2, 48000])

Testing loss computation...
Loss: 0.0041

Training on device: cpu

Starting training for 5 epochs...
Epoch 1, Batch 1, Loss: 0.0051
Epoch 1, Batch 6, Loss: 0.0042
Epoch 1, Batch 11, Loss: 0.0047
Epoch 1, Batch 16, Loss: 0.0041
Epoch 1, Batch 21, Loss: 0.0038
Epoch 1, Batch 26, Loss: 0.0021
Epoch 1, Batch 31, Loss: 0.0045
Epoch 1, Batch 36, Loss: 0.0019
Epoch 1, Batch 41, Loss: 0.0015
Epoch 1, Batch 46, Loss: 0.0049
Epoch 1, Batch 51, Loss: 0.0022
Epoch 1, Batch 56, Loss: 0.0027
Epoch 1, Batch 61, Loss: 0.0027
Epoch 1, Batch 66, Loss: 0.0034
Epoch 1 complete, Average Loss: 0.0035
Epoch 2, Batch 1, Loss: 0.0038
Epoch 2, Batch 6, Loss: 0.0042
Epoch 2, Batch 11, Loss: 0.0045
Epoch 2, Batch 16, Loss: 0.0015
Epoch 2, Batch 21, Loss: 0.0014
Epoch 2, Batch 26, Loss: 0.0048
Epoch 2,

In [None]:
# Install required packages - run this first
!pip install gtts pygame librosa soundfile

Collecting TTS==0.21.0
  Downloading TTS-0.21.0.tar.gz (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: TTS
  Building wheel for TTS (pyproject.toml) ... [?25ldone
[?25h  Created wheel for TTS: filename=tts-0.21.0-cp311-cp311-macosx_15_0_arm64.whl size=899340 sha256=4f7daa5bddf614237f19f7786eb91920d71710f1bc8e4e5aac9ff40b8df4ba3f
  Stored in directory: /Users/jeevanbhatta/Library/Caches/pip/wheels/8f/52/bf/ea322650816f5cb394bbfd87b1b400418f08b6db4067a98576
Successfully built TTS
Installing collected packages: TTS
  Attempting uninstall: TTS
    Found existing installation: TTS 0.22.0
    Uninstalling TTS-0.22.0:
      Successfully uninstalled TTS-0.22.0
Successfully installed TTS-0.21.0

[1m[[0m[

# Using Coqui TTS

In [None]:
# Simplified Nepali TTS - focusing on what works

import os
import numpy as np
from IPython.display import Audio, display
import time

# === APPROACH 1: Google TTS for Nepali (most reliable) ===
def generate_with_gtts(text, output_file="gtts_output.wav"):
    """Generate Nepali speech using Google TTS"""
    try:
        from gtts import gTTS
        print(f"Generating with Google TTS: {text}")
        
        # Use Nepali language code
        tts = gTTS(text=text, lang='ne', slow=False)
        tts.save(output_file)
        
        print(f"Audio saved to {output_file}")
        display(Audio(output_file))
        return output_file
    except Exception as e:
        print(f"Google TTS error: {e}")
        import traceback
        traceback.print_exc()
        return None

# === APPROACH 2: Facebook's MMS model for Nepali (best quality) ===
def generate_with_mms(text, output_file="mms_output.wav"):
    """Generate speech using Facebook MMS model which supports Nepali natively"""
    try:
        print("Setting up Facebook MMS model for Nepali...")
        
        # Install requirements
        !pip install -q torch transformers datasets

        # Import after installation
        from transformers import AutoProcessor, AutoModelForTextToSpeech
        import torch
        import soundfile as sf
        
        # Load the Nepali MMS model
        model_name = "facebook/mms-tts-nep"
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForTextToSpeech.from_pretrained(model_name)
        
        # Process text and generate speech
        inputs = processor(text=text, return_tensors="pt")
        
        with torch.no_grad():
            output = model(**inputs).waveform
        
        # Convert to numpy and save
        audio = output.squeeze().numpy()
        sf.write(output_file, audio, samplerate=16000)
        
        print(f"MMS output saved to {output_file}")
        display(Audio(output_file))
        return output_file
    except Exception as e:
        print(f"MMS model error: {e}")
        import traceback
        traceback.print_exc()
        return None

# === APPROACH 3: Last resort transliteration + English TTS ===
def generate_with_transliteration(text, output_file="transliterated_output.wav"):
    """
    Transliterate Nepali to Latin script and use English TTS
    This is a fallback if other methods fail
    """
    try:
        print("Attempting transliteration fallback...")
        
        # Simple mapping from Devanagari to Latin
        # This is very simplified and not linguistically accurate
        devanagari_to_latin = {
            'अ': 'a', 'आ': 'aa', 'इ': 'i', 'ई': 'ee', 'उ': 'u', 'ऊ': 'oo',
            'ए': 'e', 'ऐ': 'ai', 'ओ': 'o', 'औ': 'au', 'क': 'ka', 'ख': 'kha',
            'ग': 'ga', 'घ': 'gha', 'ङ': 'nga', 'च': 'cha', 'छ': 'chha',
            'ज': 'ja', 'झ': 'jha', 'ञ': 'nya', 'ट': 'ta', 'ठ': 'tha',
            'ड': 'da', 'ढ': 'dha', 'ण': 'na', 'त': 'ta', 'थ': 'tha',
            'द': 'da', 'ध': 'dha', 'न': 'na', 'प': 'pa', 'फ': 'pha',
            'ब': 'ba', 'भ': 'bha', 'म': 'ma', 'य': 'ya', 'र': 'ra',
            'ल': 'la', 'व': 'wa', 'श': 'sha', 'ष': 'sha', 'स': 'sa',
            'ह': 'ha', 'क्ष': 'ksha', 'त्र': 'tra', 'ज्ञ': 'gya',
            '०': '0', '१': '1', '२': '2', '३': '3', '४': '4',
            '५': '5', '६': '6', '७': '7', '८': '8', '९': '9',
            ' ': ' ', ',': ',', '।': '.', '?': '?', '!': '!'
        }
        
        # Simple transliteration
        transliterated = ''
        for char in text:
            transliterated += devanagari_to_latin.get(char, char)
        
        print(f"Transliterated text: {transliterated}")
        
        # Use gTTS with English for the transliterated text
        from gtts import gTTS
        tts = gTTS(text=transliterated, lang='en', slow=False)
        tts.save(output_file)
        
        print(f"Transliterated audio saved to {output_file}")
        display(Audio(output_file))
        return output_file
    except Exception as e:
        print(f"Transliteration error: {e}")
        import traceback
        traceback.print_exc()
        return None

# Main execution
def main():
    # Test texts
    test_texts = [
        "नमस्ते, यो नेपाली टीटीएस नमुना हो",
        "मेरो नाम टीटीएस हो, म नेपाली पनि बोल्न सक्छु।",
        "यो आवाज कम्प्युटरले बनाएको हो।"
    ]
    
    # Try Google TTS first (most reliable)
    print("\n=== TRYING GOOGLE TTS FOR NEPALI ===")
    gtts_results = []
    for i, text in enumerate(test_texts):
        result = generate_with_gtts(text, f"gtts_output_{i+1}.wav")
        if result:
            gtts_results.append(result)
    
    # If Google TTS failed, try MMS
    if not gtts_results:
        print("\n=== TRYING FACEBOOK MMS MODEL FOR NEPALI ===")
        mms_results = []
        for i, text in enumerate(test_texts[:1]):  # Just try first text as MMS is slower
            result = generate_with_mms(text, f"mms_output_{i+1}.wav")
            if result:
                mms_results.append(result)
    
        # If both failed, try transliteration
        if not mms_results:
            print("\n=== TRYING TRANSLITERATION FALLBACK ===")
            for i, text in enumerate(test_texts):
                generate_with_transliteration(text, f"trans_output_{i+1}.wav")
    
    print("\nSpeech generation complete!")

# Run the main function in a safe way
try:
    main()
except Exception as e:
    print(f"Error in main function: {e}")
    import traceback
    traceback.print_exc()

Collecting gtts
  Downloading gTTS-2.5.4-py3-none-any.whl.metadata (4.1 kB)
Collecting pygame
  Downloading pygame-2.6.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (12 kB)
Downloading gTTS-2.5.4-py3-none-any.whl (29 kB)
Downloading pygame-2.6.1-cp311-cp311-macosx_11_0_arm64.whl (12.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25hInstalling collected packages: pygame, gtts
Successfully installed gtts-2.5.4 pygame-2.6.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

=== TRYING GOOGLE TTS FOR NEPALI ===
Generating with Google TTS: नमस्ते, यो नेपाली टीटीएस नमुना हो
Audio saved to gtts_output_1.wav


Generating with Google TTS: मेरो नाम टीटीएस हो, म नेपाली पनि बोल्न सक्छु।
Audio saved to gtts_output_2.wav


Generating with Google TTS: यो आवाज कम्प्युटरले बनाएको हो।
Audio saved to gtts_output_3.wav



Speech generation complete!


Will use this tts as a comparison with the fine tuned audio.

In [36]:
# Install required packages
!pip install -q torch torchaudio numpy pandas matplotlib tqdm librosa soundfile nltk


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [39]:
# Nepali TTS Fine-tuning with Your Voice Data
# For CPU training on MacBook (8GB RAM)
# Simplified version without tqdm dependency

import os
import numpy as np
import pandas as pd
import torch
import librosa
import soundfile as sf
from IPython.display import Audio, display
import matplotlib.pyplot as plt
import shutil
import subprocess
import sys

# Install required packages
!pip install torch torchaudio numpy pandas matplotlib librosa soundfile gtts

# === 1. DATA PREPARATION ===
class VoiceDataPrep:
    def __init__(self, data_path, output_path):
        self.data_path = data_path
        self.output_path = output_path
        self.metadata_file = os.path.join(data_path, "metadata.csv")
        
        # Create Mozilla TTS format directory structure
        os.makedirs(os.path.join(output_path, "wavs"), exist_ok=True)
        
    def prepare_data(self):
        """Convert your dataset to LJSpeech-like format for Mozilla TTS"""
        print("Preparing Nepali dataset...")
        
        # Read metadata
        df = pd.read_csv(self.metadata_file)
        
        # Create new metadata files
        train_file = os.path.join(self.output_path, "metadata_train.csv")
        val_file = os.path.join(self.output_path, "metadata_val.csv")
        
        with open(train_file, "w", encoding="utf-8") as f_train, \
             open(val_file, "w", encoding="utf-8") as f_val:
            
            # Write headers
            f_train.write("file_path|text\n")
            f_val.write("file_path|text\n")
            
            # Process each entry
            val_indices = set(np.random.choice(
                range(len(df)), 
                size=min(5, max(1, len(df) // 10)),  # 10% for validation, min 1, max 5
                replace=False
            ))
            
            print(f"Total entries: {len(df)}")
            print(f"Processing audio files...")
            
            train_count = 0
            val_count = 0
            error_count = 0
            
            # Simple progress tracking
            for idx, row in df.iterrows():
                if idx % 10 == 0:
                    print(f"Processing entry {idx}/{len(df)}...", end="\r")
                    
                if pd.isna(row["text"]) or not os.path.exists(os.path.join(self.data_path, row["file"])):
                    continue
                    
                # Process audio file
                src_path = os.path.join(self.data_path, row["file"])
                file_id = f"nepali_{idx:05d}"
                dst_path = os.path.join("wavs", f"{file_id}.wav")
                dst_full_path = os.path.join(self.output_path, dst_path)
                
                try:
                    # Load, normalize and save audio at 22050Hz for better training
                    y, sr = librosa.load(src_path, sr=22050, mono=True)
                    
                    # Normalize audio
                    y = y / (np.max(np.abs(y)) + 1e-6)
                    
                    # Trim silence
                    y, _ = librosa.effects.trim(y, top_db=20)
                    
                    # Save as WAV
                    sf.write(dst_full_path, y, 22050, subtype='PCM_16')
                    
                    # Write metadata line
                    line = f"{dst_path}|{row['text']}\n"
                    
                    if idx in val_indices:
                        f_val.write(line)
                        val_count += 1
                    else:
                        f_train.write(line)
                        train_count += 1
                        
                except Exception as e:
                    print(f"Error processing {row['file']}: {e}")
                    error_count += 1
        
        print(f"\nDataset prepared at {self.output_path}")
        print(f"Train entries: {train_count}")
        print(f"Validation entries: {val_count}")
        print(f"Errors: {error_count}")
        
        return self.output_path

# === 2. FALLBACK: USE GTTS FOR NOW + INSTRUCTIONS FOR FINE-TUNING ===
def generate_with_gtts(text, output_file="gtts_output.wav"):
    """Generate Nepali speech using Google TTS"""
    from gtts import gTTS
    print(f"Generating with Google TTS: {text}")
    tts = gTTS(text=text, lang='ne', slow=False)
    tts.save(output_file)
    display(Audio(output_file))
    return output_file

# === MAIN EXECUTION ===
def main():
    # Setup paths
    data_path = "voice_data"  # Your original data directory
    prepped_data_path = "nepali_tts_data"  # Directory for prepared data
    
    # 1. Prepare data
    print("=== PREPARING VOICE DATA ===")
    data_prep = VoiceDataPrep(data_path, prepped_data_path)
    dataset_path = data_prep.prepare_data()
    
    # 2. Check hardware capability
    print("\n=== CHECKING SYSTEM CAPABILITIES ===")
    gpu_available = torch.cuda.is_available()
    cpu_count = os.cpu_count()
    
    print(f"GPU available: {gpu_available}")
    print(f"CPU cores: {cpu_count}")
    print(f"PyTorch version: {torch.__version__}")
    
    # 3. Create configuration for training
    print("\n=== CREATING TRAINING CONFIGURATION ===")
    
    # Create config file for training
    config = {
        "run_name": "nepali_tts",
        "run_description": "Nepali TTS with personal voice data",
        
        # Audio processing
        "audio": {
            "sample_rate": 22050,
            "do_trim_silence": True,
            "trim_db": 60,
            "signal_norm": False,
            "min_level_db": -100,
            "ref_level_db": 20,
            "preemphasis": 0.97,
            "symmetric_norm": False,
            "max_norm": 4.0,
            "clip_norm": False,
            "mel_fmin": 0.0,
            "mel_fmax": 8000.0,
            "spec_gain": 1.0,
            "num_mels": 80,
            "hop_length": 256,
            "win_length": 1024,
            "frame_length_ms": None,
            "frame_shift_ms": None
        },
        
        # Dataset
        "datasets": [
            {
                "name": "nepali_dataset",
                "path": prepped_data_path,
                "meta_file_train": "metadata_train.csv",
                "meta_file_val": "metadata_val.csv"
            }
        ],
        
        # Training parameters
        "batch_size": 8,
        "eval_batch_size": 4,
        "num_loader_workers": 0,  # Better for macOS
        "num_eval_loader_workers": 0,
        "epochs": 100,
        "scheduler_after_epoch": True,
        "lr": 0.001,
        "wd": 0.0,
        "warmup_steps": 0,
        "grad_clip": 1.0,
        
        # Model parameters (reduced for CPU training)
        "model": "tacotron",
        "model_params": {
            "num_chars": 100,  # Will be set based on your data
            "decoder_output_dim": 80,  # Same as num_mels
            "encoder_hidden_size": 64,  # Reduced from 256
            "decoder_hidden_size": 128,  # Reduced from 512
            "attention_hidden_size": 64,  # Reduced from 128
            "postnet_hidden_size": 64,  # Reduced from 256
            "prenet_hidden_size": 64,  # Reduced from 256
            "prenet_dropout": 0.5,
            "attention_type": "graves",
            "attention_norm": "softmax",
            "window_backward": 1,
            "window_forward": 2
        },
        
        # Logging, checkpoints, etc.
        "save_step": 1000,
        "print_step": 50,
        "print_eval": True,
        "run_eval": True,
        "use_phonemes": False,
        "phoneme_language": "en-us"
    }
    
    # Save config
    import json
    config_dir = "nepali_tts_config"
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, "config.json")
    
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)
    
    print(f"Training configuration saved to {config_path}")
        
    # 4. Create training instructions
    print("\n=== CREATING TRAINING INSTRUCTIONS ===")
    
    # Generate instructions file
    instructions_file = "nepali_tts_instructions.txt"
    with open(instructions_file, "w") as f:
        f.write("=== Nepali TTS Fine-tuning Instructions ===\n\n")
        f.write("Your data has been prepared and is ready for training.\n\n")
        
        f.write("1. Data location:\n")
        f.write(f"   - Prepared data: {os.path.abspath(prepped_data_path)}\n")
        f.write(f"   - Config file: {os.path.abspath(config_path)}\n\n")
        
        f.write("2. To train on this computer (CPU):\n")
        f.write("   This will take many hours and may not produce optimal results.\n")
        f.write("   - Install Mozilla TTS: pip install TTS\n")
        f.write("   - Run this command:\n")
        f.write(f"     python -m TTS.bin.train_tts --config_path {os.path.abspath(config_path)} --coqpit.output_path ./nepali_tts_output\n\n")
        
        f.write("3. For MUCH better results with GPU:\n")
        f.write("   - Upload your prepared data to Google Colab\n")
        f.write("   - Use the notebook provided: nepali_tts_colab.ipynb\n\n")
        
        f.write("4. In the meantime, you can use Google TTS:\n")
        f.write("   from gtts import gTTS\n")
        f.write("   tts = gTTS(\"नमस्ते, यो नेपाली टीटीएस नमुना हो\", lang='ne')\n")
        f.write("   tts.save(\"output.wav\")\n")
    
    print(f"Instructions saved to {instructions_file}")
    
    # 5. Create Colab notebook
    print("\n=== CREATING GOOGLE COLAB NOTEBOOK ===")
    colab_notebook = "nepali_tts_colab.ipynb"
    
    notebook_content = '''{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nepali TTS Fine-tuning with Google Colab\\n",
    "\\n",
    "This notebook helps you fine-tune a Nepali TTS model with your voice data using Google Colab\'s GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages\\n",
    "!pip install TTS matplotlib pandas librosa soundfile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload your prepared data\\n",
    "from google.colab import files\\n",
    "import os\\n",
    "import zipfile\\n",
    "\\n",
    "# Create directory structure\\n",
    "!mkdir -p nepali_data/wavs\\n",
    "\\n",
    "print(\\"Please upload your prepared metadata files (metadata_train.csv, metadata_val.csv)\\")\\n",
    "uploaded_metadata = files.upload()\\n",
    "\\n",
    "for filename in uploaded_metadata.keys():\\n",
    "    with open(os.path.join(\\"nepali_data\\", filename), \\"wb\\") as f:\\n",
    "        f.write(uploaded_metadata[filename])\\n",
    "    print(f\\"Saved {filename} to nepali_data/\\")\\n",
    "\\n",
    "print(\\"\\nPlease upload your WAV files as a ZIP archive\\")\\n",
    "uploaded_wavs = files.upload()\\n",
    "\\n",
    "for filename in uploaded_wavs.keys():\\n",
    "    if filename.endswith(\\".zip\\"):\\n",
    "        with zipfile.ZipFile(filename, \\"r\\") as zip_ref:\\n",
    "            zip_ref.extractall(\\"nepali_data\\")\\n",
    "        print(f\\"Extracted {filename} to nepali_data/\\")\\n",
    "\\n",
    "# List files to confirm upload\\n",
    "!ls -la nepali_data\\n",
    "!ls -la nepali_data/wavs | head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the training configuration\\n",
    "import json\\n",
    "\\n",
    "config = {\\n",
    "    \\"run_name\\": \\"nepali_tts\\",\\n",
    "    \\"run_description\\": \\"Nepali TTS with personal voice data\\",\\n",
    "    \\n",
    "    # Audio processing\\n",
    "    \\"audio\\": {\\n",
    "        \\"sample_rate\\": 22050,\\n",
    "        \\"do_trim_silence\\": True,\\n",
    "        \\"trim_db\\": 60,\\n",
    "        \\"signal_norm\\": False,\\n",
    "        \\"min_level_db\\": -100,\\n",
    "        \\"ref_level_db\\": 20,\\n",
    "        \\"preemphasis\\": 0.97,\\n",
    "        \\"symmetric_norm\\": False,\\n",
    "        \\"max_norm\\": 4.0,\\n",
    "        \\"clip_norm\\": False,\\n",
    "        \\"mel_fmin\\": 0.0,\\n",
    "        \\"mel_fmax\\": 8000.0,\\n",
    "        \\"spec_gain\\": 1.0,\\n",
    "        \\"num_mels\\": 80,\\n",
    "        \\"hop_length\\": 256,\\n",
    "        \\"win_length\\": 1024\\n",
    "    },\\n",
    "    \\n",
    "    # Dataset\\n",
    "    \\"datasets\\": [\\n",
    "        {\\n",
    "            \\"name\\": \\"nepali_dataset\\",\\n",
    "            \\"path\\": \\"nepali_data\\",\\n",
    "            \\"meta_file_train\\": \\"metadata_train.csv\\",\\n",
    "            \\"meta_file_val\\": \\"metadata_val.csv\\"\\n",
    "        }\\n",
    "    ],\\n",
    "    \\n",
    "    # Training parameters - optimized for Colab GPU\\n",
    "    \\"batch_size\\": 32,\\n",
    "    \\"eval_batch_size\\": 16,\\n",
    "    \\"num_loader_workers\\": 4,\\n",
    "    \\"num_eval_loader_workers\\": 4,\\n",
    "    \\"epochs\\": 1000,\\n",
    "    \\"scheduler_after_epoch\\": True,\\n",
    "    \\"lr\\": 0.001,\\n",
    "    \\"wd\\": 0.0,\\n",
    "    \\"warmup_steps\\": 0,\\n",
    "    \\"grad_clip\\": 1.0,\\n",
    "    \\n",
    "    # Model parameters - full size for GPU training\\n",
    "    \\"model\\": \\"tacotron\\",\\n",
    "    \\"model_params\\": {\\n",
    "        \\"num_chars\\": 100,\\n",
    "        \\"decoder_output_dim\\": 80,\\n",
    "        \\"encoder_hidden_size\\": 256,\\n",
    "        \\"decoder_hidden_size\\": 512,\\n",
    "        \\"attention_hidden_size\\": 128,\\n",
    "        \\"postnet_hidden_size\\": 256,\\n",
    "        \\"prenet_hidden_size\\": 256,\\n",
    "        \\"prenet_dropout\\": 0.5,\\n",
    "        \\"attention_type\\": \\"graves\\",\\n",
    "        \\"attention_norm\\": \\"softmax\\",\\n",
    "        \\"window_backward\\": 1,\\n",
    "        \\"window_forward\\": 2\\n",
    "    },\\n",
    "    \\n",
    "    # Logging, checkpoints, etc.\\n",
    "    \\"save_step\\": 1000,\\n",
    "    \\"print_step\\": 50,\\n",
    "    \\"print_eval\\": True,\\n",
    "    \\"run_eval\\": True,\\n",
    "    \\"use_phonemes\\": False,\\n",
    "    \\"phoneme_language\\": \\"en-us\\"\\n",
    "}\\n",
    "\\n",
    "with open(\\"config.json\\", \\"w\\") as f:\\n",
    "    json.dump(config, f, indent=4)\\n",
    "    \\n",
    "print(\\"Config file created: config.json\\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check GPU availability and start training\\n",
    "!nvidia-smi\\n",
    "\\n",
    "print(\\"\\nStarting TTS training...\\")\\n",
    "!python -m TTS.bin.train_tts --config_path config.json --coqpit.output_path ./nepali_tts_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the trained model\\n",
    "import torch\\n",
    "from TTS.tts.models import setup_model\\n",
    "from TTS.config.shared_configs import load_config\\n",
    "from TTS.tts.utils.synthesis import synthesis\\n",
    "from TTS.utils.audio import AudioProcessor\\n",
    "from IPython.display import Audio\\n",
    "\\n",
    "# Find the best model checkpoint\\n",
    "import glob\\n",
    "import os\\n",
    "\\n",
    "# Look for the best model or the latest checkpoint\\n",
    "best_model = None\\n",
    "best_model_path = os.path.join(\\"nepali_tts_output\\", \\"best_model.pth\\")\\n",
    "if os.path.exists(best_model_path):\\n",
    "    best_model = best_model_path\\n",
    "else:\\n",
    "    checkpoints = glob.glob(\\"nepali_tts_output/checkpoint_*.pth\\")\\n",
    "    if checkpoints:\\n",
    "        best_model = sorted(checkpoints)[-1]  # Get the latest checkpoint\\n",
    "\\n",
    "if not best_model:\\n",
    "    print(\\"No model checkpoints found. Training may not have completed.\\")\\n",
    "else:\\n",
    "    print(f\\"Using model checkpoint: {best_model}\\")\\n",
    "    \\n",
    "    # Load the config\\n",
    "    config_path = \\"config.json\\"\\n",
    "    config = load_config(config_path)\\n",
    "    \\n",
    "    # Load the model\\n",
    "    model = setup_model(config)\\n",
    "    checkpoint = torch.load(best_model, map_location=torch.device('cpu'))\\n",
    "    model.load_state_dict(checkpoint[\\"model\\"])\\n",
    "    model.eval()\\n",
    "    \\n",
    "    # Setup audio processor\\n",
    "    ap = AudioProcessor(**config.audio.to_dict())\\n",
    "    \\n",
    "    # Test sentences\\n",
    "    test_sentences = [\\n",
    "        \\"नमस्ते, यो नेपाली टीटीएस नमुना हो\\",\\n",
    "        \\"मेरो नाम टीटीएस हो, म नेपाली पनि बोल्न सक्छु।\\",\\n",
    "        \\"यो आवाज कम्प्युटरले बनाएको हो।\\"\\n",
    "    ]\\n",
    "    \\n",
    "    # Generate and play audio\\n",
    "    for i, text in enumerate(test_sentences):\\n",
    "        print(f\\"Generating: {text}\\")\\n",
    "        \\n",
    "        outputs = synthesis(\\n",
    "            model,\\n",
    "            text,\\n",
    "            config,\\n",
    "            use_cuda=False,\\n",
    "            ap=ap,\\n",
    "            speaker_id=None,\\n",
    "            style_wav=None,\\n",
    "            use_griffin_lim=True,\\n",
    "            do_trim_silence=True\\n",
    "        )\\n",
    "        \\n",
    "        # Save audio and play\\n",
    "        file_path = f\\"nepali_output_{i+1}.wav\\"\\n",
    "        ap.save_wav(outputs[\\"wav\\"], file_path)\\n",
    "        display(Audio(file_path, rate=ap.sample_rate))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the trained model and output samples\\n",
    "from google.colab import files\\n",
    "\\n",
    "# Download generated samples\\n",
    "for i in range(len(test_sentences)):\\n",
    "    if os.path.exists(f\\"nepali_output_{i+1}.wav\\"):\\n",
    "        files.download(f\\"nepali_output_{i+1}.wav\\")\\n",
    "\\n",
    "# Create a ZIP with the model files for download\\n",
    "!zip -r nepali_tts_model.zip nepali_tts_output\\n",
    "files.download(\\"nepali_tts_model.zip\\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}'''
    
    with open(colab_notebook, "w") as f:
        f.write(notebook_content)
    
    print(f"Google Colab notebook created: {colab_notebook}")
    print("Upload this notebook to Google Colab for training with GPU.")
    
    # 6. For immediate use: generate samples with Google TTS
    print("\n=== USING GOOGLE TTS FOR IMMEDIATE RESULTS ===")
    print("Generating sample Nepali speech with Google TTS:")
    
    test_texts = [
        "नमस्ते, यो नेपाली टीटीएस नमुना हो",
        "मेरो नाम टीटीएस हो, म नेपाली पनि बोल्न सक्छु।",
        "यो आवाज कम्प्युटरले बनाएको हो।"
    ]
    
    for i, text in enumerate(test_texts):
        generate_with_gtts(text, f"gtts_nepali_{i+1}.wav")
    
    print("\n=== PROCESS COMPLETE ===")
    print("1. Your voice data has been prepared for training")
    print("2. Training configuration has been created")
    print("3. Google Colab notebook is ready for GPU training")
    print("4. Instructions saved to nepali_tts_instructions.txt")
    print("5. Google TTS samples have been generated for immediate use")
    
    return True

# Run the main function
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
=== PREPARING VOICE DATA ===
Preparing Nepali dataset...
Total entries: 132
Processing audio files...
Processing entry 130/132...
Dataset prepared at nepali_tts_data
Train entries: 127
Validation entries: 5
Errors: 0

=== CHECKING SYSTEM CAPABILITIES ===
GPU available: False
CPU cores: 8
PyTorch version: 2.6.0

=== CREATING TRAINING CONFIGURATION ===
Training configuration saved to nepali_tts_config/config.json

=== CREATING TRAINING INSTRUCTIONS ===
Instructions saved to nepali_tts_instructions.txt

=== CREATING GOOGLE COLAB NOTEBOOK ===
Google Colab notebook created: nepali_tts_colab.ipynb
Upload this notebook to Google Colab for training with GPU.

=== USING GOOGLE TTS FOR IMMEDIATE RESULTS ===
Generating sample Nepali speech with Goo

Generating with Google TTS: मेरो नाम टीटीएस हो, म नेपाली पनि बोल्न सक्छु।


Generating with Google TTS: यो आवाज कम्प्युटरले बनाएको हो।



=== PROCESS COMPLETE ===
1. Your voice data has been prepared for training
2. Training configuration has been created
3. Google Colab notebook is ready for GPU training
4. Instructions saved to nepali_tts_instructions.txt
5. Google TTS samples have been generated for immediate use


# Facebook MMS- TTS

Set TOKENIZERS_PARALLELISM=false
Ensure you have the necessary libraries installed:
pip install torch torchaudio
pip install transformers datasets accelerate sentencepiece soundfile
------------------------------
Multiprocessing start method already 'spawn'.
Loading processor and model from 'facebook/mms-tts-nep'...
Error loading model or processor from facebook/mms-tts-nep: facebook/mms-tts-nep is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
Ensure you have internet connectivity and the model ID is correct.
You might need `pip install sentencepiece` for the tokenizer.


Traceback (most recent call last):
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/huggingface_hub/utils/_http.py", line 409, in hf_raise_for_status
    response.raise_for_status()
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/requests/models.py", line 1024, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/facebook/mms-tts-nep/resolve/main/processor_config.json

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/transformers/utils/hub.py", line 403, in cached_file
    resolved_file = hf_hub_download(
                    ^^^^^^^^^^^^^^^^
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/huggingfa

OSError: facebook/mms-tts-nep is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

## Step 1: Install Dependencies

# With Parle TTS
I spent 3 days on this so keeping it here on Appendix.

In [None]:
# Install dependencies
%pip install torch datasets accelerate soundfile
%pip install transformers==4.46.1
# Install the parler-tts package which contains ParlerTTSForConditionalGeneration
%pip install git+https://github.com/huggingface/parler-tts.git

In [38]:
import os
import torch
import numpy as np
from typing import Any, Dict, List, Union
from dataclasses import dataclass
import multiprocessing
import traceback # For printing full tracebacks on error

# --- Set Tokenizers Parallelism ---
# Set this BEFORE initializing tokenizers or trainer to avoid warnings/deadlocks
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print("Set TOKENIZERS_PARALLELISM=false")

# --- Import Libraries ---
from datasets import load_dataset, Audio
from transformers import (
    AutoTokenizer,
    AutoFeatureExtractor,
    TrainingArguments,
    Trainer,
    # Needed for checkpoint loading
    trainer_utils
)
from parler_tts import ParlerTTSForConditionalGeneration # Ensure this is installed correctly

# =============================================================================
# === CLASS DEFINITIONS ===
# =============================================================================

# --- Custom Trainer (Corrected Patch Logic) ---
# --- Custom Trainer (DEBUG VERSION - PLEASE USE THIS ONE) ---
class ParlerTTSTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Computes the loss for ParlerTTS. Patches audio_encoder.forward.
        Adds debugging to check the value associated with the audio key.
        """
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        input_features = inputs.get("input_features") # Audio tensor from collator

        if input_ids is None or attention_mask is None or input_features is None:
            missing_keys = [k for k in ["input_ids", "attention_mask", "input_features"] if k not in inputs]
            raise ValueError(f"Missing required inputs in batch. Missing: {missing_keys}. Available: {list(inputs.keys())}")

        original_audio_encoder_forward = model.audio_encoder.forward

        def patched_audio_encoder_forward(*args, **kwargs):
            # --- This is the function being patched ---
            audio_tensor_key = "input_values" # Key expected internally by ParlerTTS -> audio_encoder call
            audio_tensor = None
            remaining_args = list(args)

            # --- DETAILED DEBUGGING ---
            print(f"\n--- DEBUG INSIDE PATCH ---") # <<< LOOK FOR THIS LINE
            print(f"Received args len: {len(args)}, kwargs keys: {list(kwargs.keys())}") # <<< LOOK FOR THIS LINE

            # Check kwargs first
            if audio_tensor_key in kwargs:
                value_in_kwargs = kwargs[audio_tensor_key] # Check value BEFORE popping
                print(f"Found key '{audio_tensor_key}' in kwargs.") # <<< LOOK FOR THIS LINE
                print(f"  Value type: {type(value_in_kwargs)}") # <<< LOOK FOR THIS LINE
                if value_in_kwargs is not None:
                    print(f"  Value is NOT None. Shape: {value_in_kwargs.shape if hasattr(value_in_kwargs, 'shape') else 'N/A'}") # <<< LOOK FOR THIS LINE
                    audio_tensor = kwargs.pop(audio_tensor_key) # Pop and assign only if not None
                else:
                    print(f"  Value IS None. Popping key '{audio_tensor_key}' but keeping audio_tensor=None.") # <<< LOOK FOR THIS LINE
                    kwargs.pop(audio_tensor_key) # Pop key
                    audio_tensor = None # Explicitly ensure it remains None
            # Check args if not found or was None in kwargs
            elif remaining_args:
                value_in_args = remaining_args[0]
                print(f"Key '{audio_tensor_key}' not in kwargs. Trying args[0].") # <<< LOOK FOR THIS LINE
                print(f"  Value type: {type(value_in_args)}") # <<< LOOK FOR THIS LINE
                if value_in_args is not None:
                    print(f"  Value is NOT None. Shape: {value_in_args.shape if hasattr(value_in_args, 'shape') else 'N/A'}") # <<< LOOK FOR THIS LINE
                    audio_tensor = remaining_args.pop(0) # Assign value from args
                else:
                    print(f"  Value IS None (from args[0]). Keeping audio_tensor=None.") # <<< LOOK FOR THIS LINE
                    remaining_args.pop(0) # Remove the None from args
                    audio_tensor = None # Explicitly ensure it remains None
            else:
                # This case should not happen based on previous errors, but good to have
                print(f"Key '{audio_tensor_key}' not found in kwargs or args.") # <<< LOOK FOR THIS LINE
                raise ValueError(f"Audio tensor lookup failed unexpectedly. kwargs: {kwargs.keys()}, args len: {len(args)}")

            # The check that is currently failing
            if audio_tensor is None:
                print(f"FINAL CHECK: audio_tensor is None. Raising error.") # <<< LOOK FOR THIS LINE
                print(f"--------------------------\n")
                raise ValueError("Extracted audio_tensor is None within patch.")
            else:
                 print(f"FINAL CHECK: audio_tensor appears valid. Shape: {audio_tensor.shape}") # <<< LOOK FOR THIS LINE


            # Continue with patch logic if tensor was found
            if 'padding_mask' in kwargs:
                print(f"Removing 'padding_mask' from kwargs.") # <<< LOOK FOR THIS LINE
                del kwargs['padding_mask']

            kwargs['return_dict'] = True
            print(f"Calling original forward with audio tensor shape: {audio_tensor.shape}, remaining_args: {len(remaining_args)}, kwargs: {list(kwargs.keys())}") # <<< LOOK FOR THIS LINE
            print(f"--------------------------\n")

            try:
                 # Call the original method with extracted tensor and modified kwargs
                 output = original_audio_encoder_forward(audio_tensor, *remaining_args, **kwargs)
                 return output
            except Exception as e_inner:
                 # This error handling is for the call to the *original* method
                 print(f"\nERROR inside original_audio_encoder_forward call (within patch): {e_inner}")
                 print(f"DEBUG: audio_tensor type: {type(audio_tensor)}, shape: {audio_tensor.shape if hasattr(audio_tensor, 'shape') else 'N/A'}")
                 print(f"DEBUG: remaining_args: {remaining_args}")
                 print(f"DEBUG: kwargs passed to original: {list(kwargs.keys())}")
                 traceback.print_exc()
                 raise e_inner
        # --- End of patched_audio_encoder_forward ---

        # --- Apply Patch and Compute Loss ---
        try:
            model.audio_encoder.forward = patched_audio_encoder_forward
            # This is the call from Trainer -> Model
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                input_features=input_features, # Audio from collator uses 'input_features'
                return_dict=True
            )

            if not hasattr(outputs, "loss") or outputs.loss is None:
                print(f"DEBUG: Model output keys: {outputs.keys() if isinstance(outputs, dict) else dir(outputs)}")
                raise ValueError("Model output does not contain 'loss'. Check model forward pass and expected outputs.")

            loss = outputs.loss
            return (loss, outputs) if return_outputs else loss

        except Exception as e_outer:
            # This catches errors during the model call or the patch execution itself
            print(f"\nERROR during compute_loss execution (outside patch or in model call): {e_outer}")
            print(f"DEBUG shapes at error point: input_ids={input_ids.shape}, attention_mask={attention_mask.shape}, input_features={input_features.shape if input_features is not None else 'None'}")
            # Don't re-raise immediately if it's the specific None error from the patch, as it's already printed info
            if not isinstance(e_outer, ValueError) or "Extracted audio_tensor is None" not in str(e_outer):
                 traceback.print_exc()
            raise e_outer # Re-raise the exception

        finally:
            # --- Restore Original Method ---
            model.audio_encoder.forward = original_audio_encoder_forward
            # print("DEBUG: Restored original audio_encoder.forward")

# --- End of ParlerTTSTrainer class ---


# --- Custom Data Collator (Robust Version) ---
@dataclass
class TTSDataCollatorWithPadding:
    tokenizer: Any
    feature_extractor: Any
    audio_column_name: str = "file"
    # Standardize audio output key name
    audio_output_key_name: str = "input_features" # Use this key consistently

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor, dict]]]) -> Dict[str, torch.Tensor]:
        if not features: return {}
        if not all(isinstance(f, dict) for f in features): raise ValueError("Features must be dicts.")
        required_keys = ["input_ids", "attention_mask", self.audio_column_name]
        # Check each feature individually for required keys
        for i, f in enumerate(features):
            missing_keys_in_feature = [k for k in required_keys if k not in f]
            if missing_keys_in_feature:
                raise ValueError(f"Feature at index {i} is missing required keys: {missing_keys_in_feature}. Available keys: {list(f.keys())}")


        input_ids = [f["input_ids"] for f in features]
        attention_mask = [f["attention_mask"] for f in features]

        try:
            text_batch = self.tokenizer.pad(
                {"input_ids": input_ids, "attention_mask": attention_mask},
                padding=True, return_tensors="pt",
            )
        except Exception as e:
            print(f"Error padding text batch: {e}"); traceback.print_exc(); raise

        # Check audio column structure more carefully
        problematic_audio_indices = []
        for i, f in enumerate(features):
             audio_data = f.get(self.audio_column_name)
             if not isinstance(audio_data, dict) or 'array' not in audio_data:
                 problematic_audio_indices.append(f"Index {i}: type={type(audio_data)}, keys={audio_data.keys() if isinstance(audio_data, dict) else 'N/A'}")
        if problematic_audio_indices:
             raise ValueError(f"Audio column '{self.audio_column_name}' issues found. Problematic features: {problematic_audio_indices}")


        # Convert to numpy arrays and check for empty/invalid ones
        try:
            audio_arrays = []
            for i, f in enumerate(features):
                raw_array = f[self.audio_column_name]["array"]
                # Ensure it's convertible to a numpy array and handle potential Nones/empties
                if raw_array is None:
                     raise ValueError(f"Audio array at index {i} is None.")
                np_array = np.asarray(raw_array, dtype=np.float32)
                if np_array.ndim == 0 or np_array.size == 0: # Check for scalar or empty array
                    raise ValueError(f"Audio array at index {i} is empty or scalar after conversion. Original type: {type(raw_array)}, Shape: {np_array.shape}")
                audio_arrays.append(np_array)

        except Exception as e:
            print(f"Error processing or converting audio array to numpy at index {i}: {e}"); traceback.print_exc(); raise


        # Process with feature extractor
        try:
            audio_batch = self.feature_extractor(
                audio_arrays, sampling_rate=self.feature_extractor.sampling_rate,
                padding=True, # Let the feature extractor handle padding
                return_tensors="pt"
            )
        except ValueError as e:
            print(f"ValueError during feature extraction: {e}")
            print(f"Input audio array shapes: {[arr.shape for arr in audio_arrays]}")
            print(f"Input audio array dtypes: {[arr.dtype for arr in audio_arrays]}")
            traceback.print_exc(); raise
        except Exception as e:
            print(f"Error processing audio batch with feature extractor: {e}"); traceback.print_exc(); raise

        # Find the correct key ('input_features' or 'input_values') from feature extractor output
        potential_keys = ["input_features", "input_values"]
        audio_input_key = None
        for key in potential_keys:
            if key in audio_batch:
                audio_input_key = key
                break

        if audio_input_key is None:
             raise KeyError(f"Feature extractor output missing expected keys ({potential_keys}). Found keys: {audio_batch.keys()}")

        # Get the final tensor and validate it
        processed_audio_input = audio_batch[audio_input_key]
        if processed_audio_input is None or not isinstance(processed_audio_input, torch.Tensor):
             raise TypeError(f"Collator: Extracted audio ('{audio_input_key}') is not a valid tensor. Type: {type(processed_audio_input)}")
        if processed_audio_input.nelement() == 0: # Check if tensor is empty
             raise ValueError(f"Collator: Extracted audio tensor ('{audio_input_key}') is empty. Shape: {processed_audio_input.shape}")

        # Assemble the final batch
        batch = {
            "input_ids": text_batch["input_ids"],
            "attention_mask": text_batch["attention_mask"],
            # Use the standardized key name for the trainer
            self.audio_output_key_name: processed_audio_input,
        }
        return batch


# =============================================================================
# === MAIN EXECUTION ===
# =============================================================================

def main():
    # --- Set Multiprocessing Start Method ---
    try:
        if multiprocessing.get_start_method(allow_none=True) != 'spawn':
            multiprocessing.set_start_method('spawn', force=True)
            print("Set multiprocessing start method to 'spawn'")
        else:
            print("Multiprocessing start method already 'spawn'.")
    except RuntimeError as e:
        print(f"Multiprocessing start method info: {e}. Continuing...")
    except ValueError as e:
        print(f"Could not set/get multiprocessing start method: {e}. Check environment compatibility.")


    # --- Configuration ---
    model_name = "ai4bharat/indic-parler-tts"
    # Ensure this path is correct for your setup
    base_dir = os.getcwd() # Or specify absolute path: "/path/to/your/project"
    data_folder_name = "voice_data" # Folder containing metadata.csv and audio files
    data_folder_path = os.path.join(base_dir, data_folder_name)
    metadata_filename = "metadata.csv"
    audio_column_name = "file" # Column name in metadata.csv pointing to audio files
    text_column_name = "text"  # Column name in metadata.csv for transcriptions
    output_dir = "./nepali_parler_tts_finetuned" # Where checkpoints and final model are saved
    logs_dir = "./logs" # Tensorboard logs
    # Standardized key for audio data after collation (must match collator's audio_output_key_name)
    collated_audio_key = "input_features"

    # --- 1. Load Model, Tokenizer, and Feature Extractor ---
    print(f"Loading model and components from '{model_name}'...")
    # Consider low_cpu_mem_usage=True if loading is slow/memory intensive
    model = ParlerTTSForConditionalGeneration.from_pretrained(model_name) #, low_cpu_mem_usage=True)

    # Load T5 tokenizer associated with the text encoder
    t5_tokenizer_path = model.config.text_encoder._name_or_path
    print(f"Loading T5 tokenizer from: {t5_tokenizer_path}")
    t5_tokenizer = AutoTokenizer.from_pretrained(t5_tokenizer_path)

    # Load feature extractor (likely Encodec for ParlerTTS)
    print(f"Loading Feature Extractor from: {model_name}")
    # trust_remote_code=True might be needed for custom feature extractors like Encodec
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, trust_remote_code=True)
    except Exception as e:
         print(f"Error loading feature extractor: {e}. Ensure 'trust_remote_code=True' is appropriate or check model compatibility.")
         raise e

    # Load the main tokenizer (used for saving configs, etc.)
    print(f"Loading Tokenizer (for saving) from: {model_name}")
    tokenizer_for_saving = AutoTokenizer.from_pretrained(model_name)

    TARGET_SAMPLING_RATE = feature_extractor.sampling_rate
    print(f"Target sampling rate set by Feature Extractor: {TARGET_SAMPLING_RATE}")

    # --- 2. Load and Prepare the Dataset ---
    print(f"Loading dataset metadata '{metadata_filename}' from directory: {data_folder_path}")
    metadata_full_path = os.path.join(data_folder_path, metadata_filename)
    if not os.path.exists(metadata_full_path): raise FileNotFoundError(f"Metadata file not found: {metadata_full_path}")
    if not os.path.isdir(data_folder_path): raise NotADirectoryError(f"Data directory not found or not a directory: {data_folder_path}")

    # Load dataset using 'csv' type, specifying data_dir so audio paths are relative to it
    try:
        dataset = load_dataset(
            "csv", data_files={"train": metadata_full_path}, data_dir=data_folder_path, keep_in_memory=False
        )
    except Exception as e:
        print(f"Error loading dataset from {metadata_full_path} with data_dir {data_folder_path}: {e}")
        print("Check if the CSV format is correct and audio files exist relative to the data directory.")
        traceback.print_exc()
        raise e

    split_name = "train" # We specified {"train": ...}
    print(f"Using dataset split: '{split_name}'")
    print(f"Original dataset features: {dataset[split_name].features}")


    print(f"Casting audio column '{audio_column_name}' to Audio feature with target sampling rate {TARGET_SAMPLING_RATE}...")
    try:
        dataset = dataset.cast_column(audio_column_name, Audio(sampling_rate=TARGET_SAMPLING_RATE))
    except ValueError as e:
         print(f"Error casting audio column '{audio_column_name}': {e}")
         print("Ensure the column name matches the CSV header and contains relative paths to valid audio files.")
         traceback.print_exc()
         raise e
    except Exception as e:
        print(f"Unexpected error casting audio column: {e}"); traceback.print_exc(); raise

    print(f"Dataset features after casting: {dataset[split_name].features}")

    # --- 3. Preprocessing Function ---
    def preprocess_function(examples):
        # Defensive checks for column existence in the batch
        if text_column_name not in examples: raise KeyError(f"Text column '{text_column_name}' not found in examples batch. Available: {list(examples.keys())}")
        if audio_column_name not in examples: raise KeyError(f"Audio column '{audio_column_name}' not found in examples batch. Available: {list(examples.keys())}")

        texts = examples[text_column_name]
        # Ensure texts are strings, handle None or other types gracefully
        if not isinstance(texts, list): texts = [texts] # Handle non-batched case if it occurs
        processed_texts = []
        for i, t in enumerate(texts):
            if isinstance(t, str):
                processed_texts.append(t)
            elif t is None:
                print(f"Warning: Found None in text column at batch index {i}. Replacing with empty string.")
                processed_texts.append("")
            else:
                print(f"Warning: Non-string type ({type(t)}) found in text column at batch index {i}. Attempting conversion.")
                processed_texts.append(str(t))

        # Tokenize text - padding=False here, collator will pad the batch later
        inputs = t5_tokenizer(
            processed_texts, padding=False, truncation=True, max_length=512, return_tensors=None
        )

        # Prepare batch dictionary for the next step (mapping)
        # Keep the raw audio data associated with the audio_column_name for the collator
        batch = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            audio_column_name: examples[audio_column_name] # Pass the audio data structure through
        }
        return batch

    print("Applying preprocessing function to the dataset...")
    # Define columns to keep after mapping (those needed by collator)
    columns_to_keep = ["input_ids", "attention_mask", audio_column_name]
    # Determine columns to remove (all others)
    columns_to_remove = [col for col in dataset[split_name].column_names if col not in columns_to_keep]
    print(f"Columns to remove during map: {columns_to_remove}")

    # Limit num_proc to avoid excessive RAM usage during mapping
    # Adjust based on your system's RAM (8GB is tight, start with 1 or 2)
    map_num_proc = max(1, os.cpu_count() // 4 if os.cpu_count() else 1)
    print(f"Using {map_num_proc} processes for dataset mapping.")

    try:
        processed_dataset = dataset[split_name].map(
            preprocess_function,
            batched=True,
            batch_size=100, # Reduce if mapping causes OOM errors
            num_proc=map_num_proc,
            remove_columns=columns_to_remove,
            keep_in_memory=False # Safer for low RAM
        )
    except Exception as e:
        print(f"Error during dataset mapping: {e}")
        print("Check the preprocess_function for errors or reduce batch_size/num_proc if it's a memory issue.")
        traceback.print_exc()
        raise e

    print(f"Processed dataset columns: {processed_dataset.column_names}")
    print(f"Number of samples in processed dataset: {len(processed_dataset)}")
    if not processed_dataset or len(processed_dataset) == 0: raise ValueError("Processed dataset is empty after mapping.")
    # Verify required columns are present after mapping
    expected_cols_after_map = ["input_ids", "attention_mask", audio_column_name]
    if not all(col in processed_dataset.column_names for col in expected_cols_after_map):
        raise ValueError(f"Processed dataset missing expected columns after mapping. Expected: {expected_cols_after_map}, Got: {processed_dataset.column_names}")


    # --- 4. Data Collator ---
    print("Initializing data collator...")
    data_collator = TTSDataCollatorWithPadding(
        tokenizer=t5_tokenizer,
        feature_extractor=feature_extractor,
        audio_column_name=audio_column_name,
        audio_output_key_name=collated_audio_key # Must match key expected by trainer
    )

    # --- 5. Training Arguments (CPU / Low RAM specific) ---
    print("Configuring Training Arguments for CPU and low RAM...")
    force_cpu = True # Keep forcing CPU as requested on M2 Mac for stability/testing
    print(f"Forcing CPU usage: {force_cpu}")

    # Reduce epochs initially to test the full pipeline quickly
    num_epochs_test = 3 # Use a small number like 3-5 for initial tests
    print(f"Setting num_train_epochs to {num_epochs_test} for testing.")

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs_test, # Start with fewer epochs
        per_device_train_batch_size=1,    # Essential for 8GB RAM
        gradient_accumulation_steps=16,   # Effective batch size = 1 * 16 = 16. Increase if OOM during backward pass, decrease if forward pass fails.
        gradient_checkpointing=True,      # Saves memory by recomputing activations
        learning_rate=2e-5,               # Standard fine-tuning LR
        warmup_steps=50,                  # Number of steps for learning rate warmup
        max_grad_norm=1.0,                # Gradient clipping
        logging_dir=logs_dir,
        logging_strategy="steps",
        logging_steps=10,                 # Log loss frequently
        evaluation_strategy="no",         # Disable evaluation
        save_strategy="epoch",            # Save checkpoint every epoch
        save_total_limit=2,               # Keep last 2 checkpoints + final model
        report_to="tensorboard",          # Log to tensorboard (can set to "none")
        remove_unused_columns=False,      # Crucial: Keep columns needed by model/collator
        load_best_model_at_end=False,     # Requires evaluation
        push_to_hub=False,

        # --- Explicit CPU/MPS/CUDA Control ---
        no_cuda=force_cpu,                # Explicitly disable CUDA
        use_mps_device=False,             # Explicitly disable MPS (Apple Silicon GPU) if forcing CPU or facing issues
        fp16=False,                       # FP16 not supported/beneficial on CPU
        bf16=False,                       # BF16 not supported/beneficial on CPU

        # Dataloader settings for stability (especially in notebooks)
        dataloader_num_workers=0,         # MUST be 0 for 'spawn' or notebook usage
        dataloader_pin_memory=False,      # No GPU pinning needed/possible on CPU
    )

    # --- Sanity Check Device Configuration ---
    print("--- Device Configuration Check ---")
    if force_cpu:
        print("TrainingArguments: Forcing CPU (no_cuda=True, use_mps_device=False).")
    else:
        if training_args.use_mps_device:
             if torch.backends.mps.is_available(): print("TrainingArguments: MPS enabled and available.")
             else: print("TrainingArguments: MPS enabled BUT NOT AVAILABLE on this system.")
        elif not training_args.no_cuda:
             if torch.cuda.is_available(): print("TrainingArguments: CUDA enabled and available.")
             else: print("TrainingArguments: CUDA enabled BUT NOT AVAILABLE on this system.")
        else:
            print("TrainingArguments: Configured for CPU (either no_cuda=True or use_mps_device=False).")
    print(f"PyTorch detected default device: {torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')}")
    print("---------------------------------")


    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(logs_dir, exist_ok=True)

    # --- 6. Trainer ---
    # Initialize Trainer ONCE with all components
    print("Initializing Trainer...")
    trainer = ParlerTTSTrainer( # Use the corrected custom trainer class
        model=model,
        args=training_args,
        train_dataset=processed_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer_for_saving, # Pass tokenizer for saving alongside model
    )
    print(f"Trainer initialized. Device determined by TrainingArguments: {trainer.args.device}")

    # --- 7. Start Training ---
    print(f"\n{'='*20} STARTING TRAINING {'='*20}")
    try:
        last_checkpoint = None
        # Check for existing checkpoints in the output directory
        if os.path.isdir(training_args.output_dir):
             print(f"Checking for checkpoints in {training_args.output_dir}...")
             last_checkpoint = trainer_utils.get_last_checkpoint(training_args.output_dir)
             if last_checkpoint:
                 print(f"Resuming training from checkpoint: {last_checkpoint}")
             else:
                 print("No checkpoint found, starting training from scratch.")
        else:
             print("Output directory does not exist, starting training from scratch.")

        # Start or resume training
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

        # Log and save metrics after successful training completion
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state() # Save optimizer state etc.
        print(f"\n{'='*20} TRAINING FINISHED SUCCESSFULLY {'='*20}")
        print(f"Metrics: {metrics}")

    except Exception as e:
        print(f"\n{'='*20} TRAINING FAILED {'='*20}")
        print(f"Error: {e}")
        print(f"{'='*50}\nFull Traceback:")
        traceback.print_exc()
        print(f"{'='*50}")
        # Even if training fails, try saving the current state if possible
        try:
            print("Attempting to save trainer state after error...")
            trainer.save_state()
            print("Trainer state saved.")
        except Exception as save_e:
            print(f"Could not save trainer state after error: {save_e}")
        # Exit or re-raise if preferred
        # return # Exit the main function cleanly after failure

    # --- 8. Save the Fine-Tuned Model Components ---
    # This will save the model state at the end of successful training,
    # or the last successfully saved checkpoint state if training was interrupted.
    print(f"\nSaving final model components to {output_dir}...")
    try:
        # Save model using trainer's method (handles sharding etc. if applicable)
        trainer.save_model(output_dir)
        print(f"Model saved to {output_dir}")

        # Save tokenizer and feature extractor explicitly for completeness
        if hasattr(trainer, 'tokenizer') and trainer.tokenizer is not None:
            trainer.tokenizer.save_pretrained(output_dir)
            print("Tokenizer saved.")
        else:
            # Fallback to saving the tokenizer loaded initially if trainer doesn't have it
            if 'tokenizer_for_saving' in locals() and tokenizer_for_saving is not None:
                 tokenizer_for_saving.save_pretrained(output_dir)
                 print("Tokenizer (initial) saved.")
            else:
                 print("Warning: Trainer does not have a tokenizer attribute, and initial tokenizer not found.")

        # Check if feature_extractor is defined before saving
        if 'feature_extractor' in locals() and feature_extractor is not None:
             feature_extractor.save_pretrained(output_dir)
             print("Feature extractor saved.")
        else:
             print("Warning: Feature extractor not found or is None, cannot save.")

        print("Model components saved successfully.")

    except Exception as e:
        print(f"Error saving final model components: {e}")
        traceback.print_exc()

    print("\nFine-tuning script finished.")


if __name__ == "__main__":
    main()

Set TOKENIZERS_PARALLELISM=false
Multiprocessing start method already 'spawn'.
Loading model and components from 'ai4bharat/indic-parler-tts'...


Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "_name_or_path": "google/flan-t5-large",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32128
}

Config of the audio_encoder: <class 'transformers.models.dac.modelin

Loading T5 tokenizer from: google/flan-t5-large
Loading Feature Extractor from: ai4bharat/indic-parler-tts
Loading Tokenizer (for saving) from: ai4bharat/indic-parler-tts
Target sampling rate set by Feature Extractor: 44100
Loading dataset metadata 'metadata.csv' from directory: /Users/jeevanbhatta/Downloads/voice-classification/nepali_tts/voice_data


Generating train split: 132 examples [00:00, 248.74 examples/s]


Using dataset split: 'train'
Original dataset features: {'file': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None)}
Casting audio column 'file' to Audio feature with target sampling rate 44100...
Dataset features after casting: {'file': Audio(sampling_rate=44100, mono=True, decode=True, id=None), 'text': Value(dtype='string', id=None)}
Applying preprocessing function to the dataset...
Columns to remove during map: ['text']
Using 2 processes for dataset mapping.


Map (num_proc=2):   0%|          | 0/132 [00:01<?, ? examples/s]
multiprocess.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/multiprocess/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
                    ^^^^^^^^^^^^^^^^^^^
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/datasets/utils/py_utils.py", line 680, in _write_generator_to_queue
    for i, result in enumerate(func(**kwargs)):
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3516, in _map_single
    for i, batch in iter_outputs(shard_iterable):
  File "/Users/jeevanbhatta/Downloads/voice-classification/venv/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3466, in iter_outputs
    yield i, apply_function(example, i, offset=offset)
             ^^^^^^^^^^^^^^^^^^^^^

Error during dataset mapping: [Errno 2] No such file or directory: 'sample_000.wav'
Check the preprocess_function for errors or reduce batch_size/num_proc if it's a memory issue.


FileNotFoundError: [Errno 2] No such file or directory: 'sample_000.wav'

In [12]:
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf

# Load the fine-tuned model and tokenizer
model_path = "./nepali_parler_tts_finetuned"
model = ParlerTTSForConditionalGeneration.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)
description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)

# Nepali text input in Devnagari script
nepali_text = "यो सुन्दर बिहान हो।"
description = "A male user's voice speaking clearly and naturally." # Adjust the description to match your desired voice style

# Tokenize input
description_input_ids = description_tokenizer(description, return_tensors="pt").to(model.device)
prompt_input_ids = tokenizer(nepali_text, return_tensors="pt").to(model.device)

# Generate speech
generation = model.generate(
    input_ids=description_input_ids.input_ids,
    attention_mask=description_input_ids.attention_mask,
    prompt_input_ids=prompt_input_ids.input_ids,
    prompt_attention_mask=prompt_input_ids.attention_mask,
)

# Save the generated audio
audio_arr = generation.cpu().numpy().squeeze()
sampling_rate = model.config.sampling_rate
sf.write("nepali_output.wav", audio_arr, sampling_rate)

print("Generated audio saved as nepali_output.wav")

Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "_name_or_path": "google/flan-t5-large",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32128
}

Config of the audio_encoder: <class 'transformers.models.dac.modelin

Generated audio saved as nepali_output.wav


In [None]:
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
import os
import traceback # Import traceback

# --- Configuration ---
model_path = "./nepali_parler_tts_finetuned"
output_filename = "nepali_output_cpu.wav"

# --- Explicitly set device to CPU ---
device = torch.device("cpu")
print(f"Using device: {device}")

# --- Load Model ---
print(f"Loading model from: {model_path}")
if not os.path.isdir(model_path):
    raise FileNotFoundError(f"Model directory not found at: {model_path}.")

model = ParlerTTSForConditionalGeneration.from_pretrained(model_path).to(device)
model.eval()

# --- Load the Correct (T5) Tokenizer ---
try:
    t5_tokenizer_path = model.config.text_encoder._name_or_path
    print(f"Loading T5 tokenizer from: {t5_tokenizer_path}")
    if not os.path.isdir(t5_tokenizer_path):
         resolved_t5_path = os.path.join(model_path, t5_tokenizer_path)
         if os.path.isdir(resolved_t5_path): t5_tokenizer_path = resolved_t5_path
         else: t5_tokenizer_path = model_path # Fallback
    tokenizer = AutoTokenizer.from_pretrained(t5_tokenizer_path)
except Exception as e:
    print(f"Error loading the T5 tokenizer: {e}"); raise

# --- Input Text and Description ---
nepali_text = "आज म बजार जान्छु। मेरो नाम राम हो। उनको घर पहाडमा छ। हामीले नेपाली भाषा सिक्नुपर्छ। खाना तयार भयो। पानी परिरहेको छ।"
description = "Jeevan speaking in Nepali."
print(f"Text to synthesize: {nepali_text}")
print(f"Voice description: {description}")

# --- Tokenize Inputs using the T5 Tokenizer ---
# Tokenize the main text prompt (text to be synthesized)
prompt_inputs = tokenizer(
    nepali_text, return_tensors="pt", padding=True, truncation=True, max_length=512
).to(device)

# Tokenize the description prompt (voice style guidance)
description_inputs = tokenizer(
    description, return_tensors="pt", padding=True, truncation=True, max_length=512
).to(device)


# --- Generate Speech ---
print("Generating audio... (This might take a while on CPU)")
try:
    # --- CORRECTED ARGUMENT MAPPING ---
    # For this ParlerTTS variant, description goes to input_ids,
    # and the text to be synthesized goes to prompt_input_ids.
    with torch.no_grad():
        generation = model.generate(
            input_ids=description_inputs.input_ids,                 # Description here
            attention_mask=description_inputs.attention_mask,       # Description mask here
            prompt_input_ids=prompt_inputs.input_ids,               # Text to synthesize here
            prompt_attention_mask=prompt_inputs.attention_mask,     # Text mask here
            # Other generation parameters
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            max_new_tokens=None # Let model decide length
        ).cpu()
    # --- END CORRECTION ---

except Exception as e:
    print(f"Error during generation: {e}")
    print("Check model compatibility, arguments, and memory limits.")
    traceback.print_exc()
    raise

# --- Save the Generated Audio ---
audio_arr = generation.numpy().squeeze()
try:
    if hasattr(model.config, "sampling_rate"):
        sampling_rate = model.config.sampling_rate
    elif hasattr(model.config, "audio_encoder") and hasattr(model.config.audio_encoder, "sampling_rate"):
        sampling_rate = model.config.audio_encoder.sampling_rate
    else:
        print("Warning: Sampling rate not found in config. Using default 24000.")
        sampling_rate = 24000

    print(f"Saving audio with sampling rate: {sampling_rate}")
    sf.write(output_filename, audio_arr, sampling_rate)
    print(f"Generated audio saved as {output_filename}")
except Exception as e:
    print(f"Error saving audio file: {e}"); raise

Using device: cpu
Loading model from: ./nepali_parler_tts_finetuned


Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "_name_or_path": "google/flan-t5-large",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32128
}

Config of the audio_encoder: <class 'transformers.models.dac.modelin

Loading T5 tokenizer from: google/flan-t5-large
Text to synthesize: यो सुन्दर बिहान हो।
Voice description: A standard female voice speaking clearly in Nepali.
Generating audio... (This might take a while on CPU)
Saving audio with sampling rate: 44100
Generated audio saved as nepali_output_cpu.wav


## Step 2: Record Voice Samples

Use the interface below to record yourself reading the Nepali sentences. Aim for 50-100 clear recordings in a quiet environment.
Recordings will be saved in the `voice_data` directory along with a `metadata.csv` file.

In [None]:
# Create Gradio interface for recording/uploading samples

import os
import gradio as gr
import pandas as pd
import shutil
import traceback
import numpy as np
import soundfile as sf

# --- Configuration ---
VOICE_DATA_DIR = "voice_data"
METADATA_FILE = os.path.join(VOICE_DATA_DIR, "metadata.csv")
SENTENCES_FILE = "nepali_sentences.txt"
TARGET_SAMPLE_RATE = 24000 # Parler-TTS expects 24kHz

# --- Helper Functions ---
os.makedirs(VOICE_DATA_DIR, exist_ok=True)

def load_sentences_from_file(file_path=SENTENCES_FILE):
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            sentences = [line.strip() for line in f if line.strip()]
        if not sentences:
            print(f"Warning: No sentences found in {file_path}. Using default.")
            return ["नमस्ते, मेरो नाम जीवन हो।", "म नेपाली भाषामा बोल्छु।"]
        print(f"Loaded {len(sentences)} sentences from {file_path}")
        return sentences
    except FileNotFoundError:
        print(f"Warning: Sentences file '{file_path}' not found. Using default.")
        return ["नमस्ते, मेरो नाम जीवन हो।", "म नेपाली भाषामा बोल्छु।"]

nepali_sentences = load_sentences_from_file()

# Load existing metadata or create new DataFrame
if os.path.exists(METADATA_FILE):
    try:
        metadata_df = pd.read_csv(METADATA_FILE)
        # Find the next available index based on existing filenames like sample_XXX.wav
        existing_indices = metadata_df['file'].str.extract(r'sample_(\d+).wav').astype(int).max()[0]
        current_idx = existing_indices + 1 if pd.notna(existing_indices) else 0
    except Exception as e:
        print(f"Error reading metadata file {METADATA_FILE}: {e}. Starting from index 0.")
        metadata_df = pd.DataFrame(columns=["file", "text"])
        current_idx = 0
else:
    metadata_df = pd.DataFrame(columns=["file", "text"])
    current_idx = 0

print(f"Starting sample index: {current_idx}")

def process_and_save_audio(audio_input, text):
    global current_idx, metadata_df
    
    if audio_input is None:
        return "Please record or upload audio first.", current_idx, nepali_sentences[current_idx % len(nepali_sentences)]
    
    # Gradio provides audio as a tuple (sample_rate, numpy_array) or filepath string
    if isinstance(audio_input, tuple):
        sample_rate, audio_data = audio_input
        source_info = "microphone recording"
    elif isinstance(audio_input, str) and os.path.exists(audio_input):
        try:
            audio_data, sample_rate = sf.read(audio_input, dtype='float32')
            source_info = f"uploaded file ({os.path.basename(audio_input)})"
        except Exception as e:
             return f"Error reading uploaded file: {e}", current_idx, nepali_sentences[current_idx % len(nepali_sentences)]
        # Clean up temp file if it was an upload
        if audio_input.startswith(gr.processing_utils.TEMP_DIR):
             os.remove(audio_input)
    else:
        return "Invalid audio input.", current_idx, nepali_sentences[current_idx % len(nepali_sentences)]

    try:
        print(f"Processing audio from {source_info} with original SR: {sample_rate}")
        
        # Ensure mono
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
            
        # Resample if necessary (using librosa for potentially better quality)
        if sample_rate != TARGET_SAMPLE_RATE:
            import librosa # Import here to avoid making it a hard dependency if not needed
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
            print(f"Resampled audio to {TARGET_SAMPLE_RATE} Hz")
            sample_rate = TARGET_SAMPLE_RATE
        
        # Save the processed audio file
        output_filename = f"sample_{current_idx:03d}.wav"
        output_filepath = os.path.join(VOICE_DATA_DIR, output_filename)
        sf.write(output_filepath, audio_data, sample_rate)
        print(f"Saved processed audio to {output_filepath}")
        
        # Update metadata DataFrame
        new_row = pd.DataFrame([{"file": output_filename, "text": text}])
        metadata_df = pd.concat([metadata_df, new_row], ignore_index=True)
        
        # Save metadata to CSV
        metadata_df.to_csv(METADATA_FILE, index=False, encoding='utf-8')
        
        # Increment index for the next recording
        current_idx += 1
        next_sentence_idx = current_idx % len(nepali_sentences)
        
        return f"✅ Saved sample {current_idx-1}", current_idx, nepali_sentences[next_sentence_idx]
        
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"ERROR processing/saving audio: {error_details}")
        return f"❌ Error: {str(e)}", current_idx, nepali_sentences[current_idx % len(nepali_sentences)]

# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("# Record Your Nepali Voice Samples")
    gr.Markdown(f"Read the sentence below and record/upload your voice. Audio will be saved to '{VOICE_DATA_DIR}' as {TARGET_SAMPLE_RATE}Hz WAV.")
    
    with gr.Row():
        with gr.Column(scale=3):
            text_to_read = gr.Textbox(
                value=nepali_sentences[current_idx % len(nepali_sentences)], 
                label="Sentence to Read",
                lines=3,
                interactive=True # Allow user to potentially correct/change text
            )
            
            # Use gr.Audio which handles both recording and upload
            # Type 'numpy' returns (sr, data), 'filepath' returns path to temp file
            audio_input = gr.Audio(
                sources=["microphone", "upload"], 
                type="numpy", # Get raw data and sample rate
                label="Record or Upload Audio (.wav, .mp3, etc.)",
                format="wav" # Preferred format for saving
            )
            
            save_btn = gr.Button("Save Sample", variant="primary")
            status = gr.Textbox(value="Ready to record/upload", label="Status", interactive=False)
            count = gr.Number(value=current_idx, label="Samples Saved", interactive=False)
            
        with gr.Column(scale=2):
            gr.Markdown(""""
            ### Recording Tips:
            1.  **Quiet Environment:** Minimize background noise.
            2.  **Clear Speech:** Speak naturally and clearly.
            3.  **Mic Distance:** Maintain a consistent distance (e.g., 15-20cm).
            4.  **Consistent Tone:** Avoid large variations in volume/pitch.
            5.  **Target:** Aim for 50-100+ high-quality samples. """
                )
            gr.Markdown(f"**Sentences Source:** `{SENTENCES_FILE}`")
            gr.Markdown(f"**Output Directory:** `{VOICE_DATA_DIR}`")
            gr.Markdown(f"**Metadata File:** `{METADATA_FILE}`")

    # Connect button click to the processing function
    save_btn.click(
        process_and_save_audio, 
        inputs=[audio_input, text_to_read], 
        outputs=[status, count, text_to_read]
    )

# Launch the interface
demo.launch(share=True, debug=True) # Share=True provides a public link, Debug=True shows more logs

## Step 3: Prepare Dataset for Fine-Tuning

This step involves:
1.  Loading the `metadata.csv`.
2.  Defining a PyTorch `Dataset` class.
3.  Inside the dataset, for each audio file:
    *   Load and resample the audio.
    *   Pad or truncate audio to a fixed length.
    *   Use the DAC codec to convert audio waveforms into discrete audio tokens.
    *   Cache these tokens to disk for faster subsequent loading.
4.  Tokenize the corresponding text prompts.
5.  Create a `DataLoader` to batch the data for training.

## Step 4: Fine-Tune the Parler TTS Model

This section defines and runs the fine-tuning loop:
1.  Load the pre-trained `ai4bharat/indic-parler-tts` model.
2.  Set up the optimizer (AdamW) and learning rate scheduler.
3.  Define a voice description prompt.
4.  Iterate through epochs and batches:
    *   Get text prompts, attention masks, and pre-computed audio tokens from the dataloader.
    *   Tokenize the voice description.
    *   Pass all inputs to the model (audio tokens are used as `labels`).
    *   Calculate the loss.
    *   Perform backpropagation and update model weights.
5.  Save model checkpoints periodically and the final model.

## Step 5: Generate Speech with the Fine-Tuned Voice

Use the fine-tuned model saved in the previous step to synthesize speech for new Nepali text prompts.