### Batch synthesis with variation in speaker and on the feature axis.

In [None]:
import warnings
warnings.filterwarnings("ignore")
import matplotlib
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
%matplotlib notebook
import matplotlib.pylab as plt
import IPython.display as ipd
import sys
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import torchsummary
import csv
import os
from g2p_en import G2p
import re
import itertools as it
import pandas as pd

from tronduo.hparams import create_hparams

# Start from the project defaults, then override the settings this
# checkpoint was trained with: 2 speakers (8-dim speaker embedding) and a
# single prosodic feature dimension. Overrides are applied in this order.
hparams = create_hparams()
_hparam_overrides = dict(
    global_mean=None,
    distributed_run=False,
    prosodic=True,
    speakers=True,
    feat_dim=1,
    feat_max_bg=8,
    n_speakers=2,
    speaker_embedding_dim=8,
)
for _hp_name, _hp_value in _hparam_overrides.items():
    setattr(hparams, _hp_name, _hp_value)
import librosa
import soundfile as sf

import os
import json
from hifigan.env import AttrDict
from hifigan.models import Generator
MAX_WAV_VALUE = 32768.0  # presumably int16 full scale (2**15); NOTE(review): not used in the cells shown
device = 'cuda'  # used when loading HiFi-GAN below; later rebound to a torch.device by init_torch()
from tronduo.hifigandenoiser import Denoiser
import math

In [None]:
from tqdm.notebook import tqdm

In [None]:
from tronduo.model import Tacotron2
from tronduo.layers import TacotronSTFT, STFT
from scipy.io import wavfile
from tronduo.model_util import load_model
from tronduo import text_to_sequence
g2p = G2p()  # grapheme-to-phoneme converter used to build the braced transcripts

### Load models

In [None]:
# Location of the PyTorch (tronduo / Tacotron-style) checkpoint to synthesize with.
#path = "output/feat_ljsryan/" # added f0 feature, normalised per speaker
#path = "output/jfeat_ljsryan/" # added f0 feature, normalised over full range
path = "models/tronduo/" # speaker vector
iter = "60000" # number of iterations trained in the folder above
# NOTE(review): the name `iter` shadows the builtin iter() for the rest of the notebook.
outfolder = "syn/"  # where synthesized .wav files are written

In [None]:
# Build the tronduo model, restore its trained weights, and switch it to
# fp16 inference mode on the GPU.
checkpoint_path = f"{path}checkpoint_{iter}"
model = load_model(hparams)
model.load_state_dict(
    torch.load(checkpoint_path)['state_dict']
)
model = model.cuda().eval().half()

In [None]:
def load_checkpoint(filepath, device):
    """Load a torch checkpoint file onto `device`.

    Args:
        filepath: path to the checkpoint file.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load`` (the HiFi-GAN cell reads
        its 'generator' entry).

    Raises:
        FileNotFoundError: if `filepath` is not an existing file.
    """
    # Explicit raise rather than assert, so the check survives `python -O`.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint not found: '{filepath}'")
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

In [None]:
# HiFi-GAN vocoder: read the JSON config, build the generator on `device`,
# load the weights from g_<8-digit step>, and prepare it for inference.
hfg_path = 'models/hifigan/'
checkpoint = 3100000
config_file = f"{hfg_path}config.json"
checkpoint_file = f"{hfg_path}g_{checkpoint:08d}"

with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)

