# Installs and Imports

## Basic Imports

In [None]:
import os
import sys
import IPython
import torch
import json
import pandas as pd
import numpy as np
from tqdm import tqdm

## Vocoder Imports

In [None]:
%%capture
!pip install parallel_wavegan
!pip install h5py=='3.6.0'
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

In [None]:
from parallel_wavegan.utils import load_model
from parallel_wavegan.utils import read_hdf5
from parallel_wavegan.bin.preprocess import logmelfilterbank
import yaml

## TTS Imports

In [None]:
sys.path.insert(1, '/workspace/coqui-tts')
from TTS.config import load_config, register_config
from TTS.tts.models import setup_model
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.tts.models.styleforward_tts import StyleforwardTTS
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.styles import StyleManager
from TTS.tts.utils.visual import plot_spectrogram
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples

## UMAP Imports

In [None]:
%%capture
! pip install umap-learn
import umap
import warnings
warnings.filterwarnings('ignore')

# Load Model and Set Device

In [None]:
%%capture

# MODEL AND DEVICE SELECTION
run_name = "re+class"            # key into checkpoints_dict; also the experiment folder name
run_select = "last_checkpoint"   # which entry of the selected run holds the checkpoint file
device = "cpu"
synthesizer = "hifi-gan"         # hifi-gan or griffin-limm
use_cuda = device == "cuda"

# CHECKPOINTS
# run name -> {"last_checkpoint": checkpoint file name, "part": metadata partition}
checkpoints_dict = {
    # Neutral Models (Speaker Look-Up)
    "vctk": dict(last_checkpoint="model_file.pth.tar"),
    "neutral": dict(last_checkpoint="checkpoint_1080000.pth.tar", part="neutral"),

    # Style Finetunings (Speaker Look-Up)
    "amused": dict(last_checkpoint="checkpoint_1113000.pth.tar", part="amused"),
    "angry": dict(last_checkpoint="checkpoint_1121000.pth.tar", part="angry"),
    "disgusted": dict(last_checkpoint="checkpoint_1097000.pth.tar", part="disgusted"),
    "sleepy": dict(last_checkpoint="checkpoint_1036000.pth.tar", part="Sleepy"),

    # Speaker Finetunings (Style Look-Up)
    "sam": dict(last_checkpoint="checkpoint_1390000.pth.tar", part="sam"),
    "josh": dict(last_checkpoint="checkpoint_1150000.pth.tar", part="josh"),
    "jenie": dict(last_checkpoint="checkpoint_1160000.pth.tar", part="jenie"),
    "bea": dict(last_checkpoint="checkpoint_1130000.pth.tar", part="bea"),

    # Multi-Speaker Multi-Style (Double Look-Up)
    "double_lookup": dict(last_checkpoint="checkpoint_1240000.pth.tar", part="all"),

    # Representation Learning
    "re": dict(last_checkpoint="checkpoint_1240000.pth.tar", part="all"),
    "re+class": dict(last_checkpoint="checkpoint_1240000.pth.tar", part="all"),
}

# EXPERIMENT FOLDER
folder = "../experiments/" + run_name + "/"

# LOAD CONFIG
config = load_config(folder + "config.json")

# LOAD SPEAKERS
# Without a speakers.json in the experiment folder the model is built with no
# speaker manager and an empty name -> id mapping.
if not os.path.isfile(folder + "speakers.json"):
    spk_manager = None
    spk_to_id = {}
else:
    spk_file_path = folder + "speakers.json"
    with open(spk_file_path) as json_file:
        spk_to_id = json.load(json_file)
    spk_manager = SpeakerManager(speaker_id_file_path=spk_file_path)

# LOAD STYLES
# Same fallback as for speakers, keyed on style_ids.json.
if not os.path.isfile(folder + "style_ids.json"):
    sty_manager = None
    sty_to_id = {}
else:
    sty_file_path = folder + "style_ids.json"
    with open(sty_file_path) as json_file:
        sty_to_id = json.load(json_file)
    sty_manager = StyleManager(style_ids_file_path=sty_file_path)

# LOAD MODEL
model = setup_model(config, speaker_manager=spk_manager, style_manager=sty_manager)

# LOAD CHECKPOINT
# Only the 'model' entry of the checkpoint file is used; tensors are mapped
# onto the selected device while loading.
checkpoint_file = folder + checkpoints_dict[run_name][run_select]
checkpoint = torch.load(checkpoint_file, map_location=torch.device(device))['model']
model.load_state_dict(checkpoint)

