In [1]:
# Enable IPython's autoreload extension so that edits to imported modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import json

import scipy
import torch
import ppgs

import promonet

  from .autonotebook import tqdm as notebook_tqdm


## Pitch and WER evaluation

In [3]:
# Names of the evaluated configurations: every base representation first,
# followed by the '-latent' variant of each one in the same order
configs = [
    f'{representation}{suffix}'
    for suffix in ('', '-latent')
    for representation in (
        'bottleneck',
        'encodec',
        'mel',
        'w2v2fb',
        'w2v2fc')]

In [4]:
# For every config, read its aggregate results file and average the 'pitch'
# and 'wer' entries over the 'shifted-089' and 'shifted-112' conditions
pitch_results = {}
wer_results = {}
for config in configs:
    results_file = promonet.RESULTS_DIR / config / 'results.json'
    with open(results_file) as handle:
        result = json.load(handle)
    shifted_089 = result['shifted-089']
    shifted_112 = result['shifted-112']
    pitch_results[config] = .5 * (shifted_089['pitch'] + shifted_112['pitch'])
    wer_results[config] = .5 * (shifted_089['wer'] + shifted_112['wer'])

# Show both summaries as pretty-printed JSON with keys in sorted order
print('Pitch', json.dumps(pitch_results, indent=4, sort_keys=True))
print('WER', json.dumps(wer_results, indent=4, sort_keys=True))

Pitch {
    "bottleneck": 65.9216970205307,
    "bottleneck-latent": 55.776297301054,
    "encodec": 56.549616158008575,
    "encodec-latent": 183.79185497760773,
    "mel": 56.010619550943375,
    "mel-latent": 207.66302347183228,
    "w2v2fb": 59.523455053567886,
    "w2v2fb-latent": 57.2452649474144,
    "w2v2fc": 61.78075075149536,
    "w2v2fc-latent": 59.1887041926384
}
WER {
    "bottleneck": 0.2779349163174629,
    "bottleneck-latent": 0.055785488337278366,
    "encodec": 0.10180694609880447,
    "encodec-latent": 0.026043497025966644,
    "mel": 0.07438646629452705,
    "mel-latent": 0.023927961476147175,
    "w2v2fb": 0.0910343136638403,
    "w2v2fb-latent": 0.024351066909730434,
    "w2v2fc": 0.5073660314083099,
    "w2v2fc-latent": 0.021442336961627007
}


## PPG JSD evaluation

In [5]:
class JSDs:
    """Collection of PPG distance metrics, one per exponent

    Holds one `promonet.evaluate.metrics.PPG` metric for each exponent from
    0.0 up to (but excluding) 2.0 in steps of 0.05, and forwards updates,
    resets, and evaluation to all of them.
    """

    def __init__(self):
        # Round so the exponents are clean multiples of 0.05
        exponents = torch.round(torch.arange(0.0, 2.0, 0.05), decimals=2)
        self.jsds = []
        for exponent in exponents:
            self.jsds.append(promonet.evaluate.metrics.PPG(exponent))

    def __call__(self):
        """Evaluate every metric

        Returns a dict that maps each metric's exponent to that metric's
        current value. Keys are the exponent formatted with the `02f` spec,
        which produces six decimal places (e.g. '0.050000'); other cells of
        this notebook look results up by these exact strings.
        """
        results = {}
        for jsd in self.jsds:
            results[f'{jsd.exponent:02f}'] = jsd()
        return results

    def update(self, predicted, target):
        """Infer PPGs from two audio signals and update every metric

        Both `predicted` and `target` are passed to `ppgs.from_audio`
        together with `promonet.SAMPLE_RATE`.
        """
        # Run inference on the device holding the predicted audio; None
        # is passed when that device is the CPU
        device = predicted.device
        gpu = device.index if device.type != 'cpu' else None

        # Same file for both signals (presumably a PPG model checkpoint)
        checkpoint = ppgs.RUNS_DIR / 'mel' / '00200000.pt'

        # Compute PPG
        predicted_ppg = ppgs.from_audio(
            predicted,
            promonet.SAMPLE_RATE,
            checkpoint,
            gpu)
        target_ppg = ppgs.from_audio(
            target,
            promonet.SAMPLE_RATE,
            checkpoint,
            gpu)

        # Update metrics
        for jsd in self.jsds:
            jsd.update(predicted_ppg, target_ppg)

    def reset(self):
        """Reset every metric"""
        for jsd in self.jsds:
            jsd.reset()

In [6]:
# Measure PPG distances between every pitch-shifted file and its original.
# `jsd_results` pools all files of a config; `jsd_file_results` stores one
# result per shifted file, keyed by the file's stem.
jsd_results = {}
jsd_file_results = {}
jsds = JSDs()
file_jsds = JSDs()
original_files = sorted(
    (promonet.EVAL_DIR / 'subjective' / 'original').glob('*.wav'))
