In [None]:
import base64
import glob
import io
import os
import pickle
from typing import Dict, List

from IPython.core.display import HTML
from IPython.display import Audio, display
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial import distance
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
import torch

from models.helpers.tools import get_mask_from_lengths
from models.tts.delightful_tts.delightful_tts_refined import DelightfulTTS
from models.vocoder.univnet import UnivNet
from training.datasets import LibriTTSDatasetAcoustic
from training.loss import FastSpeech2LossGen, Metrics

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## Add the checkpoints to the `./checkpoints` folder and choose the appropriate version

In [None]:
# Checkpoint file stems used by this demo (files live in ./checkpoints)
checkpoint_name = "epoch=60-step=13908"
checkpoint_univnet = "vocoder_epoch=19-step=9680"

# Other training-run directories that have been used instead of ./checkpoints:
#   ./logs_100, ./logs_360_energy, ./logs
checkpoints_dir = "./checkpoints"
checkpoint_path = f"{checkpoints_dir}/{checkpoint_name}.ckpt"
checkpoint_path_univnet = f"{checkpoints_dir}/{checkpoint_univnet}.ckpt"

# LibriTTS subset used by the demo ("train-clean-100" is the other option)
dataset_url = "train-clean-360"

### Load from the checkpoint

In [None]:
# Restore the TTS model; strict=False tolerates checkpoint keys that do not
# match the current model definition (mismatches are not reported here).
model = DelightfulTTS.load_from_checkpoint(checkpoint_path, strict=False)
model = model.to(device)
model.eval()

print(f"Loaded checkpoint: {checkpoint_path}")

In [None]:
# univnet = UnivNet.load_from_checkpoint(checkpoint_path_univnet, strict=False).to(device)
# univnet.eval()

# print(f"Loaded checkpoint: {checkpoint_path_univnet}")

## Create an index file for the dataset or load from the cache

In [None]:
# Playback rate used for the Audio widgets rendered later in the notebook
sample_rate = 22050

# Cache file holding the mapping from utterance IDs to dataset indices
file_path = f"./datasets_cache/{dataset_url}_index_dict.pkl"

# Load the acoustic dataset for the selected LibriTTS subset
dataset = LibriTTSDatasetAcoustic(url = dataset_url)

if os.path.exists(file_path):
    # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
    with open(file_path, "rb") as f:
        index_dict = pickle.load(f)
else:
    # Map each entry of the wrapped dataset's walker to its position in the dataset.
    # NOTE(review): relies on the private `_walker` attribute of the wrapped dataset.
    index_dict = {item: index for index, item in enumerate(dataset.dataset._walker)}

    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        pickle.dump(index_dict, f)

# Number of entries in the ID -> index mapping
len(index_dict)

## Build a data structure with one long audio sample per speaker (the largest file of one chapter), then add the quality metrics

In [15]:
def get_data(filename):
    """Return the dataset sample whose ID is `filename` stripped of its extension."""
    stem, _extension = os.path.splitext(filename)
    return dataset[index_dict[stem]]


def load_text(file_path, encoding: str = "utf-8"):
    """Read and return the whole content of a text file.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        The file content as a string.
    """
    with open(file_path, encoding=encoding) as file:
        return file.read()


def create_audio(src: str, mimetype: str = "audio/wav") -> str:
    """Build an HTML5 <audio> player snippet that plays `src` with the given MIME type."""
    source_tag = f'<source src="{src}" type="{mimetype}" />'
    fallback_text = "Your browser does not support the audio element."
    return '<audio controls="controls">' + source_tag + fallback_text + "</audio>"