# PREPARE MODEL
model.to(device)
model.eval()

# PREPARE VOCODER
# Griffin-Lim needs no neural vocoder, so nothing is loaded in that case.
if synthesizer != "griffin-limm":
    voc_name = synthesizer
    voc_dir = "../experiments/vocoders/" + voc_name + "/"
    voc_checkpoint = voc_dir + "checkpoint-470000steps.pkl"
    voc_config_path = voc_dir + "config.yml"
    voc_stats = voc_dir + "stats.h5"
    with open(voc_config_path) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Only load vocoder configs from a trusted source; switch to
        # yaml.safe_load if the config contains plain YAML only.
        voc_config = yaml.load(f, Loader = yaml.Loader)
    vocoder = load_model(voc_checkpoint, voc_config)
    vocoder.to(device)
    vocoder.eval()
    vocoder.remove_weight_norm()

    # CHECK COMPATIBILITY WITH TTS
    # The TTS audio config is switched to log10 before the AudioProcessor
    # below is built from it.
    config.audio.log_func = 'np.log10'
    # (vocoder config key, TTS audio config attribute) pairs that must agree.
    for voc_key, tts_key in (
        ('sampling_rate', 'sample_rate'),
        ('fmax', 'mel_fmax'),
        ('fmin', 'mel_fmin'),
        ('fft_size', 'fft_size'),
        ('hop_size', 'hop_length'),
        ('win_length', 'win_length'),
    ):
        voc_value = voc_config[voc_key]
        tts_value = getattr(config.audio, tts_key)
        assert voc_value == tts_value, (
            "Vocoder/TTS audio mismatch: vocoder {}={!r} but TTS {}={!r}".format(
                voc_key, voc_value, tts_key, tts_value))


# Audio Processor
ap = AudioProcessor(**config.audio)

# Audios

Outputs the audio of the synthesis, the resynthesis, and the ground truth.

In [None]:
def read_emovdb_metadata(dataset, sub, root="../recipes/emovdb/emovdb/"):
    """Parse an EmoV-DB metadata file into parallel lists.

    The file ``<root>/metadata/metadata_<dataset>_<sub>.csv`` is read line by
    line; each non-blank line is expected to hold at least four ``|``-separated
    fields: relative wav path, text, speaker and style.

    The file is read directly instead of through ``pd.read_csv(sep="\\n")``,
    which recent pandas versions reject with a ``ValueError``.

    Args:
        dataset: Split name used in the file name (e.g. ``"test"``).
        sub: Partition name used in the file name (e.g. ``"all"``).
        root: Dataset root directory. Defaults to the location used by this
            notebook; a trailing ``/`` is added when missing.

    Returns:
        dict with keys ``'texts'``, ``'speakers'``, ``'styles'`` and
        ``'style_wavs'`` (wav paths prefixed with ``<root>/files/``), each a
        list with one entry per metadata line.

    Raises:
        ValueError: If a non-blank line has fewer than four fields.
    """
    if not root.endswith("/"):
        root = root + "/"
    metadata_path = root + "metadata/metadata_" + dataset + "_" + sub + ".csv"

    file_names = []
    texts = []
    spks = []
    styles = []
    with open(metadata_path, encoding="utf-8") as metadata_file:
        for line_no, raw_line in enumerate(metadata_file, start=1):
            line = raw_line.rstrip("\r\n")
            if not line.strip():
                # Blank lines carry no sample.
                continue
            # Split once per line; fields beyond the fourth are ignored.
            fields = line.split(sep='|')
            if len(fields) < 4:
                raise ValueError(
                    "{}:{}: expected at least 4 '|'-separated fields, got {}".format(
                        metadata_path, line_no, len(fields)))
            file_names.append(root + 'files/' + fields[0])
            texts.append(fields[1])
            spks.append(fields[2])
            styles.append(fields[3])
    return {'texts':texts, 'speakers':spks, 'styles':styles, 'style_wavs':file_names}

## Synthesis

In [None]:
# Fetch Inputs
dataset = "test"
partition = checkpoints_dict[run_name]["part"]
data = read_emovdb_metadata(dataset, partition)
idx = 0

# Or Insert directly
text = data['texts'][idx]
speaker = data['speakers'][idx]
style = data['styles'][idx]
style_wav = data['style_wavs'][idx]
# `style_representation` is only created by the "Extract Centroid
# Representation" cell at the end of the notebook. Look it up defensively so
# this cell also runs on a fresh kernel (falls back to None), and avoid
# truth-testing a tensor, which raises for multi-element tensors.
style_representation = globals().get("style_representation", None)