torch.manual_seed(h.seed)
generator = Generator(h).to(device)
state_dict_g = load_checkpoint(checkpoint_file, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

# Denoiser built on the generator in 'zeros' mode; the alternative mode is 'normal'.
denoiser = Denoiser(generator, mode='zeros')

### Text preparation

In [None]:
# Sentences to synthesize; index-aligned with `filenames`.
texts = [
    "Meanwhile spring had passed well into summer.",
    "They had come not to admire but to observe.",
    "There were other farmhouses nearby.",
    "Outside, only a handful of reporters remained.",
    "Coverage of primary literature will follow.",
    "The data are presented in lists and tables.",
    "A third volume remains to be published.",
    "At intervals an alumni directory is issued.",
    "Let us differentiate a few of these ideas.",
    "The system works as an impersonal mechanism.",
]

In [None]:
# Output-file stems, one per entry of `texts` (same order).
filenames = [
    "si1905", "si1710", "si1664", "si1552", "si1445",
    "si1442", "si1440", "si1297", "si1182", "si1083",
]

In [None]:
# Turn each sentence into a braced phone transcript for text_to_sequence:
# '!', '.', '?' are stripped and ';' becomes '.', g2p output is grouped into
# one {...} per word, a ',' token between words collapses to '},{', a '.'
# token (i.e. an original ';') becomes ';', and every transcript ends in '.'.
transcripts = [None]*len(texts)
txt = [re.sub(r'[!.?]+', '', tr) for tr in texts]  # raw string: '\!' is an invalid escape
txt = [re.sub(';', '.', tr) for tr in txt]
for i, sentence in enumerate(txt):
    # g2p marks word boundaries with ' '; turn each into a brace break.
    phon = ['} {' if ph == ' ' else ph for ph in g2p(sentence)]
    transcripts[i] = '{ ' + ' '.join(phon) + ' }.'
transcripts = [re.sub(r'(\s+)\{ , \}(\s+)', ',', tr) for tr in transcripts]
# The dot must be escaped: a bare '.' would turn ANY single-character token into ';'.
transcripts = [re.sub(r'(\s+)?\{ \. \}(\s+)?', ';', tr) for tr in transcripts]
transcripts = [re.sub(r'{ ', '{', tr) for tr in transcripts]
transcripts = [re.sub(r' }', '}', tr) for tr in transcripts]

In [None]:
print(transcripts)  # inspect the braced transcripts that will be passed to text_to_sequence

### Load SGR tool

In [None]:
DEVICE = 'cuda:0'


def init_torch():
    """Seed torch's RNG and choose the compute device.

    Returns:
        ``torch.device(DEVICE)`` when CUDA is available, otherwise the CPU
        device. The torch / torchaudio versions and the chosen device are
        printed as a side effect.
    """
    torch.random.manual_seed(0)
    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
    # torch.cuda.set_device only accepts CUDA devices; calling it with the
    # CPU fallback would raise, so only select a GPU when one is in use.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    print(torch.__version__)
    print(torchaudio.__version__)
    print(device)
    return device

In [None]:
class MLPClassifier(nn.Module):
    """Three-layer MLP mapping a 768-dim feature vector to 2 logits.

    Layers are 768 -> 512 -> 64 -> 2 with ReLU between them; the attribute
    names fc1/fc2/fc3 are the state_dict keys of the saved classifier.
    """

    def __init__(self):
        super().__init__()
        # Creation order is fixed (fc1, fc2, fc3) so seeded init is unchanged.
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 2)

    def forward(self, x):
        for hidden_layer in (self.fc1, self.fc2):
            x = nn.functional.relu(hidden_layer(x))
        return self.fc3(x)

In [None]:
# Rebinds `device` (previously the string 'cuda') to the torch.device chosen by init_torch().
device = init_torch()
# torchaudio's wav2vec2 ASR base (960h) bundle, used as the feature extractor
# whose features feed the sgr2 classifier during the synthesis sweep.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
w2v2 = bundle.get_model().to(device)
w2v2.eval()

In [None]:
# Classifier applied to mean-pooled wav2vec2 features in the sweep.
sgr2 = MLPClassifier()
sgr2.to(device)
# map_location lets the weights load onto whatever device init_torch() chose,
# including the CPU fallback (the file name suggests it was saved from a GPU).
sgr2.load_state_dict(torch.load('models/sgr/sgr2_model_1gpu.pt', map_location=device))
sgr2.eval()

### Test a range of outcomes on a grid of inputs

In [None]:
# Accumulator for the synthesis sweep: one row per generated clip.
results = pd.DataFrame(
    columns=["Filename", "f", "m", "f0_in", "dur", "utt", "pr_f", "pr_m", "gap"]
)
outfolder = "syn/"
out = True  # when True, the sweep also writes each clip to `outfolder`
if out:
    if not os.path.exists(outfolder):
        os.makedirs(outfolder)
i = 0

In [None]:
# variation on speaker on a grid
for i in tqdm(range(2)):
    sequence = np.array(text_to_sequence(transcripts[i], ['english_cleaners']))[None, :]
    sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
    for j in np.arange(0.,1.01,0.1):
        for k in np.arange(0.,1.01,0.1):
            for m in [-0.2, 0.2]:
                speaks = torch.as_tensor([j,k]).unsqueeze(0).cuda()
                pros = torch.as_tensor(m).unsqueeze(0).half().cuda()
                #_, mel_outputs_postnet, _, _ = model.inference(sequence, speaks=speaks, pros=pros)
                durat = 1000
                while durat > 890:
                    try:
                        _, mel_outputs_postnet, _, _ = model.inference(sequence, speaks=speaks, pros=pros)
                        durat = mel_outputs_postnet[0].size()[1]
                    except:
                        pass
                melfl = mel_outputs_postnet.float()
                y_g_hat = generator(melfl)
                audio = denoiser(y_g_hat[0], strength=0.01).squeeze().half()
                audio_out = audio.cpu().detach().numpy()
                # generate output
                filename = f'{filenames[i]}_F{int(round(100*j,0)):04}M{int(round(100*k,0)):04}f0_{int(round(100*m,0)):03}'
                waveform = torchaudio.functional.resample(audio.unsqueeze(0).float(), 22050, bundle.sample_rate)
                feats, _ = w2v2.extract_features(waveform.clone().to(device))
                vec = torch.mean(feats[2], dim=1)
                outputs = sgr2(vec).detach().cpu().numpy()
                pf = 1 / (1+math.exp(-outputs[0][0]))
                pm = 1 / (1+math.exp(-outputs[0][1]))

                results = results.append({"Filename":filename, "f":100*j,"m":100*k,"f0_in":m,
                                          "dur":np.round(len(audio_out)/hparams.sampling_rate,3), 
                                          "utt":files[i],
                                         "pr_f":pf/(pf+pm), "pr_m":pm/(pf+pm), "gap":abs(pf-pm)/(pf+pm)},
                                        ignore_index = True)
                if out:
                    print(filename)
                    sf.write(outfolder + filename + '.wav',
                                        audio_out.astype('float32'), hparams.sampling_rate)

