In [78]:
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import pandas as pd
from omegaconf import OmegaConf
import torch
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
from matplotlib.gridspec import GridSpec
from tqdm.notebook import tqdm
import bm.train
import os

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [62]:
# Work from the directory two levels above the `bm.train` module file
# (presumably the repository root, so relative project paths resolve).
repo_root = Path(bm.train.__file__).parent.parent
os.chdir(repo_root)

In [63]:
output_dir = bm.train.main.dora.dir
eval_dir = output_dir / "eval" / "signatures"
grid_dir = output_dir / "grids" / "ablation_final"
# Check the directories *before* listing them, so a missing path fails with a
# message naming it instead of a FileNotFoundError from iterdir().
assert output_dir.exists(), output_dir
assert eval_dir.exists(), eval_dir
assert grid_dir.is_dir(), grid_dir
# Sorted so the run order (and hence run_df) does not depend on the
# filesystem's directory listing order.
sigs_to_eval = sorted(p.name for p in grid_dir.iterdir())


In [89]:
# Select signatures: keep the grid runs whose evaluation produced segment-level outputs.
valid_sigs = [sig for sig in sigs_to_eval if (eval_dir / sig / "vocab_segment.pth").is_file()]
# Grid signatures without evaluation outputs (printed for inspection).
print(set(sigs_to_eval) - set(valid_sigs))
assert valid_sigs, f"no evaluated signatures found in {eval_dir}"
configs = [OmegaConf.load(eval_dir / sig / "solver_config.yaml") for sig in valid_sigs]


def _run_record(sig, conf):
    """Flatten the solver config `conf` of run `sig` into one row of `run_df`.

    Architecture switches are cast to bool so each column reads as
    "is this component enabled?", which is what `get_name` tests.
    """
    arch = conf.simpleconv
    return {
        "sig": sig,
        "dataset": "-".join(conf.dset.selections),
        "seed": conf.seed,
        "features": "-".join(conf.dset.features),
        "loss": conf.optim.loss,
        "max_scale": conf.norm.max_scale,
        "n_mels": conf.dset.features_params.MelSpectrum.n_mels,
        "deepmel": bool(conf.clip.arch),
        # True when the learning rate is 0 (presumably eval-only runs — TODO confirm).
        "ft": conf.optim.lr == 0,
        "random": conf.test.wer_random,
        "batch_size": conf.optim.batch_size,
        "dropout": arch.merger_dropout > 0,
        "gelu": bool(arch.gelu),
        "skip": bool(arch.skip),
        "initial": bool(arch.initial_linear),
        "complex": bool(arch.complex_out),
        "subject_lay": bool(arch.subject_layers),
        "subject_emb": bool(arch.subject_dim),
        "attention": bool(arch.merger),
        "glu": bool(arch.glu),
        "depth": arch.depth,
    }


run_df = pd.DataFrame([_run_record(sig, conf) for sig, conf in zip(valid_sigs, configs)])

def get_name(row):
    """
    Map one row of `run_df` to the ablation label used in the results table.

    The checks run in a fixed priority order and the first disabled component
    decides the label, so a run disabling several components is labelled by the
    first one checked. A row with every component enabled, depth != 5 and
    max_scale == 20 is the reference model, labelled "Our model".

    `autoreject` is read with a default of False because `run_df` may not have
    that column (it is not built in this notebook).
    """
    if not row.dropout:
        return r"\wo spatial attention dropout"
    if not row.gelu:
        return r"\wo GELU, \w ReLU"
    if not row.skip:
        return r"\wo skip connections"
    if not row.initial:
        return r"\wo initial 1x1 conv."
    if not row.complex:
        return r"\wo final convs"
    if not row.attention:
        return r"\wo spatial attention"
    if not row.glu:
        return r"\wo non-residual GLU conv."
    if not row.subject_lay:
        if row.subject_emb:
            return r"\w subj. embedding*"
        return r"\wo subject-specific layer"
    if row.depth == 5:
        return r"less deep"
    if row.max_scale != 20:
        if row.max_scale == 100:
            return r"\w clamp=100"
        if row.get("autoreject", False):
            return "autoreject"
        return r"\wo clamping brain signal"
    return "Our model"
    
