In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import dataclass
import itertools
from pathlib import Path
import yaml

from omegaconf import OmegaConf
import pandas as pd
from matplotlib import transforms
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec
from src.analysis.trf import coefs_to_df
from src.encoding.ecog.timit import OutFile
from src.encoding.ecog import timit as timit_encoding, \
     AlignedECoGDataset, ContrastiveModelSnapshot, epoch_by_state_space
from src.utils.timit import get_word_metadata

In [None]:
# --- Analysis configuration ---
dataset = "timit"
subject = "EC196"

# Model embeddings to study. If set to None, a later cell falls back to every
# `model2` value present in the scores table.
study_models = [
    "random8",
    "phoneme",
    "biphone_pred",
    "biphone_recon",
    "next_phoneme",
    "syllable",
    "word_broad-aniso2-w2v2_8",
]

# Inputs produced by the cross-subject encoder comparison and the
# unique-variance analysis.
comparison_dir = f"outputs/encoder_comparison_across_subjects/{dataset}"
ttest_results_path = f"{comparison_dir}/ttest.csv"
scores_path = f"{comparison_dir}/scores.csv"
unique_variance_path = (
    f"outputs/encoder_unique_variance/{dataset}/baseline/{subject}/unique_variance.csv"
)

# One directory per fitted encoder: outputs/encoders/<dataset>/<model>/<subject>
encoder_dirs = list(Path("outputs/encoders").glob(f"{dataset}/*/{subject}"))

# Where rendered figures are written.
output_dir = "."

# Significance threshold applied to the t-test p-values.
pval_threshold = 1e-3

In [None]:
# Load encoder scores and keep only the rows for this dataset / subject.
scores_df = pd.read_csv(
    scores_path,
    index_col=["dataset", "subject", "model2", "model1"],
).loc[dataset, subject]

scored_models = scores_df.index.get_level_values("model2")
if study_models is not None:
    # Restrict the table to the models named in the configuration.
    scores_df = scores_df.loc[scored_models.isin(study_models)]
else:
    # No explicit selection: study every model that has scores.
    study_models = sorted(scored_models.unique())
scores_df

In [None]:
# Load the t-test results, select this dataset / subject / studied models,
# and add the base-10 log of each p-value.
ttest_all = pd.read_csv(
    ttest_results_path,
    index_col=["dataset", "subject", "model2", "model1", "output_dim"],
)
ttest_df = ttest_all.loc[dataset].loc[subject].loc[study_models]
ttest_df = ttest_df.assign(log_pval=np.log10(ttest_df["pval"]))
ttest_df

In [None]:
# For every (model2, electrode) pair keep the comparison with the LARGEST
# p-value (rows are sorted by descending p-value before taking the first per
# group), then retain only pairs whose largest p-value is still below threshold.
least_significant = (
    ttest_df.dropna()
    .sort_values("pval", ascending=False)
    .groupby(["model2", "output_dim"])
    .first()
)
ttest_filtered_df = least_significant[least_significant["pval"] < pval_threshold]
ttest_filtered_df

In [None]:
# The CSV holds the *inputs* to the unique-variance calculation, not the result:
# rows whose `dropped_feature` is NaN versus rows where one feature was dropped.
unique_variance_df = pd.read_csv(
    unique_variance_path,
    index_col=["dropped_feature", "fold", "output_dim"],
)

nothing_dropped_scores = unique_variance_df.loc[np.nan].unique_variance_score
has_dropped_feature = ~unique_variance_df.index.get_level_values("dropped_feature").isna()
feature_dropped_scores = unique_variance_df[has_dropped_feature].unique_variance_score

# Unique variance = score with nothing dropped minus score with the feature dropped.
unique_variance = nothing_dropped_scores - feature_dropped_scores
unique_variance

In [None]:
# `encoder_dirs` starts out as the list of globbed directories and is rebound
# below to a {model_name: directory} dict. Accept either form so that re-running
# this cell does not iterate the dict's keys and silently empty the mapping.
candidate_dirs = [
    Path(p)
    for p in (encoder_dirs.values() if isinstance(encoder_dirs, dict) else encoder_dirs)
]

# The model name is the parent directory of the subject directory.
wanted_models = ["baseline"] + study_models
encoder_dirs = {candidate.parent.name: candidate
                for candidate in candidate_dirs
                if candidate.parent.name in wanted_models}

