In [1]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up without restarting the kernel (mode 2 = reload all modules before
# each cell execution).
%load_ext autoreload
%autoreload 2

In [3]:
import ast
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

In [21]:
# Significance level applied to t-test p-values when filtering results.
pval_threshold = 1e-3
# Directory that receives the aggregated CSV outputs (converted to a Path later).
output_dir = "."

# Every comparison lives four levels below outputs/encoder_comparison; the four
# trailing path components are later used as (dataset, subject, model2, model1) keys.
# Path.glob yields entries in arbitrary filesystem order, so sort them to make the
# row order of every concatenated table (and written CSV) reproducible across runs.
encoder_comparisons = sorted(Path("outputs/encoder_comparison").glob("*/*/*/*"))

In [22]:
# Convert the configured output location to a Path so that later cells can build
# output file paths with the `/` operator.
output_dir = Path(output_dir)

In [23]:
# Index level names attached to every per-comparison table; their order matches the
# order of the four trailing path components extracted below.
encoder_comparison_names = ["dataset", "subject", "model2", "model1"]

# Normalise every entry to a Path object
# (presumably so the cell also works when paths are supplied as strings -- TODO confirm).
encoder_comparisons = list(map(Path, encoder_comparisons))

# Key each comparison by the last four components of its path.
encoder_comparison_keys = [
    (parts[-4], parts[-3], parts[-2], parts[-1])
    for parts in (comparison.parts for comparison in encoder_comparisons)
]

In [None]:
# Sanity check: report every comparison whose t-test table does not have
# exactly four columns, printing the offending path and its column labels.
for p in encoder_comparisons:
    df = pd.read_csv(p / "ttest_results.csv")
    if df.shape[1] == 4:
        continue
    print(p, df.columns)

In [None]:
# Load one t-test table per comparison directory.
all_ttest_results = [
    pd.read_csv(comparison_dir / "ttest_results.csv")
    for comparison_dir in encoder_comparisons
]

# Flag the comparisons whose table has at least one row; empty tables are dropped,
# together with their keys, so that tables and keys stay aligned.
keep_result = [table.shape[0] > 0 for table in all_ttest_results]
all_ttest_keys = [
    key for keep, key in zip(keep_result, encoder_comparison_keys) if keep
]

# Stack the remaining tables under a (dataset, subject, model2, model1) index
# and make sure output_dim is stored as an integer.
all_ttest_results = (
    pd.concat(
        [table for keep, table in zip(keep_result, all_ttest_results) if keep],
        keys=all_ttest_keys,
        names=encoder_comparison_names)
    .assign(output_dim=lambda results: results.output_dim.astype(int))
)
all_ttest_results.to_csv(output_dir / "ttest.csv")
all_ttest_results

In [None]:
# Stack the per-comparison score tables under a (dataset, subject, model2, model1)
# index, drop the per-file row counter, and add the electrode name as the innermost
# index level.
all_scores_df = (
    pd.concat(
        [
            # output_name is stored in the CSV as the repr of a tuple;
            # parse it back into a real tuple while reading.
            pd.read_csv(comparison_dir / "scores.csv",
                        converters={"output_name": ast.literal_eval})
            for comparison_dir in encoder_comparisons
        ],
        keys=encoder_comparison_keys,
        names=encoder_comparison_names)
    .droplevel(-1)
    # The electrode name is taken from the second element of the output_name tuple.
    .assign(electrode_name=lambda scores: scores.output_name.str[1])
    .set_index("electrode_name", append=True)
)
all_scores_df.to_csv(output_dir / "scores.csv")
all_scores_df

In [None]:
# One electrodes table is read per (dataset, subject, model2) key. Several comparisons
# (different model1) share the same triple, so deduplicate the keys -- otherwise the
# same electrodes.csv is concatenated once per comparison and every electrode row is
# repeated. dict.fromkeys keeps the first occurrence of each key in its original order.
all_electrodes_keys = list(dict.fromkeys(
    (dataset, subject, model2)
    for dataset, subject, model2, model1 in encoder_comparison_keys))

# NOTE(review): the encoder outputs are addressed as dataset / model2 / subject, which is
# a different component order than the comparison directories -- confirm against the
# on-disk layout.
all_electrodes_paths = [Path("outputs/encoders") / dataset / model2 / subject / "electrodes.csv"
                        for dataset, subject, model2 in all_electrodes_keys]

# Stack the tables under a (dataset, subject, model) index and drop the per-file
# row counter that pd.concat keeps as the innermost level.
all_electrodes_df = pd.concat([
        pd.read_csv(path) for path in all_electrodes_paths
    ], keys=all_electrodes_keys, names=["dataset", "subject", "model"]) \
    .droplevel(-1)
all_electrodes_df.to_csv(output_dir / "electrodes.csv")
all_electrodes_df

In [None]:
# Axes of the coverage table: every model2 and every subject that appears in the
# aggregated t-test results.
covered_models = sorted(all_ttest_results.index.get_level_values("model2").unique())
covered_subjects = sorted(all_ttest_results.index.get_level_values("subject").unique())
coverage_df = pd.DataFrame(np.zeros((len(covered_subjects), len(covered_models)), dtype=int),
                           columns=pd.Index(covered_models, name="model"),
                           index=pd.Index(covered_subjects, name="subject"))

# Count the comparison directories found for each (subject, model2) cell.
# A comparison whose subject or model never made it into the t-test results has no
# cell in the table; report it (naming both labels, since either may be the missing
# one) instead of counting it.
for dataset, subject, model2, model1 in encoder_comparison_keys:
    if subject not in coverage_df.index or model2 not in coverage_df.columns:
        print(f"Missing all data for subject {subject} / model {model2}?")
        continue
    coverage_df.loc[subject, model2] += 1

# `fmt` only applies to cell annotations, so annotations must be switched on for the
# counts to be shown; draw on the axes created here rather than the implicit current axes.
f, ax = plt.subplots(figsize=(8, 4))
sns.heatmap(coverage_df, annot=True, fmt=",d", ax=ax)

In [None]:
# # Merge in electrode information
# all_scores_df = pd.merge(
#     all_scores_df,
#     all_electrodes_df.rename(columns=lambda col: f"electrode_{col}" if not col.startswith("electrode") else col),
#     left_index=True, right_index=True,
#     how="left", validate="many_to_one")

In [None]:
# all_scores_df.to_csv(Path(output_dir) / "all_encoding_scores.csv")
# all_electrodes_df.to_csv(Path(output_dir) / "all_electrodes.csv")

## Electrode selection

In [None]:
# Within every (dataset, subject, model2, model1, output_dim) group keep only the row
# with the LARGEST p-value (per the original note, the rows of a group are permutations),
# which makes the subsequent significance test more stringent.
# Then retain the groups whose worst case still shows a positive t-value with a
# p-value below the threshold, ordered from most to least significant.
ttest_results_filtered = (
    all_ttest_results
    .groupby(["dataset", "subject", "model2", "model1", "output_dim"])
    .apply(lambda group: group.loc[group.pval.idxmax()])
    .loc[lambda worst: (worst.tval > 0) & (worst.pval < pval_threshold)]
    .sort_values("pval")
)
ttest_results_filtered.to_csv(Path(output_dir) / "ttest_filtered.csv")
ttest_results_filtered

## Summary quantitative analysis

In [None]:
# Display the combined encoding-scores table as a reference for this section.
all_scores_df

In [None]:
# # plot baseline performance
# baseline_scores = all_scores_df.xs("baseline", level="model")
# plot_df = baseline_scores.groupby(["subject", "electrode_roi", "electrode_name"]).score.mean().reset_index()

# f, ax = plt.subplots(figsize=(15, 8))
# sns.barplot(data=plot_df, x="subject", y="score", ax=ax)
# ax.set_title("Mean baseline r^2 by subject, across all electrodes")

In [None]:
# # plot baseline performance
# baseline_scores = all_scores_df.xs("baseline", level="model")
# plot_df = baseline_scores.groupby(["subject", "fold"]).score.max().reset_index()

# f, ax = plt.subplots(figsize=(15, 8))
# sns.barplot(data=plot_df, x="subject", y="score", ax=ax)
# ax.set_title("Max baseline r^2 by subject, across all electrodes")

In [None]:
# def compute_improvement_within_fold(fold_df):
#     print(fold_df.score)
#     ret = fold_df.score - fold_df.xs("baseline", level="model").score
#     print(ret)
#     return ret

# # all_scores_df.groupby(["subject", "output_name"]).apply(compute_improvement_within_fold)
# improvement_df = all_scores_df.set_index(["fold", "output_name"], append=True)
# improvement_df = pd.merge(improvement_df, (improvement_df.score - improvement_df.xs("baseline", level="model").score).rename("improvement"),
#                           left_index=True, right_index=True)
# improvement_df = improvement_df.loc[improvement_df.index.get_level_values("model") != "baseline"]
# improvement_df

In [None]:
# improvement_df.groupby(["subject", "output_name", "model"]).improvement.mean().sort_values(ascending=False)

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# plot_df = improvement_df.groupby(["subject", "model", "output_name"]).improvement.mean().reset_index()
# sns.boxplot(data=plot_df, order=plot_df.groupby("model").improvement.mean().sort_values(ascending=False).index,
#             x="model", y="improvement", hue="subject", ax=ax)

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# plot_df = improvement_df.groupby(["subject", "model", "output_name"]).improvement.mean().reset_index()
# sns.barplot(data=plot_df, order=plot_df.groupby("subject").improvement.mean().sort_values(ascending=False).index,
#             x="subject", y="improvement", hue="model", ax=ax)
# ax.set_title("Mean improvement across electrodes within subject and model")

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# plot_df = improvement_df.groupby(["subject", "model", "fold"]).improvement.max().reset_index()
# sns.barplot(data=plot_df, order=plot_df.groupby("subject").improvement.mean().sort_values(ascending=False).index,
#             x="subject", y="improvement", hue="model", ax=ax)
# ax.set_title("Max improvement across electrodes within subject and model")

In [None]:
# Count electrode rows per ROI and keep only the ROIs that hold at least 1% of all rows.
# NOTE(review): this counts rows of all_electrodes_df, not unique electrodes; an electrode
# that appears in more than one of the concatenated tables is counted repeatedly --
# confirm this is intended.
plot_df = all_electrodes_df.roi.value_counts()
plot_df = plot_df[plot_df / plot_df.sum() >= 0.01]

# Pass categories and counts explicitly: how seaborn interprets a bare Series given as
# `data` differs between seaborn releases (one bar per entry vs. a single aggregated bar),
# whereas explicit x/y always yields one bar per ROI.
ax = sns.barplot(x=plot_df.index, y=plot_df.values)
ax.set_title("Number of electrodes per ROI")
ax.set_xlabel("ROI")
ax.set_ylabel("Number of electrodes")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
# Trailing None suppresses the repr of the last call in the cell output.
None

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# sns.barplot(data=improvement_df.reset_index(),
#             x="model", y="improvement", hue="electrode_roi", ax=ax)
# ax.set_title("Mean improvement across subject, electrode within ROI and model")

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# sns.barplot(data=improvement_df.reset_index(),
#             order=improvement_df.reset_index().groupby("electrode_roi").improvement.mean().sort_values(ascending=False).index,
#             x="electrode_roi", y="improvement", hue="model",
#             ax=ax)
# ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
# ax.set_title("Mean improvement across subject, electrode within ROI and model")

In [None]:
# f, ax = plt.subplots(figsize=(15, 8))
# plot_df = improvement_df.groupby(["subject", "electrode_roi", "model", "fold"]).improvement.max().reset_index()
# sns.barplot(data=plot_df,
#             order=plot_df.groupby("electrode_roi").improvement.mean().sort_values(ascending=False).index,
#             x="electrode_roi", y="improvement", hue="model",
#             ax=ax)
# ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
# ax.set_title("Max improvement across subject, electrode within ROI and model")