In [None]:
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import ttest_1samp
import torch

from src.encoding.ecog.timit import trf_grid_to_df

In [None]:
# --- Run configuration ---
# Subject and dataset whose encoder outputs are compared in this notebook.
subject = "EC270"
dataset = "timit-no_repeats"

# model1 is the reference model; model2 is the model compared against it.
model1 = "baseline"
model2 = "biphone_recon"

# Each model's outputs live in outputs/encoders/<dataset>/<model>/<subject>/.
model1_dir = f"outputs/encoders/{dataset}/{model1}/{subject}"
model2_dir = f"outputs/encoders/{dataset}/{model2}/{subject}"

model1_scores_path = f"{model1_dir}/scores.csv"
model2_scores_path = f"{model2_dir}/scores.csv"
model1_coefs_path = f"{model1_dir}/coefs.pkl"
model2_coefs_path = f"{model2_dir}/coefs.pkl"
model1_model_path = f"{model1_dir}/model.pkl"
model2_model_path = f"{model2_dir}/model.pkl"

# Score files for the permuted variants of model2: five runs (0-4) for each
# permutation type, keyed by permutation type name.
model2_permutation_score_paths = {
    permutation_type: [
        f"outputs/encoders-permute_{permutation_type}/{run_idx}/{dataset}/{model2}/{subject}/scores.csv"
        for run_idx in range(5)
    ]
    for permutation_type in ("units", "shift")
}

# Directory that receives the CSV files written by this notebook.
output_dir = "."

In [None]:
# Read the score table of each model from its CSV file.
model1_scores, model2_scores = (
    pd.read_csv(scores_path)
    for scores_path in (model1_scores_path, model2_scores_path)
)

In [None]:
# Both models must have been scored on exactly the same set of output dimensions.
model1_output_dims = set(model1_scores["output_dim"])
assert model1_output_dims == set(model2_scores["output_dim"])

In [None]:
# Stack every permutation run into one frame. The outer index level names the
# permutation type, the next level numbers the run within that type.
stacked_permutation_scores = pd.concat(
    {
        permutation_name: pd.concat(
            list(map(pd.read_csv, permutation_scores_paths)),
            keys=range(len(permutation_scores_paths)),
            names=["permutation_idx"],
        )
        for permutation_name, permutation_scores_paths in model2_permutation_score_paths.items()
    },
    names=["permutation"],
)

# Drop the per-file row counter (innermost level) and index by output_dim / fold instead.
model2_permutation_scores = (
    stacked_permutation_scores
    .droplevel(-1)
    .set_index(["output_dim", "fold"], append=True)
)
model2_permutation_scores

In [None]:
# The permuted runs must cover exactly the output dimensions scored for model1.
permuted_output_dims = set(model2_permutation_scores.index.get_level_values("output_dim"))
assert permuted_output_dims == set(model1_scores["output_dim"])

In [None]:
# Stack both models' scores under a "model" index level, then replace the
# original row counter with an (output_dim, fold) index.
stacked_model_scores = pd.concat(
    [model1_scores, model2_scores],
    keys=[model1, model2],
    names=["model"],
)
all_scores = (
    stacked_model_scores
    .droplevel(-1)
    .set_index(["output_dim", "fold"], append=True)
)
all_scores.to_csv(Path(output_dir) / "scores.csv")
all_scores

In [None]:
# Score difference (model2 minus model1), aligned on the remaining index levels.
model2_fold_scores = all_scores.loc[model2]["score"]
model1_fold_scores = all_scores.loc[model1]["score"]
all_improvements = model2_fold_scores.sub(model1_fold_scores)
all_improvements.to_csv(Path(output_dir) / "improvements.csv")
all_improvements

In [None]:
# Join the permuted-model scores with the model1 scores on their indexes,
# keeping only entries present in both. Columns shared by the two frames are
# disambiguated with the "_perm" / "_model1" suffixes.
permuted_vs_model1 = pd.merge(
    model2_permutation_scores,
    all_scores.loc[model1],
    left_index=True,
    right_index=True,
    how="inner",
    suffixes=("_perm", "_model1"),
)

# Score difference (permuted model2 minus model1), stored as a Series named "score".
permutation_improvements = (
    permuted_vs_model1["score_perm"]
    .sub(permuted_vs_model1["score_model1"])
    .rename("score")
)
permutation_improvements.to_csv(Path(output_dir) / "permutation_improvements.csv")
permutation_improvements

## Visualize electrode performance distribution

