In [2]:
# Re-import edited modules automatically before each cell executes, so changes to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import itertools
from pathlib import Path
import yaml

from omegaconf import OmegaConf
import pandas as pd
from matplotlib import transforms
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.analysis.trf import coefs_to_df
from src.encoding.ecog import timit as timit_encoding, \
     AlignedECoGDataset, ContrastiveModelSnapshot, epoch_by_state_space
from src.utils.timit import get_word_metadata

In [5]:
import sys

# NOTE(review): this cell runs *after* the `src.*` imports above, so those imports only
# succeed if `src` is already importable (e.g. the kernel's working directory is the
# project root). Consider moving this before the imports.
# TODO: avoid the hard-coded absolute path (e.g. read it from an environment variable).
project_root = "/userdata/jgauthier/projects/neural-foundation-models"
# Guard so that re-running the cell does not append duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [6]:
# --- Analysis configuration ---------------------------------------------------
dataset = "timit-no_repeats"
subject = "EC212"
baseline_model = "baseline"
pval_threshold = 1e-3  # p-value threshold applied to the per-electrode t-tests below
output_dir = "."

# Input tables read by later cells.
ttest_results_path = f"outputs/encoder_comparison_across_subjects/{dataset}/ttest.csv"
scores_path = f"outputs/encoder_comparison_across_subjects/{dataset}/scores.csv"
unique_variance_path = f"outputs/encoder_unique_variance/{dataset}/baseline/{subject}/unique_variance.csv"
contrasts_path = f"outputs/electrode_contrast/{dataset}/contrasts.csv"

# Fitted-encoder directories for this subject; each parent directory name is a model name.
encoder_dirs = [*Path("outputs/encoders").glob(f"{dataset}/*/{subject}")]

In [7]:
# Electrode contrast outcomes, indexed by (contrast_method, contrast, subject).
contrast_index_cols = ["contrast_method", "contrast", "subject"]
contrasts_df = pd.read_csv(contrasts_path, index_col=contrast_index_cols)

In [None]:
# Model-comparison scores restricted to this dataset/subject; the remaining index
# levels are (model2, model1).
scores_df = (
    pd.read_csv(scores_path, index_col=["dataset", "subject", "model2", "model1"])
    .loc[dataset, subject]
)
model2_names = scores_df.index.get_level_values("model2").unique()
study_models = sorted(model2_names)

scores_df

In [None]:
# Per-electrode t-test results for this dataset/subject and the study models; the
# remaining index levels are (model2, model1, output_dim).
ttest_index_cols = ["dataset", "subject", "model2", "model1", "output_dim"]
ttest_df = (
    pd.read_csv(ttest_results_path, index_col=ttest_index_cols)
    .loc[dataset]
    .loc[subject]
    .loc[study_models]
    .assign(log_pval=lambda frame: np.log10(frame["pval"]))
)
ttest_df

In [None]:
# For each (model2, electrode), keep the *least* significant p-value across the model1
# comparisons, then retain only those pairs that still pass the threshold.
ttest_least_significant = (
    ttest_df.dropna()
    .sort_values("pval", ascending=False)
    .groupby(["model2", "output_dim"])
    .first()
)
passes_threshold = ttest_least_significant["pval"] < pval_threshold
ttest_filtered_df = ttest_least_significant[passes_threshold]
ttest_filtered_df

In [None]:
unique_variance_df = pd.read_csv(unique_variance_path, index_col=["dropped_feature", "fold", "output_dim"])
# ^ this is actually not unique variance, but the inputs to the calculation. let's do it:
# Rows whose dropped_feature is NaN presumably hold the score of the model with nothing
# dropped; every other row holds the score after dropping `dropped_feature`.
# Unique variance of a feature = (score with nothing dropped) - (score without that feature).
# The subtraction aligns the (fold, output_dim)-indexed series against the
# (dropped_feature, fold, output_dim)-indexed series on their shared level names.
# NOTE(review): the level order of the result is decided by pandas' multi-level join;
# later code indexes it as `.loc[(slice(None), electrodes)]`, i.e. assumes output_dim is
# the second level -- confirm against the installed pandas version.
unique_variance = unique_variance_df.loc[np.nan].unique_variance_score - unique_variance_df[~unique_variance_df.index.get_level_values("dropped_feature").isna()].unique_variance_score
unique_variance