# NOTE(review): torch.load unpickles arbitrary objects -- only load encoder
# files produced by trusted runs.
encoders = {model_name: torch.load(encoder_dir / "model.pkl")
            for model_name, encoder_dir in encoder_dirs.items()}
encoder_names = sorted(encoders.keys())

In [None]:
# Scores of the baseline encoder, read from its own output directory.
baseline_scores = pd.read_csv(encoder_dirs["baseline"] / "scores.csv")

In [None]:
# Any model's Hydra config will do here: it is only used to recover the data
# paths and build the out-file. The "phoneme" run is used as the sample.
sample_model_path = encoder_dirs["phoneme"]
sample_config_path = sample_model_path / ".hydra" / "config.yaml"
with sample_config_path.open() as config_file:
    raw_config = yaml.safe_load(config_file)
model_config = OmegaConf.create(raw_config)

# Prepare the out-file for the first entry of `model_config.data`.
first_data_entry = next(iter(model_config.data))
out = timit_encoding.prepare_out_file(model_config, first_data_entry)

In [None]:
# Build the model snapshot from the first configured model feature set.
first_model_feature = next(iter(model_config.feature_sets.model_features.values()))
snapshot = ContrastiveModelSnapshot.from_config(model_config, first_model_feature)

In [None]:
# Pair the model snapshot with the prepared out-file.
aligned = AlignedECoGDataset(snapshot, out)

In [None]:
# Electrode metadata is read from the first encoder directory in `encoder_dirs`.
# NOTE(review): presumably every encoder directory carries the same electrodes.csv -- confirm.
electrode_df = pd.read_csv(next(iter(encoder_dirs.values())) / "electrodes.csv")
electrode_df

In [None]:
# The number of fitted electrodes is read off the first encoder's coefficient
# array; the electrode names are the first that many rows of `electrode_df`.
first_encoder = next(iter(encoders.values()))
num_fit_electrodes = first_encoder.coef_.shape[0]
electrode_names = electrode_df.head(num_fit_electrodes).electrode_name

# Convert each model's stored coefficients into a dataframe.
coef_dfs = {}
for model_name, encoder_dir in tqdm(encoder_dirs.items()):
    stored_coefs = torch.load(encoder_dir / "coefs.pkl")
    model_encoder = encoders[model_name]
    coef_dfs[model_name] = coefs_to_df(
        stored_coefs,
        model_encoder.feature_names,
        electrode_names,
        model_encoder.sfreq,
    )

# Stack all models into one frame indexed by model name only.
coef_df = pd.concat(coef_dfs, names=["model"]).droplevel(1)
coef_df

In [None]:
# Every distinct feature name appearing in the coefficient table.
all_trf_features = coef_df.feature.unique()
all_trf_features

### Compute epoched HGA

In [None]:
# Ad-hoc "trial" state space with one span per entry of
# `aligned.name_to_frame_bounds`, used to epoch at sentence onset.
# State-space bounds are inclusive, hence `end - 1` for each span.
trial_bounds = sorted(
    (start, end - 1)
    for start, end in aligned.name_to_frame_bounds.values()
)
trial_spec = StateSpaceAnalysisSpec(aligned.total_num_frames, ["trial"], [trial_bounds])

# Register the new state space on the snapshot so it can be referenced by name.
aligned._snapshot.all_state_spaces["trial"] = trial_spec

In [None]:
# Epoch the recording around each span of the "trial" state space.
# NOTE(review): windows are presumably in seconds relative to span onset -- confirm
# against `epoch_by_state_space`.
trial_epochs = epoch_by_state_space(
    aligned, "trial",
    epoch_window=(-0.1, 1.),
    baseline_window=(-0.1, 0.),
    return_df=True)
trial_epochs.head(3)

In [None]:
# Sanity check: no (epoch, electrode, sample) combination has more than one value.
assert trial_epochs.groupby(["epoch_idx", "electrode_idx", "epoch_sample"]).value.count().max() == 1

In [None]:
# Epoch the recording around each span of the "word" state space.
# NOTE(review): windows are presumably in seconds relative to word onset -- confirm.
word_epochs = epoch_by_state_space(
    aligned, "word",
    epoch_window=(-0.1, 0.6),
    baseline_window=(-0.1, 0.),
    return_df=True)

In [None]:
# Per-word metadata derived from the snapshot's "word" state space.
word_metadata = get_word_metadata(snapshot.all_state_spaces["word"])

