In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json

import scipy
import scipy.stats
import torch
import ppgs

import promonet

In [None]:
# Index of the GPU to run on; set to None to run on the CPU instead
gpu = 0
device = f'cuda:{gpu}' if gpu is not None else 'cpu'

## Pitch, WER, and PPG JSD evaluation

In [None]:
# Every base configuration, followed by the '-latent' variant of each
configs = [
    f'{representation}{suffix}'
    for suffix in ('', '-latent')
    for representation in (
        'bottleneck',
        'encodec',
        'mel',
        'w2v2fb',
        'w2v2fc')]

In [None]:
# For every config, average each objective metric stored in its
# results.json over the 'shifted-089' and 'shifted-112' conditions
pitch_results, wer_results, jsd_results = {}, {}, {}
for config in configs:
    results_file = promonet.RESULTS_DIR / config / 'results.json'
    with open(results_file) as file:
        result = json.load(file)
    shifted_089 = result['shifted-089']
    shifted_112 = result['shifted-112']
    pitch_results[config] = .5 * (shifted_089['pitch'] + shifted_112['pitch'])
    wer_results[config] = .5 * (shifted_089['wer'] + shifted_112['wer'])
    # The 'ppg' entry of results.json is reported here under the name JSD
    jsd_results[config] = .5 * (shifted_089['ppg'] + shifted_112['ppg'])

# One pretty-printed JSON table per metric
for title, table in (
    ('Pitch', pitch_results),
    ('WER', wer_results),
    ('JSD', jsd_results)
):
    print(title, json.dumps(table, indent=4, sort_keys=True))

## PPG JSD evaluation

In [None]:
class JSDs:
    """Bank of PPG distance metrics, one per candidate exponent

    One `promonet.evaluate.metrics.PPG` metric is created for each exponent
    from 0.00 up to (but excluding) 2.00 in steps of 0.05. All metrics are
    updated, reset, and read out together.
    """

    def __init__(self):
        # Round so that accumulated floating-point error in arange does not
        # leak into the exponents
        exponents = torch.round(torch.arange(0.0, 2.0, 0.05), decimals=2)
        self.jsds = []
        for exponent in exponents:
            self.jsds.append(promonet.evaluate.metrics.PPG(exponent))

    def __call__(self):
        """Return the current value of every metric, keyed by its exponent"""
        results = {}
        for jsd in self.jsds:
            # Keys are the exponent formatted with six decimals,
            # e.g. '1.200000'
            results[f'{jsd.exponent:02f}'] = jsd()
        return results

    def update(self, predicted, target):
        """Infer PPGs from two audio tensors and update every metric

        Arguments
            predicted
                Audio to evaluate
            target
                Reference audio
        """
        # Run PPG inference on whichever device holds the predicted audio
        audio_device = predicted.device
        gpu = audio_device.index if audio_device.type != 'cpu' else None

        # Same checkpoint for both signals
        checkpoint = ppgs.RUNS_DIR / 'mel' / '00200000.pt'
        predicted_ppg = ppgs.from_audio(
            predicted,
            promonet.SAMPLE_RATE,
            checkpoint,
            gpu)
        target_ppg = ppgs.from_audio(
            target,
            promonet.SAMPLE_RATE,
            checkpoint,
            gpu)

        # Feed the same pair of PPGs to the metric of every exponent
        for jsd in self.jsds:
            jsd.update(predicted_ppg, target_ppg)

    def reset(self):
        """Clear the accumulated state of every metric"""
        for jsd in self.jsds:
            jsd.reset()

In [None]:
# Compute PPG distances between every pitch-shifted utterance and its
# unshifted original, at every candidate exponent:
#   jsd_results[config][exponent]             aggregated over all files
#   jsd_file_results[config][stem][exponent]  for each shifted file alone
jsd_results = {}
jsd_file_results = {}
jsds = JSDs()
file_jsds = JSDs()
original_files = sorted(
    (promonet.EVAL_DIR / 'subjective' / 'original').glob('vctk*.wav'))
for config in configs:
    jsds.reset()
    jsd_file_results[config] = {}
    eval_directory = promonet.EVAL_DIR / 'subjective' / config
    shift089_files = sorted(eval_directory.glob('*shifted-089.wav'))
    shift112_files = sorted(eval_directory.glob('*shifted-112.wav'))

    # Files are paired purely by sorted position. zip() would silently
    # truncate on a length mismatch, dropping or mis-pairing files, so
    # fail loudly instead.
    if not (
        len(original_files) == len(shift089_files) == len(shift112_files)
    ):
        raise ValueError(
            f'Cannot pair files for config {config!r}: found '
            f'{len(original_files)} original, '
            f'{len(shift089_files)} shifted-089, and '
            f'{len(shift112_files)} shifted-112 files')

    for original, shift089, shift112 in zip(
        original_files,
        shift089_files,
        shift112_files
    ):
        # Load the reference once and reuse it for both shifted versions
        original_audio = promonet.load.audio(original).to(device)
        for shifted in (shift089, shift112):
            shifted_audio = promonet.load.audio(shifted).to(device)

            # Accumulate into the per-config aggregate
            jsds.update(shifted_audio, original_audio)

            # Score this file in isolation
            file_jsds.reset()
            file_jsds.update(shifted_audio, original_audio)
            jsd_file_results[config][shifted.stem] = file_jsds()
    jsd_results[config] = jsds()
print('JSD', json.dumps(jsd_results, indent=4, sort_keys=True))

## Select exponent with highest correlation with WER

In [None]:
# Collect the WER of every pitch-shifted utterance:
#   wer_file_results[config][stem]
wer_file_results = {}
for config in configs:
    wer_file_results[config] = {}
    results_dir = promonet.RESULTS_DIR / config / 'vctk'
    for results_file in results_dir.glob('*.json'):
        # Skip the file named results.json; only the other JSON files in
        # this directory are read for per-utterance scores
        if results_file.stem == 'results':
            continue
        # Use a distinct name for the handle so the path being iterated
        # over is not rebound inside its own loop
        with open(results_file) as file:
            result = json.load(file)
        for stem, scores in result['objective']['raw'].items():
            if 'shifted' in stem:
                # Scores are keyed by the last two dash-separated fields
                # of the stem, e.g. 'shifted-089'
                condition = '-'.join(stem.split('-')[-2:])
                wer_file_results[config][stem] = scores[condition]['wer']

# Exponents and stems are taken from the 'mel' config; every other config
# is indexed with the same keys below
exponents = jsd_results['mel'].keys()
stems = wer_file_results['mel'].keys()

# Pearson correlation between per-file JSD and per-file WER, pooled over
# all configs, computed separately for each exponent
correlations = {}
for exponent in exponents:
    jsd_values, wer_values = [], []
    for config in configs:
        for stem in stems:
            jsd_values.append(jsd_file_results[config][stem][exponent])
            wer_values.append(wer_file_results[config][stem])
    correlations[exponent] = scipy.stats.pearsonr(jsd_values, wer_values)
print('Correlations', json.dumps(correlations, indent=4, sort_keys=True))

In [None]:
# Exponent to report; it is used as a key into each config's entry of
# jsd_results, so it must be spelled exactly like those keys
optimal = '1.200000'
jsd_results_optim = {
    config: jsd_results[config][optimal] for config in jsd_results}
print('JSD', json.dumps(jsd_results_optim, indent=4, sort_keys=True))