In [97]:
import samples
from huggingface_hub import hf_hub_download
import torch
import yaml
from espnet2.tasks.mlm import MLMTask
from argparse import Namespace
from IPython.display import Audio
import os
from espnet2.bin.sedit_inference import prepare_features, duration_path_dict, set_all_random_seed
import matplotlib.pyplot as plt
from espnet2.bin.sedit_inference import load_vocoder
import shutil
import soundfile
from utils import mcd_calculate
from glob import glob

In [53]:
# Load the model
# Fetch the pretrained checkpoint and its training config from the Hugging Face Hub.
chkpnt_file = hf_hub_download(repo_id="richardbaihe/a3t-vctk", filename="unseen_conformer/train.loss.ave_5best.pth")
config_file = hf_hub_download(repo_id="richardbaihe/a3t-vctk", filename="unseen_conformer/config.yaml")

# NOTE: torch.load unpickles the file, so only load checkpoints from a repo you trust.
# map_location="cpu" lets the checkpoint load on machines without a GPU;
# load_state_dict below copies the values into the model's own parameters.
chkpnt = torch.load(chkpnt_file, map_location="cpu")

# Use a context manager so the config file handle is closed after parsing.
with open(config_file) as config_fh:
    config = yaml.safe_load(config_fh)

# `args` must exist before it is passed to the MLMTask builders below;
# defining it afterwards raises NameError on a fresh kernel.
args = Namespace(**config)
preprocessor = MLMTask.build_preprocess_fn(args, train=False)
collate_fn = MLMTask.build_collate_fn(args, train=False)

model = MLMTask.build_model(args)
model.load_state_dict(chkpnt)
model = model.eval()

['/private/home/mattle/a3t/samples/p225_001.flac',
 '/private/home/mattle/a3t/samples/p232_412.flac']

In [100]:
# Resolve the sample recording relative to the installed `samples` package
samples_dir = os.path.dirname(samples.__file__)
wav_file = os.path.join(samples_dir, "p232_412.flac")

# Ground-truth waveform and its sample rate, as read from the file
gt_wav, sample_rate = soundfile.read(wav_file)

# Hop length of the feature extractor, taken from the model config
hop_length = config["feats_extract_conf"]["hop_length"]

transcript = 'Senior management in Scotland threw its weight behind the orchestra.'

# Last expression of the cell: render an audio player for the sample
Audio(wav_file)

In [54]:
# Prepare the features
# Utterance id: the file name up to its first dot
test_id = os.path.basename(wav_file).partition('.')[0]