In [12]:
# Map model name -> encoder directory, restricted to the baseline and the study models,
# then load each fitted encoder.
# The directories are re-globbed (same pattern as the config cell) instead of converting
# `encoder_dirs` in place: rebinding the name from a list to a dict made a re-run of this
# cell iterate over dict keys and silently produce an empty mapping.
wanted_models = {baseline_model, *study_models}
encoder_dirs = {
    encoder_dir.parent.name: encoder_dir
    for encoder_dir in Path("outputs/encoders").glob(f"{dataset}/*/{subject}")
    if encoder_dir.parent.name in wanted_models
}
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted model files.
encoders = {model_name: torch.load(encoder_dir / "model.pkl")
            for model_name, encoder_dir in encoder_dirs.items()}
encoder_names = sorted(encoders.keys())

In [13]:
# Scores stored alongside the baseline encoder.
baseline_dir = encoder_dirs[baseline_model]
baseline_scores = pd.read_csv(baseline_dir / "scores.csv")

In [14]:
# Any study model's Hydra config will do: it is only used to extract shared paths and
# to prepare the output file.
sample_model_name = next(iter(study_models))
sample_model_path = encoder_dirs[sample_model_name]
config_path = sample_model_path / ".hydra" / "config.yaml"
with config_path.open() as config_file:
    model_config = OmegaConf.create(yaml.safe_load(config_file))
first_data_entry = next(iter(model_config.data))
out = timit_encoding.prepare_out_file(model_config, first_data_entry)

In [15]:
# Build the model snapshot from the first configured model feature set.
first_model_features = next(iter(model_config.feature_sets.model_features.values()))
snapshot = ContrastiveModelSnapshot.from_config(model_config, first_model_features)

In [None]:
# Pair the model snapshot with the prepared output file; the epoching cells below
# (`epoch_by_state_space`) all operate on `aligned`.
aligned = AlignedECoGDataset(snapshot, out)

In [None]:
# Electrode metadata; read from the first encoder directory.
first_encoder_dir = next(iter(encoder_dirs.values()))
electrode_df = pd.read_csv(first_encoder_dir / "electrodes.csv")
electrode_df

In [18]:
# Number of encoder outputs, and the names of the corresponding electrodes.
# NOTE(review): assumes encoder outputs correspond to the first rows of electrode_df.
first_encoder = next(iter(encoders.values()))
num_fit_electrodes = first_encoder.coef_.shape[0]
electrode_names = electrode_df["electrode_name"].iloc[:num_fit_electrodes]

In [None]:
# Long-format TRF coefficients for every loaded encoder, stacked under a "model" level.
coef_dfs = {}
for model_name, encoder_dir in tqdm(encoder_dirs.items()):
    model_coef_array = torch.load(encoder_dir / "coefs.pkl")
    model_encoder = encoders[model_name]
    coef_dfs[model_name] = coefs_to_df(model_coef_array,
                                       model_encoder.feature_names,
                                       electrode_names,
                                       model_encoder.sfreq)
coef_df = pd.concat(coef_dfs, names=["model"]).droplevel(1)
coef_df

In [None]:
# Keep only contrasts for electrodes that the encoders were actually fit on.
# output_dim is a 0-based position among the encoder outputs (the same values as
# `electrode_names.index`), so valid values are 0 .. num_fit_electrodes - 1 and the
# bound must be strict.
contrasts_df = contrasts_df.loc[contrasts_df.output_dim < num_fit_electrodes]
contrasts_df

In [None]:
# Every TRF feature name appearing in any encoder's coefficients.
all_trf_features = coef_df["feature"].unique()
all_trf_features

### Compute epoched HGA

