In [2]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
import librosa
import nltk
import soundfile as sf
from sklearn.model_selection import train_test_split
import re
import pandas as pd
from tqdm.notebook import tqdm
from functools import lru_cache
import gc


In [3]:
# Fetch the NLTK English POS-tagger model; the call is a no-op when the
# package is already up to date (see the cell output below).
# NOTE(review): presumably required by g2p_en's tagging step — confirm.
nltk.download('averaged_perceptron_tagger_eng')

[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


True

In [4]:
# Pin the NumPy and PyTorch RNG seeds so repeated runs draw the same numbers
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [5]:
# Install required packages
!pip install transformers
!pip install pydub
!pip install phonemizer
!pip install g2p_en
!pip install praatio
!pip install pyworld
!pip install torchtext
!pip install unidecode




In [6]:
# Import additional libraries after installation
from transformers import AutoTokenizer, AutoModel
from g2p_en import G2p
import pyworld as pw
from pydub import AudioSegment
import unidecode

In [7]:
class AudioProcessor:
    """Load audio files and derive frame-level acoustic features.

    Produces a log-mel spectrogram (dB), a min-max normalised F0 contour
    (PyWorld HARVEST) and a min-max normalised per-frame energy curve.

    Args:
        sr: Sample rate every file is resampled to on load.
        n_fft: FFT window size used for the mel spectrogram.
        hop_length: Hop between analysis frames, in samples.
        n_mels: Number of mel bands.
    """

    def __init__(self, sr=22050, n_fft=1024, hop_length=256, n_mels=80):
        self.sr = sr  # Sample rate
        self.n_fft = n_fft  # FFT window size
        self.hop_length = hop_length  # Hop length for FFT
        self.n_mels = n_mels  # Number of mel bands

    @staticmethod
    def _min_max_normalize(values):
        """Scale ``values`` to [0, 1]; empty or constant input yields zeros.

        Guarding the zero-range case avoids the division by zero (NaN output)
        that a plain ``(x - min) / (max - min)`` produces for flat signals.
        """
        values = np.asarray(values)
        if values.size == 0:
            return values
        lowest = np.min(values)
        span = np.max(values) - lowest
        if not span > 0:  # also catches a NaN span
            return np.zeros_like(values)
        return (values - lowest) / span

    @staticmethod
    def _fit_length(values, n_frames):
        """Zero-pad or truncate a 1-D array so it has exactly ``n_frames`` items."""
        if len(values) < n_frames:
            return np.pad(values, (0, n_frames - len(values)))
        return values[:n_frames]

    def load_audio(self, file_path):
        """Load an audio file as a mono waveform resampled to ``self.sr``."""
        audio, _ = librosa.load(file_path, sr=self.sr)
        return audio

    def extract_mel_spectrogram(self, audio):
        """Return the log-mel spectrogram of ``audio`` in dB relative to its peak."""
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels
        )
        # Convert to log scale (dB)
        return librosa.power_to_db(mel_spec, ref=np.max)

    def extract_pitch(self, audio):
        """Extract F0 with PyWorld's HARVEST and normalise voiced frames to [0, 1].

        Unvoiced frames (F0 == 0) are left at 0.  Audio with no voiced frame
        at all, or with a single constant F0, returns without raising and
        without NaNs.
        """
        # PyWorld requires float64 input
        audio_64 = audio.astype(np.float64)

        # frame_period is expressed in milliseconds
        f0, _ = pw.harvest(audio_64, self.sr, frame_period=self.hop_length / self.sr * 1000)

        voiced = f0 > 0
        if voiced.any():
            f0[voiced] = self._min_max_normalize(f0[voiced])

        return f0

    def extract_energy(self, mel_spec):
        """Return the per-frame L2 norm of ``mel_spec``, min-max normalised.

        A spectrogram whose frames all have the same norm yields zeros
        instead of NaN.
        """
        energy = np.linalg.norm(mel_spec, axis=0)
        return self._min_max_normalize(energy)

    def process_audio_file(self, file_path):
        """Load one file and return its waveform, mel spectrogram, pitch and energy.

        Returns:
            dict with keys ``'audio'``, ``'mel_spectrogram'``, ``'pitch'`` and
            ``'energy'``; pitch and energy are padded/trimmed to the number of
            mel frames.
        """
        audio = self.load_audio(file_path)
        mel_spec = self.extract_mel_spectrogram(audio)
        pitch = self.extract_pitch(audio)
        energy = self.extract_energy(mel_spec)

        # HARVEST and librosa may disagree by a frame or two; align to the mel
        n_frames = mel_spec.shape[1]
        pitch = self._fit_length(pitch, n_frames)
        energy = self._fit_length(energy, n_frames)

        return {
            'audio': audio,
            'mel_spectrogram': mel_spec,
            'pitch': pitch,
            'energy': energy
        }

In [8]:
class TextProcessor:
    """Clean, tokenise and approximately phonemise Roman Urdu text."""

    def __init__(self):
        # Initialize tokenizer for Roman Urdu
        # Since there's no specific tokenizer for Roman Urdu,
        # we'll use a general-purpose one and fine-tune later
        self.tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
        self.phonemizer = G2p()  # English phonemizer as a starting point

        # Map for common Roman Urdu sounds to phonemes
        self.urdu_phoneme_map = {
            'aa': 'ɑː',
            'ee': 'iː',
            'oo': 'uː',
            'ai': 'aɪ',
            'ae': 'eɪ',
            'ch': 'tʃ',
            'sh': 'ʃ',
            'kh': 'x',
            'gh': 'ɣ',
            'ph': 'f',
            'th': 'θ',
            'dh': 'ð',
            'ng': 'ŋ',
            # Add more mappings as needed
        }

    def clean_text(self, text):
        """Strip unsupported characters and collapse runs of whitespace.

        Word characters, whitespace and the punctuation ``' , . ! ?`` are kept.
        """
        # Remove special characters but keep spaces and basic punctuation
        text = re.sub(r'[^\w\s\',\.!?]', '', text)
        # Normalize spaces
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def text_to_sequence(self, text):
        """Convert text to a 1-D tensor of tokenizer IDs."""
        clean_text = self.clean_text(text)
        tokens = self.tokenizer(clean_text, return_tensors="pt")
        return tokens['input_ids'][0]

    def get_phoneme_sequence(self, text):
        """Convert Roman Urdu text to an approximate phoneme sequence.

        Digraphs listed in ``urdu_phoneme_map`` (matched case-insensitively,
        longest key first) are emitted directly as their mapped phoneme; only
        the remaining fragments are sent to the English G2P model.  The mapped
        symbols must bypass the G2P model because g2p_en discards every
        character outside ``[a-z'.,?!-]``, so substituting them into the text
        before phonemising silently deletes them.

        Returns:
            list of phoneme strings, with ``' '`` between words.
        """
        clean_text = self.clean_text(text)

        if not self.urdu_phoneme_map:
            return self.phonemizer(clean_text)

        lookup = {roman.lower(): phoneme for roman, phoneme in self.urdu_phoneme_map.items()}
        alternatives = '|'.join(re.escape(key) for key in sorted(lookup, key=len, reverse=True))
        digraph_re = re.compile(f'({alternatives})', re.IGNORECASE)

        phonemes = []
        for word_idx, word in enumerate(clean_text.split(' ')):
            if word_idx:
                phonemes.append(' ')  # word separator, as g2p_en emits it
            # With one capture group, re.split alternates: plain, match, plain, ...
            for part_idx, part in enumerate(digraph_re.split(word)):
                if not part:
                    continue
                if part_idx % 2:
                    phonemes.append(lookup[part.lower()])
                else:
                    # Use G2p for remaining text (approximation)
                    phonemes.extend(self.phonemizer(part))
        return phonemes