for config in configs:
    jsds.reset()
    jsd_file_results[config] = {}
    eval_directory = promonet.EVAL_DIR / 'subjective' / config
    shift089_files = sorted(eval_directory.glob('*shifted-089.wav'))
    shift112_files = sorted(eval_directory.glob('*shifted-112.wav'))

    # Files are paired purely by sorted position, so differing counts would
    # make zip() silently drop files or compare mismatched utterances
    if not (
        len(original_files) == len(shift089_files) == len(shift112_files)
    ):
        raise ValueError(
            f'File count mismatch for config {config}: '
            f'{len(original_files)} original, '
            f'{len(shift089_files)} shifted-089, '
            f'{len(shift112_files)} shifted-112')

    for original, shift089, shift112 in zip(
        original_files,
        shift089_files,
        shift112_files
    ):
        # Load each file from disk once and reuse it for both the pooled
        # and the per-file metrics
        original_audio = promonet.load.audio(original)
        for shifted in (shift089, shift112):
            shifted_audio = promonet.load.audio(shifted)

            # Pooled over all files of this config
            jsds.update(shifted_audio, original_audio)

            # Per-file result; reset first so only this file contributes
            file_jsds.reset()
            file_jsds.update(shifted_audio, original_audio)
            jsd_file_results[config][shifted.stem] = file_jsds()
    jsd_results[config] = jsds()
print('JSD', json.dumps(jsd_results, indent=4, sort_keys=True))

JSD {
    "bottleneck": {
        "0.000000": 0.0,
        "0.050000": 0.376953125,
        "0.100000": 0.462890625,
        "0.150000": 0.48046875,
        "0.200000": 0.578125,
        "0.250000": 0.63671875,
        "0.300000": 0.6484375,
        "0.350000": 0.6484375,
        "0.400000": 0.671875,
        "0.450000": 0.69140625,
        "0.500000": 0.69140625,
        "0.550000": 0.68359375,
        "0.600000": 0.67578125,
        "0.650000": 0.67578125,
        "0.700000": 0.671875,
        "0.750000": 0.671875,
        "0.800000": 0.66015625,
        "0.850000": 0.65234375,
        "0.900000": 0.6484375,
        "0.950000": 0.63671875,
        "1.000000": 0.63671875,
        "1.050000": 0.63671875,
        "1.100000": 0.62890625,
        "1.150000": 0.62890625,
        "1.200000": 0.6328125,
        "1.250000": 0.62890625,
        "1.300000": 0.62109375,
        "1.350000": 0.6171875,
        "1.400000": 0.6171875,
        "1.450000": 0.6171875,
        "1.500000": 0.6171875,
   

## Select exponent with highest correlation with WER

In [20]:
# Collect the per-file WER of every pitch-shifted stem from the per-file
# JSON results of each config (the aggregate 'results.json' is skipped)
wer_file_results = {}
for config in configs:
    per_stem_wer = {}
    results_dir = promonet.RESULTS_DIR / config / 'vctk'
    for results_file in results_dir.glob('*.json'):
        if results_file.stem != 'results':
            with open(results_file) as handle:
                result = json.load(handle)
            for stem, scores in result['objective']['raw'].items():
                if 'shifted' in stem:
                    # Scores are keyed by the last two dash-separated
                    # fields of the stem
                    condition = '-'.join(stem.split('-')[-2:])
                    per_stem_wer[stem] = scores[condition]['wer']
    wer_file_results[config] = per_stem_wer

# The exponents and stems of the 'mel' config are used for all configs
exponents = jsd_results['mel'].keys()
stems = wer_file_results['mel'].keys()

# For each exponent, correlate per-file PPG distance with per-file WER over
# all configs and stems
correlations = {}
for exponent in exponents:
    jsd_values = [
        jsd_file_results[config][stem][exponent]
        for config in configs
        for stem in stems]
    wer_values = [
        wer_file_results[config][stem]
        for config in configs
        for stem in stems]
    correlations[exponent] = scipy.stats.pearsonr(jsd_values, wer_values)
print('Correlations', json.dumps(correlations, indent=4, sort_keys=True))

Correlations {
    "0.000000": [
        NaN,
        NaN
    ],
    "0.050000": [
        0.2723851718136634,
        2.33000062738939e-35
    ],
    "0.100000": [
        0.32478480698513734,
        2.3666119886274414e-50
    ],
    "0.150000": [
        0.36426268338931805,
        8.439061081298532e-64
    ],
    "0.200000": [
        0.3776331493977042,
        8.319150774799756e-69
    ],
    "0.250000": [
        0.38456422449864625,
        1.7010813151905333e-71
    ],
    "0.300000": [
        0.3945249233807716,
        1.77980430326978e-75
    ],
    "0.350000": [
        0.40381665348783835,
        2.581289601491177e-79
    ],
    "0.400000": [
        0.4028310730807355,
        6.680780130471154e-79
    ],
    "0.450000": [
        0.4043719632684271,
        1.508442235527858e-79
    ],
    "0.500000": [
        0.4059339493799188,
        3.310484072039207e-80
    ],
    "0.550000": [
        0.4074567230863725,
        7.488453611476225e-81
    ],
    "0.600000": [


In [19]:
# Hard-coded exponent key; it must match the six-decimal key format of
# `jsd_results`. NOTE(review): presumably picked by hand from the
# correlations printed above - confirm it is the best-correlating exponent.
optimal = '0.600000'

# Keep only the pooled result at the selected exponent for each config
jsd_results_optim = {}
for config, value in jsd_results.items():
    jsd_results_optim[config] = value[optimal]
print('JSD', json.dumps(jsd_results_optim, indent=4, sort_keys=True))

JSD {
    "bottleneck": 0.67578125,
    "bottleneck-latent": 0.42578125,
    "encodec": 0.470703125,
    "encodec-latent": 0.345703125,
    "mel": 0.458984375,
    "mel-latent": 0.26171875,
    "w2v2fb": 0.51171875,
    "w2v2fb-latent": 0.3515625,
    "w2v2fc": 0.765625,
    "w2v2fc-latent": 0.380859375
}