In [26]:
# Register an ad-hoc "trial" state space (one state per trial/sentence) so trial onsets
# can be epoched like any other state space.
# State space bounds are inclusive, so each end frame is shifted back by one.
trial_bounds = sorted((start, end - 1) for start, end in aligned.name_to_frame_bounds.values())
trial_spec = StateSpaceAnalysisSpec(aligned.total_num_frames, ["trial"], [trial_bounds])
# NOTE(review): writes into a private attribute of `aligned`.
aligned._snapshot.all_state_spaces["trial"] = trial_spec

In [None]:
# Trial-onset epochs from -100 ms to +1 s, baselined on the 100 ms before onset.
trial_epochs = epoch_by_state_space(
    aligned,
    "trial",
    epoch_window=(-0.1, 1.0),
    baseline_window=(-0.1, 0.0),
    return_df=True,
)
trial_epochs.head(3)

In [29]:
# Sanity check: no (epoch, electrode, sample) combination has more than one non-null value.
assert trial_epochs.groupby(["epoch_idx", "electrode_idx", "epoch_sample"]).value.count().max() == 1

In [None]:
word_epochs = epoch_by_state_space(
    aligned, "word",
    epoch_window=(-0.1, 0.6),
    baseline_window=(-0.1, 0.),
    return_df=True)

In [None]:
word_metadata = get_word_metadata(snapshot.all_state_spaces["word"])

In [32]:
# Merge in word metadata
word_epochs = pd.merge(
    word_epochs, word_metadata,
    left_on=["epoch_label", "instance_idx"],
    right_on=["label", "instance_idx"],
    how="left",
    validate="many_to_one")

In [None]:
# Preview the first three rows of the merged word epochs.
word_epochs.iloc[:3]

In [None]:
# Word-offset epochs from -600 ms to +100 ms around word offset, baselined on the
# 100 ms after offset, with word metadata merged in on the same keys as the onset epochs.
word_offset_epochs = epoch_by_state_space(
    aligned,
    "word",
    align_to="offset",
    epoch_window=(-0.6, 0.1),
    baseline_window=(0.0, 0.1),
    return_df=True,
).merge(
    word_metadata,
    left_on=["epoch_label", "instance_idx"],
    right_on=["label", "instance_idx"],
    how="left",
    validate="many_to_one",
)

word_offset_epochs.head(3)

In [None]:
# Syllable-onset epochs from -100 ms to +300 ms, baselined on the 100 ms before onset.
# NOTE(review): `syllable_epochs` is not used by any active plotting code in this notebook.
syllable_epochs = epoch_by_state_space(
    aligned,
    "syllable",
    epoch_window=(-0.1, 0.3),
    baseline_window=(-0.1, 0.0),
    return_df=True,
)

### Plotting prep

In [36]:
model_color_norm = plt.Normalize(0, len(encoder_names))
model_color_mapper = plt.colormaps["tab10"]

def get_model_color(model_name):
    """Return the tab10 color assigned to `model_name` by its position in `encoder_names`.

    NOTE(review): tab10 has only 10 colors, so with more than 10 encoders colors repeat.
    """
    model_position = encoder_names.index(model_name)
    return model_color_mapper(model_color_norm(model_position))

### Correspondences between electrodes significant under different models

In [None]:
# Model x electrode matrix of significant log10 p-values; non-significant cells are 0.
log_pvals = (
    ttest_filtered_df
    .pivot_table(values="log_pval", index="model2", columns="output_dim")
    .fillna(0)
    .rename_axis(index="model_name")
)
log_pvals

In [None]:
# Cluster models and electrodes by their significance profiles.
g = sns.clustermap(log_pvals, vmax=0, xticklabels=1, figsize=(20, 16))
heatmap_ax = g.ax_heatmap
heatmap_ax.set(xlabel="Electrode", ylabel="Model name", xticks=[])
plt.setp(heatmap_ax.yaxis.get_majorticklabels(), rotation=0)