In [9]:
class LyricsProcessor:
    """Load lyrics, parse optional timestamps and align lines with audio.

    Args:
        text_processor: Object providing ``get_phoneme_sequence(line)`` and
            ``text_to_sequence(line)``.
    """

    # [MM:SS.fraction] followed by the lyric text
    _TIMESTAMP_RE = re.compile(r'\[(\d+):(\d+)\.(\d+)\](.*)')

    def __init__(self, text_processor):
        self.text_processor = text_processor

    def load_lyrics(self, lyrics_file):
        """Return the full text of a UTF-8 lyrics file."""
        with open(lyrics_file, 'r', encoding='utf-8') as f:
            return f.read()

    def parse_lyrics_with_timestamps(self, lyrics_file):
        """Parse lyrics with timestamps if available.

        Expected format: ``[MM:SS.fff] Lyric line``.  The digits after the dot
        are read as a decimal fraction of a second, so ``.5``, ``.50`` and
        ``.500`` all mean half a second.

        Returns:
            ``(lines, timestamps)``.  When at least one line carries a
            timestamp, ``timestamps`` has exactly one entry per line, with
            ``-1`` marking lines that have none.  When no line carries a
            timestamp, ``timestamps`` is empty.
        """
        lines = []
        timestamps = []
        found_timestamp = False

        with open(lyrics_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                match = self._TIMESTAMP_RE.match(line)
                if match:
                    minutes, seconds, fraction, text = match.groups()
                    time_in_seconds = (
                        int(minutes) * 60
                        + int(seconds)
                        + int(fraction) / 10 ** len(fraction)
                    )
                    timestamps.append(time_in_seconds)
                    lines.append(text.strip())
                    found_timestamp = True
                else:
                    lines.append(line)
                    # Always record a placeholder so both lists stay index-aligned
                    timestamps.append(-1)

        if not found_timestamp:
            timestamps = []

        return lines, timestamps

    def process_lyrics(self, lyrics_text):
        """Split lyrics into non-empty lines and phonemise/tokenise each one.

        Returns:
            dict with parallel lists ``'lines'``, ``'phoneme_sequences'`` and
            ``'token_sequences'``.
        """
        lines = [line.strip() for line in lyrics_text.split('\n') if line.strip()]

        phoneme_sequences = []
        token_sequences = []

        for line in lines:
            phoneme_sequences.append(self.text_processor.get_phoneme_sequence(line))
            token_sequences.append(self.text_processor.text_to_sequence(line))

        return {
            'lines': lines,
            'phoneme_sequences': phoneme_sequences,
            'token_sequences': token_sequences
        }

    def estimate_alignment(self, audio_length, num_lines):
        """Estimate a start time (seconds) for each line by uniform spacing.

        Used when no timestamps are available: the audio is divided into
        ``num_lines`` equal spans and the start of each span is returned.
        """
        # num_lines + 1 boundaries; drop the last one (the end of the audio)
        return np.linspace(0, audio_length, num_lines + 1)[:-1]



In [28]:
class SingerDataset(Dataset):
    """Memory-conscious dataset of per-lyric-line singing segments.

    Each item pairs one lyric line with the slice of the song's mel
    spectrogram, pitch and energy that a uniform (linear) alignment assigns to
    it.  Only a ``(file index, line index)`` table is built up front; features
    are computed on first access and kept in small per-instance LRU caches.

    Args:
        audio_files: Paths of the audio files.
        lyrics_files: Paths of the lyrics files, paired by position with
            ``audio_files``.
        audio_processor: Object exposing ``sr``, ``hop_length`` and
            ``process_audio_file(path)`` (and optionally ``load_audio(path)``
            and ``n_mels``).
        lyrics_processor: Object exposing ``load_lyrics``, ``process_lyrics``
            and ``estimate_alignment``.
        max_len: Number of frames every returned segment is padded/trimmed to.
        cache_size: Number of sample indices between forced garbage
            collections; adjusted from available memory at construction.
        memory_limit_mb: Soft memory budget used by the memory checks.

    Raises:
        FileNotFoundError: If any listed audio or lyrics file does not exist.
    """

    _AUDIO_CACHE_SIZE = 16      # processed audio files kept per instance
    _LYRICS_CACHE_SIZE = 32     # processed lyrics files kept per instance
    _MIN_SEGMENT_FRAMES = 5     # shorter segments are skipped

    def __init__(self, audio_files, lyrics_files, audio_processor, lyrics_processor,
                 max_len=500, cache_size=100, memory_limit_mb=12000):
        self.audio_files = audio_files
        self.lyrics_files = lyrics_files
        self.audio_processor = audio_processor
        self.lyrics_processor = lyrics_processor
        self.max_len = max_len
        self.cache_size = cache_size
        self.memory_limit_mb = memory_limit_mb

        # Per-instance caches: a class-level lru_cache would be shared by every
        # dataset and would keep each instance alive through its `self` key.
        self._audio_cache = {}
        self._lyrics_cache = {}

        # Adjust cache size based on available memory
        self._adjust_cache_size()

        # Instead of loading all data, just prepare metadata and file mapping
        self.file_pairs = list(zip(audio_files, lyrics_files))

        # Validate files exist
        self._validate_files()

        # Pre-index the dataset structure without loading full data
        self.sample_indices = self._index_dataset()

        # Memory monitoring, counted in __getitem__ calls
        self._getitem_calls = 0
        self.last_memory_check = 0
        self.memory_check_interval = 50  # Check memory every 50 getitem calls

        print(f"Singer dataset initialized with {len(self.sample_indices)} segments across {len(self.file_pairs)} files")
        print(f"Using a cache size of {self.cache_size} and memory limit of {self.memory_limit_mb}MB")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _load_psutil():
        """Return the ``psutil`` module, or ``None`` when it is not installed."""
        try:
            import psutil
        except ImportError:
            return None
        return psutil

    @staticmethod
    def _lru_fetch(cache, key, maxsize, loader):
        """Return ``cache[key]``, computing it with ``loader(key)`` on a miss.

        ``cache`` is a plain dict used in insertion order: a hit is re-inserted
        to mark it most recently used and the oldest entries are evicted once
        the dict holds more than ``maxsize`` items.
        """
        if key in cache:
            value = cache.pop(key)
            cache[key] = value
            return value
        value = loader(key)
        cache[key] = value
        while len(cache) > maxsize:
            del cache[next(iter(cache))]
        return value

    def _frame_bounds(self, timestamps, line_idx, audio_length):
        """Return ``(start_frame, end_frame)`` of line ``line_idx``.

        A line ends where the next one starts; the last line ends with the audio.
        """
        start_time = timestamps[line_idx]
        end_time = timestamps[line_idx + 1] if line_idx + 1 < len(timestamps) else audio_length
        sr = self.audio_processor.sr
        hop = self.audio_processor.hop_length
        return int(start_time * sr / hop), int(end_time * sr / hop)

    def _audio_duration(self, audio_file):
        """Return the duration of ``audio_file`` in seconds.

        Only the waveform is loaded when the processor offers ``load_audio``;
        otherwise the full feature extraction is used as a fallback.
        """
        load_audio = getattr(self.audio_processor, 'load_audio', None)
        if load_audio is not None:
            audio = load_audio(audio_file)
        else:
            audio = self.audio_processor.process_audio_file(audio_file)['audio']
        return len(audio) / self.audio_processor.sr

    # --------------------------------------------------------------- setup

    def _validate_files(self):
        """Verify that all files exist before starting."""
        for audio_file, lyrics_file in self.file_pairs:
            if not os.path.exists(audio_file):
                raise FileNotFoundError(f"Audio file not found: {audio_file}")
            if not os.path.exists(lyrics_file):
                raise FileNotFoundError(f"Lyrics file not found: {lyrics_file}")

    def _index_dataset(self):
        """Build the list of ``(file_idx, line_idx)`` pairs for usable segments.

        Only the audio duration and the number of non-empty lyric lines are
        needed here, so no spectrogram, pitch, phoneme or token extraction is
        performed.  Files that fail to load are reported and skipped.
        """
        indices = []

        for file_idx, (audio_file, lyrics_file) in enumerate(self.file_pairs):
            try:
                audio_length = self._audio_duration(audio_file)

                # Same non-empty-line rule the lyrics processor applies
                lyrics_text = self.lyrics_processor.load_lyrics(lyrics_file)
                lines_count = sum(1 for line in lyrics_text.split('\n') if line.strip())

                timestamps = self.lyrics_processor.estimate_alignment(audio_length, lines_count)

                for line_idx in range(lines_count):
                    start_frame, end_frame = self._frame_bounds(timestamps, line_idx, audio_length)
                    if end_frame - start_frame >= self._MIN_SEGMENT_FRAMES:
                        indices.append((file_idx, line_idx))

            except Exception as e:
                print(f"Error indexing {audio_file} and {lyrics_file}: {e}")

        gc.collect()
        return indices

    # ------------------------------------------------------------ file loading

    def _load_audio_features(self, audio_file):
        """Run the audio processor on one file; return ``None`` on failure."""
        try:
            features = self.audio_processor.process_audio_file(audio_file)
            return {
                'audio': features.get('audio'),
                'mel_spectrogram': features.get('mel_spectrogram'),
                'pitch': features.get('pitch'),
                'energy': features.get('energy')
            }
        except Exception as e:
            print(f"Error processing audio file {audio_file}: {e}")
            return None

    def _load_lyrics_data(self, lyrics_file):
        """Load and process one lyrics file; return ``None`` on failure."""
        try:
            lyrics_text = self.lyrics_processor.load_lyrics(lyrics_file)
            return self.lyrics_processor.process_lyrics(lyrics_text)
        except Exception as e:
            print(f"Error processing lyrics file {lyrics_file}: {e}")
            return None

    def _process_audio_file(self, audio_file):
        """Process audio file with caching (failures are cached as ``None``)."""
        cache = self.__dict__.setdefault('_audio_cache', {})
        return self._lru_fetch(cache, audio_file, self._AUDIO_CACHE_SIZE, self._load_audio_features)

    def _process_lyrics_file(self, lyrics_file):
        """Process lyrics file with caching (failures are cached as ``None``)."""
        cache = self.__dict__.setdefault('_lyrics_cache', {})
        return self._lru_fetch(cache, lyrics_file, self._LYRICS_CACHE_SIZE, self._load_lyrics_data)

    # ------------------------------------------------------------ segmentation

    def _extract_segment(self, audio_features, lyrics_data, line_idx, audio_length):
        """Cut the features of one lyric line out of a processed file.

        Returns:
            dict with the line's text, phonemes and tokens plus
            ``mel_spectrogram`` ``(n_mels, max_len)``, ``pitch``, ``energy`` and
            ``mask`` (each ``max_len`` long; mask is 1 on real frames and 0 on
            padding), or ``None`` when the segment is unusable.
        """
        try:
            if audio_features is None:
                return None
            mel_spec = audio_features.get('mel_spectrogram')
            pitch = audio_features.get('pitch')
            energy = audio_features.get('energy')
            if mel_spec is None or pitch is None or energy is None:
                return None

            timestamps = self.lyrics_processor.estimate_alignment(
                audio_length, len(lyrics_data['lines'])
            )
            start_frame, end_frame = self._frame_bounds(timestamps, line_idx, audio_length)

            # The estimate can overshoot the frames actually available
            end_frame = min(end_frame, mel_spec.shape[1], len(pitch), len(energy))
            if end_frame - start_frame < self._MIN_SEGMENT_FRAMES:
                return None

            mel_segment = mel_spec[:, start_frame:end_frame][:, :self.max_len]
            pitch_segment = pitch[start_frame:end_frame][:self.max_len]
            energy_segment = energy[start_frame:end_frame][:self.max_len]

            if mel_segment.size == 0 or len(pitch_segment) == 0 or len(energy_segment) == 0:
                return None

            # Record the real length BEFORE padding so the mask marks padding
            valid_len = mel_segment.shape[1]
            pad_len = self.max_len - valid_len

            # np.pad returns new arrays, so the cached features are not aliased.
            # NOTE(review): padding uses 0; if the mel is dB-scaled relative to
            # its peak, 0 is the loudest value — consumers should apply `mask`.
            mel_segment = np.pad(mel_segment, ((0, 0), (0, pad_len)))
            pitch_segment = np.pad(pitch_segment, (0, pad_len))
            energy_segment = np.pad(energy_segment, (0, pad_len))

            mask = np.zeros(self.max_len)
            mask[:valid_len] = 1

            return {
                'text': lyrics_data['lines'][line_idx],
                'phonemes': lyrics_data['phoneme_sequences'][line_idx],
                'tokens': lyrics_data['token_sequences'][line_idx],
                'mel_spectrogram': mel_segment,
                'pitch': pitch_segment,
                'energy': energy_segment,
                'mask': mask
            }

        except Exception as e:
            print(f"Error extracting segment {line_idx}: {e}")
            return None

    def __len__(self):
        return len(self.sample_indices)

    # ------------------------------------------------------------------ memory

    def _adjust_cache_size(self):
        """Dynamically adjust cache size based on available memory.

        Falls back to a conservative size of 50 when ``psutil`` is not
        installed or the query fails.
        """
        psutil = self._load_psutil()
        if psutil is None:
            self.cache_size = 50
            return
        try:
            available_memory = psutil.virtual_memory().available / (1024 * 1024)  # MB
        except Exception:
            self.cache_size = 50
            return
        if available_memory < self.memory_limit_mb * 0.2:  # less than 20% of limit available
            self.cache_size = max(10, self.cache_size // 2)
        elif available_memory > self.memory_limit_mb * 0.8:  # more than 80% of limit available
            self.cache_size = min(200, self.cache_size * 2)

    def _check_memory_usage(self):
        """Clear the caches when process RSS exceeds 80% of the memory limit.

        Returns:
            True if the caches were cleared, False otherwise (including when
            ``psutil`` is unavailable).
        """
        psutil = self._load_psutil()
        if psutil is None:
            return False
        try:
            used_memory_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        except Exception:
            return False  # best effort: monitoring must never break data loading
        if used_memory_mb > self.memory_limit_mb * 0.8:
            print(f"Memory usage high ({used_memory_mb:.1f}MB), clearing caches")
            self.cleanup()
            return True
        return False

    # ------------------------------------------------------------------- items

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``; an all-zero item on failure."""
        # Periodically check memory usage, counted in calls (indices may be shuffled)
        self._getitem_calls = getattr(self, '_getitem_calls', 0) + 1
        if self._getitem_calls - self.last_memory_check >= self.memory_check_interval:
            self.last_memory_check = self._getitem_calls
            self._check_memory_usage()

        try:
            file_idx, line_idx = self.sample_indices[idx]
            audio_file, lyrics_file = self.file_pairs[file_idx]

            audio_features = self._process_audio_file(audio_file)
            if audio_features is None:
                return self._get_empty_item()

            lyrics_data = self._process_lyrics_file(lyrics_file)
            if lyrics_data is None:
                return self._get_empty_item()

            audio = audio_features.get('audio')
            if audio is None or len(audio) == 0:
                return self._get_empty_item()
            audio_length = len(audio) / self.audio_processor.sr

            segment = self._extract_segment(audio_features, lyrics_data, line_idx, audio_length)
            if segment is None:
                return self._get_empty_item()

            # Tokens may already be a tensor (cached); copy instead of re-wrapping
            raw_tokens = segment['tokens']
            if torch.is_tensor(raw_tokens):
                tokens = raw_tokens.detach().clone().to(torch.long)
            else:
                tokens = torch.tensor(raw_tokens, dtype=torch.long)
            mel_spec = torch.tensor(segment['mel_spectrogram'], dtype=torch.float)
            pitch = torch.tensor(segment['pitch'], dtype=torch.float)
            energy = torch.tensor(segment['energy'], dtype=torch.float)
            mask = torch.tensor(segment['mask'], dtype=torch.float)

            # Force garbage collection periodically
            if idx % self.cache_size == 0:
                gc.collect()

            return {
                'text': segment['text'],
                'phonemes': segment['phonemes'],
                'tokens': tokens,
                'mel_spectrogram': mel_spec,
                'pitch': pitch,
                'energy': energy,
                'mask': mask
            }

        except Exception as e:
            print(f"Error in __getitem__ for index {idx}: {e}")
            return self._get_empty_item()

    def _get_empty_item(self):
        """Return an all-zero placeholder item (mask all zero) for error cases."""
        # Follow the processor's mel size; 80 is the fallback
        n_mels = getattr(self.audio_processor, 'n_mels', 80)
        return {
            'text': "",
            'phonemes': [],
            'tokens': torch.zeros(1, dtype=torch.long),
            'mel_spectrogram': torch.zeros((n_mels, self.max_len), dtype=torch.float),
            'pitch': torch.zeros(self.max_len, dtype=torch.float),
            'energy': torch.zeros(self.max_len, dtype=torch.float),
            'mask': torch.zeros(self.max_len, dtype=torch.float)
        }

    def cleanup(self):
        """Explicitly clear this instance's caches and run garbage collection."""
        for cache_name in ('_audio_cache', '_lyrics_cache'):
            cache = self.__dict__.get(cache_name)
            if cache is not None:
                cache.clear()
        gc.collect()
        gc.collect()  # Run twice to ensure collection of cyclical references

    def reduce_memory_usage(self):
        """More aggressive memory reduction when running out of RAM."""
        self.cleanup()

        # Collect garbage more often from now on
        self.cache_size = max(10, self.cache_size // 2)

        # Ask glibc to return freed memory to the OS; unavailable elsewhere
        try:
            import ctypes
            ctypes.CDLL('libc.so.6').malloc_trim(0)
        except (OSError, AttributeError):
            pass

    def __del__(self):
        """Cleanup when object is destroyed."""
        try:
            self.cleanup()
        except Exception:
            # Module globals may already be torn down at interpreter exit
            pass

In [11]:
class TextEncoder(nn.Module):
    """Embeds token IDs and contextualises them with a bidirectional LSTM.

    Input is ``(batch_size, seq_len)`` token IDs; output is
    ``(batch_size, seq_len, hidden_dim)``.
    """

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, n_layers=3, dropout=0.1):
        super().__init__()
        # Inter-layer LSTM dropout only applies when there is more than one layer
        lstm_dropout = dropout if n_layers > 1 else 0

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=lstm_dropout,
            bidirectional=True,
            batch_first=True,
        )
        # Fold the concatenated forward/backward states back to hidden_dim
        self.projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # (batch_size, seq_len) -> (batch_size, seq_len, embed_dim)
        token_vectors = self.dropout(self.embedding(x))

        # -> (batch_size, seq_len, 2 * hidden_dim); final LSTM state is unused
        contextual, _final_state = self.lstm(token_vectors)

        # -> (batch_size, seq_len, hidden_dim)
        return self.projection(contextual)


In [12]:
class _ChannelLayerNorm(nn.LayerNorm):
    """LayerNorm over the channel axis of a ``(batch, channels, time)`` tensor."""

    def forward(self, x):
        # nn.LayerNorm normalises the last axis, so move channels there and back
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class _ChannelProjection(nn.Linear):
    """Linear layer over the channel axis; maps ``(B, C, T)`` to ``(B, T, out)``."""

    def forward(self, x):
        return super().forward(x.transpose(1, 2))


def _make_variance_predictor(hidden_dim, kernel_size, dropout):
    """Build one conv-based predictor: ``(B, hidden_dim, T)`` -> ``(B, T, 1)``.

    The layer positions inside the Sequential are kept as
    conv/relu/norm/dropout x2 + projection, so parameter names in the state
    dict are unchanged (``<predictor>.0.weight`` ... ``<predictor>.8.bias``).
    """
    return nn.Sequential(
        nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=kernel_size//2),
        nn.ReLU(),
        _ChannelLayerNorm(hidden_dim),
        nn.Dropout(dropout),
        nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=kernel_size//2),
        nn.ReLU(),
        _ChannelLayerNorm(hidden_dim),
        nn.Dropout(dropout),
        _ChannelProjection(hidden_dim, 1)
    )


class VarianceAdaptor(nn.Module):
    """Adds pitch and energy information to the encoder states.

    Pitch and energy are either supplied by the caller or predicted from the
    encoder states; each scalar is embedded to ``hidden_dim`` and added to the
    input.

    Args:
        hidden_dim: Feature size of the encoder states.
        kernel_size: Convolution kernel size used by the predictors.
        dropout: Dropout probability inside the predictors.
    """

    def __init__(self, hidden_dim=512, kernel_size=3, dropout=0.1):
        super().__init__()

        # The predictors normalise and project over the channel axis; a plain
        # LayerNorm/Linear after Conv1d would act on the time axis and fail
        # whenever seq_len != hidden_dim.
        self.pitch_predictor = _make_variance_predictor(hidden_dim, kernel_size, dropout)
        self.energy_predictor = _make_variance_predictor(hidden_dim, kernel_size, dropout)

        # Duration predictor (for training); not used by forward()
        self.duration_predictor = _make_variance_predictor(hidden_dim, kernel_size, dropout)

        # Embeddings for transforming predicted values
        self.pitch_embedding = nn.Linear(1, hidden_dim)
        self.energy_embedding = nn.Linear(1, hidden_dim)

    def forward(self, x, pitch=None, energy=None):
        """Apply the variance embeddings.

        Args:
            x: Encoder states, ``(batch_size, seq_len, hidden_dim)``.
            pitch: Optional ``(batch_size, seq_len)`` pitch values; predicted
                when ``None``.
            energy: Optional ``(batch_size, seq_len)`` energy values; predicted
                when ``None``.

        Returns:
            ``(outputs, pitch, energy)`` where ``outputs`` has the shape of
            ``x`` and the other two are ``(batch_size, seq_len)`` — the supplied
            values when given, otherwise the predictions.
        """
        # Conv1d expects (batch, channels, time)
        x_conv = x.transpose(1, 2)

        # NOTE(review): when ground truth is supplied the predictors are not
        # run, so they receive no gradient from this call — confirm intended.
        if pitch is None:
            pitch_pred = self.pitch_predictor(x_conv)  # (B, T, 1)
        else:
            pitch_pred = pitch.unsqueeze(-1)

        if energy is None:
            energy_pred = self.energy_predictor(x_conv)  # (B, T, 1)
        else:
            energy_pred = energy.unsqueeze(-1)

        # Get embeddings
        pitch_embedding = self.pitch_embedding(pitch_pred)
        energy_embedding = self.energy_embedding(energy_pred)

        # Add variance embeddings to x
        outputs = x + pitch_embedding + energy_embedding

        return outputs, pitch_pred.squeeze(-1), energy_pred.squeeze(-1)


In [13]:
class Decoder(nn.Module):
    """Autoregressive attention decoder that generates mel spectrograms.

    Args:
        hidden_dim: Feature size of the encoder states and of both LSTM cells.
        n_mels: Number of mel bands produced per frame.
        kernel_size: Kernel size of the post-net convolutions.
        n_layers: Total number of post-net convolution layers.
        dropout: Dropout probability in the pre-net and post-net.
    """

    def __init__(self, hidden_dim=512, n_mels=80, kernel_size=5, n_layers=4, dropout=0.1):
        super().__init__()

        self.pre_net = nn.Sequential(
            nn.Linear(n_mels, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Input: pre-net output concatenated with the previous decoder state
        self.attention_rnn = nn.LSTMCell(hidden_dim * 2, hidden_dim)

        # Input: attention state concatenated with the attention context,
        # hence 2 * hidden_dim (a hidden_dim-wide cell cannot accept it)
        self.decoder_rnn = nn.LSTMCell(hidden_dim * 2, hidden_dim)

        self.attention = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.projection = nn.Linear(hidden_dim, n_mels)

        self.post_net = nn.Sequential(
            nn.Conv1d(n_mels, hidden_dim, kernel_size, padding=kernel_size//2),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Dropout(dropout),
            *[nn.Sequential(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=kernel_size//2),
                nn.BatchNorm1d(hidden_dim),
                nn.Tanh(),
                nn.Dropout(dropout)
            ) for _ in range(n_layers - 2)],
            nn.Conv1d(hidden_dim, n_mels, kernel_size, padding=kernel_size//2),
            nn.BatchNorm1d(n_mels),
            nn.Dropout(dropout)
        )

    def forward(self, encoder_outputs, mel_targets=None, teacher_forcing_ratio=0.5,
                max_decoder_steps=200):
        """Decode a mel spectrogram frame by frame.

        Args:
            encoder_outputs: ``(batch_size, seq_len, hidden_dim)`` states.
            mel_targets: Optional ``(batch_size, n_frames, n_mels)`` ground
                truth; sets the number of frames and enables teacher forcing.
            teacher_forcing_ratio: Per-step probability of feeding the previous
                ground-truth frame instead of the previous prediction.
            max_decoder_steps: Number of frames generated when ``mel_targets``
                is ``None``.

        Returns:
            ``(mel_outputs, mel_outputs_postnet)``, both
            ``(batch_size, n_frames, n_mels)``.
        """
        batch_size, seq_len, hidden_dim = encoder_outputs.shape
        n_mels = self.projection.out_features

        # Zero-initialised recurrent states on the encoder's device/dtype
        attn_hidden = encoder_outputs.new_zeros(batch_size, hidden_dim)
        attn_cell = encoder_outputs.new_zeros(batch_size, hidden_dim)
        decoder_hidden = encoder_outputs.new_zeros(batch_size, hidden_dim)
        decoder_cell = encoder_outputs.new_zeros(batch_size, hidden_dim)

        # All-zero "go" frame; the pre-net consumes mel-sized inputs
        prev_frame = encoder_outputs.new_zeros(batch_size, n_mels)

        mel_outputs = []
        target_len = mel_targets.shape[1] if mel_targets is not None else max_decoder_steps

        for t in range(target_len):
            # Teacher forcing feeds the PREVIOUS target frame; feeding frame t
            # while predicting frame t would leak the answer.
            use_teacher = (
                mel_targets is not None
                and t > 0
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            current_input = mel_targets[:, t - 1, :] if use_teacher else prev_frame

            # Apply pre-net
            prenet_out = self.pre_net(current_input)

            # Attention RNN
            attn_input = torch.cat([prenet_out, decoder_hidden], dim=1)
            attn_hidden, attn_cell = self.attention_rnn(attn_input, (attn_hidden, attn_cell))

            # Score every encoder position against the attention state
            attn_energy = self.attention(
                torch.cat([attn_hidden.unsqueeze(1).expand(-1, seq_len, -1),
                           encoder_outputs], dim=2)
            ).squeeze(-1)

            attn_weights = F.softmax(attn_energy, dim=1)
            context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

            # Decoder RNN
            rnn_input = torch.cat([attn_hidden, context], dim=1)
            decoder_hidden, decoder_cell = self.decoder_rnn(rnn_input, (decoder_hidden, decoder_cell))

            # Project to get mel output
            mel_output = self.projection(decoder_hidden)
            mel_outputs.append(mel_output)

            # Use this output as next input
            prev_frame = mel_output

        # Stack mel outputs
        mel_outputs = torch.stack(mel_outputs, dim=1)

        # Apply post-net and add residual connection
        post_output = self.post_net(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + post_output

        return mel_outputs, mel_outputs_postnet


In [14]:
class VoiceCloning(nn.Module):
    """End-to-end model: text encoder -> variance adaptor -> mel decoder."""

    def __init__(self, vocab_size, n_mels=80, hidden_dim=512):
        super().__init__()

        # All three stages share the same hidden size
        self.text_encoder = TextEncoder(vocab_size, hidden_dim=hidden_dim)
        self.variance_adaptor = VarianceAdaptor(hidden_dim=hidden_dim)
        self.decoder = Decoder(hidden_dim=hidden_dim, n_mels=n_mels)

    def forward(self, tokens, pitch=None, energy=None, mel_targets=None, teacher_forcing_ratio=0.5):
        # Stage 1: token IDs -> encoder states
        encoded = self.text_encoder(tokens)

        # Stage 2: inject pitch/energy (given or predicted) into the states
        adapted, pitch_pred, energy_pred = self.variance_adaptor(encoded, pitch, energy)

        # Stage 3: decode to mel frames, before and after the post-net
        mel_outputs, mel_outputs_postnet = self.decoder(
            adapted, mel_targets, teacher_forcing_ratio
        )

        result = {}
        result['mel_outputs'] = mel_outputs
        result['mel_outputs_postnet'] = mel_outputs_postnet
        result['pitch_pred'] = pitch_pred
        result['energy_pred'] = energy_pred
        return result


In [15]:
class WaveNetVocoder(nn.Module):
    """WaveNet-style stack of gated dilated convolutions mapping mels to a waveform.

    Args:
        n_mels: Number of input channels (mel bins).
        channels: Width of the residual and skip paths.
        kernel_size: Kernel size of the dilated convolutions.
        n_layers: Number of residual blocks.
        dilation_cycle: Dilation is ``2 ** (layer_index % dilation_cycle)``.
    """

    def __init__(self, n_mels=80, channels=256, kernel_size=3, n_layers=24, dilation_cycle=12):
        super().__init__()

        # 1x1 convolution lifting the mel bins to the residual width
        self.input_conv = nn.Conv1d(n_mels, channels, kernel_size=1)

        # Gated residual stack; each block is built in the same key order
        # (dilated_conv, res_conv, skip_conv) so state_dict keys are stable.
        blocks = []
        for layer_idx in range(n_layers):
            dilation = 2 ** (layer_idx % dilation_cycle)
            blocks.append(nn.ModuleDict({
                'dilated_conv': nn.Conv1d(
                    channels, 2 * channels, kernel_size,
                    dilation=dilation, padding='same'
                ),
                'res_conv': nn.Conv1d(channels, channels, kernel_size=1),
                'skip_conv': nn.Conv1d(channels, channels, kernel_size=1),
            }))
        self.residual_blocks = nn.ModuleList(blocks)

        # Head: summed skip path -> single channel squashed by tanh
        self.output_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(channels, 1, kernel_size=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map ``(batch_size, n_mels, time_steps)`` to ``(batch_size, 1, time_steps)``.

        Every convolution is either 1x1 or uses ``padding='same'``, so the
        time axis keeps its length.
        """
        hidden = self.input_conv(x)

        # Running sum of the per-block skip outputs
        skip_total = 0

        for block in self.residual_blocks:
            # First half of the channels is the filter, second half the gate
            filter_part, gate_part = block['dilated_conv'](hidden).chunk(2, dim=1)
            gated = torch.tanh(filter_part) * torch.sigmoid(gate_part)

            hidden = hidden + block['res_conv'](gated)
            skip_total = skip_total + block['skip_conv'](gated)

        return self.output_layers(skip_total)


In [16]:
def create_alignment(text, mel):
    """Create a simple linear (hard, monotonic) alignment between text and mel frames.

    The mel frames are split into ``text_len`` contiguous spans and token ``i``
    is marked against span ``i``. Span boundaries are computed with integer
    arithmetic (``i * mel_len // text_len``) rather than a floating-point step,
    so the spans always tile ``[0, mel_len)`` exactly; with a float step the
    last boundary can truncate to ``mel_len - 1`` and leave frames unassigned.

    Args:
        text: Tensor whose ``shape[1]`` is the number of tokens.
        mel: Tensor whose ``shape[2]`` is the number of mel frames.

    Returns:
        Float tensor of shape ``(1, text_len, mel_len)`` holding 1.0 where a
        token is aligned to a frame and 0.0 elsewhere. When there are more
        tokens than frames, some tokens receive an empty span. An empty text
        yields an empty ``(1, 0, mel_len)`` tensor instead of raising.
    """
    text_len = text.shape[1]
    mel_len = mel.shape[2]

    alignment = torch.zeros(1, text_len, mel_len)

    for i in range(text_len):
        # Exact integer boundaries: span i ends where span i + 1 starts,
        # and the final span ends exactly at mel_len.
        start = (i * mel_len) // text_len
        end = ((i + 1) * mel_len) // text_len
        alignment[0, i, start:end] = 1.0

    return alignment



In [17]:
def collate_fn(batch):
    """Merge a list of dataset items into one batch dict for the DataLoader.

    Token sequences are right-padded with zeros to the longest sequence in the
    batch. The ``mel_spectrogram``, ``pitch``, ``energy`` and ``mask`` entries
    are combined with ``torch.stack``, so they must already share a shape
    across items. ``text`` and ``phonemes`` are passed through as lists.
    """
    token_lengths = [len(sample['tokens']) for sample in batch]

    def stack_field(key):
        # Stack one per-item tensor field along a new leading batch axis
        return torch.stack([sample[key] for sample in batch])

    # Zero-initialised so positions past each sequence's end stay padding
    padded_tokens = torch.zeros(len(batch), max(token_lengths), dtype=torch.long)
    for row, (sample, length) in enumerate(zip(batch, token_lengths)):
        padded_tokens[row, :length] = sample['tokens']

    return {
        'texts': [sample['text'] for sample in batch],
        'phonemes': [sample['phonemes'] for sample in batch],
        'tokens': padded_tokens,
        'mel_spectrograms': stack_field('mel_spectrogram'),
        'pitches': stack_field('pitch'),
        'energies': stack_field('energy'),
        'masks': stack_field('mask'),
    }


In [18]:
def _voice_cloning_losses(model, batch, device, mel_loss_fn, pitch_loss_fn, energy_loss_fn):
    """Run the acoustic model on one collated batch and compute its losses.

    Shared by the training and validation loops so both use exactly the same
    masking and loss definitions.

    Returns:
        Tuple ``(loss, mel_loss, pitch_loss, energy_loss, mel_specs)`` where
        ``loss`` is the unweighted sum of the three terms and ``mel_specs`` is
        the target mel batch already moved to ``device``.
    """
    tokens = batch['tokens'].to(device)
    mel_specs = batch['mel_spectrograms'].to(device)
    pitches = batch['pitches'].to(device)
    energies = batch['energies'].to(device)
    masks = batch['masks'].to(device)

    outputs = model(tokens, pitches, energies, mel_specs)

    # Zero out predictions on masked (padded) positions before scoring
    mel_outputs = outputs['mel_outputs'] * masks.unsqueeze(-1)
    mel_outputs_postnet = outputs['mel_outputs_postnet'] * masks.unsqueeze(-1)
    pitch_pred = outputs['pitch_pred'] * masks
    energy_pred = outputs['energy_pred'] * masks

    # Both the pre- and post-net mel predictions are scored against the target
    mel_loss = mel_loss_fn(mel_outputs, mel_specs) + mel_loss_fn(mel_outputs_postnet, mel_specs)
    pitch_loss = pitch_loss_fn(pitch_pred, pitches)
    energy_loss = energy_loss_fn(energy_pred, energies)

    loss = mel_loss + pitch_loss + energy_loss
    return loss, mel_loss, pitch_loss, energy_loss, mel_specs


def _step_plateau_scheduler(scheduler, optimizer, metric, label):
    """Step a ReduceLROnPlateau scheduler and print any learning-rate change.

    Replaces the scheduler's own ``verbose=True`` flag, which recent PyTorch
    releases deprecate.
    """
    lrs_before = [group['lr'] for group in optimizer.param_groups]
    scheduler.step(metric)
    lrs_after = [group['lr'] for group in optimizer.param_groups]
    if lrs_after != lrs_before:
        print(f'  {label} learning rate changed: {lrs_before} -> {lrs_after}')


def train_model(model, vocoder, train_loader, valid_loader, device, epochs=100,
                start_epoch=0, vocoder_start_epoch=5,
                checkpoint_dir='checkpoints', log_dir='runs/voice_cloning'):
    """Train the voice cloning model and, after a warm-up, the vocoder.

    Args:
        model: Acoustic model called as ``model(tokens, pitches, energies, mel_specs)``
            and returning a dict with ``mel_outputs``, ``mel_outputs_postnet``,
            ``pitch_pred`` and ``energy_pred``.
        vocoder: Module called on mel batches shaped ``(batch, n_mels, time_steps)``.
        train_loader: Loader yielding batches in the format produced by ``collate_fn``.
        valid_loader: Loader used for the per-epoch validation pass.
        device: Device the batches are moved to.
        epochs: Exclusive upper bound of the epoch index.
        start_epoch: First epoch index to run (use a checkpoint's ``epoch + 1``
            to resume). Defaults to 0.
        vocoder_start_epoch: Epoch index from which the vocoder is trained.
        checkpoint_dir: Directory for ``voice_cloning_best.pt`` and
            ``voice_cloning_latest.pt``; created if missing.
        log_dir: TensorBoard log directory.

    Returns:
        ``(model, vocoder)`` after training.

    Raises:
        ValueError: If either loader yields no batches (the per-epoch averages
            would otherwise divide by zero after a wasted epoch).
    """
    # Imported here because the notebook's visible import cell does not
    # provide SummaryWriter.
    from torch.utils.tensorboard import SummaryWriter

    n_train_batches = len(train_loader)
    n_valid_batches = len(valid_loader)
    if n_train_batches == 0 or n_valid_batches == 0:
        raise ValueError(
            'train_loader and valid_loader must each yield at least one batch '
            f'(got {n_train_batches} and {n_valid_batches})'
        )

    # torch.save does not create missing parent directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_path = os.path.join(checkpoint_dir, 'voice_cloning_best.pt')
    latest_path = os.path.join(checkpoint_dir, 'voice_cloning_latest.pt')

    # Define optimizers
    optimizer = Adam(model.parameters(), lr=0.001)
    vocoder_optimizer = Adam(vocoder.parameters(), lr=0.0005)

    # Define loss functions
    mel_loss_fn = nn.L1Loss()
    pitch_loss_fn = nn.MSELoss()
    energy_loss_fn = nn.MSELoss()
    waveform_loss_fn = nn.L1Loss()

    # Learning rate schedulers (LR changes are reported by _step_plateau_scheduler)
    model_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    vocoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        vocoder_optimizer, mode='min', factor=0.5, patience=8
    )

    # TensorBoard writer
    writer = SummaryWriter(log_dir)

    best_valid_loss = float('inf')

    try:
        for epoch in range(start_epoch, epochs):
            # ---- Training ----
            model.train()
            vocoder.train()
            train_vocoder = epoch >= vocoder_start_epoch

            train_loss = 0.0
            train_mel_loss = 0.0
            train_pitch_loss = 0.0
            train_energy_loss = 0.0
            train_vocoder_loss = 0.0

            for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]'):
                loss, mel_loss, pitch_loss, energy_loss, mel_specs = _voice_cloning_losses(
                    model, batch, device, mel_loss_fn, pitch_loss_fn, energy_loss_fn
                )

                # Backward pass and optimize
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                if train_vocoder:
                    # (batch, time_steps, n_mels) -> (batch, n_mels, time_steps)
                    mel_specs_for_vocoder = mel_specs.transpose(1, 2)
                    waveform_pred = vocoder(mel_specs_for_vocoder)

                    # NOTE(review): placeholder target. The batch carries no
                    # waveform, so the vocoder is fitted to random noise and
                    # learns nothing useful until real, frame-aligned target
                    # audio is supplied here.
                    target_waveform = torch.randn_like(waveform_pred)

                    vocoder_loss = waveform_loss_fn(waveform_pred, target_waveform)

                    vocoder_optimizer.zero_grad()
                    vocoder_loss.backward()
                    torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
                    vocoder_optimizer.step()

                    train_vocoder_loss += vocoder_loss.item()

                # Accumulate losses
                train_loss += loss.item()
                train_mel_loss += mel_loss.item()
                train_pitch_loss += pitch_loss.item()
                train_energy_loss += energy_loss.item()

            # Calculate average training losses
            train_loss /= n_train_batches
            train_mel_loss /= n_train_batches
            train_pitch_loss /= n_train_batches
            train_energy_loss /= n_train_batches
            train_vocoder_loss /= n_train_batches

            # ---- Validation ----
            model.eval()
            vocoder.eval()

            valid_loss = 0.0
            valid_mel_loss = 0.0
            valid_pitch_loss = 0.0
            valid_energy_loss = 0.0

            with torch.no_grad():
                for batch in tqdm(valid_loader, desc=f'Epoch {epoch+1}/{epochs} [Valid]'):
                    loss, mel_loss, pitch_loss, energy_loss, _ = _voice_cloning_losses(
                        model, batch, device, mel_loss_fn, pitch_loss_fn, energy_loss_fn
                    )

                    valid_loss += loss.item()
                    valid_mel_loss += mel_loss.item()
                    valid_pitch_loss += pitch_loss.item()
                    valid_energy_loss += energy_loss.item()

            # Calculate average validation losses
            valid_loss /= n_valid_batches
            valid_mel_loss /= n_valid_batches
            valid_pitch_loss /= n_valid_batches
            valid_energy_loss /= n_valid_batches

            # Update learning rate schedulers. The vocoder is not evaluated in
            # the validation pass, so its scheduler follows its own mean
            # training loss rather than the acoustic model's validation loss.
            _step_plateau_scheduler(model_scheduler, optimizer, valid_loss, 'Model')
            if train_vocoder:
                _step_plateau_scheduler(
                    vocoder_scheduler, vocoder_optimizer, train_vocoder_loss, 'Vocoder'
                )

            # Log to TensorBoard
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/valid', valid_loss, epoch)
            writer.add_scalar('MelLoss/train', train_mel_loss, epoch)
            writer.add_scalar('MelLoss/valid', valid_mel_loss, epoch)
            writer.add_scalar('PitchLoss/train', train_pitch_loss, epoch)
            writer.add_scalar('PitchLoss/valid', valid_pitch_loss, epoch)
            writer.add_scalar('EnergyLoss/train', train_energy_loss, epoch)
            writer.add_scalar('EnergyLoss/valid', valid_energy_loss, epoch)
            if train_vocoder:
                writer.add_scalar('VocoderLoss/train', train_vocoder_loss, epoch)

            # Print progress
            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Train Loss: {train_loss:.4f} (Mel: {train_mel_loss:.4f}, Pitch: {train_pitch_loss:.4f}, Energy: {train_energy_loss:.4f})')
            print(f'  Valid Loss: {valid_loss:.4f} (Mel: {valid_mel_loss:.4f}, Pitch: {valid_pitch_loss:.4f}, Energy: {valid_energy_loss:.4f})')

            # The same payload is written as "latest" every epoch and as
            # "best" whenever the validation loss improves.
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'vocoder_state_dict': vocoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'vocoder_optimizer_state_dict': vocoder_optimizer.state_dict(),
                'loss': valid_loss,
            }

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(checkpoint, best_path)
                print('  New best model saved!')

            torch.save(checkpoint, latest_path)
    finally:
        # Flush and close the TensorBoard writer even if training is
        # interrupted or raises.
        writer.close()

    return model, vocoder



In [19]:

class SongGenerator:
    """Generate a waveform and diagnostic features from text with a trained model and vocoder."""

    def __init__(self, model, vocoder, text_processor, audio_processor, device):
        """Store the collaborators and switch both networks to evaluation mode.

        Args:
            model: Callable as ``model(tokens)``; returns a dict with at least
                ``mel_outputs_postnet``, ``pitch_pred`` and ``energy_pred``.
            vocoder: Callable on a ``(batch, n_mels, frames)`` mel tensor.
            text_processor: Object exposing ``text_to_sequence(text)``.
            audio_processor: Stored on the instance; not used by the methods
                of this class.
            device: Device the token tensor is moved to.
        """
        self.model = model
        self.vocoder = vocoder
        self.text_processor = text_processor
        self.audio_processor = audio_processor
        self.device = device

        # Set models to evaluation mode
        self.model.eval()
        self.vocoder.eval()

    def generate_from_text(self, text, pitch_control=1.0, energy_control=1.0, speed_control=1.0):
        """Generate a song from input text.

        Args:
            text: Input text, converted with ``text_processor.text_to_sequence``.
            pitch_control: Factor applied to the returned pitch prediction.
            energy_control: Factor applied to the returned energy prediction.
            speed_control: Values above 1.0 shorten the mel spectrogram (faster),
                values below 1.0 lengthen it, via linear interpolation along
                the frame axis. Must be positive.

        Returns:
            Dict of NumPy arrays: ``'waveform'`` (1-D), ``'mel_spectrogram'``
            (frames, n_mels), ``'pitch'`` (1-D) and ``'energy'`` (1-D).

        Raises:
            ValueError: If ``speed_control`` is not positive.

        Note:
            ``pitch_control`` and ``energy_control`` only scale the returned
            predictions. The mel spectrogram is produced by the model before
            the scaling is applied, so they do not change the generated audio.
        """
        if speed_control <= 0:
            raise ValueError(f"speed_control must be positive, got {speed_control}")

        # Process text into a single-item batch of token ids
        tokens = self.text_processor.text_to_sequence(text)
        tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(tokens)

            pitch_pred = outputs['pitch_pred'] * pitch_control
            energy_pred = outputs['energy_pred'] * energy_control

            # Adjust speed by resampling the frame axis
            mel_outputs = outputs['mel_outputs_postnet']
            if speed_control != 1.0:
                # Never request fewer than one frame from the interpolation
                target_length = max(1, int(mel_outputs.shape[1] / speed_control))
                mel_outputs = F.interpolate(
                    mel_outputs.transpose(1, 2),
                    size=target_length,
                    mode='linear'
                ).transpose(1, 2)

            # Generate waveform with vocoder
            waveform = self.vocoder(mel_outputs.transpose(1, 2))

        # Explicit reshapes instead of squeeze(): squeeze() would also drop a
        # length-1 time axis and return 0-d arrays for very short outputs.
        return {
            'waveform': waveform.reshape(-1).cpu().numpy(),
            'mel_spectrogram': mel_outputs[0].cpu().numpy(),
            'pitch': pitch_pred.reshape(-1).cpu().numpy(),
            'energy': energy_pred.reshape(-1).cpu().numpy()
        }

    def save_audio(self, waveform, file_path, sample_rate=22050):
        """Save audio to file with ``soundfile.write``."""
        sf.write(file_path, waveform, sample_rate)

    def plot_features(self, mel_spec, pitch, energy):
        """Plot generated features as three stacked panels and return the figure.

        Args:
            mel_spec: Mel spectrogram shaped ``(frames, n_mels)``, as returned
                under ``'mel_spectrogram'`` by ``generate_from_text``.
            pitch: 1-D pitch contour.
            energy: 1-D energy contour.
        """
        fig, axes = plt.subplots(3, 1, figsize=(12, 10))

        # imshow puts rows on the y axis, so transpose (frames, n_mels) to get
        # mel bins on the y axis, matching the axis label.
        im = axes[0].imshow(np.asarray(mel_spec).T, aspect='auto', origin='lower')
        axes[0].set_title('Mel Spectrogram')
        axes[0].set_ylabel('Mel Bins')
        fig.colorbar(im, ax=axes[0])

        # Plot pitch
        axes[1].plot(pitch)
        axes[1].set_title('Pitch')
        axes[1].set_ylabel('Normalized F0')

        # Plot energy
        axes[2].plot(energy)
        axes[2].set_title('Energy')
        axes[2].set_ylabel('Normalized Energy')
        axes[2].set_xlabel('Frames')

        plt.tight_layout()
        return fig



In [20]:
# Input locations (dataset mounts under /kaggle/input)
audio_dir, lyrics_dir = '/kaggle/input/audio-dataset', '/kaggle/input/lyrics'

# Working locations, relative to the notebook's working directory
data_dir, output_dir, checkpoint_dir = 'data/', 'output/', 'checkpoints/'


In [21]:
# Create the writable output directories if they don't exist
for writable_dir in (output_dir, checkpoint_dir):
    os.makedirs(writable_dir, exist_ok=True)

# The input directories hold the dataset and must already exist. Creating them
# here would hide a missing dataset behind empty folders, which only surfaces
# later as empty file lists.
for input_dir in (audio_dir, lyrics_dir):
    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")

In [22]:
# Initialize processors
# AudioProcessor is built with its defaults (sr=22050, n_fft=1024, hop_length=256, n_mels=80)
audio_processor = AudioProcessor()
text_processor = TextProcessor()
# LyricsProcessor is given the text processor instance created above
lyrics_processor = LyricsProcessor(text_processor)


In [23]:
# Collect input files; sorting makes both lists deterministic
audio_pattern = os.path.join(audio_dir, '*.wav')
lyrics_pattern = os.path.join(lyrics_dir, '*.txt')

audio_files = glob.glob(audio_pattern)
audio_files.sort()

lyrics_files = glob.glob(lyrics_pattern)
lyrics_files.sort()


In [24]:
# Ensure we have matching files
if len(audio_files) != len(lyrics_files):
    print(f"Warning: Number of audio files ({len(audio_files)}) doesn't match number of lyrics files ({len(lyrics_files)})")

    # Truncating two independently sorted lists does not guarantee that
    # audio_files[i] and lyrics_files[i] belong together (one missing file
    # shifts every later pair). Pair by file name without extension instead.
    audio_by_stem = {os.path.splitext(os.path.basename(p))[0]: p for p in audio_files}
    lyrics_by_stem = {os.path.splitext(os.path.basename(p))[0]: p for p in lyrics_files}
    shared_stems = sorted(audio_by_stem.keys() & lyrics_by_stem.keys())

    if shared_stems:
        audio_files = [audio_by_stem[stem] for stem in shared_stems]
        lyrics_files = [lyrics_by_stem[stem] for stem in shared_stems]
        n_files = len(shared_stems)
    else:
        # No common names at all: fall back to positional truncation
        n_files = min(len(audio_files), len(lyrics_files))
        audio_files = audio_files[:n_files]
        lyrics_files = lyrics_files[:n_files]



In [25]:
# Hold out 20% of the (audio, lyrics) pairs for validation; the fixed
# random_state keeps the split reproducible
split_parts = train_test_split(audio_files, lyrics_files, test_size=0.2, random_state=42)
train_audio, valid_audio, train_lyrics, valid_lyrics = split_parts


In [None]:
# Create datasets
train_dataset = SingerDataset(train_audio, train_lyrics, audio_processor, lyrics_processor)
valid_dataset = SingerDataset(valid_audio, valid_lyrics, audio_processor, lyrics_processor)

# Fail fast with an actionable message: an empty training set otherwise only
# surfaces when the shuffling DataLoader is built, as
# "num_samples should be a positive integer value, but got num_samples=0".
if len(train_dataset) == 0:
    raise ValueError(
        f"Training dataset is empty ({len(train_audio)} audio / {len(train_lyrics)} lyrics files were provided)"
    )
if len(valid_dataset) == 0:
    print("Warning: validation dataset is empty")


In [27]:
# Create data loaders; both share every setting except shuffling
loader_kwargs = dict(batch_size=8, collate_fn=collate_fn, num_workers=2)

# Training batches are reshuffled each epoch
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation keeps the dataset order
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:

# Initialize models
# Vocabulary size is the length of the text processor's tokenizer
vocab_size = len(text_processor.tokenizer)
# The acoustic model is created before the vocoder; both are moved to `device`.
# Both use n_mels=80.
model = VoiceCloning(vocab_size, n_mels=80, hidden_dim=512).to(device)
vocoder = WaveNetVocoder(n_mels=80).to(device)



In [None]:

# Check if checkpoint exists
checkpoint_path = os.path.join(checkpoint_dir, 'voice_cloning_latest.pt')
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU can be restored on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    vocoder.load_state_dict(checkpoint['vocoder_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    start_epoch = 0
    print("Starting training from scratch")

# NOTE(review): only the network weights are restored here. The optimizer
# states stored in the checkpoint are not loaded, and the training cell's
# train_model(...) call does not receive start_epoch.



In [None]:
# Train the acoustic model and vocoder; the trained modules are returned
model, vocoder = train_model(
    model=model,
    vocoder=vocoder,
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=device,
    epochs=100,
)


In [None]:
# Wrap the trained networks and processors for inference
song_generator = SongGenerator(
    model=model,
    vocoder=vocoder,
    text_processor=text_processor,
    audio_processor=audio_processor,
    device=device,
)



In [None]:
# Example: Generate a song from text
sample_text = "Mein tumhare saath hoon, har pal har lamha"

# All three controls are left at their neutral value of 1.0
generation_controls = {'pitch_control': 1.0, 'energy_control': 1.0, 'speed_control': 1.0}
generated = song_generator.generate_from_text(sample_text, **generation_controls)


In [None]:
# Write the generated waveform into the output directory
output_path = os.path.join(output_dir, 'generated_song.wav')
generated_waveform = generated['waveform']
song_generator.save_audio(generated_waveform, output_path)
print(f"Generated song saved to {output_path}")


In [None]:
# Plot the generated mel spectrogram, pitch and energy, then save the figure
feature_plot_path = os.path.join(output_dir, 'feature_plot.png')
fig = song_generator.plot_features(
    generated['mel_spectrogram'], generated['pitch'], generated['energy']
)
fig.savefig(feature_plot_path)

print("Voice cloning complete!")