In [1]:
import os
os.chdir('../')

In [2]:
%pwd

'd:\\MLOps-Project\\text-to-speech-using-mlops'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class DataTransformationConfig:
    root_dir : Path
    wave_path : Path
    csv_path : Path
    sr  : int
    n_fft : int
    n_stft : int
    frame_length : float
    win_length : int
    mel_freq : int
    max_mel_time : int
    max_db : int
    scale_db : int
    ref : float
    power : float
    norm_db : int
    ampl_multiplier : float
    ampl_amin : str
    db_multiplier : float
    ampl_ref : float
    ampl_power : float
    min_level_db : float
    frame_shift : float
    hop_length : int

In [4]:
from src.simpletts.constants import *
from src.simpletts.utils.common import create_directories, read_yaml

class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):

        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])
        
    def get_data_transformation_config(self) -> DataTransformationConfig:
        config = self.config.data_transformation
        
        create_directories([config.root_dir])
        
        data_transformation_config = DataTransformationConfig(
            root_dir = config.root_dir,
            wave_path = config.wave_path,
            csv_path = config.csv_path,
            sr  = self.params.sr,
            n_fft = self.params.n_fft,
            n_stft = self.params.n_stft,
            frame_length = self.params.frame_length,
            win_length = self.params.win_length,
            mel_freq = self.params.mel_freq,
            max_mel_time = self.params.max_mel_time,
            max_db = self.params.max_db,
            scale_db = self.params.scale_db,
            ref = self.params.ref,
            power = self.params.power,
            norm_db = self.params.norm_db,
            ampl_multiplier = self.params.ampl_multiplier,
            ampl_amin = self.params.ampl_amin,
            db_multiplier = self.params.db_multiplier,
            ampl_ref = self.params.ampl_ref,
            ampl_power = self.params.ampl_power,
            min_level_db = self.params.min_level_db,
            frame_shift=self.params.frame_shift,
            hop_length=self.params.hop_length
        )
        
        
        return data_transformation_config


In [5]:
import torch


symbols = [
    'EOS', ' ', '!', ',', '-', '.', \
    ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', \
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', \
    'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', \
    'â', 'è', 'é', 'ê', 'ü', '’', '“', '”' \
  ]


symbol_to_id = {
  s: i for i, s in enumerate(symbols)
}

def mask_from_seq_lengths(
    sequence_lengths: torch.Tensor, 
    max_length: int
) -> torch.BoolTensor:
   
    # (batch_size, max_length)
    ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
    range_tensor = ones.cumsum(dim=1)
    return sequence_lengths.unsqueeze(1) >= range_tensor 
  
def text_to_seq(text):
    text = text.lower()
    seq = []
    for s in text:
        _id = symbol_to_id.get(s, None)
        if _id is not None:
            seq.append(_id)
    seq.append(symbol_to_id['EOS'])
    return torch.IntTensor(seq)


In [6]:
import torchaudio
from torchaudio.functional import spectrogram

In [7]:
class AudioProcessor:
    def __init__(self, config: DataTransformationConfig):
        self.config = config
        
        print(type(self.config.min_level_db))
        print(type(self.config.max_db))
        print(type(self.config.norm_db))
        print(type(self.config.ref))

        
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.config.n_fft,
            win_length=self.config.win_length,
            hop_length=self.config.hop_length,
            power=self.config.power
        )
        
        self.mel_scale_transform = torchaudio.transforms.MelScale(
            n_mels=self.config.mel_freq,
            sample_rate=self.config.sr,
            n_stft=self.config.n_stft
        )
        
        self.mel_inverse_transform = torchaudio.transforms.InverseMelScale(
            n_mels=self.config.mel_freq,
            sample_rate=self.config.sr,
            n_stft=self.config.n_stft
        ).cuda()
        
        self.griffnlim_transform = torchaudio.transforms.GriffinLim(
            n_fft=self.config.n_fft,
            win_length=self.config.win_length,
            hop_length=self.config.hop_length
        ).cuda()
        
    def norm_mel_spec_db(mel_spec): 
        min_level_db = -100.0
        max_db = 100
        norm_db = 10
        ref = 4.0
        mel_spec = ((2.0*mel_spec - min_level_db) / (max_db/norm_db)) - 1.0
        mel_spec = torch.clip(mel_spec, -ref*norm_db, ref*norm_db)
        return mel_spec


    
    def denorm_mel_spec_db(self, mel_spec):
        mel_spec = (((1.0 + mel_spec) * (self.config.max_db / self.config.norm_db)) + self.config.min_level_db) / 2.0
        return mel_spec
    
    def pow_to_db_mel_spec(self, mel_spec):
        mel_spec = torchaudio.functional.amplitude_to_DB(
            mel_spec,
            multiplier=self.config.ampl_multiplier,
            amin=self.config.ampl_amin,
            db_multiplier=self.config.db_multiplier,
            top_db=self.config.max_db
        )
        mel_spec = mel_spec / self.config.scale_db
        return mel_spec
    
    def db_to_power_mel_spec(self, mel_spec):
        mel_spec = mel_spec * self.config.scale_db
        mel_spec = torchaudio.functional.DB_to_amplitude(
            mel_spec,
            ref=self.config.ampl_ref,
            power=self.config.ampl_power
        )
        return mel_spec
    
    def convert_to_mel_spec(self, wav):
        spec = self.spec_transform(wav)
        mel_spec = self.mel_scale_transform(spec)
        db_mel_spec = self.pow_to_db_mel_spec(mel_spec)
        db_mel_spec = db_mel_spec.squeeze(0)
        return db_mel_spec
    
    def inverse_mel_spec_to_wav(self, mel_spec):
        power_mel_spec = self.db_to_power_mel_spec(mel_spec)
        spectrogram = self.mel_inverse_transform(power_mel_spec)
        pseudo_wav = self.griffnlim_transform(spectrogram)
        return pseudo_wav