# Keep the first and last `split` words (about a third each, at least one)
# and put a single [MASK] token between them
words = transcript.split()
split = max(len(words) // 3, 1)
kept_head = words[:split]
kept_tail = words[-split:]
new_str = " ".join([*kept_head, '[MASK]', *kept_tail])

duration_predictor_path = duration_path_dict['vctk']
batch, speech_lengths, old_span_boundary, new_span_boundary = prepare_features(
    model,
    preprocessor,
    wav_file,
    transcript,
    new_str,
    duration_predictor_path,
    test_id,
    mask_reconstruct=True,
)

# collate_fn returns a pair; only the second element (the features) is used
_, feats = collate_fn(batch)

In [75]:
# infer the spectrogram
with torch.no_grad():
    inference_result = model.inference(
        **feats,
        span_boundary=new_span_boundary,
        use_teacher_forcing=True,
    )
    feat_gen = inference_result["feat_gen"]

# "feat_gen" is a sequence of segments. The first and last get squeeze(0)
# (presumably they carry an extra leading dim — TODO confirm); the middle
# segments are used as-is. Everything is then joined along dim 0.
leading_segment = feat_gen[0].squeeze(0)
trailing_segment = feat_gen[-1].squeeze(0)
output = torch.cat([leading_segment, *feat_gen[1:-1], trailing_segment], dim=0)

In [90]:
# Vocode the same spectrogram 10 times (without setting the seed)
vocoder = load_vocoder("vctk_parallel_wavegan.v1.long").eval()

# Recreate both output directories from scratch so files left over from a
# previous run are not picked up by the later glob over samples/*/*.wav.
# The loop variable is deliberately not called `split`, to avoid clobbering
# the `split` word count set in the feature-preparation cell.
for subset in ["gt", "generated"]:
    subset_dir = f"samples/{subset}"
    if os.path.exists(subset_dir):
        shutil.rmtree(subset_dir)
    os.makedirs(subset_dir)

# Save only the masked region
# The span boundaries are multiplied by hop_length to get waveform indices.
# NOTE(review): the same indices and sample_rate are applied to both the
# vocoder output and the ground truth; this assumes the vocoder emits audio
# at the file's sample rate with hop_length samples per frame — confirm.
start_idx, end_idx = (boundary * hop_length for boundary in new_span_boundary)

# The ground-truth segment does not change between iterations, so slice it once.
gt_segment = gt_wav[start_idx:end_idx]

for i in range(10):
    wav = vocoder(output)
    soundfile.write(f"samples/generated/{test_id}_{i}.wav", wav[start_idx:end_idx], sample_rate)
    # One identically named ground-truth copy per generated file
    # (presumably so the MCD script can pair them by file name — verify).
    soundfile.write(f"samples/gt/{test_id}_{i}.wav", gt_segment, sample_rate)

In [99]:
# Compute the MCD
# https://github.com/richardbaihe/a3t/blob/dev_richard/aggregate_output/sedit_mcd.py#L38
# These option values follow the script linked above.
opt = mcd_calculate.get_parser().parse_args([
    "--wavdir", "samples/generated",
    "--gtwavdir", "samples/gt",
    "--mcep_dim", "80",
    "--f0min", "80",
    "--f0max", "7600",
    "--shiftms", "300",
    "--silenced", "1"
])

# `calculate` fills this list in place with the MCD values.
results = []
# glob returns paths in arbitrary order; sort both lists so the input order
# (and hence, presumably, the order of `results`) is reproducible across runs.
gt_files = sorted(glob("samples/gt/*.wav"))
files = sorted(glob("samples/generated/*.wav"))
mcd_calculate.calculate(files, gt_files, opt, results)
results

[8.672639155544799,
 5.308921012564444,
 4.900816440561954,
 7.131769090844503,
 13.143600342966797,
 12.224695584920918,
 6.072500175332512,
 11.702555791173376,
 8.065890147005044,
 9.663162912389703]

In [101]:
# Note the wide range of MCD results for different vocodings of the *same* spectrogram
lowest_mcd = min(results)
highest_mcd = max(results)
(lowest_mcd, highest_mcd)

(4.900816440561954, 13.143600342966797)

In [102]:
# Instead use --shiftms 10 and --f0max 1000
# All other options are the same as in the previous MCD cell.
opt = mcd_calculate.get_parser().parse_args([
    "--wavdir", "samples/generated",
    "--gtwavdir", "samples/gt",
    "--mcep_dim", "80",
    "--f0min", "80",
    "--f0max", "1000",
    "--shiftms", "10",
    "--silenced", "1"
])

# `calculate` fills this list in place with the MCD values.
results = []
# glob returns paths in arbitrary order; sort both lists so the input order
# (and hence, presumably, the order of `results`) is reproducible across runs.
gt_files = sorted(glob("samples/gt/*.wav"))
files = sorted(glob("samples/generated/*.wav"))
mcd_calculate.calculate(files, gt_files, opt, results)
results

[9.86404666714688,
 9.9295087348438,
 9.901495293436984,
 10.069008521267204,
 9.92362636539008,
 9.804463382249361,
 9.911342836866577,
 10.105765569630515,
 9.756784830891704,
 10.003715515907786]

In [103]:
# Smallest and largest MCD over the vocodings, for the most recently computed `results`
lowest_mcd = min(results)
highest_mcd = max(results)
(lowest_mcd, highest_mcd)

(9.756784830891704, 10.105765569630515)