# Label every run with its ablation name, then keep only the CLIP-loss runs.
run_df = run_df.assign(name=run_df.apply(get_name, axis=1))
run_df = run_df.loc[run_df["loss"] == "clip"]

set()


In [90]:
# Number of runs carrying each ablation label, in order of first appearance.
for label in run_df["name"].unique():
    n_runs = (run_df["name"] == label).sum()
    print(label, n_runs)

Our model 12
\wo spatial attention 12
\wo spatial attention dropout 12
\wo non-residual GLU conv. 12
\wo initial 1x1 conv. 12
\wo GELU, \w ReLU 12
\wo skip connections 12
\wo final convs 12
\wo subject-specific layer 12
\w subj. embedding* 12
\wo clamping brain signal 12


In [79]:
# Sort runs by dataset and list the dataset of every reference run.
# `get_name` labels the non-ablated model "Our model" (there is no "base" label).
run_df = run_df.sort_values("dataset", ignore_index=True)
run_df.loc[run_df["name"] == "Our model", "dataset"].tolist()

['audio_mous',
 'audio_mous',
 'audio_mous',
 'brennan2019',
 'brennan2019',
 'brennan2019',
 'broderick2019',
 'broderick2019',
 'broderick2019',
 'gwilliams2022',
 'gwilliams2022',
 'gwilliams2022']

In [80]:
# For main table: for every table row, the signatures of its runs sorted by
# (dataset, seed). The reference row comes first, under the key "reference".
REFERENCE_NAME = "Our model"  # label `get_name` gives to the non-ablated model
run_df = run_df.sort_values(["dataset", "seed"], ignore_index=True)
is_reference = run_df["name"] == REFERENCE_NAME
table_sigs = {"reference": run_df.loc[is_reference, "sig"].tolist()}
for name in run_df["name"].unique():
    if name == REFERENCE_NAME:
        continue
    table_sigs[name] = run_df.loc[run_df["name"] == name, "sig"].tolist()
table_dataset_names = run_df.loc[is_reference, "dataset"].tolist()
seed_names = run_df.loc[is_reference, "seed"].tolist()

assert table_sigs["reference"], f"no run labelled {REFERENCE_NAME!r}"
# Dataset/seed are later looked up by position in these reference lists, which
# is only valid if every row covers the same (dataset, seed) grid in the same order.
for name in table_sigs:
    if name == "reference":
        continue
    row_runs = run_df[run_df["name"] == name]
    assert row_runs["dataset"].tolist() == table_dataset_names, f"dataset mismatch for {name!r}"
    assert row_runs["seed"].tolist() == seed_names, f"seed mismatch for {name!r}"

print(table_dataset_names)
print(seed_names)


['audio_mous', 'audio_mous', 'audio_mous', 'brennan2019', 'brennan2019', 'brennan2019', 'broderick2019', 'broderick2019', 'broderick2019', 'gwilliams2022', 'gwilliams2022', 'gwilliams2022']
[2036, 2037, 2038, 2036, 2037, 2038, 2036, 2037, 2038, 2036, 2037, 2038]


# Common functions

In [81]:
def load_data_from_sig(sig, level="word", root=None):
    """
    Load evaluation outputs of one solver signature.

    Reads `<root>/<sig>/probs_<level>.pth`, `vocab_<level>.pth` and `metadata.csv`.

    Args:
        sig: solver signature (sub-directory name).
        level: "word" or "segment"; selects the files and the `<level>_hashes` column.
        root: directory containing the signature folders; defaults to the
            notebook-level `eval_dir` when None.

    Returns:
        - probs (torch tensor): probability on vocab [N, V]
        - vocab (torch tensor): vocab of hashes [V]
        - words (torch long tensor): the `<level>_hashes` value of each sample [N]
        - metadata (pandas dataframe) of len [N], read from metadata.csv, with
          `word_strings` lower-cased and an added positional `idx` column
          (0..N-1, the row of each sample in `probs`).

    Raises:
        AssertionError: unknown `level`, or `probs` shape inconsistent with
            the vocab / number of samples.
    """
    assert level in ["word", "segment"], "level should be 'word' or 'segment'"
    sig_dir = Path(eval_dir if root is None else root) / sig
    # map_location="cpu": keeps probs/vocab on the same device as `words`
    # (built on CPU below) even if they were saved from a GPU.
    probs = torch.load(sig_dir / f"probs_{level}.pth", map_location="cpu")
    vocab = torch.load(sig_dir / f"vocab_{level}.pth", map_location="cpu")  # vocab (hashes)
    metadata = pd.read_csv(sig_dir / "metadata.csv", index_col=0)
    # For each sample, the word/segment hash it corresponds to.
    words = torch.as_tensor(metadata[f"{level}_hashes"].to_numpy(), dtype=torch.long)
    assert probs.shape == (len(words), len(vocab)), (
        f"probs shape {tuple(probs.shape)} != ({len(words)}, {len(vocab)})"
    )
    metadata["idx"] = range(len(words))
    metadata["word_strings"] = metadata["word_strings"].str.lower()

    return probs, vocab, words, metadata


