In [1]:
# Reload imported modules before each cell executes so that edits to the
# project source take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import json

import scipy
import torch
import ppgs

import promonet

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Index of the GPU to run on; set to None to run on the CPU instead
gpu = 0
if gpu is None:
    device = 'cpu'
else:
    device = f'cuda:{gpu}'

## Pitch, WER, and PPG JSD evaluation

In [4]:
# Configuration names to evaluate: the five base names first, followed by
# the same names with a '-latent' suffix
base_configs = ['bottleneck', 'encodec', 'mel', 'w2v2fb', 'w2v2fc']
configs = base_configs + [f'{name}-latent' for name in base_configs]

In [5]:
def _shift_average(result, metric):
    """Average one metric over the 'shifted-089' and 'shifted-112' entries"""
    return .5 * (result['shifted-089'][metric] + result['shifted-112'][metric])


# Read each configuration's results.json and keep the shift-averaged metrics
pitch_results, wer_results, jsd_results = {}, {}, {}
for config in configs:
    results_path = promonet.RESULTS_DIR / config / 'results.json'
    with open(results_path) as file:
        result = json.load(file)
    pitch_results[config] = _shift_average(result, 'pitch')
    wer_results[config] = _shift_average(result, 'wer')
    jsd_results[config] = _shift_average(result, 'ppg')

# Print one sorted, indented JSON table per metric
for label, table in (
    ('Pitch', pitch_results),
    ('WER', wer_results),
    ('JSD', jsd_results)
):
    print(label, json.dumps(table, indent=4, sort_keys=True))