In [None]:
def plot_electrode_performance_distribution(score_data: pd.DataFrame, ax=None):
    """
    Draw a violin plot of the ``score`` column with the individual points
    overlaid as a swarm plot, plus a dashed reference line at zero.

    Args:
        score_data: frame with a ``score`` column. If it also has an
            ``electrode_group`` column, that column is used as the x axis so
            that each group gets its own violin and swarm.
        ax: matplotlib axes to draw on. A new figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Only split along the x axis when the grouping column is available.
    if "electrode_group" in score_data.columns:
        grouping = {"x": "electrode_group"}
    else:
        grouping = {}

    sns.violinplot(data=score_data, y="score", ax=ax, **grouping)
    sns.swarmplot(data=score_data, y="score", color="black", alpha=0.5, ax=ax, **grouping)

    ax.axhline(0, color="black", linestyle="--")
    return ax

### Baseline

In [None]:
# Mean model1 score per output dimension, then the spread of those means.
baseline_mean_scores = all_scores.loc[model1].groupby("output_dim")["score"].mean()
plot_electrode_performance_distribution(baseline_mean_scores.to_frame())


In [None]:
# Output dimensions whose mean baseline score exceeds 0.025 are labelled
# "speech responsive" in the plots below.
exceeds_threshold = baseline_mean_scores > 0.025
speech_responsive_electrodes = baseline_mean_scores[exceeds_threshold].index
speech_responsive_electrodes

### Model improvements

In [None]:
# Mean improvement per output dimension; every row starts out in the "na" group.
swarm_data = (
    all_improvements
    .groupby(["output_dim"])
    .mean()
    .to_frame()
    .assign(electrode_group="na")
)
# Relabel the speech-responsive output dimensions.
swarm_data.loc[speech_responsive_electrodes, "electrode_group"] = "speech responsive"

plot_electrode_performance_distribution(swarm_data.reset_index())

### Permutation improvements

In [None]:
# Mean improvement of the permuted models over model1, per permutation type
# and output dimension. Every row starts out in the "na" group.
swarm_data = permutation_improvements.groupby(["permutation", "output_dim"]).mean().to_frame()
swarm_data["electrode_group"] = "na"
# Relabel the speech-responsive output dimensions within every permutation type.
swarm_data.loc[(slice(None), speech_responsive_electrodes), "electrode_group"] = "speech responsive"
swarm_data = swarm_data.reset_index()

permutation_names = swarm_data.permutation.unique()

# squeeze=False makes plt.subplots always return an array of axes. With the
# default, a single permutation type yields a bare Axes object, which cannot be
# iterated over in the zip below.
f, axs = plt.subplots(len(permutation_names), 1, squeeze=False,
                      figsize=(6, 5 * len(permutation_names)))
axs = axs.ravel()

# One panel per permutation type.
for permutation, ax in zip(permutation_names, axs):
    plot_electrode_performance_distribution(swarm_data[swarm_data.permutation == permutation], ax=ax)
    ax.set_title(permutation)

## Visualize contrast as heatmap

In [None]:
# Grid layout for the heatmaps: as many rows as the rounded-up square root of
# the number of output dimensions, and enough columns to hold all of them.
num_output_dims = len(pd.unique(model1_scores["output_dim"]))
sqrt_num_output_dims = np.sqrt(num_output_dims)
grid_num_rows = np.ceil(sqrt_num_output_dims).astype(int)
grid_num_cols = np.ceil(num_output_dims / grid_num_rows).astype(int)

def output_dim_to_grid_coords(output_dim, num_rows=None, num_cols=None):
    """
    Map an output dimension to its ``(x, y)`` cell on the plotting grid.

    Output dimensions fill the grid one column at a time, mirrored along both
    axes: dimension 0 lands in the last column and last row, and the last
    cell of the layout lands at ``(0, 0)``.

    Args:
        output_dim: non-negative integer index of the output dimension.
        num_rows: number of grid rows; defaults to the notebook-level
            ``grid_num_rows``.
        num_cols: number of grid columns; defaults to the notebook-level
            ``grid_num_cols``.

    Returns:
        Tuple ``(x, y)`` of column and row index, to be used as ``grid[y, x]``.
    """
    if num_rows is None:
        num_rows = grid_num_rows
    if num_cols is None:
        num_cols = grid_num_cols

    # The column advances once per full column of `num_rows` entries, so the
    # divisor has to be the row count. Dividing by the column count only gives
    # the same answer on square grids; on rectangular grids it yields negative
    # or colliding coordinates.
    x = num_cols - output_dim // num_rows - 1
    y = num_rows - output_dim % num_rows - 1
    return x, y

grid_shape = (grid_num_rows, grid_num_cols)

# Index of the output dimension shown at each grid cell; unused cells stay NaN.
electrode_grid = np.full(grid_shape, np.nan)
for electrode_idx in range(num_output_dims):
    col, row = output_dim_to_grid_coords(electrode_idx)
    electrode_grid[row, col] = electrode_idx

# Mean score per output dimension, one layer per model (0: model1, 1: model2).
scores_grid = np.full((2, *grid_shape), np.nan)
for model_idx, model_scores in enumerate((model1_scores, model2_scores)):
    for output_dim, model_rows in model_scores.groupby("output_dim"):
        col, row = output_dim_to_grid_coords(output_dim)
        scores_grid[model_idx, row, col] = model_rows.score.mean()

# Mean improvement of model2 over model1 per output dimension.
scores_diff_grid = np.full(grid_shape, np.nan)
for output_dim, improvement_rows in all_improvements.groupby("output_dim"):
    col, row = output_dim_to_grid_coords(output_dim)
    scores_diff_grid[row, col] = improvement_rows.mean()

# Same, for the permuted models (averaged over everything but output_dim).
permutation_scores_diff_grid = np.full(grid_shape, np.nan)
for output_dim, permuted_rows in permutation_improvements.groupby("output_dim"):
    col, row = output_dim_to_grid_coords(output_dim)
    permutation_scores_diff_grid[row, col] = permuted_rows.mean()


# One heatmap per panel, described as (grid, title, extra heatmap options).
heatmap_panels = [
    # Sanity check: plot electrode IDs in grid form. Cross-check this with recon
    (electrode_grid, "Electrode IDs", dict(annot=True, fmt=".0f")),
    (scores_grid[0], model1, {}),
    (scores_grid[1], model2, {}),
    # Difference maps use a diverging colormap centered on zero.
    (scores_diff_grid, f"{model2} - {model1}", dict(center=0.0, cmap="RdBu")),
    (permutation_scores_diff_grid, f"Permutation {model2} - {model1}", dict(center=0.0, cmap="RdBu")),
]

f, axs = plt.subplots(5, 1, figsize=(6, 6 * 5))
for panel_ax, (panel_grid, panel_title, panel_options) in zip(axs, heatmap_panels):
    sns.heatmap(panel_grid, ax=panel_ax, **panel_options)
    panel_ax.set_title(panel_title)

## Quantitative

In [None]:
# Flag output dimensions whose mean improvement over the baseline is above zero.
mean_improvement_per_output_dim = all_improvements.groupby("output_dim").mean()
positive_improvement = mean_improvement_per_output_dim > 0

num_improved = positive_improvement.sum()
percent_improved = positive_improvement.mean() * 100
print(f"Electrodes showing numerical improvement over baseline: "
      f"{num_improved} ({percent_improved:.2f}%)")

In [None]:
# Keep only the output dimensions flagged as improved.
improved_output_dims = positive_improvement[positive_improvement]
study_improvements = all_improvements.loc[improved_output_dims.index]

# Same restriction for the permuted models (inner join on the index), then
# average within each (output_dim, fold, permutation) group.
study_permutation_improvements = (
    pd.merge(
        permutation_improvements,
        improved_output_dims.rename("positive_improvement"),
        left_index=True,
        right_index=True,
        how="inner",
    )
    .score
    .groupby(["output_dim", "fold", "permutation"])
    .mean()
)

In [None]:
# Conservative view: for every (output_dim, fold), keep the smallest margin
# between the full model's improvement and a permuted model's improvement,
# i.e. collapse over permutation type with a minimum.
improvement_margins = study_improvements - study_permutation_improvements
improvements_over_permutation = improvement_margins.groupby(["output_dim", "fold"]).min()

In [None]:
# Box plot of the per-output-dimension mean margin, with a dashed reference line at zero.
mean_margin_per_output_dim = improvements_over_permutation.groupby("output_dim").mean()
ax = sns.boxplot(mean_margin_per_output_dim)
ax.axhline(0, color="k", linestyle="--")

In [None]:
# Restrict the test to output dimensions flagged as improved over the baseline.
ttest_improvements = all_improvements.loc[positive_improvement[positive_improvement].index]
ttest_permutation_improvements = pd.merge(permutation_improvements, positive_improvement[positive_improvement].rename("positive_improvement"),
         left_index=True, right_index=True, how="inner").score
if len(ttest_improvements) == 0:
    # Nothing to test: still write an (empty) results file.
    print("No electrodes showing improvement. Stop.")
    pd.DataFrame().to_csv(Path(output_dir) / "ttest_results.csv")
else:
    # Margin of the full model over each permuted model, averaged within each
    # permutation run. Per (output_dim, permutation type), run a one-sample
    # t-test of those run means against 0.
    ttest_results = (ttest_improvements - ttest_permutation_improvements) \
        .groupby(["output_dim", "permutation", "permutation_idx"]).mean() \
        .groupby(["output_dim", "permutation"]).apply(lambda xs: pd.Series(ttest_1samp(xs, 0), index=["tval", "pval"])) \
        .unstack() \
        .sort_values("pval")
    ttest_results.to_csv(Path(output_dir) / "ttest_results.csv")

    # Start from NaN, like the other grids in this notebook, so that untested
    # output dimensions and unused grid cells are left blank by the heatmap
    # instead of being drawn as t = 0.
    ttest_grid = np.full((grid_num_rows, grid_num_cols), np.nan)
    for output_dim, ttest_rows in ttest_results.groupby("output_dim"):
        x, y = output_dim_to_grid_coords(output_dim)
        # Show the smallest t-value over permutation types, ignoring NaNs.
        ttest_grid[y, x] = np.nanmin(ttest_rows.tval)

    ax = sns.heatmap(ttest_grid)
    ax.set_title("ttest t-values")