In [12]:
import torch.utils


class TextMelDataset(torch.utils.data.Dataset):
    def __init__(self, df, config: DataTransformationConfig):
        self.df = df
        self.cache = {}
        self.config = config
        self.audio_processor = AudioProcessor(config)  # Pass the config here
        
    def get_item(self, row):
        wav_id = row['wav']
        wav_path = f"{self.config.wave_path}/{wav_id}.wav"
        text = row['text_norm']
        text = text_to_seq(text)
        waveform, sample_rate = torchaudio.load(wav_path, normalize=True)
        assert sample_rate == self.config.sr
        mel = self.audio_processor.convert_to_mel_spec(waveform)
        return (text, mel)
    
    def __getitem__(self, index):
        row = self.df.iloc[index]
        wave_id = row['wav']
        text_mel = self.cache.get(wave_id)
        if text_mel is None:
            text_mel = self.get_item(row)
            self.cache[wave_id] = text_mel
            
        return text_mel
    
    def __len__(self):
        return len(self.df)
    
    
    
    
    @staticmethod
    def text_mel_collate_fn(batch):
        text_length_max = torch.tensor(
            [text.shape[-1] for text, _ in batch], 
            dtype=torch.int32
        ).max()
        mel_length_max = torch.tensor(
            [mel.shape[-1] for _, mel in batch],
            dtype=torch.int32
        ).max()
    
        text_lengths = []
        mel_lengths = []
        texts_padded = []
        mels_padded = []
        for text, mel in batch:
            text_length = text.shape[-1]      
            text_padded = torch.nn.functional.pad(
                text,
                pad=[0, text_length_max-text_length],
                value=0
            )
            mel_length = mel.shape[-1]
            mel_padded = torch.nn.functional.pad(
                mel,
                pad=[0, mel_length_max-mel_length],
                value=0
            )
            text_lengths.append(text_length)    
            mel_lengths.append(mel_length)    
            texts_padded.append(text_padded)    
            mels_padded.append(mel_padded)
        text_lengths = torch.tensor(text_lengths, dtype=torch.int32)
        mel_lengths = torch.tensor(mel_lengths, dtype=torch.int32)
        texts_padded = torch.stack(texts_padded, 0)
        mels_padded = torch.stack(mels_padded, 0).transpose(1, 2)
        stop_token_padded = mask_from_seq_lengths(
            mel_lengths,
            mel_length_max
        )
        stop_token_padded = (~stop_token_padded).float()
        stop_token_padded[:, -1] = 1.0
    
        return texts_padded, \
            text_lengths, \
            mels_padded, \
            mel_lengths, \
            stop_token_padded

            
    
            
            
             

In [9]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.utils
from src.simpletts.logging import logger
class DataTransformation:
    def __init__(self, config: DataTransformationConfig):
        self.config = config
        
    def load_data(self):
        df = pd.read_csv(self.config.csv_path)
        return df
    
    def split_data(self, data):
        train_df, test_df = train_test_split(data, test_size=0.4)
        return train_df, test_df
    
    def create_dataset(self, train_df, test_df):
        train_dataset = torch.utils.data.DataLoader(
            TextMelDataset(train_df, self.config),
            num_workers = 2,
            shuffle = True,
            sampler = None,
            batch_size = 1,
            pin_memory = True,
            drop_last = True,
            collate_fn = TextMelDataset.text_mel_collate_fn
        )
        
        test_dataset = torch.utils.data.DataLoader(
            TextMelDataset(test_df, self.config),
            num_workers = 2,
            shuffle = True,
            sampler = None,
            batch_size = 1,
            pin_memory = True,
            drop_last = True,
            collate_fn = TextMelDataset.text_mel_collate_fn
        )
        
        return  train_dataset, test_dataset
        
    
    def save_datasets(self, train_dataset, test_dataset):
        os.makedirs(self.config.root_dir, exist_ok=True)

        train_dataset_path = os.path.join(self.config.root_dir, 'train_dataset.pt')
        test_dataset_path = os.path.join(self.config.root_dir, 'test_dataset.pt')

        try:
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)

            logger.info(f"Train dataset saved at: {train_dataset_path}")
            logger.info(f"Test dataset saved at: {test_dataset_path}")
        except Exception as e:
            logger.error(f"Error saving datasets: {str(e)}")
            raise e
    