def audio_for_speaker_libri(dataset = dataset_url):
    """Collect one audio sample per speaker of a LibriTTS subset, with its dataset fields.

    For every speaker directory under ``./datasets_cache/LIBRITTS/LibriTTS/<dataset>``,
    the first chapter directory in sorted order is inspected, so the choice is
    reproducible across filesystems. Its largest ``.wav`` file by size on disk is
    selected. Speakers without a chapter directory, or whose chosen chapter has no
    ``.wav`` files, are skipped. Non-directory entries are ignored.

    Args:
        dataset: LibriTTS subset name, e.g. ``"train-clean-360"``.

    Returns:
        A list of dicts with ``READER`` (the speaker directory name, a string),
        ``AUDIO`` (the selected wav path) and every field returned by
        ``get_data`` for that file.
    """
    audio_dir = f"./datasets_cache/LIBRITTS/LibriTTS/{dataset}"

    def _sorted_subdirs(path):
        # Directory names only, sorted: os.listdir order is filesystem-dependent.
        return sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    audios_for_author = {}
    for speaker in _sorted_subdirs(audio_dir):
        speaker_path = os.path.join(audio_dir, speaker)
        chapters = _sorted_subdirs(speaker_path)
        if not chapters:
            continue
        wavs = sorted(glob.glob(os.path.join(speaker_path, chapters[0], "*.wav")))
        if not wavs:
            continue
        # File size serves as a cheap proxy for audio duration.
        audios_for_author[speaker] = max(wavs, key=os.path.getsize)

    def create_audio_and_text(speaker, audio_path):
        return {
            "READER": speaker,
            "AUDIO": audio_path,
            **get_data(os.path.basename(audio_path)),
        }

    audios_and_text = Parallel(n_jobs=-1)(
        delayed(create_audio_and_text)(speaker, audio_path)
        for speaker, audio_path in audios_for_author.items()
    )

    return audios_and_text


def prepared_dataset_subset(dataset = dataset_url):
    """Join LibriTTS speaker metadata with one audio example per speaker.

    Args:
        dataset: LibriTTS subset name used to filter ``speakers.tsv``.

    Returns:
        A DataFrame of the speakers in `dataset` that also have an example from
        ``audio_for_speaker_libri``, with ``READER`` cast to ``int``.
    """
    speakers_df = pd.read_csv(
        "./datasets_cache/LIBRITTS/LibriTTS/speakers.tsv",
        sep="\t",
        names=["READER", "GENDER", "SUBSET", "NAME"],
    )
    in_subset = speakers_df["SUBSET"] == dataset

    examples_df = pd.DataFrame(audio_for_speaker_libri(dataset))

    # NOTE(review): the merge needs READER to have the same dtype on both sides;
    # the examples carry speaker directory names (strings). Confirm the TSV column matches.
    merged = pd.merge(speakers_df[in_subset], examples_df, on="READER")
    merged["READER"] = merged["READER"].astype(int)
    return merged


### Save/load cached data from the dataset

In [None]:
# Cache of one example per speaker for the selected subset
file_path = f"./datasets_cache/{dataset_url}_energy.pkl"

