In [None]:
from pathlib import Path
import ppgs
from json import load
import numpy as np
from matplotlib import pyplot as plt
import torch
from matplotlib.colors import PowerNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Root directory of the evaluation results; each model has its own sub-directory
eval_dir = Path(ppgs.EVAL_DIR)

# Identifiers of the models whose results are compared in this notebook
models = ['bottleneck', 'encodec', 'mel', 'w2v2fb', 'w2v2fc']

# Path of each model's overall test results file, keyed by model identifier
results_filename = 'overall-test.json'
eval_files = {name: eval_dir / name / results_filename for name in models}

In [None]:
# Parse every model's JSON results file into a dict keyed by model identifier
evals = {}
for name, results_path in eval_files.items():
    with results_path.open('r') as handle:
        evals[name] = load(handle)

In [None]:
def _per_dataset_accuracy(dataset):
    """Return {model: accuracy} read from each model's results for one dataset key."""
    return {name: evals[name][dataset]['Accuracy/per-dataset'] for name in models}


timit_accuracies = _per_dataset_accuracy('timit')
arctic_accuracies = _per_dataset_accuracy('arctic')
# The Common Voice numbers are read from the 'charsiu' entry of the results
common_voice_accuracies = _per_dataset_accuracy('charsiu')

In [None]:
# Display the per-model accuracies read from the 'timit' results
timit_accuracies

In [None]:
# Display the per-model accuracies read from the 'arctic' results
arctic_accuracies

In [None]:
# Display the per-model accuracies read from the 'charsiu' results
common_voice_accuracies

In [None]:
# Unweighted mean of each model's accuracy over the three datasets
average_accuracies = {
    name: (
        timit_accuracies[name]
        + arctic_accuracies[name]
        + common_voice_accuracies[name]
    ) / 3
    for name in models
}

In [None]:
# (model, average accuracy) pairs
average_accuracy_items = list(average_accuracies.items())

# Positions into `average_accuracy_items`, ordered from the highest to the
# lowest average accuracy (ascending argsort, reversed)
average_values = [accuracy for _, accuracy in average_accuracy_items]
indices = list(np.argsort(average_values)[::-1])
indices

In [None]:
# Framewise accuracy figure: one bar-chart panel per dataset followed by a
# legend-only panel. Within every panel the bars are ordered from the highest
# to the lowest average accuracy (the ordering stored in `indices`).

# Colors are assigned by bar position, i.e. by rank, not by model identifier
model_colors = ['red', 'green', 'blue', 'orange', 'purple', 'cyan', 'lime', 'magenta']
# Human-readable legend/tick label for each model identifier
label_map = {
    'bottleneck': 'ASR bottleneck',
    'encodec': 'EnCodec',
    'mel': 'Mel spectrogram',
    'w2v2fb': 'Wav2vec 2.0',
    'w2v2fc': 'Charsiu',
    'w2v2fc-pretrained': 'Charsiu (Pretrained)'
}
top_legend_models = ['bottleneck', 'encodec', 'mel', 'w2v2fb', 'w2v2fc']
# Not used by this figure; kept for any cell that may still reference it
bottom_legend_models = [m for m in models if m not in top_legend_models]

# Single source of truth for the bar order: bar heights, tick labels and
# legend entries are all derived from this list so they cannot disagree
ranked_models = [models[idx] for idx in indices]
bar_positions = range(len(ranked_models))

figure, axes = plt.subplots(1, 4, sharey=True, figsize=(8, 2.8), width_ratios=[1, 1, 1, 1.4])
datasets = ['Common Voice', 'TIMIT', 'Arctic']
dataset_accuracies = [common_voice_accuracies, timit_accuracies, arctic_accuracies]
# How far (in axes-fraction units) the guide lines extend past a panel edge
# so that they visually bridge the gap to the neighbouring panel
inter_figure_distance = 0.075 - 0.0025
guide_levels = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
for i, (dataset, accuracies) in enumerate(zip(datasets, dataset_accuracies)):
    ax = axes[i]
    ax.set_ylim(0.3, 0.9)
    for pos in ['top', 'bottom', 'left', 'right']:
        ax.spines[pos].set_visible(False)
    ax.tick_params(left=False, bottom=False)
    # Look values up by model name instead of relying on positional dict order
    bar = ax.bar(
        bar_positions,
        [accuracies[m] for m in ranked_models],
        align='center',
        color=model_colors
    )
    # Tick labels are hidden (visible=False) but follow the same ranked order
    # as the bars, so they stay correct if they are ever made visible
    ax.set_xticks(
        bar_positions,
        [label_map[m] for m in ranked_models],
        rotation=45,
        rotation_mode='anchor',
        ha='right',
        visible=False
    )
    ax.set_title(f'{dataset}')
    # Guide lines stop at the outer edge of the first and last data panels
    # and overshoot on every edge that faces another data panel
    xmin = 0 if i == 0 else 0 - inter_figure_distance
    xmax = 1 if i == len(datasets) - 1 else 1 + inter_figure_distance
    for y in guide_levels:
        ax.axhline(y, linestyle='dashed', clip_on=False, xmin=xmin, xmax=xmax)