In [None]:
# Long format: one row per significant (model, electrode) pair, joined with electrode
# metadata and indexed by model name.
log_pvals_flat = (
    log_pvals
    .reset_index()
    .melt(id_vars="model_name", value_name="log_pval")
    .query("log_pval < 0")
    .merge(electrode_df, left_on=["output_dim"], right_on=["electrode_idx"])
    .set_index("model_name")
)
log_pvals_flat

In [None]:
from matplotlib_venn import venn3

# Display label -> model name for the three models compared in the Venn diagrams.
venn_labels = {
    "Word": "ph-ls-word_broad-hinge-w2v2_8-l2norm",
    "Word discrim2": "ph-ls-word_broad-hinge-w2v2_8-discrim2-l2norm",
    "Phoneme": "phoneme-w2v2_8-l2norm"
}

f, ax = plt.subplots(figsize=(10, 15))
# Overlap of significant electrodes between the three models.
# `.loc[[model_name]]` (list indexer) always returns a DataFrame. A scalar indexer returns
# a Series when a model has exactly one significant electrode; `.electrode_idx` is then a
# scalar, which `set()` cannot consume.
venn3(
    [set(log_pvals_flat.loc[[model_name]].electrode_idx)
     if model_name in log_pvals_flat.index
     else set()
     for model_name in venn_labels.values()],
    set_labels=[label if model_name in log_pvals_flat.index else ""
                for label, model_name in venn_labels.items()],
    ax=ax,
)

In [None]:
# One Venn diagram per ROI: overlap of significant electrodes within that ROI.
rois = sorted(log_pvals_flat.roi.unique())
n_cols = 3
# Allocate at least one row so plt.subplots does not fail when there are no ROIs.
n_rows = max(1, int(np.ceil(len(rois) / n_cols)))

# squeeze=False: always a 2-D array of axes, whatever the grid shape.
f, axes = plt.subplots(n_rows, n_cols, figsize=(10 * n_cols, 10 * n_rows), squeeze=False)
for roi, ax in zip(rois, axes.flat):
    plot_data = []
    for model_name in venn_labels.values():
        if model_name in log_pvals_flat.index:
            # List indexer: always a DataFrame, even for a single significant electrode.
            model_rows = log_pvals_flat.loc[[model_name]]
            plot_data.append(set(model_rows.loc[model_rows.roi == roi, "electrode_idx"]))
        else:
            plot_data.append(set())
    # skip labels for empty sets
    plot_labels = [label if len(data) > 0 else ""
                   for label, data in zip(venn_labels.keys(), plot_data)]

    venn3(plot_data, set_labels=plot_labels, ax=ax)
    ax.set_title(roi)

# Hide grid cells left empty when len(rois) is not a multiple of n_cols.
for ax in axes.flat[len(rois):]:
    ax.axis("off")

## Colocation

### Colocation of baseline predictiveness and model improvement

In [None]:
# Per (model2, electrode, fold): baseline score vs. the score of the model compared
# against the baseline (relabelled "full_model").
non_baseline_models = set(scores_df.model) - {baseline_model}
full_model_relabel = {name: "full_model" for name in non_baseline_models}
score_relationship = (
    scores_df
    .assign(model=scores_df.model.replace(full_model_relabel))
    .xs(baseline_model, level="model1")
    .reset_index()
    .pivot(index=["model2", "output_dim", "fold"], columns="model", values="score")
)
score_relationship

In [None]:
g = sns.lmplot(
    data=score_relationship.reset_index(),
    x="baseline", y="full_model",
    col="model2", col_wrap=3,
    facet_kws=dict(sharex=False, sharey=False),
)

# Give every facet the same square limits and an identity reference line.
ax_min, ax_max = 0., score_relationship.max().max()
for ax in g.axes.flat:
    ax.plot([0, 1], [0, 1], color="black", linestyle="--", alpha=0.4)
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_xlabel("Baseline encoder $r^2$")
    ax.set_ylabel("Full model $r^2$")

### Colocation study by $p$-value

In [None]:
# Least-significant (largest) log p-value per model2 / electrode, over comparisons
# against the baseline. Model--electrode pairs with no result get log p = 0.
all_model_electrode_pairs = pd.MultiIndex.from_product(
    [study_models, electrode_names.index], names=["model2", "output_dim"])
