# Set up

In [None]:
!pip install librosa
!pip install soundfile
!pip install speechbrain

In [None]:
import numpy as np
import librosa
import matplotlib.pyplot as plt
import soundfile as sf

import logging

from speechbrain.inference.vocoders import HIFIGAN
from speechbrain.inference.TTS import Tacotron2
from speechbrain.lobes.models.FastSpeech2 import mel_spectogram

import IPython.display as ipd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio

from datasets import load_dataset # Expresso dataset
import tqdm.notebook

# Utils

In [None]:
# Absolute tolerance used by `equals` when comparing two numbers.
EPS = 1e-6

def equals(a, b):
    """Return True when `a` and `b` differ by strictly less than `EPS`."""
    gap = a - b
    return abs(gap) < EPS

def dtw(a, b):
    """Align two frame sequences with dynamic time warping (DTW).

    The cost of pairing frame `a[i]` with frame `b[j]` is the Euclidean norm of
    their difference. The accumulated-cost table is filled first, then the
    optimal path is recovered by walking back from the last cell.

    Parameters
    ----------
    a, b : array-like
        Sequences indexed along axis 0 (one frame per row); both must expose `.shape`.

    Returns
    -------
    total_cost
        Accumulated cost of the optimal warping path (`dtw_matrix[n, m]`).
    path : list of (int, int)
        Index pairs `(index_in_a, index_in_b)` ordered from the first frames
        to the last frames.

    Raises
    ------
    ValueError
        If exactly one of the sequences is empty. No alignment exists in that
        case; previously this returned an infinite cost together with a path
        made of `-1` indices.
    """
    n, m = a.shape[0], b.shape[0]
    if (n == 0) != (m == 0):
        raise ValueError("dtw cannot align an empty sequence with a non-empty one")

    # Frame-to-frame distances, stored once so backtracking can reuse them
    # instead of recomputing the norm.
    frame_cost = np.empty((n, m))

    # Row/column 0 are an infinite border so that every path starts at (0, 0).
    dtw_matrix = np.full((n + 1, m + 1), np.inf)
    dtw_matrix[0, 0] = 0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = np.linalg.norm(a[i - 1] - b[j - 1])  # Euclidean distance
            frame_cost[i - 1, j - 1] = cost
            dtw_matrix[i, j] = cost + min(dtw_matrix[i - 1, j],      # Insertion
                                          dtw_matrix[i, j - 1],      # Deletion
                                          dtw_matrix[i - 1, j - 1])  # Match

    # Backtrack from the final cell. Ties are resolved in the order
    # diagonal, then step in `a`, then step in `b`, using the module-level
    # `equals` tolerance check.
    i, j = n, m
    path = []

    while i > 0 or j > 0:
        path.append((i - 1, j - 1))
        if i > 0 and j > 0:
            here = dtw_matrix[i, j]
            step = frame_cost[i - 1, j - 1]
            if equals(here, dtw_matrix[i - 1, j - 1] + step):
                i -= 1
                j -= 1
            elif equals(here, dtw_matrix[i - 1, j] + step):
                i -= 1
            else:
                j -= 1
        elif i > 0:
            i -= 1
        else:
            j -= 1

    path.reverse()
    return dtw_matrix[n, m], path

def load_audio(file_path):
    """Load an audio file with librosa, keeping the file's own sampling rate.

    Returns the `(samples, sampling_rate)` pair produced by
    `librosa.load(file_path, sr=None)`.
    """
    samples, sampling_rate = librosa.load(file_path, sr=None)
    return samples, sampling_rate

def align(signal_a, signal_b, path):
    """Re-time `signal_b` onto the frame grid of `signal_a` along a warping path.

    The result has the shape and dtype of `signal_a` and starts out as zeros.
    For every `(idx_a, idx_b)` pair in `path`, row `idx_a` of the result is set
    to `signal_b[idx_b]`; when several pairs share the same `idx_a`, the last
    one in `path` wins. Rows never mentioned in `path` stay zero.
    """
    warped = np.zeros_like(signal_a)

    for step in path:
        target_row, source_row = step
        warped[target_row] = signal_b[source_row]

    return warped

