In [None]:
# Draw FSC/SNIPS

from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from utils import add_bow_feat_df
from generate_classifier_figures import _decision_tree, _relabel

# Per-model plot styling, keyed by the model identifier embedded in the accuracy
# table file names (looked up below as ``f"model-{key}"``).
# Each value is ``(legend label, matplotlib format string, colour)``.
meta = {
    "hubert-base-ls960": ("HuBERT-base", "-", "C0"),
    "wav2vec2-base": ("wav2vec 2.0-base", ":", "C4"),
    "hubert-large-ll60k": ("HuBERT-large", ".-", "C1"),
    "wav2vec2-large": ("wav2vec 2.0-large", ".:", "C3"),
    "wavlm-large": ("WavLM-large", "+-", "C5"),
    "wav2vec2-xls-r-300m": ("XLS-R-300M", "+:", "C2"),
}

# Folder with the pickled result tables and root folder of the split definitions.
TABLES_DIR = Path("tables")
SPLITS_ROOT = Path("datasets/mase/slu_splits")


def _load_pickle(path):
    """Unpickle ``path`` and close the file handle afterwards.

    NOTE: unpickling runs arbitrary code; only load tables you produced yourself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _bow_accuracy(pred_df, split):
    """Return the fraction of rows of ``split`` whose ``pred`` equals ``label``."""
    rows = pred_df[pred_df.split == split]
    return (rows.pred == rows.label).mean()


def _plot_split_kind(ax, dataset, table_prefix, split_kind, panels, dropna):
    """Draw the accuracy curves and the bag-of-words baseline for one split kind.

    Args:
        ax: Sequence of matplotlib axes to draw into.
        dataset: Dataset folder / table name, e.g. ``"snips_close_field"``.
        table_prefix: File-name prefix used to glob the accuracy tables.
        split_kind: ``"challenge"`` or ``"original"``.
        panels: Iterable of ``(axis index, split name, panel title)``.
        dropna: Drop NaN entries from every accuracy column before plotting.
    """
    # The table is re-read for every split kind (as the original script did),
    # because `_relabel` is opaque here and might modify its input.
    df = _relabel(
        pd.read_pickle(TABLES_DIR / f"{dataset}.df.pkl"),
        SPLITS_ROOT / dataset / f"{split_kind}_splits",
    )
    bow_df = _decision_tree(add_bow_feat_df(df, use_train=True))

    # The 4th underscore-separated token of the file stem is "model-<name>".
    accs = {
        p.stem.split("_")[3]: pd.DataFrame(_load_pickle(p))
        for p in TABLES_DIR.glob(f"{table_prefix}*pool-mean*{split_kind}*.pkl")
    }

    for axis_idx, task, title in panels:
        panel = ax[axis_idx]
        for key, (name, fmt, color) in meta.items():
            curve = accs[f"model-{key}"][task]
            if dropna:
                curve = curve.dropna()
            panel.plot(curve, fmt, label=name, color=color)
        panel.axhline(_bow_accuracy(bow_df, task), label="Bag of Words", ls="--", c="black")
        panel.set_title(title)


def _draw_slu_figure(dataset, table_prefix, out_path, *, dropna=False, ylims=None):
    """Create and save the three-panel (Original / SPK / UTT) accuracy figure.

    Args:
        dataset: Dataset folder / table name.
        table_prefix: File-name prefix used to glob the accuracy tables.
        out_path: Where the figure is written.
        dropna: Drop NaN entries from every accuracy column before plotting.
        ylims: Optional ``{axis index: (low, high)}`` y-limits.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    fig, ax = plt.subplots(1, 3, figsize=(8, 2))

    # Challenge panels are processed before the original one, keeping the call
    # order of the script this helper replaces.
    _plot_split_kind(
        ax, dataset, table_prefix, "challenge",
        ((1, "speaker_test", "Challenge (SPK)"), (2, "utterance_test", "Challenge (UTT)")),
        dropna,
    )
    _plot_split_kind(ax, dataset, table_prefix, "original", ((0, "test", "Original"),), dropna)

    ax[0].set_xlabel("Layer index")
    ax[0].set_ylabel("Accuracy")
    for axis_idx, (low, high) in (ylims or {}).items():
        ax[axis_idx].set_ylim(low, high)

    # One shared legend, anchored relative to the right-most panel so that it
    # sits centred above the row of panels.
    ax[2].legend(ncol=4, loc="center", bbox_to_anchor=(-0.75, 1.35))
    fig.savefig(out_path, bbox_inches="tight")
    return fig, ax


