# WhisperTune: Synthetic Data Generation

This notebook demonstrates how to generate synthetic speech data for fine-tuning Whisper on low-resource languages.

## Why this Notebook?

Fine-tuning Whisper for low-resource languages faces a common challenge: limited high-quality audio-text paired data. This notebook provides a solution by:

1. Using Meta's MMS-TTS models to generate synthetic speech
2. Applying realistic augmentations to improve robustness
3. Creating large-scale training data for low-resource languages

The synthetic data can be used to pre-train or supplement existing datasets for Whisper fine-tuning.

## Note on Implementation

While it's generally good practice to store classes and functions in separate Python modules, this notebook intentionally keeps all code in cells because:

1. It is designed for Google Colab, where importing local modules can be cumbersome
2. It is easier to share and run without setting up a full development environment
3. It allows interactive experimentation and modification of the code

For production use, you may want to refactor the code into proper Python modules.

## How to Use this Notebook

### Setup
1. Run the import cell (torch, transformers, librosa, soundfile, and tqdm are typically preinstalled in Colab; `pip install` any that are missing)
2. Mount your Google Drive using the provided cell
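If an import fails, a quick check can list which of the notebook's dependencies are missing from the current runtime. This is a small sketch using only the standard library (`importlib.util.find_spec`); the helper name `missing_packages` is hypothetical and not part of the notebook.

```python
import importlib.util
from typing import Iterable, List

# Top-level module names imported by this notebook (pip names are the same here)
REQUIRED = ["torch", "numpy", "librosa", "soundfile", "transformers", "tqdm"]


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the module names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Install with: pip install " + " ".join(missing))
else:
    print("All dependencies found.")
```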

### Configuration
1. Modify the `config` dictionary to set:
   - TTS model (default: facebook/mms-tts-tgk for Tajik)
   - Augmentation parameters (`noise_factor`, `speed_range`, `pitch_shift_range`)
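Because a malformed `config` only surfaces once generation is underway, a quick check before constructing the generator can save a long Colab run. The `validate_config` helper below is hypothetical (not part of the notebook); it checks only the keys that `SyntheticDataGenerator` reads and assumes integer pitch steps, a non-negative noise factor, and a positive speed range.

```python
from typing import Any, Dict


def validate_config(config: Dict[str, Any]) -> None:
    """Hypothetical helper: raise ValueError if the keys read by
    SyntheticDataGenerator are missing or out of range."""
    if 'tts' not in config or 'augmentation' not in config:
        raise ValueError("config needs 'tts' and 'augmentation' sections")

    aug = config['augmentation']
    if 'enabled' not in aug:
        raise ValueError("augmentation.enabled is required")

    noise = aug.get('noise_factor')
    if noise is not None and noise < 0:
        raise ValueError("noise_factor must be non-negative")

    speed = aug.get('speed_range')
    if speed is not None:
        low, high = speed
        # librosa's time_stretch needs a strictly positive rate
        if not 0 < low <= high:
            raise ValueError("speed_range must satisfy 0 < low <= high")

    pitch = aug.get('pitch_shift_range')
    if pitch is not None:
        low, high = pitch
        if not (isinstance(low, int) and isinstance(high, int) and low <= high):
            raise ValueError("pitch_shift_range must be two integers with low <= high")


# The notebook's default augmentation settings pass the check
validate_config({
    'tts': {'model': 'facebook/mms-tts-tgk'},
    'augmentation': {
        'enabled': True,
        'noise_factor': 0.003,
        'speed_range': (0.9, 1.1),
        'pitch_shift_range': (-2, 2),
    },
})
```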

### Data Generation
1. Prepare your text corpus in a plain text file (one sentence per line)
2. Set the `output_dir` and `json_output` paths
3. Run the data generation cell with desired parameters:
   - `batch_size`: Number of sentences synthesized together in one forward pass
   - `sample_size`: Number of lines to sample at random from the corpus (`None` uses the whole corpus)
   - `random_seed`: For reproducibility
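The sampling and batching steps can be illustrated without loading the TTS model. This sketch mirrors what `sample_text_corpus` and `process_text_corpus` do (strip blank lines, sample without replacement, slice into batches of `batch_size`), except that it uses a local `random.Random(seed)` instead of seeding the global generator. `sample_lines` and `make_batches` are hypothetical names used only here.

```python
import math
import random
from typing import List


def sample_lines(lines: List[str], sample_size: int, seed: int) -> List[str]:
    """Drop blank lines, then sample without replacement using a seeded RNG."""
    texts = [line.strip() for line in lines if line.strip()]
    if sample_size >= len(texts):
        return texts
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    indices = rng.sample(range(len(texts)), sample_size)
    return [texts[i] for i in indices]


def make_batches(texts: List[str], batch_size: int) -> List[List[str]]:
    """Slice texts into consecutive batches, as the generation loop does."""
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]


corpus = [f"sentence {i}\n" for i in range(100)] + ["\n", "   \n"]
first = sample_lines(corpus, sample_size=40, seed=42)
second = sample_lines(corpus, sample_size=40, seed=42)
print(first == second)                # same seed gives the same sample
print(len(make_batches(first, 16)))   # ceil(40 / 16) batches
print(math.ceil(32000 / 16))          # batches for the notebook's run
```

With `sample_size=32000` and `batch_size=16` this works out to 2000 batches, which matches the 2000 iterations in the progress bar of the generation cell.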

### Output
The notebook will generate:
- WAV audio files in the specified output directory
- A JSONL metadata file with paths and transcripts
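Each line of the JSONL file is one JSON object with the `audio_path`, `text`, and `duration` (in seconds) keys written by `process_text_corpus`. A minimal reader that totals the generated audio might look like the following; `load_metadata` and `summarize_metadata` are hypothetical helpers, and the example writes two toy rows to a temporary file instead of touching Drive.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def load_metadata(path: str) -> List[Dict]:
    """Read one JSON object per non-empty line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def summarize_metadata(entries: List[Dict]) -> Dict[str, float]:
    """Count clips and total their durations."""
    total = sum(e['duration'] for e in entries)
    return {
        'num_files': len(entries),
        'total_hours': total / 3600,
        'mean_seconds': total / len(entries) if entries else 0.0,
    }


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'metadata.jsonl'
    rows = [  # toy rows in the same shape the notebook writes
        {"audio_path": "audio/synthetic_000000.wav", "text": "Салом", "duration": 1.5},
        {"audio_path": "audio/synthetic_000001.wav", "text": "Раҳмат", "duration": 2.5},
    ]
    with open(path, 'w', encoding='utf-8') as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + '\n')
    summary = summarize_metadata(load_metadata(str(path)))

print(summary)
```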

In [None]:
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple
from tqdm import tqdm
import librosa
import soundfile as sf
from transformers import VitsModel, AutoTokenizer
import logging
import os
import json
import random

In [None]:
from google.colab import drive
drive.mount('/content/drive')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Mounted at /content/drive


In [None]:
class SyntheticDataGenerator:
    def __init__(self, config: Dict):
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_model()
        self.sampling_rate = self.model.config.sampling_rate

    def _initialize_model(self):
        """Initialize VITS model and tokenizer from HuggingFace."""
        try:
            model_name = self.config['tts'].get('model', 'facebook/mms-tts-tgk')
            self.model = VitsModel.from_pretrained(model_name).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            logger.info(f"Initialized model {model_name} on {self.device}")
        except Exception as e:
            logger.error(f"Error initializing VITS model: {e}")
            raise

    def sample_text_corpus(
        self,
        input_file: str,
        sample_size: int,
        output_file: str = None,
        random_seed: int = None
    ) -> List[str]:
        """
        Sample random lines from the input text corpus.

        Args:
            input_file (str): Path to input text file
            sample_size (int): Number of lines to sample
            output_file (str): Path to save sampled text (optional)
            random_seed (int): Random seed for reproducibility

        Returns:
            List[str]: Sampled text lines
        """
        if random_seed is not None:
            random.seed(random_seed)
            np.random.seed(random_seed)

        # Read all lines
        with open(input_file, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]

        total_lines = len(texts)
        logger.info(f"Total lines in corpus: {total_lines}")

        if sample_size >= total_lines:
            logger.warning(f"Sample size {sample_size} is >= total lines {total_lines}. Using entire corpus.")
            sampled_texts = texts
        else:
            # Random sampling without replacement
            sampled_indices = random.sample(range(total_lines), sample_size)
            sampled_texts = [texts[i] for i in sampled_indices]
            logger.info(f"Sampled {sample_size} lines from corpus")

        # Save sampled text if output file is specified
        if output_file:
            output_file = str(Path(output_file))
            # Path.parent is '.' for a bare filename, unlike os.path.dirname, which returns ''
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, 'w', encoding='utf-8') as f:
                for text in sampled_texts:
                    f.write(text + '\n')
            logger.info(f"Saved sampled text to {output_file}")

        return sampled_texts

    def generate_audio(self, text: str, output_path: str) -> str:
        """Generate synthetic audio from text."""
        try:
            # Tokenize and generate audio
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                waveform = self.model(**inputs).waveform.squeeze()

            # Move to CPU and convert to numpy
            wav = waveform.cpu().numpy()

            # Apply augmentation if enabled
            if self.config['augmentation']['enabled']:
                wav = self._augment_audio(wav)

            # Save audio
            sf.write(output_path, wav, self.sampling_rate)
            return output_path

        except Exception as e:
            logger.error(f"Error generating audio for text: {text}")
            logger.error(f"Error: {e}")
            return None

    def _augment_audio(self, audio: np.ndarray) -> np.ndarray:
        """Apply various augmentation techniques to the audio."""
        # Add noise if configured
        if self.config['augmentation'].get('noise_factor'):
            noise = np.random.randn(len(audio))
            audio = audio + self.config['augmentation']['noise_factor'] * noise

        # Apply speed perturbation
        if self.config['augmentation'].get('speed_range'):
            speed_factor = np.random.uniform(*self.config['augmentation']['speed_range'])
            audio = librosa.effects.time_stretch(audio, rate=speed_factor)

        # Apply pitch shift
        if self.config['augmentation'].get('pitch_shift_range'):
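            low, high = self.config['augmentation']['pitch_shift_range']
            # np.random.randint excludes its upper bound, so add 1 to make the range inclusive
            pitch_shift = np.random.randint(low, high + 1)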
            audio = librosa.effects.pitch_shift(
                audio,
                sr=self.sampling_rate,
                n_steps=pitch_shift
            )

        return audio

    def process_text_corpus(
        self,
        input_file: str,
        output_dir: str,
        json_output: str,
        batch_size: int = 16,
        sample_size: int = None,
        random_seed: int = None
    ):
        """Process entire text corpus and generate synthetic audio with metadata."""
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Sample text corpus if sample_size is specified
        if sample_size is not None:
            # Create a sampled text file in the same directory as json_output
            sampled_text_file = str(Path(json_output).parent / 'sampled_corpus.txt')
            texts = self.sample_text_corpus(
                input_file=input_file,
                sample_size=sample_size,
                output_file=sampled_text_file,
                random_seed=random_seed
            )
        else:
            # Read entire corpus
            with open(input_file, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]

        logger.info(f"Processing {len(texts)} text entries...")

        metadata = []
        # Process in batches for efficiency
        for batch_idx in tqdm(range(0, len(texts), batch_size), desc="Generating Audio"):
            batch_texts = texts[batch_idx:batch_idx + batch_size]

            # Tokenize batch
            inputs = self.tokenizer(batch_texts, return_tensors="pt", padding=True).to(self.device)

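            # Generate audio for batch. Every waveform is padded to the longest
            # item, so keep the per-item lengths (in samples) to trim the padding.
            with torch.no_grad():
                outputs = self.model(**inputs)
            waveforms = outputs.waveform  # [batch, time]
            lengths = outputs.sequence_lengths  # [batch]

            # Process each item in batch
            for idx, (text, waveform, length) in enumerate(zip(batch_texts, waveforms, lengths)):
                try:
                    output_path = os.path.join(output_dir, f"synthetic_{batch_idx + idx:06d}.wav")

                    # Trim padding, then move to CPU and convert to numpy
                    wav = waveform[:int(length)].cpu().numpy()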

                    # Apply augmentation if enabled
                    if self.config['augmentation']['enabled']:
                        wav = self._augment_audio(wav)

                    # Save audio
                    sf.write(output_path, wav, self.sampling_rate)

                    metadata.append({
                        "audio_path": output_path,
                        "text": text,
                        "duration": len(wav) / self.sampling_rate
                    })

                except Exception as e:
                    logger.error(f"Error processing text entry {batch_idx + idx}: {e}")
                    continue

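        # Save metadata (create the parent directory if needed)
        Path(json_output).parent.mkdir(parents=True, exist_ok=True)
        with open(json_output, 'w', encoding='utf-8') as f: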
            for entry in metadata:
                f.write(json.dumps(entry, ensure_ascii=False) + '\n')

        logger.info(f"Generated {len(metadata)} audio files with metadata at {json_output}")
        return metadata

    def evaluate_audio_quality(self, audio_path: str) -> Dict[str, float]:
        """Evaluate the quality metrics of generated audio."""
        try:
            audio, sr = librosa.load(audio_path, sr=self.sampling_rate)
            metrics = {}

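            # Rough SNR proxy: mean magnitude of samples at or above the mean value
            # versus below it. This is a heuristic, not a true signal-to-noise ratio.
            mean_value = np.mean(audio)
            below = audio[audio < mean_value]
            above = audio[audio >= mean_value]
            noise_floor = np.mean(np.abs(below)) if below.size else 0.0
            signal = np.mean(np.abs(above)) if above.size else 0.0
            metrics['snr'] = float(20 * np.log10(signal / noise_floor)) if noise_floor > 0 else float('inf')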

            # RMS energy
            metrics['rms'] = float(np.sqrt(np.mean(audio**2)))

            # Zero-crossing rate
            metrics['zcr'] = float(np.mean(librosa.feature.zero_crossing_rate(audio)))

            # Spectral centroid
            cent = librosa.feature.spectral_centroid(y=audio, sr=sr)
            metrics['spectral_centroid'] = float(np.mean(cent))

            return metrics

        except Exception as e:
            logger.error(f"Error calculating audio quality metrics: {e}")
            return {}

## Configuration

Set up the configuration for synthetic data generation including TTS model, augmentation parameters, and output paths.
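`noise_factor` scales unit-variance Gaussian noise, so its value is the RMS of the added noise in the same units as the waveform. One way to judge a setting is to convert it to a true SNR against a known clean signal. The sketch below assumes a 440 Hz sine at amplitude 0.5 and a 16 kHz sampling rate purely as stand-ins for TTS output; `snr_db` is a hypothetical helper.

```python
import numpy as np


def snr_db(clean: np.ndarray, noisy: np.ndarray) -> float:
    """True SNR in dB, computed from the known clean signal and the added noise."""
    noise = noisy - clean
    return float(10 * np.log10(np.mean(clean ** 2) / np.mean(noise ** 2)))


sr = 16000  # assumed sampling rate for this illustration
t = np.arange(sr) / sr
clean = 0.5 * np.sin(2 * np.pi * 440 * t)  # stand-in signal with RMS of about 0.354

rng = np.random.default_rng(0)
for noise_factor in (0.001, 0.003, 0.01):
    # Same recipe as _augment_audio: noise_factor times standard normal noise
    noisy = clean + noise_factor * rng.standard_normal(len(clean))
    print(f"noise_factor={noise_factor}: SNR ~ {snr_db(clean, noisy):.1f} dB")
```

For the default `0.003` this comes out at roughly 41 dB against the sine. If the synthesized speech is quieter than this reference, the effective SNR on the generated clips will be correspondingly lower.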

In [None]:
output_dir = '/content/drive/My Drive/Colab Outputs/audio'
json_output = '/content/drive/My Drive/Colab Outputs/metadata.jsonl'

In [None]:
config = {
    'tts': {
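        'model': 'facebook/mms-tts-tgk',
        'language': 'tgk'  # informational only; SyntheticDataGenerator does not read it
    },
    'data': {
        'sample_rate': None  # unused; audio is saved at the model's own sampling rate
    },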
    'augmentation': {
        'enabled': True,
        'noise_factor': 0.003,
        'speed_range': (0.9, 1.1),
        'pitch_shift_range': (-2, 2)
    }
}

generator = SyntheticDataGenerator(config)

metadata = generator.process_text_corpus(
    input_file='final_filtered_newest.txt',
    output_dir=output_dir,
    json_output=json_output,
    batch_size=16,
    sample_size=32000,
    random_seed=42
)

Generating Audio: 100%|██████████| 2000/2000 [1:57:03<00:00,  3.51s/it]
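Whisper's feature extractor pads or truncates input to 30-second windows, so a synthetic clip longer than that would lose the end of its audio while keeping the full transcript. Speed perturbation can also lengthen clips: a `speed_range` lower bound of 0.9 stretches a clip to about 1/0.9, or 111%, of its original duration. The sketch below drops out-of-range entries from the metadata before fine-tuning; `filter_by_duration`, `write_jsonl`, the 30 s cap, and the 0.5 s minimum are assumptions to adjust.

```python
import json
from typing import Dict, List, Tuple


def filter_by_duration(
    entries: List[Dict], min_s: float = 0.5, max_s: float = 30.0
) -> Tuple[List[Dict], List[Dict]]:
    """Split metadata entries into (kept, dropped) by clip duration in seconds."""
    kept, dropped = [], []
    for entry in entries:
        if min_s <= entry['duration'] <= max_s:
            kept.append(entry)
        else:
            dropped.append(entry)
    return kept, dropped


def write_jsonl(entries: List[Dict], path: str) -> None:
    """Write one JSON object per line, keeping non-ASCII text readable."""
    with open(path, 'w', encoding='utf-8') as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')


# Example usage in the notebook (output filename is an assumption):
# kept, dropped = filter_by_duration(metadata)
# write_jsonl(kept, json_output.replace('.jsonl', '_filtered.jsonl'))
entries = [
    {"audio_path": "a.wav", "text": "short", "duration": 0.2},
    {"audio_path": "b.wav", "text": "ok", "duration": 4.8},
    {"audio_path": "c.wav", "text": "long", "duration": 31.0},
]
kept, dropped = filter_by_duration(entries)
print([e['audio_path'] for e in kept], [e['audio_path'] for e in dropped])
```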