def main(audio_file_1, audio_file_2):
    """Warp the second recording onto the timing of the first and save the result.

    Both files are loaded at their native sampling rates, described by 13 MFCCs
    per frame, and aligned with `dtw` on per-coefficient standardised MFCCs.
    The (unnormalised) MFCC frames of the second file are then re-timed along
    the DTW path, inverted back to a waveform, and written to
    `./{audio_file_2}_aligned.wav` at the second file's sampling rate.

    Parameters
    ----------
    audio_file_1 : path of the reference recording (defines the target timing).
    audio_file_2 : path of the recording to be warped; also used to build the
        output file name.
    """

    # 0. Load audio files
    audio_a, sr_a = load_audio(audio_file_1)
    audio_b, sr_b = load_audio(audio_file_2)

    # 1. Extract MFCC features, transposed so that frames run along axis 0
    mfcc_a = librosa.feature.mfcc(y=audio_a, sr=sr_a, n_mfcc=13).T
    mfcc_b = librosa.feature.mfcc(y=audio_b, sr=sr_b, n_mfcc=13).T

    # 2. Normalise MFCC features (zero mean, unit variance per coefficient)
    mfcc_a_normalised = (mfcc_a - np.mean(mfcc_a, axis=0))/(np.std(mfcc_a, axis=0))
    mfcc_b_normalised = (mfcc_b - np.mean(mfcc_b, axis=0))/(np.std(mfcc_b, axis=0))

    # 3. Perform DTW
    _, path = dtw(mfcc_a_normalised, mfcc_b_normalised)

    # 4. Align audio_b using DTW path.
    #    `align` is the helper defined in this notebook (the former call to
    #    `align_mfcc` referred to a name that does not exist). The first
    #    argument only supplies the output shape/dtype.
    mfcc_b_aligned = align(mfcc_a_normalised, mfcc_b, path)
    #    Invert with the same sampling rate the MFCCs were computed with;
    #    otherwise librosa falls back to its default rate.
    audio_b_aligned = librosa.feature.inverse.mfcc_to_audio(mfcc_b_aligned.T, sr=sr_b)

    # 5. Export
    sf.write(f'./{audio_file_2}_aligned.wav', audio_b_aligned, sr_b)
    print(f"Aligned audio saved as '{audio_file_2}_aligned.wav'.")

    return

def naive_cut(audio_file_1, audio_file_2):
    """Baseline "alignment": truncate the second recording to the first one's sample count.

    The truncated audio is written to `./audio_b_cut.wav` at the second
    file's sampling rate.
    """
    reference, _ = load_audio(audio_file_1)
    candidate, candidate_rate = load_audio(audio_file_2)
    n_keep = len(reference)
    sf.write('./audio_b_cut.wav', candidate[:n_keep], candidate_rate)
    print("Aligned audio saved as 'audio_b_cut.wav'.")

def naive_speed(audio_file_1, audio_file_2):
    """Baseline "alignment": re-label the second recording's sampling rate.

    The samples of the second file are written unchanged to
    `./audio_b_speed.wav`, with a sampling rate scaled by the ratio of the two
    sample counts so that its duration matches the first recording.
    """
    reference, reference_rate = load_audio(audio_file_1)
    candidate, _ = load_audio(audio_file_2)
    stretched_rate = int(reference_rate*len(candidate)/len(reference))
    sf.write('./audio_b_speed.wav', candidate, stretched_rate)
    print("Aligned audio saved as 'audio_b_speed.wav'.")

In [None]:
def view_spectrogram(spectrogram, title="Mel Spectrogram", n_mels=80):
    """Plot a 2-D mel spectrogram with librosa / matplotlib.

    Parameters
    ----------
    spectrogram : np.ndarray or torch.Tensor
        2-D spectrogram, either `(n_mels, frames)` or `(frames, n_mels)`. It is
        transposed when its first axis is not `n_mels`.
    title : str
        Figure title.
    n_mels : int
        Expected number of mel bands; used for the orientation check and the
        shape assertion.

    The plot uses a fixed sampling rate of 22050 Hz and `fmax=8000`.
    """
    if isinstance(spectrogram, np.ndarray):
        spectrogram = torch.tensor(spectrogram)
    # Decide the orientation from `n_mels` (not a hard-coded 80) so that the
    # parameter is honoured for other mel resolutions.
    if spectrogram.shape[0] != n_mels:
        spectrogram = torch.einsum("ij->ji", spectrogram)
    assert spectrogram.shape[0] == n_mels, f"spectrogram shape {spectrogram.shape} != ({n_mels}, seq_length)"
    print(spectrogram.shape)
    plt.figure(figsize=(10, 4))
    # librosa's plotting code works on NumPy arrays, so hand it one explicitly.
    librosa.display.specshow(spectrogram.detach().cpu().numpy(), sr=22050, x_axis='time', y_axis='mel', fmax=8000)
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    plt.show()