# SNIPS START
fig, ax = _draw_slu_figure("snips_close_field", "snips", "snips.pdf")
# SNIPS END

# FSC START
fig, ax = _draw_slu_figure(
    "fluent_speech_commands",
    "fluent_",
    "fsc.pdf",
    dropna=True,
    ylims={1: (0.13, 1.025)},  # y-range of the Challenge (SPK) panel
)
# FSC END

In [None]:
# Draw LibriSpeech/MSW

import os
import pickle
from itertools import product
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
matplotlib.rcParams.update({'font.size': 14})

# One entry per two-panel figure; the dict key becomes the output file name
# (``figs/<key>.pdf``). Fields read by the plotting loop further down:
#   "left"/"right": file-name template of the distance table for that panel;
#                   the placeholders "spk-x" and "seed-x" are substituted with
#                   the concrete speaker and seed.
#   "num_seeds":    seeds 0 .. num_seeds-1 are loaded and aggregated.
#   "speakers":     speaker ids substituted into the template.
#   "normalizer":   per-panel mode; "subtract" subtracts the "random" curve
#                   from every curve, any other value plots raw values.
#   "legend":       optional keyword arguments for the figure legend; no
#                   legend is drawn when the field is absent.
#   "title":        per-panel titles (left, right).
pairs = {
    "MSW": {
        "left": "MSW_model-wavlm-large_slice-True_spk-x_size-2000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "MSW_model-wav2vec2-xls-r-300m_slice-True_spk-x_size-2000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 2},
        "title": ("WavLM-Large", "XLS-R-300M"),
    },
    "hubert-spk": {
        "left": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-full_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-full_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 1,
        "speakers": (5142, 2412, 6313, 1580, 2277),
        "normalizer": ("none", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 2},
        "title": ("HuBERT-large", "HuBERT-large (Norm.)"),
    },
    "hubert": {
        "left": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("none", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 3},
        "title": ("Audio slicing", "Audio slicing (Norm.)"),
    },
    "hubert-slice": {
        "left":  "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-False_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-False_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("none", "subtract"),
        # "legend": ("left", {"loc": "lower left"}),
        "title": ("Feature slicing", "Feature slicing (Norm.)"),
    },
    "hubert-pool": {
        "left": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-10000_pool-center_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-10000_pool-median_cosine_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 3},
        "title": ("Center pooling (Norm.)", "Centroid pooling (Norm.)"),
    },
    "w2v2": {
        "left": "librispeech-dev-clean-test-clean_model-wav2vec2-large_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-wav2vec2-large_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("none", "subtract"),
        # "legend": ("left", {"loc": "upper left"}),
        "title": ("wav2vec 2.0", "wav2vec 2.0 (Norm.)"),
    },
    "base": {
        "left": "librispeech-dev-clean-test-clean_model-hubert-base-ls960_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-wav2vec2-base_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "legend": {"loc": "center", "bbox_to_anchor": (-0.2, 1.4), "ncols": 3},
        "title": ("HuBERT-base (Norm.)", "wav2vec 2.0-base (Norm.)"),
    },
    "large": {
        "left": "librispeech-dev-clean-test-clean_model-hubert-large-ll60k_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-wav2vec2-large_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "title": ("HuBERT-large (Norm.)", "wav2vec 2.0-large (Norm.)"),
    },
    "wavlm-xls": {
        "left": "librispeech-dev-clean-test-clean_model-wavlm-large_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "right": "librispeech-dev-clean-test-clean_model-wav2vec2-xls-r-300m_slice-True_spk-x_size-10000_pool-mean_seed-x_dist-cos_sim.dist.pkl",
        "num_seeds": 5,
        "speakers": ("everyone", ),
        "normalizer": ("subtract", "subtract"),
        "title": ("WavLM-large (Norm.)", "XLS-R-300M (Norm.)"),
    },
}

# Plot style per curve, looked up with the keys of each loaded distance table.
# Each value is (colour, marker, legend label); an empty marker string draws
# the line without markers.
keys = {
    "random": ("C0", "", "Random"),
    "synonym": ("C1", "x", "Synonym"),
    "homophone": ("C2", "o", "Near homophone"),
    "speaker": ("C3", "|", "Same speaker"),
    "same_word": ("C4", "^", "Same word"),
}