In [None]:
# Merge in word metadata
# Each epoch row is matched to exactly one metadata row on (label, instance_idx);
# `validate="many_to_one"` raises if the metadata keys are not unique.
# NOTE: this rebinds `word_epochs`, so re-running this cell alone merges the
# metadata a second time (overlapping columns would then pick up merge suffixes).
# Re-run the word epoching cell first.
word_epochs = pd.merge(
    word_epochs, word_metadata,
    left_on=["epoch_label", "instance_idx"],
    right_on=["label", "instance_idx"],
    how="left",
    validate="many_to_one")

In [None]:
# Preview the word epochs after the metadata merge.
word_epochs.head(3)

In [None]:
# Word epochs aligned to word OFFSET, with the word metadata merged in directly.
# Building the epochs and merging in a single expression keeps this cell safe to
# re-run.
word_offset_epochs = pd.merge(
    epoch_by_state_space(
        aligned, "word",
        align_to="offset",
        epoch_window=(-0.6, 0.1),
        baseline_window=(0., 0.1),
        return_df=True,
    ),
    word_metadata,
    left_on=["epoch_label", "instance_idx"],
    right_on=["label", "instance_idx"],
    how="left",
    validate="many_to_one",
)

word_offset_epochs.head(3)

In [None]:
# Epoch the recording around each span of the "syllable" state space.
# NOTE(review): windows are presumably in seconds relative to syllable onset -- confirm.
syllable_epochs = epoch_by_state_space(
    aligned, "syllable",
    epoch_window=(-0.1, 0.3),
    baseline_window=(-0.1, 0.),
    return_df=True)

### Plotting prep

In [None]:
# Models are colored by their position in the sorted `encoder_names` list.
model_color_norm = plt.Normalize(0, len(encoder_names))
model_color_mapper = plt.colormaps["tab10"]


def get_model_color(model_name):
    """Return the tab10 color assigned to `model_name` by its index in `encoder_names`."""
    model_position = encoder_names.index(model_name)
    return model_color_mapper(model_color_norm(model_position))

### Correspondences between electrodes significant under different models

In [None]:
# Model x electrode table of log10 p-values for the significant pairs only;
# pairs that did not pass the threshold are filled with 0 (i.e. p = 1).
significant_log_pvals = ttest_filtered_df.pivot_table(
    values="log_pval",
    index="model2",
    columns="output_dim",
)
log_pvals = significant_log_pvals.fillna(0)
log_pvals

In [None]:
# Cluster models and electrodes by their log p-value profiles; the color scale is
# capped at 0 (p = 1) and every electrode column gets a tick label.
sns.clustermap(log_pvals, vmax=0, xticklabels=1, figsize=(14, 8))

## Colocation

### Colocation of baseline predictiveness and model improvement

In [None]:
# Relabel every non-baseline entry of the `model` column as "full_model", then
# pivot so each (model2, electrode, fold) row has a `baseline` and a
# `full_model` score side by side.
non_baseline_labels = set(scores_df.model) - {"baseline"}
model_role = scores_df.model.replace(
    {model_name: "full_model" for model_name in non_baseline_labels}
)
score_relationship = (
    scores_df.assign(model=model_role)
    .reset_index()
    .pivot(index=["model2", "output_dim", "fold"], columns="model", values="score")
)
score_relationship

In [None]:
# One regression panel per model: baseline score on x, full-model score on y.
g = sns.lmplot(
    data=score_relationship.reset_index(),
    x="baseline", y="full_model",
    col="model2", col_wrap=3,
    facet_kws=dict(sharex=False, sharey=False),
)

# Use the same square limits on every panel, from 0 to the largest score seen.
ax_min = 0.
ax_max = score_relationship.max().max()
for ax in g.axes.flat:
    # Identity line: points above it are electrodes the full model predicts better.
    ax.plot([0, 1], [0, 1], color="black", linestyle="--", alpha=0.4)
    ax.set(
        xlim=(ax_min, ax_max),
        ylim=(ax_min, ax_max),
        xlabel="Baseline encoder $r^2$",
        ylabel="Full model $r^2$",
    )

### Colocation of model embedding and baseline predictivity

In [None]:
# Electrodes whose fold-averaged baseline score reaches the threshold are
# treated as speech responsive.
speech_responsive_threshold = 0.01
mean_baseline_score = baseline_scores.groupby("output_dim").score.mean()
passes_threshold = mean_baseline_score >= speech_responsive_threshold
speech_responsive_electrodes = mean_baseline_score.loc[passes_threshold].index
speech_responsive_electrodes