electrode_pvals = (
    ttest_df.loc[(slice(None), baseline_model), "log_pval"]
    .groupby(["model2", "output_dim"])
    .max()
    .reindex(all_model_electrode_pairs)
    .fillna(0.)
)
electrode_pvals

In [None]:
# Per (model2, electrode, fold) r^2 of the baseline and of the compared model
# (relabelled "full_model"), plus the non-negative absolute improvement over baseline.
r2_comparison = scores_df.xs(baseline_model, level="model1")
# Relabel through `assign`/`where`, which builds a new frame, instead of assigning into
# the result of `xs`: that is chained assignment (SettingWithCopyWarning).
r2_comparison = r2_comparison.assign(
    model=r2_comparison.model.where(r2_comparison.model == baseline_model, "full_model"))
r2_comparison = r2_comparison.reset_index().pivot_table(
    index=["model2", "output_dim", "fold"], columns="model", values="score")
# Negative baseline r^2 counts as 0, and negative improvements are clamped to 0
# (NaNs stay NaN).
r2_comparison["absolute_improvement"] = (
    r2_comparison["full_model"] - r2_comparison[baseline_model].clip(lower=0)
).clip(lower=0)
r2_comparison = r2_comparison.sort_values("absolute_improvement", ascending=False)
r2_comparison

In [None]:
# Inspect the TRF coefficients of the study models.
coef_df.loc[study_models]

In [191]:
# Exploratory: coefficient tables for a hand-picked set of electrodes.
coef_study_electrodes = [152]  # output_dim values to inspect

plot_coef_df = coef_df.loc[study_models].reset_index()
# filter to electrodes of interest; copy so the relabel below writes to an independent
# frame rather than into a slice (SettingWithCopyWarning)
plot_coef_df = plot_coef_df[plot_coef_df.output_dim.isin(coef_study_electrodes)].copy()
# name model embedding coefficients according to model: "<model>_<feature>"
is_embedding = plot_coef_df.feature.str.startswith("model_embedding")
model_coefs = plot_coef_df.loc[is_embedding]
plot_coef_df.loc[is_embedding, "feature"] = model_coefs.model.str.cat(model_coefs.feature, sep="_")
# after renaming, embedding rows are the ones containing "model_embedding"
is_renamed_embedding = plot_coef_df.feature.str.contains("model_embedding")

# filter to features of interest. Renamed embedding rows are excluded explicitly: a model
# whose name starts with a basic feature name would otherwise pass the prefix test.
features = sorted([f for f in coef_df.feature.unique() if not f.startswith("model_embedding")])
plot_coef_df_features = plot_coef_df.loc[
    plot_coef_df.feature.str.startswith(tuple(features)) & ~is_renamed_embedding,
    ["fold", "feature", "output_dim", "time", "coef"],
].assign(type="basic_feature")

# per-model L1 norm (sum of absolute coefficients) of the embedding coefficients
plot_coef_df_embeddings = plot_coef_df[is_renamed_embedding] \
    .groupby(["fold", "model", "output_dim", "time"]) \
    .coef.apply(lambda xs: xs.abs().sum()).reset_index() \
    .rename(columns={"model": "feature"}).assign(type="model_embedding")

In [None]:
# Contrast outcomes for this subject under one contrast, indexed by (outcome, output_dim).
panel_contrast = "word_dominant"
panel_contrast_df = contrasts_df.xs(panel_contrast, level="contrast").xs(subject, level="subject") \
    .reset_index().set_index(["outcome", "output_dim"])

# Electrodes (output_dim values) grouped by contrast outcome, skipping missing outcomes.
# `pd.notna` is used because a membership test like `outcome not in [np.nan, None]`
# matches NaN only by object identity (NaN != NaN), so a NaN that is not the `np.nan`
# singleton would slip through.
panel_contrast_electrodes = {
    outcome: panel_contrast_df.loc[outcome].index.unique()
    for outcome in panel_contrast_df.index.get_level_values("outcome").unique()
    if pd.notna(outcome)
}