# SYNTHESIS
# Arguments shared by both synthesis back-ends.
synthesis_kwargs = dict(
    text = text,
    speaker_id = spk_to_id.get(speaker),
    style_id = sty_to_id.get(style),
    style_wav = style_wav,
    style_representation = style_representation,
    model = model,
    CONFIG = config,
    use_cuda = use_cuda,
    ap = ap,
)
if synthesizer == "griffin-limm":
    out = synthesis(use_griffin_lim=True, **synthesis_kwargs)
else:
    fp_out = synthesis(**synthesis_kwargs)
    # First item of the batch, as a CPU float tensor.
    feat_gen_denorm = torch.Tensor(fp_out['outputs']['model_outputs'].cpu().numpy()[0])
    # exp followed by log10 re-expresses the features in base 10 (presumably
    # the model emits natural-log mels while the vocoder statistics are
    # log10 -- confirm).
    feat_gen_denorm = torch.log10(torch.exp(feat_gen_denorm))
    # Normalise with the vocoder's stored mean/scale, then move the result to
    # the device the vocoder lives on instead of hard-coding CUDA.
    voc_mean = torch.from_numpy(read_hdf5(voc_stats, "mean"))
    voc_scale = torch.from_numpy(read_hdf5(voc_stats, "scale"))
    feat_gen_norm = ((feat_gen_denorm - voc_mean) / voc_scale).to(device=device, dtype=torch.float)
    with torch.no_grad():
        out = {'wav':vocoder.inference(feat_gen_norm).cpu().numpy().squeeze(1)}

# RESULTS
audio_syn = out['wav']
print("Text = {}".format(text))
print("Spk = {}".format(speaker))
print("Sty = {}".format(style))
print("Sty_Wav = {}".format(style_wav))
print("Sty_Rep = {}".format(style_representation))
IPython.display.Audio(audio_syn, rate=config.audio.sample_rate)

## GT Resynthesis

In [None]:
# Ground-truth waveform and its mel-spectrogram for the sample selected in the
# synthesis cell (`data`, `idx`).
gt_wav = ap.load_wav(data['style_wavs'][idx])
gt_spectrogram = ap.melspectrogram(gt_wav)

# Turn the ground-truth mel back into audio with the selected synthesizer.
if synthesizer == 'griffin-limm':
    res_wav = ap.inv_melspectrogram(gt_spectrogram)
else:
    # The spectrogram is transposed before being handed to the vocoder.
    feat_gen_denorm = torch.Tensor(gt_spectrogram.T)
    # Normalise with the vocoder's stored mean/scale, then move the result to
    # the device the vocoder lives on instead of hard-coding CUDA.
    voc_mean = torch.from_numpy(read_hdf5(voc_stats, "mean"))
    voc_scale = torch.from_numpy(read_hdf5(voc_stats, "scale"))
    feat_gen_norm = ((feat_gen_denorm - voc_mean) / voc_scale).to(device=device, dtype=torch.float)
    with torch.no_grad():
        res_wav = vocoder.inference(feat_gen_norm).cpu().numpy().squeeze(1)
res_spectrogram = ap.melspectrogram(res_wav)

# PLAY GT+RES
IPython.display.Audio(res_wav, rate=config.audio.sample_rate)

## GT

In [None]:
# PLAY GT: the reference recording loaded with ap.load_wav in the
# "GT Resynthesis" cell, played at the TTS sample rate.
IPython.display.Audio(gt_wav, rate=config.audio.sample_rate)

# Mel-Spectrograms

Outputs images of mel-spectrograms.

## Synthesis

In [None]:
# PLOT SYNTHESIS MEL-SPECTROGRAM
# The mel is recomputed from the synthesized waveform (`audio_syn`) and
# transposed before plotting.
plot_spectrogram(ap.melspectrogram(audio_syn).T, ap, fig_size=(8,3))

## GT Resynthesis

In [None]:
# PLOT GT+RES MEL-SPECTROGRAM
# `res_spectrogram` is the mel of the resynthesized ground truth computed in
# the "GT Resynthesis" audio cell; it is transposed before plotting.
plot_spectrogram(res_spectrogram.T, ap, fig_size=(8,3))

## GT