# The last panel only hosts the legend. The legend handles are the bars of the
# last data panel drawn above; every panel uses the same order and colors.
lax = axes[-1]
lax.axis('off')
legend_labels = [
    label_map[m] + f'\n(avg={average_accuracies[m]:.3f})'
    for m in ranked_models
]
top_legend = lax.legend(
    [b for m, b in zip(ranked_models, bar) if m in top_legend_models],
    [l for m, l in zip(ranked_models, legend_labels) if m in top_legend_models],
    loc='center',
    title=r'$\bf{Input\ representation:}$',
    frameon=False,
    fontsize=11
)
plt.subplots_adjust(wspace=0.05)
figure.savefig('framewise_accuracy.pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Print the model identifiers in the order given by `indices`, one per line
ranked_model_names = [models[position] for position in indices]
for model_name in ranked_model_names:
    print(model_name)

In [None]:
# Print the {model: average accuracy} mapping
print(average_accuracies)

In [None]:
# Display the {model: accuracy} mapping stored in `common_voice_accuracies`
common_voice_accuracies

In [None]:
# Keep the bars (paired position-by-position with `indices`) whose model is
# listed in `top_legend_models`
bars = [
    patch
    for model_idx, patch in zip(indices, bar)
    if models[model_idx] in top_legend_models
]

In [None]:
# Left x coordinate of the third selected bar, read through the public
# Rectangle accessor rather than the private `_x0` attribute
bars[2].get_x()

In [None]:
# Display the list of model identifiers
models

In [None]:
# For each entry of `indices`, in order: is that model listed in
# `top_legend_models`? Iterates over `indices` itself instead of a hard-coded
# count of 5, so it stays correct if the number of models changes.
[models[idx] in top_legend_models for idx in indices]

In [None]:
# Bars whose model is listed in `top_legend_models`. Bars and `indices` are
# paired position-by-position, which avoids both the hard-coded count of 5
# and rebuilding `list(bar)` on every iteration.
bars = [patch for idx, patch in zip(indices, bar) if models[idx] in top_legend_models]

In [None]:
# Toy 4x4 matrix (each row sums to 1) used to try out drawing a highlight
# rectangle on top of a matshow image
M = np.array([
    [0.8, 0.05, 0.08, 0.07],
    [0.1, 0.75, 0.1, 0.05],
    [0.01, 0.04, 0.9, 0.05],
    [0, 0.25, 0.05, 0.7]
])

# The rectangle is centred on this (x, y) position and extends 0.75 in
# every direction, giving a 1.5 x 1.5 box
box_center = (2, 2)
half_extent = 0.75
center_x, center_y = box_center
box_start_coords = (center_x - half_extent, center_y - half_extent)
box_width = 1.5
box_height = 1.5

fig, ax = plt.subplots(1)
gax = ax.matshow(M)

# Unfilled red outline drawn over the image
highlight = plt.Rectangle(
    box_start_coords,
    box_width,
    box_height,
    facecolor='none',
    edgecolor='red',
    linewidth=4
)
ax.add_patch(highlight)

In [None]:
# Distance-matrix figure: renders the saved aggregate distance matrix with one
# labeled tick per phoneme and outlines selected phoneme pairs in red.

# Wrap EVAL_DIR in Path (as the configuration cell does) so the `/` joins
# work whether it is stored as a str or as a Path.
# NOTE(review): torch.load unpickles the file; only load results you trust.
M = torch.load(
    Path(ppgs.EVAL_DIR) / 'balanced' / 'overall-valid' / 'DistanceMatrix-aggregate-data.pt'
)
figure = plt.figure(dpi=400, figsize=(6, 6))
ax = figure.add_subplot()
# Power-law color normalization with gamma = 1/3
mat = ax.matshow(M, norm=PowerNorm(gamma=1/3))

# Place exactly one tick per phoneme at the integer cell positions and label
# it directly, instead of padding the label list to line up with whatever
# ticks the automatic locator happens to produce.
# NOTE(review): assumes M has one row and one column per entry of
# ppgs.PHONEMES, in the same order — confirm against the code that saves it.
phones = ppgs.PHONEMES
num_phones = len(phones)
ax.set_xticks(range(num_phones), phones, rotation='vertical')
ax.set_yticks(range(num_phones), phones)
ax.tick_params(axis='x', top=True, bottom=True, labelbottom=True, labeltop=True)

# Colorbar in its own axes to the right of the matrix
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.1)
ax.figure.colorbar(mat, cax=cax)
figure.align_labels()

# Phoneme pairs to outline
phone_pairs = [
    ('f', 'v'),
    ('s', 'z'),
    ('sh', 'zh')
]

# Half-size of each outline: 0.5 reaches the edge of a unit cell centred on
# integer coordinates, plus a 0.2 margin
padding = 0.5 + 0.2
box_width = padding*2
box_height = padding*2

for phone0, phone1 in phone_pairs:

    idx0 = phones.index(phone0)
    idx1 = phones.index(phone1)
    # Named `pair_indices` (not `indices`) so the notebook-level ranking
    # stored in `indices` is not overwritten by this cell
    pair_indices = [idx0, idx1]

    # Outline both cells of the pair: (idx0, idx1) and its mirror (idx1, idx0)
    for center in [pair_indices, list(reversed(pair_indices))]:

        box_start_coords = (center[0]-padding, center[1]-padding)

        ax.add_patch(
            plt.Rectangle(
                box_start_coords,
                box_width,
                box_height,
                facecolor='none',
                edgecolor='red',
                linewidth=0.5
            )
        )
figure.savefig('distance_matrix.pdf', bbox_inches='tight', pad_inches=0)