def get_accuracy_from_probs(probs, row_labels, col_labels, topk=10):
    """
    Top-k accuracy: fraction of rows whose true label is among the labels of
    the `topk` highest-scoring columns of that row.

    Inputs:
        probs: [B, V] scores over the vocab (e.g. probabilities)
        row_labels: [B] tensor, true label of each row
        col_labels: [V] tensor, label of each column
        topk: int; values larger than V are clamped to V (every column is then
            a candidate) instead of making torch.topk raise.
    Returns: float scalar, topk accuracy (nan when B == 0).
    """
    assert len(row_labels) == len(probs)
    assert len(col_labels) == probs.shape[1]

    k = min(topk, probs.shape[1])
    # Column indices of the k best-scoring entries of each row: [B, k]
    idx = probs.topk(k, dim=1).indices

    # Labels of those columns: [B, k]
    topk_labels = col_labels[idx]

    # 1 if any of the k labels matches the row's target
    correct = (topk_labels == row_labels[:, None]).any(dim=1).float()
    assert len(correct) == len(row_labels)

    # Average across samples
    return correct.mean().item()

def eval_acc(sigs, level="word", add_baselines=True, per_sub=False, max_workers=20):
    """
    Return the accuracy dataframe for multiple sigs, computed in a process pool.

    level: whether to return word or segment level accuracy
    add_baselines, per_sub: forwarded to `eval_acc_one_sig`
    max_workers: number of worker processes

    Each signature's frame gets a "sig" column; the frames are concatenated.
    A signature whose evaluation raises is reported (with the error) and skipped.
    Raises ValueError if no signature could be evaluated.
    """
    futures = []
    acc = []
    with ProcessPoolExecutor(max_workers) as pool:
        for sig in sigs:
            future = pool.submit(eval_acc_one_sig, sig, level=level, add_baselines=add_baselines, per_sub=per_sub)
            futures.append((sig, future))
        # Collect in submission order so the progress bar tracks `sigs`.
        for sig, future in tqdm(futures):
            try:
                acc_sig = future.result()
            except Exception as exc:
                # Best effort: skip this signature, but say why it failed.
                print("ERROR WITH", sig, f"{type(exc).__name__}: {exc}")
                continue
            acc_sig["sig"] = sig
            acc.append(acc_sig)
    if not acc:
        raise ValueError(f"accuracy could not be computed for any of the {len(sigs)} signatures")
    return pd.concat(acc)


def eval_acc_nopool(sigs, level="word", add_baselines=True, per_sub=False):
    """
    Return the accuracy dataframe for multiple sigs, evaluated sequentially
    in this process (no process pool).

    level: whether to return word or segment level accuracy
    add_baselines, per_sub: forwarded to `eval_acc_one_sig`

    Each signature's frame gets a "sig" column; the frames are concatenated.
    A signature whose evaluation raises is reported (with the error) and skipped.
    Raises ValueError if no signature could be evaluated.
    """
    acc = []
    for sig in tqdm(sigs):
        try:
            acc_sig = eval_acc_one_sig(sig, level=level, add_baselines=add_baselines, per_sub=per_sub)
        except Exception as exc:
            # Best effort: skip this signature, but say why it failed.
            print("ERROR WITH", sig, f"{type(exc).__name__}: {exc}")
            continue
        acc_sig["sig"] = sig
        acc.append(acc_sig)
    if not acc:
        raise ValueError(f"accuracy could not be computed for any of the {len(sigs)} signatures")
    return pd.concat(acc)