panel_contrast_electrodes

In [227]:
def _prepare_panel_coefs(model_embeddings, features, electrodes):
    """Build the coefficient tables shown in an electrode panel.

    Returns ``(plot_coef_df, plot_coef_df_features, plot_coef_df_embeddings)``:
    the long-format coefficients of ``model_embeddings`` restricted to ``electrodes``
    (model-embedding features renamed ``<model>_<feature>``); the rows for the basic
    (non-embedding) ``features``; and, per fold / model / electrode / time, the L1 norm
    (sum of absolute values) of each model's embedding coefficients.
    """
    plot_coef_df = coef_df.loc[model_embeddings].reset_index()
    # filter to electrodes of interest; copy so the relabel below does not write into a slice
    plot_coef_df = plot_coef_df[plot_coef_df.output_dim.isin(electrodes)].copy()

    # name model embedding coefficients according to model
    is_embedding = plot_coef_df.feature.str.startswith("model_embedding")
    model_coefs = plot_coef_df.loc[is_embedding]
    plot_coef_df.loc[is_embedding, "feature"] = model_coefs.model.str.cat(model_coefs.feature, sep="_")
    # after renaming, embedding rows are the ones containing "model_embedding"
    is_renamed_embedding = plot_coef_df.feature.str.contains("model_embedding")

    # Basic features. Renamed embedding rows are excluded explicitly: a model whose name
    # starts with a basic feature name would otherwise pass the prefix test.
    plot_coef_df_features = plot_coef_df.loc[
        plot_coef_df.feature.str.startswith(tuple(features)) & ~is_renamed_embedding,
        ["fold", "feature", "output_dim", "time", "coef"],
    ].assign(type="basic_feature")

    plot_coef_df_embeddings = plot_coef_df[is_renamed_embedding] \
        .groupby(["fold", "model", "output_dim", "time"]) \
        .coef.apply(lambda xs: xs.abs().sum()).reset_index() \
        .rename(columns={"model": "feature"}).assign(type="model_embedding")

    return plot_coef_df, plot_coef_df_features, plot_coef_df_embeddings


def _plot_erp(ax, epochs_df, electrode, title, errorbar, **lineplot_kwargs):
    """Plot ``electrode``'s epoched responses from a long-format epochs frame on ``ax``.

    A dashed vertical line marks t = 0. Extra keyword arguments (e.g. ``hue``) are
    forwarded to ``sns.lineplot``.
    """
    ax.set_title(title)
    ax.axvline(0, color="gray", linestyle="--")
    sns.lineplot(data=epochs_df[epochs_df.electrode_idx == electrode],
                 x="epoch_time", y="value", ax=ax, errorbar=errorbar, **lineplot_kwargs)


def _move_legend(src_ax, dst_ax):
    """Turn ``dst_ax`` into a legend-only axis holding ``src_ax``'s legend, if it has one."""
    dst_ax.axis("off")
    src_legend = src_ax.get_legend()
    if src_legend is None:
        # Nothing to move, e.g. when no `hue` was passed to the line plot.
        return
    dst_ax.legend(*src_ax.get_legend_handles_labels(), loc="upper left") \
        .set_title(src_legend.get_title().get_text())
    src_legend.remove()