Pitch {
    "bottleneck": 61.91497966647148,
    "bottleneck-latent": 53.125523775815964,
    "encodec": 56.40435889363289,
    "encodec-latent": 184.4439446926117,
    "mel": 53.10908630490303,
    "mel-latent": 207.25626647472382,
    "w2v2fb": 56.69333338737488,
    "w2v2fb-latent": 52.99616530537605,
    "w2v2fc": 58.27789306640625,
    "w2v2fc-latent": 57.52389058470726
}
WER {
    "bottleneck": 0.25773540139198303,
    "bottleneck-latent": 0.039343612268567085,
    "encodec": 0.0746385008096695,
    "encodec-latent": 0.01696464605629444,
    "mel": 0.07383536919951439,
    "mel-latent": 0.024388889782130718,
    "w2v2fb": 0.0894194133579731,
    "w2v2fb-latent": 0.024340685456991196,
    "w2v2fc": 0.46374256908893585,
    "w2v2fc-latent": 0.045124998316168785
}
JSD {
    "bottleneck": 0.7077649235725403,
    "bottleneck-latent": Infinity,
    "encodec": 0.24726836383342743,
    "encodec-latent": 11.509827781126294,
    "mel": 0.21264910697937012,
    "mel-latent": Infinity,
    "

## PPG JSD evaluation

In [10]:
class JSDs:
    """Bank of PPG distance metrics, one per exponent

    Exponents start at 0.0 and advance in steps of 0.05 up to, but not
    including, 2.0.
    """

    def __init__(self):
        exponents = torch.round(torch.arange(0.0, 2.0, 0.05), decimals=2)
        self.jsds = []
        for exponent in exponents:
            self.jsds.append(promonet.evaluate.metrics.PPG(exponent))

    def __call__(self):
        """Return the current value of every metric keyed by its exponent"""
        results = {}
        for metric in self.jsds:
            results[f'{metric.exponent:02f}'] = metric()
        return results

    def update(self, predicted, target):
        """Infer PPGs from both audio signals and update every metric"""
        # Run inference on the device that holds the predicted audio
        audio_device = predicted.device
        gpu = audio_device.index if audio_device.type != 'cpu' else None

        # Compute PPG
        predicted = ppgs.from_audio(
            predicted,
            promonet.SAMPLE_RATE,
            ppgs.RUNS_DIR / 'mel' / '00200000.pt',
            gpu)
        target = ppgs.from_audio(
            target,
            promonet.SAMPLE_RATE,
            ppgs.RUNS_DIR / 'mel' / '00200000.pt',
            gpu)

        # Update metrics
        for metric in self.jsds:
            metric.update(predicted, target)

    def reset(self):
        """Reset every metric"""
        for metric in self.jsds:
            metric.reset()

In [11]:
# Aggregate (per-configuration) and per-file PPG JSDs at every exponent
jsd_results = {}
jsd_file_results = {}
jsds = JSDs()
file_jsds = JSDs()
original_files = sorted(list(
    (promonet.EVAL_DIR / 'subjective' / 'original').glob('vctk*.wav')))
for config in configs:
    jsds.reset()
    jsd_file_results[config] = {}
    eval_directory = promonet.EVAL_DIR / 'subjective' / config
    shift089_files = sorted(list(eval_directory.glob('*shifted-089.wav')))
    shift112_files = sorted(list(eval_directory.glob('*shifted-112.wav')))

    # Files are paired by sorted position; zip stops at the shortest list
    for original, shift089, shift112 in zip(
        original_files,
        shift089_files,
        shift112_files
    ):
        # Load each waveform once and reuse it for both the aggregate and
        # the per-file metrics instead of re-reading it from disk every time
        original_audio = promonet.load.audio(original).to(device)
        for shifted in (shift089, shift112):
            shifted_audio = promonet.load.audio(shifted).to(device)

            # Accumulate into the per-configuration metrics
            jsds.update(shifted_audio, original_audio)

            # Compute the metrics of this file in isolation
            file_jsds.reset()
            file_jsds.update(shifted_audio, original_audio)
            jsd_file_results[config][shifted.stem] = file_jsds()
    jsd_results[config] = jsds()
print('JSD', json.dumps(jsd_results, indent=4, sort_keys=True))

JSD {
    "bottleneck": {
        "0.000000": 0.0027158150915056467,
        "0.050000": 0.12569043040275574,
        "0.100000": 0.2222704291343689,
        "0.150000": 0.29630085825920105,
        "0.200000": 0.35214531421661377,
        "0.250000": 0.3934328258037567,
        "0.300000": 0.4233401417732239,
        "0.350000": 0.44412246346473694,
        "0.400000": 0.45779773592948914,
        "0.450000": 0.4660024344921112,
        "0.500000": 0.47004860639572144,
        "0.550000": 0.4709796607494354,
        "0.600000": 0.4696318507194519,
        "0.650000": 0.46665164828300476,
        "0.700000": 0.4625588059425354,
        "0.750000": 0.45777440071105957,
        "0.800000": 0.45262089371681213,
        "0.850000": 0.4473276734352112,
        "0.900000": 0.4420848488807678,
        "0.950000": 0.437012642621994,
        "1.000000": 0.43220528960227966,
        "1.050000": 0.42771586775779724,
        "1.100000": 0.4235795736312866,
        "1.150000": 0.41980981826782227,


## Select exponent with highest correlation with WER

In [12]:
# Collect the per-file WER of every pitch-shifted utterance
wer_file_results = {}
for config in configs:
    wer_file_results[config] = {}
    results_dir = promonet.RESULTS_DIR / config / 'vctk'
    # Use a separate name for the path so that the file handle opened below
    # does not shadow the loop variable
    for results_file in results_dir.glob('*.json'):
        if results_file.stem == 'results':
            continue
        with open(results_file) as file:
            result = json.load(file)
        for stem, scores in result['objective']['raw'].items():
            if 'shifted' in stem:
                # Scores are keyed by the last two dash-separated stem fields
                wer_file_results[config][stem] = scores['-'.join(stem.split('-')[-2:])]['wer']

# The exponents and file stems of 'mel' are used for every configuration
exponents = jsd_results['mel'].keys()
stems = wer_file_results['mel'].keys()

# The WER values do not depend on the exponent, so gather them only once
wer_values = [
    wer_file_results[config][stem]
    for config in configs
    for stem in stems]

# Pearson correlation between per-file JSD and WER at each exponent
# NOTE(review): scipy.stats is accessed after a bare `import scipy`; confirm
# the submodule is loaded for the SciPy version in use
correlations = {}
for exponent in exponents:
    jsd_values = [
        jsd_file_results[config][stem][exponent]
        for config in configs
        for stem in stems]
    correlations[exponent] = scipy.stats.pearsonr(jsd_values, wer_values)
print('Correlations', json.dumps(correlations, indent=4, sort_keys=True))

Correlations {
    "0.000000": [
        0.06779130852191431,
        0.002418838009404917
    ],
    "0.050000": [
        0.6820096444854464,
        8.545463392438829e-274
    ],
    "0.100000": [
        0.6828002692862649,
        1.1351766318355809e-274
    ],
    "0.150000": [
        0.6837990213797482,
        8.782569910600452e-276
    ],
    "0.200000": [
        0.6847979787613037,
        6.721554033064936e-277
    ],
    "0.250000": [
        0.6857852617960345,
        5.24724265628967e-278
    ],
    "0.300000": [
        0.6868297583726961,
        3.494505666272777e-279
    ],
    "0.350000": [
        0.6878354791979264,
        2.5456670819450134e-280
    ],
    "0.400000": [
        0.688827590342266,
        1.901293929271267e-281
    ],
    "0.450000": [
        0.6898052099657075,
        1.4598071647337976e-282
    ],
    "0.500000": [
        0.6907533203377103,
        1.1993237008660013e-283
    ],
    "0.550000": [
        0.6916610814632764,
        1.0859

In [14]:
# Exponent key to report, written in the same format as the jsd_results keys
optimal = '1.200000'

# Keep only the JSD at the chosen exponent for each configuration
jsd_results_optim = {
    name: by_exponent[optimal]
    for name, by_exponent in jsd_results.items()}
print('JSD', json.dumps(jsd_results_optim, indent=4, sort_keys=True))

JSD {
    "bottleneck": 0.41640302538871765,
    "bottleneck-latent": 0.2026045173406601,
    "encodec": 0.23291300237178802,
    "encodec-latent": 0.16535606980323792,
    "mel": 0.2013757824897766,
    "mel-latent": 0.10632917284965515,
    "w2v2fb": 0.26156085729599,
    "w2v2fb-latent": 0.15279409289360046,
    "w2v2fc": 0.5244527459144592,
    "w2v2fc-latent": 0.1651759147644043
}