def eval_acc_one_sig(sig, topks=(1, 5, 10), level="word", add_baselines=True, per_sub=False):
    """
    Return the accuracy dataframe of one solver signature.

    level: whether to return `word` or `segment` level accuracy
    per_sub: if True, one row per (topk, subject_id); otherwise all samples
        are pooled and subject_id is the placeholder "nan".
    add_baselines: also compute
        - "baseline_vocab": accuracy of a uniform score over the vocab;
        - "baseline": accuracy of always ranking words by their frequency
          over all samples of this signature.

    Returns: DataFrame with columns acc, topk, subject_id
        (+ baseline_vocab, baseline when add_baselines).
    """
    # Load data
    probs, vocab, words, meta = load_data_from_sig(sig, level=level)

    if not per_sub:
        # A single group so the loop below pools all samples.
        meta["subject_id"] = "nan"

    if add_baselines:
        # Word frequencies over all samples. `counts` is aligned with
        # `freq_vocab` (sorted unique hashes), NOT with `vocab`, whose length
        # and order may differ — so `freq_vocab` must be the column labels.
        # Loop-invariant, hence computed once.
        freq_vocab, counts = torch.unique(words, return_counts=True)
        freqs = counts / len(words)

    # Compute acc
    acc_df = []
    for topk in topks:
        for subject, metasub in meta.groupby("subject_id"):
            # `idx` is the positional row of each sample in probs/words.
            idx = metasub["idx"].values
            sub_probs = probs[idx]
            sub_words = words[idx]

            out = {
                "acc": get_accuracy_from_probs(sub_probs, sub_words, vocab, topk=topk),
                "topk": topk,
                "subject_id": subject,
            }
            if add_baselines:
                # --- Baseline on vocab: uniform prob on vocab (~ k / vocab_len) ---
                rand_probs_vocab = torch.ones_like(sub_probs) / len(vocab)
                out["baseline_vocab"] = get_accuracy_from_probs(rand_probs_vocab, sub_words, vocab, topk=topk)

                # --- Baseline on words: rank words by their frequency ---
                rand_probs_words = freqs.expand(len(idx), -1)
                out["baseline"] = get_accuracy_from_probs(rand_probs_words, sub_words, freq_vocab, topk=topk)

            acc_df.append(out)
    return pd.DataFrame(acc_df)



In [82]:
per_row_frames = []
for row_label, sigs in table_sigs.items():
    print(f"Computing acc for {row_label} {sigs}")
    frame = eval_acc_nopool(sigs, level="segment", add_baselines=False, per_sub=True,)
    # Position of each evaluated sig in `sigs`, which is aligned with the
    # reference dataset/seed lists.
    positions = [sigs.index(sig) for sig in frame.sig.values]
    frame["dataset"] = [table_dataset_names[pos] for pos in positions]
    frame["seed"] = [seed_names[pos] for pos in positions]
    frame["row_label"] = row_label
    per_row_frames.append(frame)
acc_df = pd.concat(per_row_frames)

def dset_order(name):
    """Map a Series of dataset names to their display rank (NaN for unknown names)."""
    rank = {
        "audio_mous": 0,
        "gwilliams2022": 1,
        "broderick2019": 2,
        "brennan2019": 3,
    }
    return name.map(rank)

# Order rows by the dataset rank from `dset_order` (unmapped datasets become NaN and go last).
acc_df = acc_df.sort_values(by="dataset", key=dset_order)

Computing acc for reference ['01cd28fc', '49efae33', '6b9ba19e', 'd605ce2f', 'bd231c7a', 'ce18f336', '1dd888f0', 'bd30ebf1', 'bd04f96e', 'ceb23552', 'bd7d06d2', 'e059e7e5']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo clamping brain signal ['b476448e', 'bea7f00a', 'd1d37637', 'f3cae34e', '84f95be4', '586f1ad0', '22fb342c', '7f15adbc', 'c1150d28', 'b978fc86', '9ed67c53', 'fbdd79aa']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo final convs ['480110a3', 'e521c72b', 'cfe0aefd', 'de507ded', '5f8b5d6a', 'bba535c9', '666e391f', 'a5e3ebae', '626f961e', '092d508e', '5b38fa38', '9567ff2c']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo skip connections ['fb2cced8', 'c4098073', 'ac741d87', '27332012', 'a781c6a7', '5349feef', 'bb4f4ee9', 'd3399072', '15546f79', '70a0e185', 'd9cec30d', '9783831e']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo subject-specific layer ['9e8b929d', 'ac5b3b57', '91c5c41b', '2325aad9', '00e0bdcd', '71dd6af4', 'e48c36ea', '5c633e70', '76b0e85a', 'a5acd4a3', '4479100b', '78d5d473']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo initial 1x1 conv. ['8e9c33b1', '036741ea', '1481dbbb', '5d3de32c', '1d014dcd', '6019c47b', '08cf8898', '22d2260a', '87c1ffa6', '81bd9784', '0d4f43b4', '5d7a51d4']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo non-residual GLU conv. ['55751f43', '832d9622', '3a9ae364', '493023a0', '0f34bbaf', 'a4b730f1', 'fa2d2578', '2ae3913a', '850b80f6', 'ff85bfd6', '0896a4a8', '0f17594c']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \w subj. embedding* ['f33cf48b', '3cf51478', '745a6231', '64c1362f', '8ababc44', 'b8d46745', '6577f78f', '0e6262e9', '203c64a9', 'c3b555db', '73c5dce9', 'd733b028']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo spatial attention ['ec29e81e', 'c32f25b6', '8a7419d6', '2aecc316', '437a096c', '58a6fa89', 'fc475d47', 'ea519a2a', 'd4cf1bbe', 'ca8c38f8', '1bd6de99', '57a34964']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo GELU, \w ReLU ['4438d221', '766c6d08', '4e9d4727', '16d17b9c', '957fe7a4', '0c80dd82', '38449131', '6324264b', '3d61e840', '048edc79', '86811260', '46c3d964']


  0%|          | 0/12 [00:00<?, ?it/s]

Computing acc for \wo spatial attention dropout ['2dd6ddbe', 'b23c2cfa', '22ca6dba', 'd890f8eb', 'e5c6b552', '4c3a83bf', '41831434', '8d13203d', '5c64a266', '85f40582', '55c4a03c', 'd43b1b05']


  0%|          | 0/12 [00:00<?, ?it/s]

In [83]:
from scipy.stats import wilcoxon

# Accuracy reported in the table: top-k with k = TOPK.
TOPK = 10
acc_df = acc_df[acc_df.topk == TOPK]

# One row per (dataset, subject), one column per table row label; pivot_table
# averages (its default aggfunc) runs sharing an index/column pair, e.g. seeds.
pivot = pd.pivot_table(acc_df, index=["dataset", "subject_id"], columns=["row_label"], values="acc")
# Reorder the *columns* (one per table row) to follow table_sigs, reference first.
pivot = pivot[list(table_sigs)]

# Average score
means = pivot.groupby("dataset").agg("mean").T

# Standard Error of the Means (std too high)
sems = pivot.groupby("dataset").agg("sem").T

# Difference in score between reference and each row
deltas = means - means.loc["reference"]

# Paired Wilcoxon signed-rank p-values of each row against the reference,
# pairing by subject. DataFrame.corr is used as a pairwise driver: it calls the
# callable on every pair of columns (rows where both are non-NaN) and puts 1 on
# the diagonal, hence p = 1.0 for the reference itself.
# NOTE(review): no multiple-comparison correction is applied.

# Pvalues: grouped by dataset => 1 pvalue per row and dataset
pvalues = pivot.groupby("dataset").corr(method=lambda x, y: wilcoxon(x, y)[1])
pvalues = pvalues.reset_index().query("row_label=='reference'").set_index("dataset").T.drop("row_label")
pvalues = pvalues.astype(float)

# Pvalues: aggregated on datasets (subjects of all datasets pooled) => 1 pvalue per row
pvalues_agg_dset = pivot.corr(method=lambda x, y: wilcoxon(x, y)[1])[["reference",]]



In [84]:
# Mean top-k accuracy over subjects: one row per table row label, one column per dataset.
means