def get_spectrogram(file_name):
    """Load an audio file and return its mel spectrogram.

    The file is read with `torchaudio.load` and handed to
    `get_spectrogram_from_waveform`, which resamples to 22050 Hz and applies
    the mel settings used throughout this notebook. Delegating keeps a single
    copy of that configuration instead of repeating it here.
    """
    signal, rate = torchaudio.load(file_name)
    return get_spectrogram_from_waveform(signal, rate)

def get_spectrogram_from_waveform(signal, rate):
    """Compute the mel spectrogram of an in-memory waveform.

    Parameters
    ----------
    signal : np.ndarray or torch.Tensor
        Waveform samples; NumPy input is converted to a float32 tensor.
    rate : int
        Sampling rate of `signal`; the waveform is resampled to 22050 Hz.

    Returns
    -------
    The first value returned by `mel_spectogram` (the second is discarded).
    """
    waveform = signal
    if isinstance(waveform, np.ndarray):
        waveform = torch.tensor(waveform, dtype=torch.float32)

    waveform = torchaudio.functional.resample(waveform, orig_freq=rate, new_freq=22050)

    # Mel settings shared by every spectrogram in this notebook.
    mel_settings = dict(
        sample_rate=22050,
        hop_length=256,
        win_length=None,
        n_mels=80,
        n_fft=1024,
        f_min=0.0,
        f_max=8000.0,
        power=1,
        normalized=False,
        min_max_energy_norm=True,
        norm="slaney",
        mel_scale="slaney",
        compression=True,
    )
    mel, _ = mel_spectogram(audio=waveform.squeeze(), **mel_settings)

    return mel

def spectrogram_to_waveform(spectrogram, save_file_name):
    """Vocode a mel spectrogram with the global `hifi_gan` and save it as audio.

    The decoded waveform is written to `save_file_name` at 22050 Hz.
    """
    decoded = hifi_gan.decode_batch(spectrogram)  # spectrogram to waveform
    audio = decoded.squeeze(1)
    torchaudio.save(save_file_name, audio, 22050)

def get_reconstructed_sample(file_name, save_file_name):
    """Round-trip an audio file through mel spectrogram and vocoder.

    The file is converted to a mel spectrogram with `get_spectrogram` and
    decoded back to audio with `spectrogram_to_waveform`, which writes the
    result to `save_file_name` at 22050 Hz. Composing the two helpers replaces
    a duplicated copy of their bodies.
    """
    spectrogram = get_spectrogram(file_name)
    spectrogram_to_waveform(spectrogram, save_file_name)

def transcript_to_audio(sentence, save_file_name):
    """Synthesise `sentence` with the global Tacotron2 + HiFi-GAN models and save it.

    The audio is written to `save_file_name` at 22050 Hz.

    Author's notes on what `tacotron2.encode_text` returns:
      1. Mel spectrogram with the properties in the Tacotron paper (or see
         get_reconstructed_sample); shape (batch_size, n_mels=80, Mel_length + 1),
         with Mel_length proportional to the length of the sequence.
      2. Mel_length = mel_output.shape[2] - 1.
      3. Alignment of shape (batch_size, Mel_length, Token_length), where
         Token_length comes from tacotron2.text_to_seq(txt).
    """
    # Only the spectrogram is needed here; the other two outputs are unused.
    mel_output, _mel_length, _alignment = tacotron2.encode_text(sentence)

    audio = hifi_gan.decode_batch(mel_output)  # spectrogram to waveform
    audio = audio.squeeze(1)

    torchaudio.save(save_file_name, audio, 22050)