In [None]:
# Per-electrode mean improvement of the full model over the baseline, one column per model.
score_gain = score_relationship.full_model - score_relationship.baseline
model_embedding_improvements = (
    score_gain.unstack("model2")
    .groupby("output_dim")
    .mean()
)
# not interested in overfit electrodes: blank out negative improvements
model_embedding_improvements = model_embedding_improvements.mask(model_embedding_improvements < 0)
model_embedding_improvements

In [None]:
# Per-electrode mean unique variance, one column per dropped baseline feature.
baseline_feature_improvements = (
    unique_variance.unstack("dropped_feature")
    .groupby("output_dim")
    .mean()
)
# not interested in overfit electrodes: blank out negative values
baseline_feature_improvements = baseline_feature_improvements.mask(baseline_feature_improvements < 0)
baseline_feature_improvements

In [None]:
# Correlate, across electrodes, every baseline-feature column with every
# model-embedding column, and cluster the resulting cross-correlation block.
improvement_corr = pd.concat(
    [model_embedding_improvements, baseline_feature_improvements],
    axis=1,
).corr()
sns.clustermap(
    improvement_corr.loc[baseline_feature_improvements.columns,
                         model_embedding_improvements.columns]
)

In [None]:
# Put baseline-feature and model-embedding improvements side by side, keeping
# every electrode present in the baseline-feature table.
all_improvements = baseline_feature_improvements.merge(
    model_embedding_improvements,
    left_index=True,
    right_index=True,
    how="left",
    validate="one_to_one",
)

In [None]:
# Inspect the combined improvement table.
all_improvements

In [None]:
# Represent improvement within each column (baseline feature or model) as a
# fraction of that column's maximum, so the best electrode per column scores 1.
# (The previous divisor, max - min, did not subtract the minimum from the
# numerator, so it matched neither "% of maximum" nor a min-max rescaling and
# could produce values above 1.)
all_improvements_relative = all_improvements / all_improvements.max()
all_improvements_relative

In [None]:
# Rank electrodes by relative improvement under the word-level model.
all_improvements_relative.sort_values("word_broad-aniso2-w2v2_8", ascending=False)

In [None]:
# Improvement of the word-level model minus the best improvement among the
# phoneme / biphone models, per electrode (missing improvements count as 0).
word_model_gain = all_improvements["word_broad-aniso2-w2v2_8"].fillna(0)
best_phonetic_gain = (
    all_improvements[["biphone_pred", "biphone_recon", "phoneme"]]
    .max(axis=1)
    .fillna(0)
)
(word_model_gain - best_phonetic_gain).dropna().sort_values(ascending=False)

### Colocation study by $p$-value

In [None]:
# get least-significant (largest) log p-value against the baseline per model -- electrode
log_pvals_vs_baseline = ttest_df.loc[(slice(None), "baseline"), "log_pval"]
electrode_pvals = log_pvals_vs_baseline.groupby(["model2", "output_dim"]).max()

# fill in a log p-value of 0 (p = 1) for model--electrode combinations with no test result
model_electrode_grid = pd.MultiIndex.from_product(
    [study_models, electrode_names.index],
    names=["model2", "output_dim"],
)
electrode_pvals = electrode_pvals.reindex(model_electrode_grid).fillna(0.)
electrode_pvals

In [None]:
contrast_focus = "word_broad-aniso2-w2v2_8"
contrast_negatives = ["phoneme", "biphone_pred", "biphone_recon"]


def _focus_minus_best_negative(pvals):
    """Log p-value of the focus model minus the smallest log p-value among the negatives."""
    return pvals.loc[contrast_focus] - pvals.loc[contrast_negatives].min()


# Negative contrast = the focus model is more significant than every negative model.
word_contrasts = (
    electrode_pvals.groupby("output_dim")
    .apply(_focus_minus_best_negative)
    .sort_values(ascending=True)
)
# The apply result repeats the electrode level; drop the inner copy.
word_contrasts = word_contrasts.rename("word_contrast").to_frame().droplevel(-1)
# Attach the focus model's own log p-value next to each contrast.
word_contrasts = pd.merge(
    word_contrasts,
    electrode_pvals.loc[contrast_focus],
    left_index=True,
    right_index=True,
)
word_contrasts

