# TTS Model Distillation: Setup & Data Prep

This notebook initializes the workspace and prepares the dataset from the parent project's `output/` directory.

## Objectives
1.  Verify GPU availability (RTX 6000).
2.  Define paths and configuration.
3.  Ingest audio/transcript pairs from `../output`.
4.  Format data for training (LJSpeech format).
5.  Visualize audio quality (Mel Spectrograms).

In [None]:
import os
import sys
import torch
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm

# Make the project's ../src package importable. The path is relative to the
# kernel's working directory. The membership check keeps re-runs of this cell
# from appending duplicate entries to sys.path.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

print(f"PyTorch Version: {torch.__version__}")
cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")
if cuda_available:
    # Device 0 is reported; multi-GPU hosts would need to list the others.
    print(f"Device: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Paths
# Input: the ElevenLabs-generated audio/transcript files from the parent project.
PARENT_OUTPUT_DIR = Path("..") / "output"
# Output: the prepared training dataset, created up front so later cells can write into it.
DATASET_DIR = Path("..", "datasets", "raw")
DATASET_DIR.mkdir(exist_ok=True, parents=True)

# Audio Settings
# Sample rate in Hz. The original note calls 22050 usual for VITS/XTTS, with 24000
# as the alternative. TODO confirm against the target model's config.
TARGET_SR = 22_050
# Whether preprocessing strips leading/trailing silence.
TRIM_SILENCE = True

## 2. Data Ingestion
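We need to find pairs of `.mp3` (audio) and `.md` (transcript) files.
*Note: The parent project writes `output/<name>.mp3` together with `output/<name>_transcript.md`, so pairs are matched on the shared `<name>` stem. Audio files without a matching transcript are reported and skipped.*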

In [None]:
def find_data_pairs(source_dir):
    """Collect (audio, transcript) pairs from the parent project's output folder.

    The parent script writes ``<name>.mp3`` alongside ``<name>_transcript.md``.
    Each mp3 whose transcript exists becomes one row. Audio files without a
    transcript are reported and skipped.

    Args:
        source_dir: Directory (str or Path) to scan. Only its top level is searched.

    Returns:
        DataFrame with columns ``audio_path`` (str), ``text`` (raw transcript
        contents) and ``id`` (the shared file stem), ordered by file name.
        The columns are present even when no pairs are found.
    """
    source_path = Path(source_dir)
    columns = ['audio_path', 'text', 'id']

    if not source_path.is_dir():
        # Without this, a wrong path would silently look like "Found 0 pairs."
        print(f"Warning: {source_path} is not a directory; no pairs found.")
        return pd.DataFrame(columns=columns)

    data = []
    # glob() order is filesystem-dependent. Sorting makes the row order (and
    # therefore the metadata.csv order) reproducible across runs and machines.
    for audio_file in sorted(source_path.glob("*.mp3")):
        base_name = audio_file.stem
        transcript_file = source_path / f"{base_name}_transcript.md"

        if not transcript_file.exists():
            print(f"Warning: No transcript found for {audio_file.name}")
            continue

        # utf-8-sig drops a leading byte-order mark if an editor added one, so it
        # cannot leak into the transcript. BOM-less files decode as plain UTF-8.
        text = transcript_file.read_text(encoding='utf-8-sig')
        data.append({'audio_path': str(audio_file), 'text': text, 'id': base_name})

    return pd.DataFrame(data, columns=columns)

# Scan the parent project's output folder and preview the first few pairs.
df = find_data_pairs(source_dir=PARENT_OUTPUT_DIR)
n_pairs = df.shape[0]
print(f"Found {n_pairs} pairs.")
df.head()

## 3. Preprocessing & Formatting (LJSpeech Style)
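Each clip is resampled to `TARGET_SR`, optionally silence-trimmed, and written to `wavs/<id>.wav`. `metadata.csv` gets one `id|text` line per clip. (The original LJSpeech release uses three columns, `id|transcription|normalized_transcription`, so check which layout your trainer's loader expects.)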

In [None]:
def _clean_transcript(text):
    """Collapse a transcript onto one line so it fits a single `id|text` record."""
    # A newline would split one clip across several metadata lines, and a
    # literal "|" would add a spurious field, so both are normalised to spaces.
    # TODO: markdown markup (headings, emphasis) is kept as-is; strip it or
    # phonemize here if the trainer needs plain prose.
    return " ".join(str(text).replace("|", " ").split())


def preprocess_dataset(df, output_dir, target_sr=22050, trim_silence=None, top_db=20):
    """Resample, optionally trim, and export pairs as an LJSpeech-style dataset.

    Writes one ``<id>.wav`` per row into ``output_dir / "wavs"``. Also writes a
    ``metadata.csv`` with one ``id|text`` line per clip, in the row order of ``df``.

    Args:
        df: DataFrame with ``audio_path``, ``text`` and ``id`` columns, as
            produced by ``find_data_pairs``.
        output_dir: Dataset root (str or Path). Created if missing.
        target_sr: Sample rate the audio is resampled to and saved at.
        trim_silence: Whether to trim leading/trailing silence. ``None`` falls
            back to the notebook-level ``TRIM_SILENCE`` setting.
        top_db: Threshold passed to ``librosa.effects.trim``.
    """
    import soundfile as sf  # imported once here rather than on every loop iteration

    if trim_silence is None:
        trim_silence = TRIM_SILENCE

    output_dir = Path(output_dir)
    wavs_dir = output_dir / "wavs"
    wavs_dir.mkdir(parents=True, exist_ok=True)

    metadata = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        # librosa resamples to target_sr while loading.
        y, _loaded_sr = librosa.load(row['audio_path'], sr=target_sr)

        if trim_silence:
            y, _trim_idx = librosa.effects.trim(y, top_db=top_db)

        out_path = wavs_dir / f"{row['id']}.wav"
        sf.write(out_path, y, target_sr)

        metadata.append(f"{row['id']}|{_clean_transcript(row['text'])}")

    with open(output_dir / "metadata.csv", 'w', encoding='utf-8') as f:
        f.write("\n".join(metadata))
        if metadata:
            f.write("\n")  # terminate the last record like a normal text file

    print(f"Processed dataset saved to {output_dir}")

# Only build the dataset when ingestion actually found something.
if len(df):
    preprocess_dataset(df, output_dir=DATASET_DIR, target_sr=TARGET_SR)

## 4. Visualization

In [None]:
if len(df) > 0:
    # Visualize the last processed clip. It is looked up by id rather than by
    # globbing so the choice is deterministic and matches this description.
    example_wav = DATASET_DIR / "wavs" / f"{df.iloc[-1]['id']}.wav"
    # sr=None keeps the file's own rate (TARGET_SR). librosa's default would
    # silently resample to 22050 Hz and misrepresent the saved audio.
    y, sr = librosa.load(example_wav, sr=None)

    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    S_dB = librosa.power_to_db(S, ref=np.max)

    fig, ax = plt.subplots(figsize=(10, 4))
    img = librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', sr=sr, fmax=8000, ax=ax)
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set_title(f'Mel-frequency spectrogram: {example_wav.name}')
    fig.tight_layout()
    plt.show()