In [None]:
# PLOT GT MEL-SPECTROGRAM
# `gt_spectrogram` is the mel of the reference recording computed in the
# "GT Resynthesis" audio cell; it is transposed before plotting.
plot_spectrogram(gt_spectrogram.T, ap, fig_size=(8,3))

# UMAP

Outputs style spaces and style representations.

## Setup

In [None]:
#### FUNCTIONS ####

def numpy_to_torch(np_array, dtype, cuda=False):
    """Convert array-like data to a torch tensor of the given dtype.

    Args:
        np_array: Data accepted by ``torch.as_tensor``, or ``None``.
        dtype: Target torch dtype.
        cuda: When true, the tensor is moved to the GPU with ``.cuda()``.

    Returns:
        The converted tensor, or ``None`` when ``np_array`` is ``None``.
    """
    if np_array is not None:
        converted = torch.as_tensor(np_array, dtype=dtype)
        return converted.cuda() if cuda else converted
    return None

def id_to_torch(speaker_id, cuda=False):
    """Wrap an id in a ``torch.long`` tensor with a leading dimension added.

    Args:
        speaker_id: Integer id (or array-like of ids), or ``None``.
        cuda: When true, the tensor is moved to the GPU with ``.cuda()``.

    Returns:
        A ``torch.long`` tensor holding the id with a dimension of size 1
        prepended, or ``None`` when ``speaker_id`` is ``None`` (previously a
        ``None`` id crashed with ``AttributeError`` on ``None.type``).
    """
    if speaker_id is None:
        return None
    id_tensor = torch.from_numpy(np.asarray(speaker_id)).unsqueeze(0).type(torch.long)
    return id_tensor.cuda() if cuda else id_tensor


def compute_style_mel(style_wav, ap, cuda=False):
    """Load a reference wav and return its mel-spectrogram as a batched tensor.

    Args:
        style_wav: Path of the wav file, passed to ``ap.load_wav``.
        ap: Audio processor providing ``load_wav``, ``melspectrogram`` and
            ``sample_rate``.
        cuda: When true, the tensor is moved to the GPU with ``.cuda()``.

    Returns:
        A ``torch.FloatTensor`` of the mel-spectrogram with a leading
        dimension of size 1 added.
    """
    waveform = ap.load_wav(style_wav, sr=ap.sample_rate)
    mel = ap.melspectrogram(waveform)
    style_mel = torch.FloatTensor(mel).unsqueeze(0)
    return style_mel.cuda() if cuda else style_mel


#### DATASET DEFINITIONS ####

# Point the first dataset entry of the loaded config at the local EmoV-DB copy.
config.datasets[0].path = '../recipes/emovdb/emovdb'

# Test-split metadata: '|'-separated, no header row. The path gets its own
# name so `test` only ever refers to the DataFrame.
test_metadata_path = '../recipes/emovdb/emovdb/metadata/metadata_test_all.csv'
test = pd.read_csv(
    test_metadata_path,
    delimiter='|',
    encoding= 'utf-8',
    header=None,
    names = ['wav_path', 'text', 'speaker', 'style'],
)

# dict style2id
map_style = {
    'Amused': 0,
    'Angry': 1,
    'Disgusted': 2,
    'Neutral': 3,
    'Sleepy': 4
}

# dict id2style: the inverse of map_style, derived from it so the two
# mappings cannot drift apart.
map_id2style = {style_id: style_name for style_name, style_id in map_style.items()}

def map_wavpath2style(wav_path):
    """Infer the style name from a wav path by substring match.

    The style names are tried in the order Amused, Angry, Disgusted, Neutral,
    Sleepy and the first one contained in ``wav_path`` is returned. Matching
    is case-sensitive.

    Returns:
        The matched style name, or the string ``'none'`` when no name occurs
        in the path.
    """
    for style_name in ('Amused', 'Angry', 'Disgusted', 'Neutral', 'Sleepy'):
        if style_name in wav_path:
            return style_name
    return 'none'


#### TRAINER ####

# init trainer args
train_args = TrainingArgs()

# load training samples
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
all_samples = train_samples + eval_samples

# init speaker manager
# Built from the dataset samples or a d-vector file, depending on the config;
# no speaker manager when neither option is enabled.
if config.use_speaker_embedding:
    speaker_manager = SpeakerManager(data_items=all_samples)
elif config.use_d_vector_file:
    speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
else:
    speaker_manager = None