if not os.path.exists(file_path):
    # Build the per-speaker subset and store it as a list of row dicts
    example_demo_ = prepared_dataset_subset(dataset_url)
    example_demo: List[Dict] = example_demo_.to_dict("records")
    with open(file_path, "wb") as f:
        pickle.dump(example_demo, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load caches written by this notebook.
    with open(file_path, "rb") as f:
        example_demo = pickle.load(f)

len(example_demo)

## Add metrics and losses to the data

In [None]:
# Cache of per-utterance predictions, losses and metrics for this checkpoint.
# NOTE: delete this file to recompute after changing the code below.
file_path = f"./datasets_cache/{dataset_url}_with_metrics_{checkpoint_name}_energy.pkl"

if os.path.exists(file_path):
    # NOTE: pickle.load executes arbitrary code; only load caches written by this notebook.
    with open(file_path, "rb") as f:
        example_demo_with_metrics = pickle.load(f)
else:
    loss = FastSpeech2LossGen()
    metrics = Metrics()

    # Loop-invariant: every batch below is moved to `device`, so placing the
    # vocoder there once is enough.
    model.vocoder_module.to(device)

    example_demo_with_metrics = []

    for row in example_demo:
        # Collate a one-row batch and move its tensors to the target device
        batch = [
            r.to(device) if isinstance(r, torch.Tensor) else r
            for r in dataset.collate_fn([row])
        ]

        (
            _,
            _,
            speakers,
            texts,
            src_lens,
            mels,
            pitches,
            _,
            mel_lens,
            langs,
            attn_priors,
            _,
            energies,
        ) = batch

        src_mask = get_mask_from_lengths(src_lens.float())
        mel_mask = get_mask_from_lengths(mel_lens.float())

        # Evaluation only: keep the acoustic model, vocoder, loss and metrics out of
        # autograd so no computation graph is retained for each utterance.
        with torch.no_grad():
            output = model.acoustic_model.forward_train(
                x=texts,
                speakers=speakers,
                src_lens=src_lens,
                mels=mels,
                mel_lens=mel_lens,
                pitches=pitches,
                langs=langs,
                attn_priors=attn_priors,
                energies=energies,
            )

            y_pred = output["y_pred"]

            # Vocode both the predicted and the ground-truth mel spectrograms
            wav_prediction = model.vocoder_module.forward(y_pred.float())
            wav_original = model.vocoder_module.forward(mels.float())

            (
                total_loss,
                mel_loss,
                ssim_loss,
                duration_loss,
                u_prosody_loss,
                p_prosody_loss,
                pitch_loss,
                ctc_loss,
                bin_loss,
                energy_loss,
            ) = loss(
                src_masks=src_mask,
                mel_masks=mel_mask,
                mel_targets=mels,
                mel_predictions=y_pred,
                log_duration_predictions=output["log_duration_prediction"],
                u_prosody_ref=output["u_prosody_ref"],
                u_prosody_pred=output["u_prosody_pred"],
                p_prosody_ref=output["p_prosody_ref"],
                p_prosody_pred=output["p_prosody_pred"],
                pitch_predictions=output["pitch_prediction"],
                p_targets=output["pitch_target"],
                durations=output["attn_hard_dur"],
                attn_logprob=output["attn_logprob"],
                attn_soft=output["attn_soft"],
                attn_hard=output["attn_hard"],
                src_lens=src_lens,
                mel_lens=mel_lens,
                energy_pred=output["energy_pred"],
                energy_target=output["energy_target"],
                # NOTE(review): fixed step passed to the loss; presumably selects the
                # step-dependent loss weighting. Confirm in FastSpeech2LossGen.
                step=50000,
            )

            metrics_logs = metrics(
                wav_prediction, wav_original, y_pred, mels,
            )

        example_demo_with_metrics.append({
            "ID": row["id"],
            "TEXT": row["normalized_text"],
            "READER": row["READER"],
            "GENDER": row["GENDER"],
            "SUBSET": row["SUBSET"],
            "NAME": row["NAME"],
            "AUDIO": row["AUDIO"],

            # Mel spectrograms
            "mels": mels.detach().cpu().numpy(),
            "mels_prediction": y_pred.detach().cpu().numpy(),

            # Waveforms
            "wav_original": wav_original.detach().cpu().numpy(),
            "wav_prediction": wav_prediction.detach().cpu().numpy(),

            # Losses
            "total_loss": total_loss.item(),
            "mel_loss": mel_loss.item(),
            "ssim_loss": ssim_loss.item(),
            "duration_loss": duration_loss.item(),
            "u_prosody_loss": u_prosody_loss.item(),
            "p_prosody_loss": p_prosody_loss.item(),
            "pitch_loss": pitch_loss.item(),
            "ctc_loss": ctc_loss.item(),
            "bin_loss": bin_loss.item(),
            "energy_loss": energy_loss.item(),

            # Metrics
            "si_sdr": metrics_logs.si_sdr.item(),
            "si_snr": metrics_logs.si_snr.item(),
            "c_si_snr": metrics_logs.c_si_snr.item(),
            "energy": metrics_logs.energy.item(),
            "mcd": metrics_logs.mcd.item(),
            # Previously a copy of `mcd`, which double-weighted MCD as a clustering feature
            "spec_dist": metrics_logs.spec_dist.item(),
            "f0_rmse": metrics_logs.f0_rmse,
            "jitter": metrics_logs.jitter,
            "shimmer": metrics_logs.shimmer,
        })

    # Serialize and save the data
    with open(file_path, "wb") as f:
        pickle.dump(example_demo_with_metrics, f)

len(example_demo_with_metrics)

### Choose metrics for clustering and scale them

In [None]:
scaler = StandardScaler()

# Features used for clustering: selected losses followed by quality metrics
losses = [
    "ctc_loss",
]
metrics = [
    "si_sdr",
    "si_snr",
    "c_si_snr",
    "energy",
    "mcd",
    "spec_dist",
    "f0_rmse",
]

clusters_metrics = losses + metrics

# Feature matrix: one row per utterance, one column per selected feature
X = np.array(
    [
        [data[metric] for metric in clusters_metrics]
        for data in example_demo_with_metrics
    ],
    dtype=float,
)

# StandardScaler rejects NaN and inf, so make the matrix finite:
# +/-inf keep their near-zero substitutes, while NaN is imputed with the column
# mean of the remaining values (0 if a column has none), which maps to 0 after scaling.
X = np.nan_to_num(X, nan=np.nan, posinf=1e-6, neginf=-1e-6)
is_valid = ~np.isnan(X)
valid_counts = is_valid.sum(axis=0)
valid_sums = np.where(is_valid, X, 0.0).sum(axis=0)
column_means = np.divide(
    valid_sums,
    valid_counts,
    out=np.zeros_like(valid_sums, dtype=float),
    where=valid_counts > 0,
)
X = np.where(is_valid, X, column_means)

X_scaled = scaler.fit_transform(X)
X_scaled[:1]

## Find the optimal number of clusters
### Check the plots below and choose the hyperparameter K

In [None]:
# The Elbow Method: KMeans inertia (within-cluster sum of squares) per k.
# A fixed random_state makes the curve reproducible across runs, and k never
# exceeds the number of samples (KMeans requires n_clusters <= n_samples).

distortions = []
K = range(1, min(10, X_scaled.shape[0] + 1))
for k in K:
    kmeanModel = KMeans(n_clusters=k, n_init=10, random_state=0)
    kmeanModel.fit(X_scaled)
    distortions.append(kmeanModel.inertia_)

fig, ax = plt.subplots(figsize=(16, 8))
ax.plot(K, distortions, "bx-")
ax.set_xlabel("k")
ax.set_ylabel("Distortion")
ax.set_title("The Elbow Method showing the optimal k")
plt.show()

In [None]:
# The Silhouette Method: mean silhouette coefficient per k.
# The silhouette score is defined only for 2 <= k <= n_samples - 1, hence the
# range bounds; a fixed random_state keeps the scores reproducible.

sil = []
K = range(2, min(10, X_scaled.shape[0]))
for k in K:
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X_scaled)
    labels = kmeans.labels_
    sil.append(silhouette_score(X_scaled, labels, metric = "euclidean"))

fig, ax = plt.subplots(figsize=(16, 8))
ax.plot(K, sil, "bx-")
ax.set_xlabel("k")
ax.set_ylabel("Silhouette Score")
ax.set_title("Silhouette Method showing the optimal k")
plt.show()

In [None]:
# Number of clusters, chosen from the elbow / silhouette plots above
K = 5

# Use the same n_init as the exploration cells: the sklearn default for n_init
# differs between versions, so leaving it implicit makes results version-dependent.
kmeans = KMeans(n_clusters=K, n_init=10, random_state=0).fit(X_scaled)

# Attach each sample's cluster label and its Euclidean distance (in scaled
# feature space) to the center of its cluster.
for data, label, X_s in zip(example_demo_with_metrics, kmeans.labels_, X_scaled):
    data["cluster"] = label
    data["distance"] = distance.euclidean(X_s, kmeans.cluster_centers_[label])

example_demo_with_metrics[:1]

## Define a `gen_table` function that renders a cluster with audio, text and metrics

In [21]:
metrics_ins = Metrics()

def gen_table(selected_speakers: list):
    """Render rows as an HTML table with audio players, text, metrics and spectrograms.

    Args:
        selected_speakers: Row dicts; each must provide ``READER``, ``NAME``,
            ``TEXT``, ``mels``, ``mels_prediction``, ``wav_original``,
            ``wav_prediction``, ``si_sdr``, ``si_snr``, ``ctc_loss`` and ``energy``.

    Returns:
        An IPython ``HTML`` object; its ``.data`` attribute holds the markup.
    """
    # Local import: the name `html` is used below for the markup being built.
    from html import escape

    # The heading goes before <table>: only rows are valid children of a table.
    html = "<h4>Speakers: </h4>"
    html += "<table border='1'>"
    html += "<tr><th>SpeakerID</th><th>Speaker Name</th><th>Audio</th><th>Pred Audio</th><th>Text</th><th>si_sdr</th><th>si_snr</th><th>ctc_loss</th><th>energy</th><th>spec</th></tr>"

    for row in selected_speakers:
        # Metrics shown in the table, rounded to 3 decimal places
        metrics = [round(row[key], 3) for key in ("si_sdr", "si_snr", "ctc_loss", "energy")]

        audio_original = Audio(row["wav_original"], rate=sample_rate, autoplay=False)
        audio_prediction = Audio(row["wav_prediction"], rate=sample_rate)

        # Ground-truth vs. predicted mel spectrogram; index 0 selects the item
        # of the one-row batch these arrays came from.
        fig = metrics_ins.plot_spectrograms(
            row["mels"][0],
            row["mels_prediction"][0],
        )
        try:
            # Embed the current figure as an inline base64-encoded PNG
            buf = io.BytesIO()
            plt.savefig(buf, format="png")
            img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
        finally:
            # Always release the figure, even if saving fails
            plt.close(fig)

        cells = [
            # Escape free text so characters such as "&" or "<" render literally
            escape(str(row["READER"])),
            escape(str(row["NAME"])),
            audio_original._repr_html_(),
            audio_prediction._repr_html_(),
            escape(str(row["TEXT"])),
            *metrics,
            f"<img src='data:image/png;base64,{img_str}' />",
        ]
        html += "<tr>" + "".join(f"<td>{cell}</td>" for cell in cells) + "</tr>"

    html += "</table>"

    return HTML(html)