In [None]:
# The six electrodes with the most negative contrast (focus model most dominant).
plot_greatest_contrast = word_contrasts.head(6).index.get_level_values(0)

# Focus-model coefficients for those electrodes, model-embedding features only.
focus_model_coefs = coef_df.loc[contrast_focus]
plot_greatest_contrast_df = focus_model_coefs.loc[
    focus_model_coefs.output_dim.isin(plot_greatest_contrast)
]
plot_greatest_contrast_df = plot_greatest_contrast_df.loc[
    plot_greatest_contrast_df.feature.str.startswith("model_embedding")
]

g = sns.relplot(
    data=plot_greatest_contrast_df,
    col="output_dim", col_wrap=2, col_order=plot_greatest_contrast,
    x="time", y="coef", hue="feature",
    kind="line", errorbar="se",
    facet_kws=dict(sharex=False),
)
# Zero reference line on every panel.
for ax in g.axes.flat:
    ax.axhline(0, color="gray", linestyle="--", alpha=0.5)

In [None]:
def _plot_erp_panel(figure, subplot_spec, epochs_df, electrode, title, errorbar,
                    lineplot_kwargs=None):
    """Add one ERP line plot for a single electrode to `figure`.

    Creates a subplot at `subplot_spec`, marks time zero with a dashed line, and
    plots `value` against `epoch_time` for the rows of `epochs_df` whose
    `electrode_idx` equals `electrode`. `lineplot_kwargs` is forwarded to
    `sns.lineplot` (e.g. `hue=...`). Returns the new axes.
    """
    ax = figure.add_subplot(subplot_spec)
    ax.set_title(title)
    ax.axvline(0, color="gray", linestyle="--")
    electrode_epochs = epochs_df[epochs_df.electrode_idx == electrode]
    sns.lineplot(data=electrode_epochs, x="epoch_time", y="value", ax=ax,
                 errorbar=errorbar, **(lineplot_kwargs or {}))
    return ax