In [11]:
try:
    # Initialize the ConfigurationManager and DataTransformation
    config_manager = ConfigurationManager()
    data_transformation_config = config_manager.get_data_transformation_config()
    data_transformation = DataTransformation(config=data_transformation_config)
    
    # Load the dataset
    dataset = data_transformation.load_data()
    
    # Split the data into train and test datasets
    train_df, test_df = data_transformation.split_data(dataset)
    
    # Create the datasets using the split data
    train_dataset, test_dataset = data_transformation.create_dataset(train_df, test_df)
    
    # Save the datasets to files
    data_transformation.save_datasets(train_dataset, test_dataset)
    
    # Initialize the AudioProcessor with the DataTransformationConfig
    audio_processor = AudioProcessor(config=data_transformation_config)
    
    

except Exception as e:
    logger.error(f"An error occurred during data transformation: {str(e)}")
    raise e

[2024-08-31 20:23:31,907: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-31 20:23:31,910: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-31 20:23:31,911: INFO: common: created directory at: artifacts]
[2024-08-31 20:23:31,912: INFO: common: created directory at: artifacts/data_transformation]
<class 'float'>
<class 'int'>
<class 'int'>
<class 'float'>
<class 'float'>
<class 'int'>
<class 'int'>
<class 'float'>
[2024-08-31 20:23:34,243: INFO: 2220319521: Train dataset saved at: artifacts/data_transformation\train_dataset.pt]
[2024-08-31 20:23:34,244: INFO: 2220319521: Test dataset saved at: artifacts/data_transformation\test_dataset.pt]
<class 'float'>
<class 'int'>
<class 'int'>
<class 'float'>
Inspecting train_dataset:


In [18]:
import torch
from torch.utils.data import DataLoader

# First, let's define the collate function here in the main module
def text_mel_collate_fn(batch):
    text_length_max = torch.tensor(
        [text.shape[-1] for text, _ in batch], 
        dtype=torch.int32
    ).max()
    mel_length_max = torch.tensor(
        [mel.shape[-1] for _, mel in batch],
        dtype=torch.int32
    ).max()

    text_lengths = []
    mel_lengths = []
    texts_padded = []
    mels_padded = []
    for text, mel in batch:
        text_length = text.shape[-1]      
        text_padded = torch.nn.functional.pad(
            text,
            pad=[0, text_length_max-text_length],
            value=0
        )
        mel_length = mel.shape[-1]
        mel_padded = torch.nn.functional.pad(
            mel,
            pad=[0, mel_length_max-mel_length],
            value=0
        )
        text_lengths.append(text_length)    
        mel_lengths.append(mel_length)    
        texts_padded.append(text_padded)    
        mels_padded.append(mel_padded)
    text_lengths = torch.tensor(text_lengths, dtype=torch.int32)
    mel_lengths = torch.tensor(mel_lengths, dtype=torch.int32)
    texts_padded = torch.stack(texts_padded, 0)
    mels_padded = torch.stack(mels_padded, 0).transpose(1, 2)
    stop_token_padded = mask_from_seq_lengths(
        mel_lengths,
        mel_length_max
    )
    stop_token_padded = (~stop_token_padded).float()
    stop_token_padded[:, -1] = 1.0

    return texts_padded, \
        text_lengths, \
        mels_padded, \
        mel_lengths, \
        stop_token_padded

# Now let's load the saved dataset and create a new DataLoader
try:
    # Load the saved dataset
    saved_dataset = torch.load('D:\\MLOps-Project\\text-to-speech-using-mlops\\artifacts\\train_loader.pt')
    
    # Recreate the DataLoader
    train_dataset = DataLoader(
        saved_dataset,
        num_workers=2,
        shuffle=True,
        batch_size=1,
        pin_memory=True,
        drop_last=True,
        collate_fn=text_mel_collate_fn
    )

    print("Inspecting train_dataset:")
    for batch in train_dataset:
        texts_padded, text_lengths, mels_padded, mel_lengths, stop_token_padded = batch
        
        print(f"\nBatch contents:")
        print(f"1. texts_padded shape: {texts_padded.shape}")
        print(f"   Sample text (indices): {texts_padded[0]}")
        print(f"2. text_lengths: {text_lengths}")
        print(f"3. mels_padded shape: {mels_padded.shape}")
        print(f"   Sample mel spectrogram shape: {mels_padded[0].shape}")
        print(f"4. mel_lengths: {mel_lengths}")
        print(f"5. stop_token_padded shape: {stop_token_padded.shape}")
        print(f"   Sample stop token: {stop_token_padded[0]}")
        break

except Exception as e:
    print(f"An error occurred: {str(e)}")
    raise e

  saved_dataset = torch.load('D:\\MLOps-Project\\text-to-speech-using-mlops\\artifacts\\train_loader.pt')


Inspecting train_dataset:
