In [45]:
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
import pandas as pd
from omegaconf import OmegaConf
import os
import torch
import numpy as np
import matplotlib.pyplot as plt 
from matplotlib.gridspec import GridSpec
from tqdm.notebook import tqdm
import bm
os.chdir(Path(bm.__file__).parent.parent)
from bm import train

%matplotlib inline
%config InlineBackend.figure_format='retina'

In [46]:
output_dir = train.main.dora.dir
print(output_dir)
eval_dir = output_dir / "eval" / "signatures"
sigs_to_eval = [p.name for p in (output_dir / "grids" / "ablation_final").iterdir()]
# sigs_to_eval = [p.name for p in (output_dir / "grids" / "lr_batch_size").iterdir()]
# eval_dir = Path('/private/home/defossez/projs/brainmagick/outputs/eval/signatures_final')
# sigs_to_eval = [p.name for p in eval_dir.iterdir()]
print(sigs_to_eval)
assert output_dir.exists()
# assert grid_dir.exists()
assert eval_dir.exists()


/checkpoint/defossez/brainmagick/experiments
['01cd28fc', 'ec29e81e', '2dd6ddbe', '55751f43', '8e9c33b1', '4438d221', 'fb2cced8', '480110a3', '9e8b929d', 'f33cf48b', 'b476448e', 'ceb23552', 'ca8c38f8', '85f40582', 'ff85bfd6', '81bd9784', '048edc79', '70a0e185', '092d508e', 'a5acd4a3', 'c3b555db', 'b978fc86', '1dd888f0', 'fc475d47', '41831434', 'fa2d2578', '08cf8898', '38449131', 'bb4f4ee9', '666e391f', 'e48c36ea', '6577f78f', '22fb342c', 'd605ce2f', '2aecc316', 'd890f8eb', '493023a0', '5d3de32c', '16d17b9c', '27332012', 'de507ded', '2325aad9', '64c1362f', 'f3cae34e', '49efae33', 'c32f25b6', 'b23c2cfa', '832d9622', '036741ea', '766c6d08', 'c4098073', 'e521c72b', 'ac5b3b57', '3cf51478', 'bea7f00a', 'bd7d06d2', '1bd6de99', '55c4a03c', '0896a4a8', '0d4f43b4', '86811260', 'd9cec30d', '5b38fa38', '4479100b', '73c5dce9', '9ed67c53', 'bd30ebf1', 'ea519a2a', '8d13203d', '2ae3913a', '22d2260a', '6324264b', 'd3399072', 'a5e3ebae', '5c633e70', '0e6262e9', '7f15adbc', 'bd231c7a', '437a096c', 'e5c6b

# Common functions

In [47]:
def load_data_from_sig(sig, level="segment"):
    """
    Load data from solver signature
    - probs (torch tensor): probability on vocab [N, V]
    - vocab (torch tensor): vocab of word hashes [V]
    - words (torch tensor): the word hash for each sample [N]
    - metadata (panda dataframe) of len [N] which contains for each sample:
           'word_hashes', 'word_indices', 'seq_indices',
           'word_strings', 'subject_id', 'recording_id'
    """
    assert level in ["word", "segment"], "level should be 'word' or 'segment'"
    probs = torch.load(eval_dir / sig / f"probs_{level}.pth") 
    vocab = torch.load(eval_dir / sig / f"vocab_{level}.pth") # vocab (hashes)
    metadata = pd.read_csv(eval_dir / sig / "metadata.csv", index_col=0, low_memory=False) 
    
    words = torch.tensor(metadata[f"{level}_hashes"].tolist()).long() # for each sample, the word (hashes)
    assert probs.shape == (len(words), len(vocab)), (probs.shape, len(words), len(vocab))
    assert len(words) == len(metadata)
    metadata["idx"] = range(len(words))
    metadata["word_strings"] = metadata["word_strings"].str.lower()

    return probs, vocab, words, metadata


def get_accuracy_from_probs(probs, row_labels, col_labels, topk=10):
    """
    probs: for each row, the probability distribution over a vocab
    returns the topk accuracy that the topk best predicted labels
    match the row_labels
    Inputs:
        probs: of shape [B, V] probability over vocab, each row sums to 1
        row_labels: of shape [B] true word for each row
        col_labels: [V] word that correspond to each column
        topk: int
    Returns: float scalar, topk accuracy
    """
    assert len(row_labels) == len(probs)
    assert len(col_labels) == probs.shape[1]

    # Extract topk indices
    idx = probs.topk(topk, dim=1).indices

    # Get the corresponding topk labels
    whs = col_labels[idx.view(-1)].reshape(idx.shape)

    # 1 if the labels matches with the targets
    correct = ((whs == row_labels[:, None]).any(1)).float()
    assert len(correct) == len(row_labels)

    # Average across samples
    acc = correct.mean()

    return acc.item()


def eval_acc_one_sig(sig, topks=(1, 5, 10), level="word", add_baselines=True):
    """
    Return accuracy dataframe from one solver signature
    level: whether to return `word` or `segment` level accuracy
    """
    # Load data
    probs, vocab, words, _ = load_data_from_sig(sig, level=level)
#     if level == "segment":
#         print(probs.shape)
        
    # Compute acc
    acc_df = []
    for topk in topks:
        
        # --- Acc ---
        acc = get_accuracy_from_probs(probs, words, vocab, topk=topk)
        
        out = {
            f"acc":acc,
            "topk":topk,
        }
        
        if add_baselines:
        
            # --- Baseline on vocab ---
            # equivalent to : shuffle targets vocab (inf times)
            # equivalent to : output uniform prob on vocab
            # equivalent to : 1/vocab_len
            rand_probs_vocab = torch.ones_like(probs) / len(vocab)
            out["baseline_vocab"] = get_accuracy_from_probs(rand_probs_vocab, words, vocab, topk=topk)

            # --- Baseline on words ---
            # equivalent to : shuffle word targets before aggregating on vocab (inf times)
            # equivalent to : output uniform prob on samples
            # equivalent to : each_word_freq
            check_vocab, counts = torch.unique(words, return_counts=True)
            import pdb
#             assert (check_vocab == vocab).all()
            rand_probs_words = torch.stack([counts/len(words)]*len(probs))
            out["baseline"] = get_accuracy_from_probs(rand_probs_words, words, vocab, topk=topk)

            # Update
            acc_df.append(out)
    acc_df = pd.DataFrame(acc_df)
    return acc_df

def eval_acc(sigs, level="word", add_baselines=True):
    """
    Return accuracy dataframe for multiple sigs 
    level: whether to return word or segment level accuracy
    """
    futures = []
    acc = []
    with ProcessPoolExecutor(20) as pool:
        for sig in sigs:
            future = pool.submit(eval_acc_one_sig, sig, level=level, add_baselines=add_baselines)
            futures.append((sig, future))
        for sig, future in tqdm(futures):
            try:
                acc_sig = future.result()
            except Exception:
                print("ERROR WITH", sig)
                raise
                continue
            acc_sig["sig"] = sig
            acc.append(acc_sig)
    acc = pd.concat(acc)
    return acc

# Load meta dataframe

In [48]:
# Select signatures
valid_sigs = [sig for sig in sigs_to_eval if (eval_dir / sig / "vocab_segment.pth").is_file()]
configs = [OmegaConf.load(eval_dir / sig / "solver_config.yaml") for sig in valid_sigs]
for c, s in zip(configs, valid_sigs):
    if not hasattr(c.dset, 'features'):
        c.dset.features = c.dset.forcings
        c.dset.features_params = c.dset.forcings_params
        
print(set(sigs_to_eval) - set(valid_sigs))
run_df = pd.DataFrame({
    "sig":valid_sigs,
})
run_df["dataset"] = ["-".join(conf.dset.selections) for conf in configs]
run_df["seed"] = [conf.seed for conf in configs]
run_df["forcings"] = ["-".join(conf.dset.features) for conf in configs]
run_df["loss"] = [conf.optim.loss for conf in configs]
run_df["is_random"] = [conf.test.wer_random for conf in configs]
run_df["max_scale"] = [conf.norm.max_scale for conf in configs]
run_df["n_mels"] = [conf.dset.features_params.MelSpectrum.n_mels for conf in configs]
run_df["deepmel"] = [bool(conf.clip.arch) for conf in configs]
run_df["ft"] = [conf.optim.epochs == 1 and not conf.test.wer_random for conf in configs]
run_df["random"] = [conf.test.wer_random for conf in configs]

run_df["batch_size"] = [conf.optim.batch_size for conf in configs]
run_df["lr"] = [conf.optim.lr for conf in configs]
run_df["autorej"] = [conf.dset.autoreject for conf in configs]
run_df["n_rec"] = [conf.dset.n_recordings for conf in configs]
# run_df["ft"] = [conf.optim.lr == 0 for conf in configs]
run_df["dropout"] = [conf.simpleconv.merger_dropout > 0 for conf in configs]
run_df["gelu"] = [bool(conf.simpleconv.gelu) for conf in configs]
run_df["skip"] = [bool(conf.simpleconv.skip) for conf in configs]
run_df["initial"] = [bool(conf.simpleconv.initial_linear) for conf in configs]
run_df["complex"] = [bool(conf.simpleconv.complex_out) for conf in configs]
run_df["subject_lay"] = [bool(conf.simpleconv.subject_layers) for conf in configs]
run_df["subject_emb"] = [bool(conf.simpleconv.subject_dim) for conf in configs]
run_df["attention"] = [bool(conf.simpleconv.merger) for conf in configs]
run_df["glu"] = [bool(conf.simpleconv.glu) for conf in configs]
run_df["depth"] = [conf.simpleconv.depth for conf in configs]
run_df["offset_meg"] = [conf.task.offset_meg_ms for conf in configs]
run_df = run_df[run_df.loss == "clip"]
# run_df = run_df[run_df.dataset == "gwilliams2022"]



set()


In [49]:
def get_name(row):
    if not row.dropout:
        return r"\wo spatial attention dropout"
    if not row.gelu:
        return r"\wo GELU, \w ReLU"
    if not row.skip:
        return r"\wo skip connections"
    if not row.initial:
        return r"\wo initial 1x1 conv."
    if not row.complex:
        return r"\wo final convs"
    if not row.attention:
        return r"\wo spatial attention"
    if not row.glu:
        return r"\wo non-residual GLU conv."
    if not row.subject_lay:
        if row.subject_emb:
            return r"\w subj. embedding*"
        else:
            return r"\wo subject-specific layer"
    if row.depth == 5:
        return r"less deep"
    elif row.max_scale != 20:
        if row.max_scale == 100:
            return "\w clamp=100"
        else:
            return "\wo clamping brain signal"
    return "Our model"
    
        
run_df['name'] = run_df.apply(get_name, axis=1)

# Accuracy table

In [50]:
%%time
acc_df = eval_acc(run_df["sig"].values, level="segment")
# acc_df = pd.merge(acc_df, on=["sig", "topk"], how="outer")
def dset_order(name):
    return name.map({
       'audio_mous': 0, 
       'gwilliams2022': 1,
       'broderick2019': 2,
       'brennan2019': 3,
    })



  0%|          | 0/144 [00:00<?, ?it/s]

CPU times: user 277 ms, sys: 230 ms, total: 507 ms
Wall time: 26.6 s


In [51]:
keys = ['dataset', 'name']
acc_table = pd.merge(acc_df, run_df, on="sig", how="left")
acc_table = acc_table.sort_values(["dataset"], key=dset_order);
acc_table = acc_table.query("topk==10").sort_values(keys).groupby(keys)["acc"].agg(["mean", "std"])
key = "acc"
acc_table["str_acc"] = (100 * acc_table["mean"]).round(1).astype(str) + r" PM " + (100 * acc_table["std"]).round(2).astype(str)
print(acc_table)
    

                                                 mean       std        str_acc
dataset       name                                                            
audio_mous    Our model                      0.673448  0.003624   67.3 PM 0.36
              \w clamp=100                   0.669298  0.007908   66.9 PM 0.79
              \w subj. embedding*            0.645722  0.003076   64.6 PM 0.31
              \wo GELU, \w ReLU              0.656240  0.005631   65.6 PM 0.56
              \wo clamping brain signal      0.015385  0.003648    1.5 PM 0.36
              \wo final convs                0.672309  0.003212   67.2 PM 0.32
              \wo initial 1x1 conv.          0.626807  0.007912   62.7 PM 0.79
              \wo non-residual GLU conv.     0.667211  0.002295   66.7 PM 0.23
              \wo skip connections           0.652577  0.002846   65.3 PM 0.28
              \wo spatial attention          0.656174  0.004009    65.6 PM 0.4
              \wo spatial attention dropout  0.67313

In [52]:
def convert(x):
    if isinstance(x, float):
        print(x)
    return float(x.split(" ")[0])
    
toplot = acc_table.reset_index()
index = list(keys)
index.remove('dataset')
toplot =  pd.pivot_table(toplot, values=["str_acc"], columns="dataset", index=index, aggfunc="first")

toplot[('str_acc', "mean_dataset")] = toplot.applymap(convert).mean(axis=1).round(1)
toplot= toplot.sort_values([('str_acc', 'mean_dataset')], ascending=False)
dsets = ['audio_mous', 'gwilliams2022', 'broderick2019', 'brennan2019']
BASE_ACC = 43.3
toplot[('str_acc', 'delta')] = toplot[('str_acc', "mean_dataset")] - BASE_ACC
toplot[('str_acc', 'p_value')] = ''  # done in another notebook.
extra = ['delta', 'p_value']

toplot = toplot[[('str_acc', dset) 
                 for dset in dsets + ['mean_dataset'] +  extra]]
toplot

Unnamed: 0_level_0,str_acc,str_acc,str_acc,str_acc,str_acc,str_acc,str_acc
dataset,audio_mous,gwilliams2022,broderick2019,brennan2019,mean_dataset,delta,p_value
name,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
\w subj. embedding*,64.6 PM 0.31,66.7 PM 0.25,14.7 PM 0.45,29.4 PM 0.59,43.8,0.5,
\w clamp=100,66.9 PM 0.79,66.5 PM 0.36,13.3 PM 0.24,27.1 PM 2.63,43.4,0.1,
Our model,67.3 PM 0.36,66.7 PM 0.15,13.4 PM 0.39,25.7 PM 2.93,43.3,0.0,
\wo spatial attention dropout,67.3 PM 0.12,65.2 PM 0.09,12.3 PM 1.14,26.8 PM 0.68,42.9,-0.4,
"\wo GELU, \w ReLU",65.6 PM 0.56,64.9 PM 1.29,12.5 PM 0.22,24.6 PM 2.12,41.9,-1.4,
\wo skip connections,65.3 PM 0.28,62.4 PM 0.3,11.0 PM 1.33,24.3 PM 2.67,40.8,-2.5,
\wo final convs,67.2 PM 0.32,65.1 PM 0.92,11.2 PM 0.89,19.0 PM 4.4,40.6,-2.7,
\wo initial 1x1 conv.,62.7 PM 0.79,64.0 PM 0.62,11.9 PM 0.52,22.1 PM 1.92,40.2,-3.1,
\wo spatial attention,65.6 PM 0.4,61.9 PM 0.35,11.9 PM 0.42,20.6 PM 2.15,40.0,-3.3,
\wo non-residual GLU conv.,66.7 PM 0.23,66.3 PM 0.17,6.9 PM 5.13,6.0 PM 0.17,36.5,-6.8,


In [54]:
print(toplot.to_latex(index=True).replace('PM', r'$\pm$'))

\begin{tabular}{lllllrrl}
\toprule
{} & \multicolumn{7}{l}{str\_acc} \\
dataset &    audio\_mous &  gwilliams2022 & broderick2019 &   brennan2019 & mean\_dataset & delta & p\_value \\
name                          &               &                &               &               &              &       &         \\
\midrule
\textbackslash w subj. embedding*           &  64.6 $\pm$ 0.31 &   66.7 $\pm$ 0.25 &  14.7 $\pm$ 0.45 &  29.4 $\pm$ 0.59 &         43.8 &   0.5 &         \\
\textbackslash w clamp=100                  &  66.9 $\pm$ 0.79 &   66.5 $\pm$ 0.36 &  13.3 $\pm$ 0.24 &  27.1 $\pm$ 2.63 &         43.4 &   0.1 &         \\
Our model                     &  67.3 $\pm$ 0.36 &   66.7 $\pm$ 0.15 &  13.4 $\pm$ 0.39 &  25.7 $\pm$ 2.93 &         43.3 &   0.0 &         \\
\textbackslash wo spatial attention dropout &  67.3 $\pm$ 0.12 &   65.2 $\pm$ 0.09 &  12.3 $\pm$ 1.14 &  26.8 $\pm$ 0.68 &         42.9 &  -0.4 &         \\
\textbackslash wo GELU, \textbackslash w ReLU             &  6

  print(toplot.to_latex(index=True).replace('PM', r'$\pm$'))