def transcript_to_mel(sentence):
    """Return the Tacotron2 mel spectrogram for `sentence` with size-1 dimensions removed."""
    mel_output, _, _ = tacotron2.encode_text(sentence)
    squeezed = mel_output.squeeze()  # remove the batch dimension
    return squeezed

def mel_to_audio(mel_output, save_file_name=None, display=False):
    """Decode a mel spectrogram to audio with the global `hifi_gan` vocoder.

    Parameters
    ----------
    mel_output : np.ndarray or torch.Tensor
        2-D mel spectrogram; it is transposed when its first axis is not 80.
    save_file_name : str, optional
        When given, the waveform is also saved there at 22050 Hz.
    display : bool
        When true, an `IPython.display.Audio` widget is returned instead of
        the raw waveform.

    Returns
    -------
    The decoded waveform, or an `ipd.Audio` object when `display` is true.
    """
    mel = torch.tensor(mel_output) if isinstance(mel_output, np.ndarray) else mel_output

    # Put the mel axis first.
    if mel.shape[0] != 80:
        mel = torch.einsum("ij->ji", mel)

    waveforms = hifi_gan.decode_batch(mel)  # spectrogram to waveform

    if save_file_name is not None:
        torchaudio.save(save_file_name, waveforms.squeeze(1), 22050)

    if display:
        return ipd.Audio(waveforms, rate=22050)
    return waveforms

def sample_audio(dataset, idx:int):
    """Print one dataset item and save its aligned and original mel spectrograms as audio.

    Writes `sample_{idx}.wav` (from the item's "data_mel") and
    `sample_{idx}_original.wav` (from "original_data_mel"), both decoded by
    `mel_to_audio`.

    The item is fetched once and reused: indexing `dataset[idx]` repeatedly
    would rebuild the item each time on a dataset that does not cache.
    """
    item = dataset[idx]
    print(item)
    mel_to_audio(torch.einsum("ij->ji", item["data_mel"]), f"sample_{idx}.wav")
    mel_to_audio(torch.einsum("ij->ji", torch.tensor(item["original_data_mel"])), f"sample_{idx}_original.wav")

# Data and Pre-trained Models

In [None]:
# Pretrained SpeechBrain models used as notebook-wide globals by the helper functions:
# `tacotron2` turns text into mel spectrograms (`encode_text`) and `hifi_gan`
# turns mel spectrograms into waveforms (`decode_batch`). Both are LJSpeech checkpoints.
tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tts")
hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="tmpdir_vocoder")

In [None]:
# Expresso dataset, loaded through `datasets.load_dataset`.
dataset = load_dataset("ylacombe/expresso")

In [None]:
# Display the first training example to see which fields each item carries.
dataset["train"][0] # visualise data

In [None]:
# Number of examples in the "train" split; the split boundaries used below rely on this size.
len(dataset["train"]) # 11615