In [None]:
# Save the sweep results as a '|'-separated file (pandas writes the index column by default).
results.to_csv('test_results.csv', sep='|')

## Create grid to display results

In [None]:
import ipywidgets as widgets
from IPython.display import display, Audio, clear_output

In [None]:
outfolder = "syn/"  # folder the sweep wrote the clips to
size = 9 # grid dimension
# grid midpoint: F weight (x100) of the centre row.
# NOTE(review): the percentile note that follows looks carried over from a feature-axis
# grid -- confirm for the speaker-weight grid: (0 corresponds to 1st percentile in training, 100 to 99th percentile)
mid = 50
step = 10 # grid step size
utt = 1  # index into `filenames` of the sentence to display
f0 = 0.2  # f0 offset encoded in the clip names (the sweep used -0.2 and 0.2)
high = "100%"  # height of every grid button
# NOTE(review): F spans mid +/- (size-1)/2*step and M spans F +/- (size-1)/2*step,
# which reaches outside the 0-100 range synthesized by the sweep (e.g. M = -30 or 130).

In [None]:
%%html
<style>.myclass { border:1px solid white ; font-size:70% ; grid-row-gap:0px}</style>

In [None]:
def play_sound(b):
    """Button callback: play the clip named by the clicked button's tooltip.

    Plays ``<outfolder><tooltip>.wav`` in the `audios` output widget. Grid
    cells whose clip was never synthesized (the grid can reach speaker
    weights outside the sweep's range) show a short message instead of a
    traceback.
    """
    wav_path = outfolder + f"{b.tooltip}.wav"
    audios.clear_output(wait=True)
    with audios:
        if os.path.isfile(wav_path):
            display(Audio(filename=wav_path, autoplay=True))
        else:
            print(f"{wav_path} not found (not synthesized)")

In [None]:
# Clip stems for the size x size grid, row-major by F weight. Each row's M
# weights are centred on that row's F weight, giving a diagonal band.
half_span = (size - 1) / 2 * step
f0_tag = f'{int(round(f0 * 100, 0)):03}'
files = []
for i in np.arange(round(mid - half_span), round(mid + half_span) + 0.01, step):
    for j in np.arange(round(i - half_span), round(i + half_span) + 0.01, step):
        files.append(f'{filenames[utt]}_F{int(round(i)):04}M{int(round(j)):04}f0_{f0_tag}')

In [None]:
# One clickable button per clip; the tooltip carries the file stem that
# play_sound() turns into a .wav path.
buttons = [widgets.Button(description="", tooltip=files[i],
                          layout=widgets.Layout(width="100%", height=high),
                          button_style='primary').add_class('myclass')
           for i in range(size * size)]
# Highlight the middle button of each row: for odd `size` that is the clip
# whose M weight equals its F weight.
for i in range(size // 2, size * size, size):
    buttons[i].button_style = 'info'
for button in buttons:
    button.on_click(play_sound)

In [None]:
# Render the clip buttons as a skewed grid: one GridBox per F row (highest F
# on top), padded so that each column corresponds to a single M value, then
# an x-axis row labelled with those M values.
def _cell(description=""):
    """Return a full-width, `high`-tall label/placeholder button with class myclass."""
    return widgets.Button(description=description,
                          layout=widgets.Layout(width="100%", height=high)).add_class('myclass')


def _grid_row(cells):
    """Lay out `cells` in a GridBox of 2*size equal columns with no gaps."""
    # grid_gap is a Layout property; passed as a GridBox keyword it is not applied.
    return widgets.GridBox(cells, layout=widgets.Layout(
        grid_template_columns=f"repeat({2*size},1fr)", grid_gap='0px 0px'))


grid_out = widgets.Output(layout={"display": "flex", "flex_flow": "row wrap", "align_items": "flex-start", "margin": "0"})
for i in range(size):
    row = size - 1 - i  # button row shown on display line i
    grid = _grid_row([_cell(str(mid + (size // 2 - i) * step))]   # F label
                     + [_cell() for _ in range(size - 1 - i)]    # left padding
                     + buttons[row * size:(row + 1) * size]
                     + [_cell() for _ in range(i)])              # right padding
    grid_out.append_display_data(grid)
# add x-axis: 'F/M' corner label followed by the 2*size-1 M values of the band
xax = [_cell(str(x)) for x in range(mid - (size - 1) * step, mid + size * step, step)]
grid = _grid_row([_cell('F/M')] + xax)
grid_out.append_display_data(grid)
display(grid_out)
audios = widgets.Output(layout={'border': '1px solid black'})
display(audios)