def render_electrode_panel(
        electrode, model_embeddings=None, features=None,
        trial_epoch_kwargs=None,
        word_epoch_kwargs=None,
        word_epoch2_kwargs=None,
        smoke_test=False):
    """Render a 3x4 summary figure for a single electrode.

    Layout (row, column):
      (0, 0) log p-values per model embedding    (0, 1) unique variance per feature
      (1, 0:2) heatmap of basic-feature coefficients over time
      (2, 0:2) heatmap of summed |coef| of each model's embedding over time
      (0, 2) trial ERP      (0, 3) syllable ERP
      (1, 2) word ERP       (1, 3) word-offset ERP   -- both use `word_epoch_kwargs`
      (2, 2) word ERP       (2, 3) word-offset ERP   -- both use `word_epoch2_kwargs`

    Reads the notebook-level names `electrode_pvals`, `coef_df`,
    `unique_variance`, `pval_threshold`, `trial_epochs`, `word_epochs`,
    `word_offset_epochs` and `syllable_epochs`.

    Parameters
    ----------
    electrode
        Electrode identifier, matched against `output_dim` in the p-value,
        unique-variance and coefficient tables and against `electrode_idx` in
        the epoch frames.
    model_embeddings : list of str, optional
        Models to include. Defaults to every `model2` value in
        `electrode_pvals` except "baseline".
    features : list of str, optional
        Feature-name prefixes to show in the basic-feature heatmap. Defaults to
        every feature in `coef_df` that does not start with "model_embedding".
    trial_epoch_kwargs, word_epoch_kwargs, word_epoch2_kwargs : dict, optional
        Extra keyword arguments forwarded to `sns.lineplot` for the trial panel,
        the middle-row word panels and the bottom-row word panels respectively.
    smoke_test : bool
        If true, use a small figure and skip error bars.

    Returns
    -------
    pandas.DataFrame
        Coefficient rows for this electrode across `model_embeddings`, with
        model-embedding feature names prefixed by their model name.

    The figure is created with `plt.figure`, so it remains pyplot's current
    figure after the call.
    """
    figure = plt.figure(figsize=(10, 8) if smoke_test else (32, 24))
    gs = gridspec.GridSpec(3, 4, figure=figure,
                           width_ratios=[3, 3, 2, 2], hspace=0.25, wspace=0.25)
    electrodes = [electrode]

    errorbar = None if smoke_test else "se"

    if model_embeddings is None:
        model_embeddings = sorted(
            m for m in electrode_pvals.index.get_level_values("model2").unique()
            if m != "baseline")
    if features is None:
        features = sorted(
            f for f in coef_df.feature.unique()
            if not f.startswith("model_embedding"))
    feature_prefixes = tuple(features)

    ##### log p-values per model embedding and unique variance per feature

    tval_ax = figure.add_subplot(gs[0, 0])
    tval_ax.set_title("Improvement log $p$-values by model embedding")
    # Mark the significance threshold on the log scale.
    tval_ax.axvline(np.log10(pval_threshold), color="black", linestyle="--", linewidth=2)
    feature_norm_ax = figure.add_subplot(gs[0, 1])
    feature_norm_ax.set_title("Unique variance")

    tval_df = electrode_pvals.loc[model_embeddings].loc[(slice(None), electrodes)]
    # Most significant (most negative log p-value) model first.
    tval_df_order = tval_df.sort_values(ascending=True).index.get_level_values("model2")
    sns.barplot(data=tval_df.reset_index(), x="log_pval", y="model2",
                ax=tval_ax, order=tval_df_order)
    for ticklabel in tval_ax.get_yticklabels():
        if ticklabel.get_text() in model_embeddings:
            ticklabel.set_fontweight("bold")

    # Named differently from the notebook-level `unique_variance_df` to avoid shadowing it.
    electrode_unique_variance = unique_variance.loc[(slice(None), electrodes)] \
        .reset_index().rename(columns={"dropped_feature": "feature"})
    unique_variance_means = electrode_unique_variance.groupby("feature").unique_variance_score.mean()
    # Only show features with non-negative mean unique variance, largest first.
    unique_variance_order = unique_variance_means[unique_variance_means >= 0] \
        .sort_values(ascending=False).index
    sns.barplot(data=electrode_unique_variance,
                x="unique_variance_score", y="feature",
                ax=feature_norm_ax, order=unique_variance_order)
    feature_norm_ax.set_xlim((0, feature_norm_ax.get_xlim()[1]))
    for ticklabel in feature_norm_ax.get_yticklabels():
        if ticklabel.get_text().startswith(feature_prefixes):
            ticklabel.set_fontweight("bold")

    ##### prepare a single coefficient frame

    plot_coef_df = coef_df.loc[model_embeddings].reset_index()
    # name model embedding coefficients according to model
    is_embedding_coef = plot_coef_df.feature.str.startswith("model_embedding")
    model_coefs = plot_coef_df.loc[is_embedding_coef]
    plot_coef_df.loc[is_embedding_coef, "feature"] = \
        model_coefs.model.str.cat(model_coefs.feature, sep="_")

    # filter to electrodes of interest
    plot_coef_df = plot_coef_df[plot_coef_df.output_dim.isin(electrodes)]
    # filter to features of interest
    plot_coef_df_features = plot_coef_df[plot_coef_df.feature.str.startswith(feature_prefixes)]
    plot_coef_df_features = plot_coef_df_features[["fold", "feature", "output_dim", "time", "coef"]] \
        .assign(type="basic_feature")
    # summarize each model's embedding by the sum of absolute coefficients per time point
    plot_coef_df_embeddings = plot_coef_df[plot_coef_df.feature.str.contains("model_embedding")]
    plot_coef_df_embeddings = plot_coef_df_embeddings.groupby(["fold", "model", "output_dim", "time"]) \
        .coef.apply(lambda xs: xs.abs().sum()).reset_index() \
        .rename(columns={"model": "feature"}).assign(type="model_embedding")

    ##### coefficient heatmaps (rows ordered by name)

    feature_coef_heatmap_ax = figure.add_subplot(gs[1, :2])
    feature_heatmap_df = plot_coef_df_features.pivot_table(
        index="feature", columns="time", values="coef", aggfunc="mean")
    feature_heatmap_df = feature_heatmap_df.loc[sorted(plot_coef_df_features.feature.unique())]
    sns.heatmap(feature_heatmap_df, ax=feature_coef_heatmap_ax, cmap="RdBu", center=0, yticklabels=True)

    model_coef_heatmap_ax = figure.add_subplot(gs[2, :2])
    model_heatmap_df = plot_coef_df_embeddings.pivot_table(
        index="feature", columns="time", values="coef", aggfunc="mean")
    model_heatmap_df = model_heatmap_df.loc[sorted(plot_coef_df_embeddings.feature.unique())]
    sns.heatmap(model_heatmap_df, ax=model_coef_heatmap_ax, cmap="RdBu", center=0, yticklabels=True)

    ##### ERP panels

    _plot_erp_panel(figure, gs[0, 2], trial_epochs, electrode, "Trial ERP",
                    errorbar, trial_epoch_kwargs)
    _plot_erp_panel(figure, gs[1, 2], word_epochs, electrode, "Word ERP",
                    errorbar, word_epoch_kwargs)
    _plot_erp_panel(figure, gs[2, 2], word_epochs, electrode, "Word ERP",
                    errorbar, word_epoch2_kwargs)
    _plot_erp_panel(figure, gs[0, 3], syllable_epochs, electrode, "Syllable ERP",
                    errorbar)
    _plot_erp_panel(figure, gs[1, 3], word_offset_epochs, electrode, "Word offset ERP",
                    errorbar, word_epoch_kwargs)
    _plot_erp_panel(figure, gs[2, 3], word_offset_epochs, electrode, "Word offset ERP",
                    errorbar, word_epoch2_kwargs)

    # Title the figure explicitly rather than whichever figure pyplot considers current.
    figure.suptitle(f"Electrode {electrode} study")

    return plot_coef_df