In [None]:
class MelSpectrogramDataset(torch.utils.data.Dataset):
    """Paired synthetic / recorded mel spectrograms for one split of a source dataset.

    For every utterance the dataset produces the Tacotron2 mel spectrogram of
    the transcript ("ai_mel"), the recorded mel spectrogram warped onto the
    synthetic one by DTW ("data_mel"), per-frame durations derived from the DTW
    path, and the style / speaker labels. Built items are memoised in
    `self.cache`, keyed by their absolute index in the shuffled source.
    """

    def __init__(self, split, source, params):
        """Select one split of `source`.

        Parameters
        ----------
        split : str
            One of "train", "valid", "test".
        source
            Dataset object providing `.shuffle(seed=...)` and integer indexing;
            items are read through the keys "style", "speaker_id", "text" and
            "audio" (with "array" and "sampling_rate").
        params : mapping
            Maps each split name to a `(base, limit)` pair of indices into the
            shuffled source; the split covers `base <= index < limit`.
        """
        assert split in ("train", "valid", "test"), "invalid split"
        self.source = source.shuffle(seed=42) # for reproducibility
        self.base_pointer, self.limit_pointer = params[split]
        self.cache = dict()
        # Style name -> integer class id.
        self.label_encoder = {
            'confused': 0,
            'default': 1,
            'emphasis': 2,
            'enunciated': 3,
            'essentials': 4,
            'happy': 5,
            'laughing': 6,
            'longform': 7,
            'sad': 8,
            'singing': 9,
            'whisper': 10
        }
        # Speaker id -> integer class id.
        self.speaker_encoder = {
            'ex01': 0,
            'ex02': 1,
            'ex03': 2,
            'ex04': 3
        }

    def __len__(self):
        """Number of items in this split."""
        return self.limit_pointer - self.base_pointer

    @staticmethod
    def _durations_from_path(path, n_frames):
        """Count, for each frame of the first sequence, the second-sequence frames assigned to it.

        `path` is a list of `(x, y)` index pairs as returned by `dtw`. Every
        distinct `y` is counted exactly once and credited to the last `x` it is
        paired with, so the counts add up to the number of distinct `y` values.
        Returns a float array of length `n_frames`.
        """
        durations = np.zeros(n_frames)
        previous = None
        for x, y in path:
            if previous is not None and previous[1] == y:
                # Same `y` as the previous step: move its count to the new `x`.
                durations[previous[0]] -= 1
            durations[x] += 1
            previous = (x, y)
        return durations

    def __getitem__(self, idx):
        """Build (or fetch from the cache) the item at position `idx` within the split.

        Negative indices count from the end of the split. Raises `IndexError`
        for positions outside the split, so the dataset follows the sequence
        protocol and can never read rows belonging to another split.
        """

        # 0. Preprocessing

        split_size = self.limit_pointer - self.base_pointer
        if idx < 0:
            idx += split_size
        if not 0 <= idx < split_size:
            raise IndexError("index out of bounds")
        idx += self.base_pointer
        if idx in self.cache: return self.cache[idx] # memoisation

        item = self.source[idx]

        label = item["style"]
        speaker = item["speaker_id"]

        # 1. Obtain Mel Spectrograms

        data_mel_spectrogram = get_spectrogram_from_waveform(item["audio"]["array"], item["audio"]["sampling_rate"])
        ai_mel_spectrogram = transcript_to_mel(item["text"])

        # Transpose so that frames run along axis 0 and the 80 mel bands along axis 1.
        data_mel_spectrogram = np.einsum("ij->ji", data_mel_spectrogram)
        ai_mel_spectrogram = np.einsum("ij->ji", ai_mel_spectrogram)

        assert ai_mel_spectrogram.shape[1] == 80
        assert data_mel_spectrogram.shape[1] == 80

        # 2. DTW: warp the recorded spectrogram onto the synthetic one's frames
        dtw_cost, path = dtw(ai_mel_spectrogram, data_mel_spectrogram)
        aligned_to_ai_spectrogram = align(ai_mel_spectrogram, data_mel_spectrogram, path)

        assert aligned_to_ai_spectrogram.shape == ai_mel_spectrogram.shape, "DTW was not successful"
        assert aligned_to_ai_spectrogram.shape[1] == 80

        # 3. Duration modelling

        duration_arr = self._durations_from_path(path, len(ai_mel_spectrogram))
        assert sum(duration_arr) == len(data_mel_spectrogram), "duration modelling not successful"

        # 4. Return AI Mel, aligned Data Mel, Emotion Label, Speaker Label, original Data Mel
        self.cache[idx] = {
            "ai_mel": torch.tensor(ai_mel_spectrogram),
            "data_mel": torch.tensor(aligned_to_ai_spectrogram),
            "label": torch.tensor([self.label_encoder[label]]),
            "speaker": torch.tensor([self.speaker_encoder[speaker]]),
            "original_data_mel": torch.tensor(data_mel_spectrogram),
            "sequence_length": torch.tensor([ai_mel_spectrogram.shape[0]]),
            "duration": torch.tensor(duration_arr),
            "text": item["text"],
            "data_audio": item["audio"]["array"],
            "data_sample_rate": item["audio"]["sampling_rate"],
            "ai_audio": mel_to_audio(ai_mel_spectrogram),
            "ai_sample_rate": 22050
        }
        return self.cache[idx]

    @staticmethod
    def collate(batch):
        """Pad a list of items into batch tensors on the CUDA device.

        Variable-length sequences are padded with NaN; `mask` is True at padded
        positions. Requires CUDA (asserted). Returns a dict with the keys
        "ai_mel", "data_mel", "labels", "sequence_length", "mask" and "duration".
        """

        assert torch.cuda.is_available()
        device = torch.device("cuda")

        ai_mel = pad_sequence(
            [item["ai_mel"] for item in batch],
            batch_first=True, padding_value=np.nan
        )
        data_mel = pad_sequence(
            [item["data_mel"] for item in batch],
            batch_first=True, padding_value=np.nan
        )
        duration = pad_sequence(
            [item["duration"] for item in batch],
            batch_first=True, padding_value=np.nan
        )
        labels = torch.cat(tuple([item["label"] for item in batch]))
        sequence_lengths = torch.cat(tuple([item["sequence_length"] for item in batch]))

        # A frame is padding when all of its mel bands are NaN. The same mask is
        # derived from three tensors to cross-check that they agree.
        mask = torch.all(torch.isnan(ai_mel), 2)
        mask_check = torch.all(torch.isnan(data_mel), 2)
        mask_double_check = torch.isnan(duration)
        assert torch.equal(mask, mask_check), "mask is dubious"
        assert torch.equal(mask, mask_double_check), f"mask is dubious {mask.shape}, {mask_double_check.shape}"

        batch_size = len(batch)
        _, ai_mel_max_length, _ = ai_mel.shape
        assert ai_mel.shape == (batch_size, ai_mel_max_length, 80)
        assert data_mel.shape == ai_mel.shape
        assert duration.shape == ai_mel.shape[:2]
        assert sequence_lengths.shape == torch.Size([batch_size])
        assert torch.all(sequence_lengths > 0), "not all sequence lengths are positive"
        assert mask.shape == ai_mel.shape[:2]

        return {
            "ai_mel": ai_mel.to(device),
            "data_mel": data_mel.to(device),
            "labels": labels.to(device),
            "sequence_length": sequence_lengths.to(device),
            "mask": mask.to(device),
            "duration": duration.to(device)
        }