### Preview one row of the data

In [22]:
# Per-utterance records as a DataFrame (the cluster report cell below reuses `df`)
df = pd.DataFrame(example_demo_with_metrics)
# Order by energy, descending (highest first), and render only the top row
gen_table(df.sort_values(by=["energy"], ascending=False)[:1].to_dict("records"))

SpeakerID,Speaker Name,Audio,Pred Audio,Text,ctc_loss,ssim_loss,mel_loss,Unnamed: 8
1061,Missie,Your browser does not support the audio element.,Your browser does not support the audio element.,"They had intended to name the baby Lucy, if it were a girl; but they hadn't expected her on Christmas morning, and a real Christmas baby was not to be lightly named-the whole family agreed in that.",1,1,1,1


## Show the cluster data

In [None]:
# Build an HTML report: for every cluster, summary statistics of the clustering
# features plus a table of its samples closest to the cluster center.
clusters = df["cluster"].sort_values().unique()

# Number of samples shown per cluster, closest to the center first
N_ROWS_PER_CLUSTER = 10

result = ""

for cluster in clusters:
    # Samples of this cluster, closest to the center first
    df_cluster = df[df["cluster"] == cluster].sort_values(by=["distance"], ascending=True)

    # Columns summarised by describe(): ID, distance and the clustering features
    df_cluster_ = df_cluster[["READER", "distance"] + clusters_metrics]

    result += f"<h3>Cluster #{cluster}: </h3>"
    result += df_cluster_.describe().to_html()

    result += f"<h3>Cluster #{cluster} audio data: </h3>"
    result += gen_table(df_cluster.head(N_ROWS_PER_CLUSTER).to_dict("records")).data # type: ignore

# Save the report; the logs/ directory may not exist on a fresh checkout, and the
# utterance text can contain non-ASCII characters, hence the explicit encoding.
output_path = f"logs/output_{dataset_url}_{checkpoint_name}.html"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    f.write(result)

display(HTML(result))