dataset,audio_mous,brennan2019,broderick2019,gwilliams2022
row_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
reference,0.675469,0.25741,0.134172,0.669869
\wo clamping brain signal,0.014541,0.140769,0.009778,0.226597
\wo final convs,0.674825,0.19021,0.111579,0.653442
\wo skip connections,0.654489,0.242466,0.10951,0.62722
\wo subject-specific layer,0.424489,0.201883,0.068915,0.447212
\wo initial 1x1 conv.,0.628566,0.220709,0.118565,0.641835
\wo non-residual GLU conv.,0.66952,0.060001,0.068657,0.664579
\w subj. embedding*,0.647897,0.293972,0.146858,0.670004
\wo spatial attention,0.658781,0.205716,0.118991,0.621047
"\wo GELU, \w ReLU",0.658231,0.245611,0.125274,0.652161


In [85]:
# Standard error of the mean over subjects, same layout as `means`.
sems

dataset,audio_mous,brennan2019,broderick2019,gwilliams2022
row_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
reference,0.01299,0.018229,0.015285,0.026932
\wo clamping brain signal,0.000688,0.009408,1.9e-05,0.019061
\wo final convs,0.013111,0.014835,0.011719,0.02658
\wo skip connections,0.013221,0.015965,0.011245,0.028014
\wo subject-specific layer,0.01587,0.013817,0.008554,0.02773
\wo initial 1x1 conv.,0.013233,0.01544,0.0135,0.027369
\wo non-residual GLU conv.,0.01296,0.003823,0.007538,0.026894
\w subj. embedding*,0.013135,0.019236,0.015858,0.026313
\wo spatial attention,0.013168,0.014461,0.013219,0.027978
"\wo GELU, \w ReLU",0.01342,0.016774,0.013479,0.027318


In [86]:
# Mean accuracy of each row minus the reference's, per dataset (negative = worse).
deltas

dataset,audio_mous,brennan2019,broderick2019,gwilliams2022
row_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
reference,0.0,0.0,0.0,0.0
\wo clamping brain signal,-0.660927,-0.116641,-0.124394,-0.443271
\wo final convs,-0.000643,-0.0672,-0.022593,-0.016426
\wo skip connections,-0.02098,-0.014944,-0.024662,-0.042648
\wo subject-specific layer,-0.250979,-0.055527,-0.065257,-0.222657
\wo initial 1x1 conv.,-0.046902,-0.036701,-0.015607,-0.028034
\wo non-residual GLU conv.,-0.005949,-0.197409,-0.065515,-0.00529
\w subj. embedding*,-0.027571,0.036562,0.012686,0.000136
\wo spatial attention,-0.016688,-0.051694,-0.015181,-0.048821
"\wo GELU, \w ReLU",-0.017238,-0.011799,-0.008898,-0.017708


In [87]:
# Wilcoxon p-value of each row vs the reference, per dataset (reference row is 1.0 by construction).
pvalues

dataset,audio_mous,brennan2019,broderick2019,gwilliams2022
row_label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
reference,1.0,1.0,1.0,1.0
\wo clamping brain signal,1.781294e-17,1.629815e-09,4e-06,1.490116e-08
\wo final convs,0.5500196,1.77417e-07,1.9e-05,1.490116e-07
\wo skip connections,7.027166e-13,0.05513315,1.9e-05,1.490116e-08
\wo subject-specific layer,1.7811850000000003e-17,7.964151e-05,4e-06,1.490116e-08
\wo initial 1x1 conv.,1.8381850000000002e-17,4.170742e-05,0.001694,1.490116e-08
\wo non-residual GLU conv.,0.00692265,4.656613e-10,4e-06,0.04098319
\w subj. embedding*,6.841807e-13,1.564149e-05,1.9e-05,0.8779178
\wo spatial attention,3.508152e-08,1.247972e-07,0.000164,1.490116e-08
"\wo GELU, \w ReLU",1.216311e-11,0.04842899,0.007145,1.490116e-08


In [88]:
# Wilcoxon p-value of each row vs the reference, with subjects of all datasets pooled.
pvalues_agg_dset

row_label,reference
row_label,Unnamed: 1_level_1
reference,1.0
\wo clamping brain signal,1.9371799999999997e-30
\wo final convs,5.296545e-10
\wo skip connections,9.782553e-21
\wo subject-specific layer,8.523572e-30
\wo initial 1x1 conv.,7.723849e-28
\wo non-residual GLU conv.,4.289443e-14
\w subj. embedding*,0.005663648
\wo spatial attention,7.156011e-21
"\wo GELU, \w ReLU",3.7310640000000004e-17