def render_electrode_panel(
        electrode, model_embeddings=None, features=None,
        trial_epoch_kwargs=None,
        word_epoch_kwargs=None,
        word_epoch2_kwargs=None,
        smoke_test=False):
    """Render a 3 x 5 summary figure for a single electrode.

    Row 0: per-model improvement log p-values, per-model r^2 improvement, per-feature
    unique variance, trial ERP and contrast outcomes. Rows 1-2: basic-feature and
    model-embedding TRF coefficient heatmaps, plus word-onset / word-offset ERPs (with
    their legend in the last column), once for each of the two word kwarg sets.

    Parameters
    ----------
    electrode : output_dim of the electrode to study.
    model_embeddings : models to show; defaults to every non-baseline model2 in
        ``electrode_pvals``.
    features : basic (non-embedding) TRF feature names/prefixes to show; defaults to every
        feature in ``coef_df`` not starting with "model_embedding".
    trial_epoch_kwargs, word_epoch_kwargs, word_epoch2_kwargs : extra ``sns.lineplot``
        kwargs (e.g. ``hue``) for the trial ERP, the row-1 word ERPs and the row-2 word ERPs.
    smoke_test : draw a small figure without error bands.

    Returns
    -------
    The long-format coefficient table (restricted to ``electrode``) used for the heatmaps.

    Reads notebook globals: electrode_pvals, r2_comparison, unique_variance, coef_df,
    panel_contrast_df, trial_epochs, word_epochs, word_offset_epochs, pval_threshold,
    baseline_model.
    """
    figure = plt.figure(figsize=(48, 24) if not smoke_test else (10, 8))
    gs = gridspec.GridSpec(3, 5, figure=figure,
                           width_ratios=[3, 3, 2, 2, 1.5], hspace=0.25, wspace=0.25)
    electrodes = [electrode]

    errorbar = "se" if not smoke_test else None

    if model_embeddings is None:
        model_embeddings = sorted(m for m in electrode_pvals.index.get_level_values("model2").unique()
                                  if m != baseline_model)
    if features is None:
        features = sorted(f for f in coef_df.feature.unique() if not f.startswith("model_embedding"))

    ##### row 0: improvement statistics per model, unique variance per feature
    tval_ax = figure.add_subplot(gs[0, 0])
    tval_ax.set_title("Improvement log $p$-values\nby model embedding")
    tval_ax.axvline(np.log10(pval_threshold), color="black", linestyle="--", linewidth=2)
    r2_ax = figure.add_subplot(gs[0, 1])
    r2_ax.set_title("Improvement $r^2$\nby model embedding")
    feature_norm_ax = figure.add_subplot(gs[0, 2])
    feature_norm_ax.set_title("Unique variance")

    tval_df = electrode_pvals.loc[model_embeddings].loc[(slice(None), electrodes)]
    tval_df_order = tval_df.sort_values(ascending=True).index.get_level_values("model2")
    sns.barplot(data=tval_df.reset_index(), x="log_pval", y="model2",
                ax=tval_ax, order=tval_df_order)
    for ticklabel in tval_ax.get_yticklabels():
        if ticklabel.get_text() in model_embeddings:
            ticklabel.set_fontweight("bold")

    # Restrict r^2 to this electrode and the displayed models, matching the p-value panel.
    r2_model2 = r2_comparison.index.get_level_values("model2")
    r2_output_dim = r2_comparison.index.get_level_values("output_dim")
    r2_df = r2_comparison[r2_model2.isin(model_embeddings) & r2_output_dim.isin(electrodes)]
    # share order + ticks with p-value plot
    sns.barplot(data=r2_df.reset_index(), x="absolute_improvement", y="model2",
                order=tval_df_order, ax=r2_ax)
    r2_ax.set_yticklabels([])
    r2_ax.set_ylabel(None)

    uv_df = unique_variance.loc[(slice(None), electrodes)].reset_index() \
        .rename(columns={"dropped_feature": "feature"})
    uv_means = uv_df.groupby("feature").unique_variance_score.mean()
    # only show features with non-negative mean unique variance, largest first
    uv_order = uv_means[uv_means >= 0].sort_values(ascending=False).index
    sns.barplot(data=uv_df, x="unique_variance_score", y="feature",
                ax=feature_norm_ax, order=uv_order)
    feature_norm_ax.set_xlim((0, feature_norm_ax.get_xlim()[1]))
    for ticklabel in feature_norm_ax.get_yticklabels():
        if ticklabel.get_text().startswith(tuple(features)):
            ticklabel.set_fontweight("bold")

    ##### rows 1-2, columns 0-1: TRF coefficient heatmaps (mean over rows per feature/time)
    plot_coef_df, plot_coef_df_features, plot_coef_df_embeddings = \
        _prepare_panel_coefs(model_embeddings, features, electrodes)

    feature_coef_heatmap_ax = figure.add_subplot(gs[1, :2])
    feature_heatmap_df = plot_coef_df_features.pivot_table(
        index="feature", columns="time", values="coef", aggfunc="mean")
    feature_heatmap_df = feature_heatmap_df.loc[sorted(plot_coef_df_features.feature.unique())]
    sns.heatmap(feature_heatmap_df, ax=feature_coef_heatmap_ax, cmap="RdBu", center=0, yticklabels=True)

    model_coef_heatmap_ax = figure.add_subplot(gs[2, :2])
    embedding_heatmap_df = plot_coef_df_embeddings.pivot_table(
        index="feature", columns="time", values="coef", aggfunc="mean")
    # order by model name
    embedding_heatmap_df = embedding_heatmap_df.loc[sorted(plot_coef_df_embeddings.feature.unique())]
    sns.heatmap(embedding_heatmap_df, ax=model_coef_heatmap_ax, cmap="RdBu", center=0, yticklabels=True)

    ##### row 0, column 4: contrast outcome per contrast method
    # integer codes select Set1 colors (presumably chosen for color, not ordering)
    outcome_int_map = {"positive": 2, "balanced": 1, "none": 7, "negative": 0}
    outcome_heatmap_data = panel_contrast_df.xs(electrode, level="output_dim").reset_index() \
        .set_index("contrast_method").outcome.sort_index().fillna("none")
    outcome_ax = figure.add_subplot(gs[0, 4])
    sns.heatmap(outcome_heatmap_data.map(outcome_int_map).to_frame(),
                cmap="Set1", vmin=min(outcome_int_map.values()), vmax=max(outcome_int_map.values()),
                annot=outcome_heatmap_data.to_frame(), fmt="", cbar=False, ax=outcome_ax)

    ##### ERPs
    _plot_erp(figure.add_subplot(gs[0, 3]), trial_epochs, electrode, "Trial ERP", errorbar,
              **(trial_epoch_kwargs or {}))
    # Word-onset ERPs suppress their legend; it is shown once per row, moved off the
    # matching word-offset ERP into the last column.
    _plot_erp(figure.add_subplot(gs[1, 2]), word_epochs, electrode, "Word ERP", errorbar,
              legend=False, **(word_epoch_kwargs or {}))
    _plot_erp(figure.add_subplot(gs[2, 2]), word_epochs, electrode, "Word ERP", errorbar,
              legend=False, **(word_epoch2_kwargs or {}))

    word_offset_epochs_ax = figure.add_subplot(gs[1, 3])
    _plot_erp(word_offset_epochs_ax, word_offset_epochs, electrode, "Word offset ERP", errorbar,
              **(word_epoch_kwargs or {}))
    _move_legend(word_offset_epochs_ax, figure.add_subplot(gs[1, 4]))

    word_offset_epoch2_ax = figure.add_subplot(gs[2, 3])
    _plot_erp(word_offset_epoch2_ax, word_offset_epochs, electrode, "Word offset ERP", errorbar,
              **(word_epoch2_kwargs or {}))
    _move_legend(word_offset_epoch2_ax, figure.add_subplot(gs[2, 4]))

    figure.suptitle(f"Electrode {electrode} study")

    return plot_coef_df

In [None]:
# Render and save one panel per electrode with a "positive" outcome on the panel contrast.
smoke_test = False
panel_font_size = 24 if not smoke_test else 12
for electrode in tqdm(panel_contrast_electrodes["positive"]):
    print(electrode)
    with plt.rc_context(rc={"font.size": panel_font_size}):
        render_electrode_panel(
            electrode,
            model_embeddings=study_models,
            word_epoch_kwargs=dict(hue="monosyllabic"),
            word_epoch2_kwargs=dict(hue="word_frequency_quantile"),
            smoke_test=smoke_test,
        )
        panel_figure = plt.gcf()
        panel_figure.savefig(f"{output_dir}/electrode_panel-{subject}-{electrode}.png")
        plt.close(panel_figure)