# init style manager
if config.style_encoder_config.use_supervised_style:
    style_manager = StyleManager(data_items=all_samples)
    # The number of styles is written to `model_args` when the config has
    # one, otherwise to the top level of the config.
    if hasattr(config, "model_args"):
        config.model_args.num_styles = style_manager.num_styles
    else:
        config.num_styles = style_manager.num_styles
else:
    style_manager = None

# init the model from config
# NOTE: this rebinds `model`, replacing the model loaded in the
# "Load Model and Set Device" cell.
language_manager = None
model = setup_model(config, speaker_manager, language_manager, style_manager)

# init the trainer
trainer = Trainer(
    train_args,
    config,
    config.output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
    parse_command_line_args=False,
)

# restore checkpoint
# Use the same `run_select` entry as the "Load Model and Set Device" cell so
# both cells always restore the same checkpoint file.
checkpoint = folder + checkpoints_dict[run_name][run_select]
trainer.model, opt, scaler, restore_step = trainer.restore_model(config, checkpoint, trainer.model, trainer.optimizer, trainer.scaler)

# extract representations
# Follow the device chosen in the "Load Model and Set Device" cell instead of
# forcing CUDA, so the cell also runs on CPU-only machines.
extract_device = torch.device(device)
use_cuda = extract_device.type == "cuda"

N = config['style_encoder_config']['style_embedding_dim']

train_feats = np.zeros((len(test), N))
valid_feats = np.zeros((len(eval_samples), N))  # allocated but not filled in this cell

styles = []

# Move the model once, not on every iteration.
# NOTE(review): the model is not switched to eval mode here; confirm whether
# the style encoder should run with dropout/batch-norm frozen.
model.to(extract_device)
se_type = config['style_encoder_config']['se_type']

# Gradients are not needed: the outputs are only copied into numpy arrays.
with torch.no_grad():
    for i, wav_rel_path in enumerate(tqdm(test.wav_path.values)):
        style_wav = '../recipes/emovdb/emovdb/files/' + wav_rel_path
        style_mel = compute_style_mel(style_wav, ap, cuda=False)
        style_mel = numpy_to_torch(style_mel, torch.float)[0].T.to(extract_device)

        # The first list element is random data standing in for the encoder
        # output (presumably unused by the style branch; 384 is assumed to
        # match the encoder hidden size -- confirm).
        encoder_stub = torch.rand(1,384,1).to(extract_device)
        o_en, outputs = model.style_encoder_layer.forward([encoder_stub, style_mel.unsqueeze(0)], None)

        # Some style encoders return a dict; pick the entry that holds the
        # style representation.
        if se_type == 'vae':
            outputs = outputs['z']
        elif se_type == 'diffusion':
            outputs = outputs['style']

        train_feats[i] = outputs.squeeze(0).squeeze(0).detach().cpu().numpy()
        styles.append(map_style[map_wavpath2style(wav_rel_path)])

## Style Space

In [None]:
# FIT / TRANSFORM UMAP

#u = umap.UMAP(random_state = 42)

#embeddings = u.fit_transform(train_feats)
embeddings = u.transform(train_feats) # when the umap is already trained

# Creating dataframe to better manipulate
df = pd.DataFrame({'style': styles, 'dim1': embeddings[:,0], 'dim2': embeddings[:,1]})

# Plot
df['style_id'] = df['style'].map(map_style)
df.head(), df.tail()
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(12,5))
for i in range(5):
    df_filt = df[df['style'] == i]
    plt.scatter(df_filt['dim1'], df_filt['dim2'], label = map_id2style[i])
plt.legend(fontsize=15)
plt.xlabel('UMAP dim 0', fontsize = 20)
plt.ylabel('UMAP dim 1', fontsize = 20)

plt.show()

## Extract Centroid Representation

In [125]:
style_centroid = "Amused"
# Rows of `df` line up with rows of `train_feats`, so the index labels of the
# selected style can be used to pick its representations.
idxs = df.index[df['style'] == map_style[style_centroid]].tolist()
if not idxs:
    # An empty selection would silently give a NaN centroid.
    raise ValueError("No samples found for style '{}'".format(style_centroid))
style_representations = train_feats[idxs]
centroid = style_representations.mean(axis=0)
# Position of the centroid in the fitted UMAP space.
print(u.transform(centroid.reshape(1,-1)))
# Keep the representation on the same device as the model (`device`) rather
# than hard-coding 'cuda:0', which breaks synthesis when running on CPU.
style_representation = torch.Tensor(centroid).unsqueeze(0).to(device)
# HOW TO GET RANDOM REPRESENTATION