# Data Processing

In [None]:
# Training split of the shuffled source: each entry of the mapping is a
# (base, limit) index pair, and together the three ranges cover indices 0..11615.
preprocessed = MelSpectrogramDataset("train", dataset["train"], {
    "train": (0, 7000),
    "valid": (7000, 10000),
    "test": (10000, 11615)
})

In [None]:
def cache_dataset(source, params, save_file_name="alldata.pth"):
    """Materialise every split of `MelSpectrogramDataset` and save it with `torch.save`.

    Parameters
    ----------
    source
        Source dataset handed to `MelSpectrogramDataset`.
    params : mapping
        Split name -> `(base, limit)` index pair, as expected by
        `MelSpectrogramDataset`.
    save_file_name : str
        File that receives all three splits. Each split is additionally saved
        on its own as `train_{save_file_name}`, `valid_{save_file_name}` and
        `test_{save_file_name}` as soon as it is finished.
    """
    # The notebook's `import tqdm.notebook` binds `tqdm` to the package, which
    # is not callable; import the progress-bar class explicitly.
    from tqdm.notebook import tqdm as progress_bar

    def build_split(split):
        """Return the list of all items of one split, showing a progress bar."""
        split_dataset = MelSpectrogramDataset(split, source, params)
        return [split_dataset[i] for i in progress_bar(range(len(split_dataset)))]

    train_ls = build_split("train")
    torch.save({"training_data": train_ls}, f"train_{save_file_name}")
    print("finished train")

    valid_ls = build_split("valid")
    torch.save({"validation_data": valid_ls}, f"valid_{save_file_name}")
    print("finished valid")

    # Must be the "test" split; building it from "valid" would save the
    # validation items a second time under the test key.
    test_ls = build_split("test")
    torch.save({"testing_data": test_ls}, f"test_{save_file_name}")
    print("finished test")

    torch.save({
        "training_data": train_ls,
        "validation_data": valid_ls,
        "testing_data": test_ls
    }, save_file_name)
    print("DONE")

In [None]:
# Build and save all three splits; the (base, limit) pairs partition indices 0..11615.
cache_dataset(dataset["train"], {
            "train": (0, 7000),
            "valid": (7000, 10000),
            "test": (10000, 11615)
}, save_file_name="alldata.pth")