def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    Args:
        data: Sequence of numeric samples.
        confidence: Two-sided confidence level, e.g. ``0.95``.

    Returns:
        ``(mean, half_width)``; the interval is ``mean ± half_width``, where the
        half-width is the standard error of the mean times the t critical value
        for ``len(data) - 1`` degrees of freedom.
    """
    # Adapted from
    # https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data
    samples = np.array(data) * 1.0  # promote integer/bool input to float
    dof = len(samples) - 1
    center = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = std_err * t_crit
    return center, half_width


# Folder holding the pre-computed ``*.dist.pkl`` tables. It defaults to the
# location used when this notebook was written; set the DIST_TABLES_DIR
# environment variable to point it somewhere else.
DIST_TABLES_DIR = Path(os.environ.get("DIST_TABLES_DIR", "/Users/kwangheechoi/Desktop/dev_tables"))

# Output folder for the figures; created up front so that saving cannot fail on
# a fresh checkout.
FIG_DIR = Path("figs")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Panels whose y-axis is pinned to a common range.
FIXED_YLIM_TITLES = ("HuBERT-large (Norm.)", "Audio slicing (Norm.)", "Center pooling (Norm.)", "Centroid pooling (Norm.)")

# The loop variable is called `pair_cfg` (not `meta`) so that it does not
# overwrite the model-style dict named `meta` used by the FSC/SNIPS cell.
for pair_key, pair_cfg in pairs.items():
    fig, axes = plt.subplots(1, 2, figsize=(8, 3))
    for loc, ax, normalizer, title in zip(("left", "right"), axes, pair_cfg["normalizer"], pair_cfg["title"]):
        # 1) Load one table per (seed, speaker) and reduce every layer's
        #    samples to their mean.
        seedwise_dists = []
        for seed, speaker in product(range(pair_cfg["num_seeds"]), pair_cfg["speakers"]):
            dists_path = DIST_TABLES_DIR / pair_cfg[loc].replace("spk-x", f"spk-{speaker}").replace("seed-x", f"seed-{seed}")
            # NOTE: unpickling runs arbitrary code; only load trusted tables.
            with open(dists_path, "rb") as handle:
                dists = pickle.load(handle)
            seedwise_dists.append({
                k: [mean_confidence_interval(v)[0] for v in vs]
                for k, vs in dists.items()
            })

        # 2) Aggregate across runs: per key and layer, the mean of the run
        #    means and the half-width of its confidence interval.
        num_layers = len(seedwise_dists[0]["random"])
        agg_dists = {}
        for key in seedwise_dists[0]:
            agg_dists[key] = [
                mean_confidence_interval([run[key][layer] for run in seedwise_dists])
                for layer in range(num_layers)
            ]

        # 3) Plot every curve, optionally relative to the "random" baseline.
        if normalizer == "subtract":
            baseline = np.array([t[0] for t in agg_dists["random"]])
        for key, tuples in agg_dists.items():
            value = np.array([t[0] for t in tuples])
            bound = np.array([t[1] for t in tuples])
            if normalizer == "subtract":
                value = value - baseline
            color, marker, label = keys[key]
            style = "dotted" if (key == "random" and normalizer == "subtract") else "solid"
            ax.plot(value, label=label, marker=marker, color=color, linestyle=style)
            # Pass the colour explicitly: without it the band colour comes from
            # the axes' patch colour cycle and only matches the line when the
            # table's key order happens to follow C0, C1, ...
            ax.fill_between(np.arange(len(value)), value - bound, value + bound, color=color, alpha=0.2)

        ax.set_title(title)
        ax.set_ylabel("Norm. cos. sim." if normalizer == "subtract" else "Cos. sim.")
        ax.set_xlabel("Layer index")
        if title in FIXED_YLIM_TITLES:
            ax.set_ylim(-0.02, 0.3)

    fig.tight_layout()
    if "legend" in pair_cfg:
        # Attached to the right-hand panel (the axes `plt.legend` used to pick
        # up implicitly); the negative x anchor centres it above both panels.
        axes[-1].legend(**pair_cfg["legend"])
    fig.savefig(FIG_DIR / f"{pair_key}.pdf", bbox_inches="tight")
    plt.show()
    # Release the figure so that looping over many pairs does not pile up
    # open figures.
    plt.close(fig)