In [None]:
# Group electrodes by how the word-level model compares with the biphone and
# phoneme models. Only electrodes where the word-level model itself reaches
# log10 p <= -2 are considered; a negative contrast means the word-level model
# is the more significant one.
word_model_significant = word_contrasts.log_pval <= -2
contrast_values = word_contrasts.word_contrast
panel_electrodes = {
    "word_dominant": word_contrasts[(contrast_values <= -1) & word_model_significant].index.tolist(),
    "phone_dominant": word_contrasts[(contrast_values >= 1) & word_model_significant].index.tolist(),
    "balanced": word_contrasts[contrast_values.between(-0.5, 0.5) & word_model_significant].index.tolist(),
}

# # electrodes showing balanced response between word_broad and biphone/phoneme features
# panel_electrodes += [200, 204, 221, 173, 199, 123, 257, 314, 172, 234, 331, 211]

# electrodes tuned to matched features and different models
# panel_electrodes = [231, 205, 373, 173, 214, 123]
# electrodes tuned to F0 but not to models
# panel_electrodes += [212, 33, 219, 193]

In [None]:
# Inspect which electrodes fall in each group.
panel_electrodes

In [None]:
# Flatten {group: [electrodes]} into (group, electrode) jobs so tqdm knows the total.
panel_jobs = [
    (group, electrode)
    for group, electrodes in panel_electrodes.items()
    for electrode in electrodes
]

for group, electrode in tqdm(panel_jobs):
    # Larger font only while rendering and saving this panel.
    with plt.rc_context(rc={"font.size": 24}):
        render_electrode_panel(
            electrode,
            model_embeddings=study_models,
            word_epoch_kwargs=dict(hue="monosyllabic"),
            word_epoch2_kwargs=dict(hue="word_frequency_quantile"))
        # The panel function leaves its figure as pyplot's current figure.
        f = plt.gcf()
        f.savefig(f"{output_dir}/electrode_panel-{subject}-{group}-{electrode}.png")
        # Close it so figures do not accumulate across iterations.
        plt.close(f)

In [None]:
# Distribution of fold-averaged baseline scores across electrodes, with the
# electrodes chosen for the panels labelled.
baseline_mean_scores = baseline_scores.groupby("output_dim").score.mean()
ax = sns.swarmplot(baseline_mean_scores, color="gray")
ax.axhline(0, color="gray", linestyle="--")

all_panel_electrodes = list(itertools.chain.from_iterable(panel_electrodes.values()))

# Labels are jittered horizontally so they overlap less. A seeded generator
# makes the figure identical on every run (the global NumPy RNG is never seeded
# in this notebook).
label_jitter_rng = np.random.default_rng(0)
# x is in axes coordinates (0-1), y is in data coordinates (the score).
label_transform = transforms.blended_transform_factory(ax.transAxes, ax.transData)
for elec, score in baseline_mean_scores.loc[all_panel_electrodes].items():
    ax.text(0.2 + label_jitter_rng.normal(0, 0.1), score, str(elec),
            ha="center", va="bottom", transform=label_transform)
    ax.axhline(score, color="blue", linestyle="--", alpha=0.3)

ax.set_title("Electrode baseline performance")
ax.set_ylabel("Baseline $r^2$");