Run model comparisons on individual electrodes in order to assign each electrode a qualitative "contrast" label.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial
import logging
from pathlib import Path
from typing import Literal

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm.auto import tqdm

In [None]:
# Notebook-level logger; used further down to warn about contrast models
# that are missing from the set of study models.
L = logging.getLogger(__name__)

In [None]:
# Name of the dataset to analyze; it is interpolated into the input paths below
# and used to select rows from the loaded CSVs.
dataset = "timit-no_repeats"
# Display name -> model code for the models under study.
study_models = {
    "Random": "random32-w2v2_8-l2norm",
    "Phoneme": "phoneme-w2v2_8-l2norm",
    "Word": "ph-ls-word_broad-hinge-w2v2_8-l2norm",
    "Word discrim2": "ph-ls-word_broad-hinge-w2v2_8-discrim2-l2norm",
}
# Input CSVs: t-test results and scores for this dataset.
ttest_results_path = f"outputs/encoder_comparison_across_subjects/{dataset}/ttest.csv"
scores_path = f"outputs/encoder_comparison_across_subjects/{dataset}/scores.csv"

# Every path matching outputs/encoders/<dataset>/*/*; later filtered by the
# name of the parent directory.
encoder_dirs = list(Path("outputs/encoders").glob(f"{dataset}/*/*"))

# Significance threshold on t-test p-values (the commented line is an
# alternative setting).
# NOTE(review): confirm this value is actually forwarded to PValueContrast
# where the contrast methods are constructed.
pval_threshold = 1e-4
# pval_threshold = 5e-7

# Model code that identifies the baseline model in the scores table.
baseline_model = "baseline"

# Contrast name -> (positive model names, negative model names), expressed
# with the display names (keys) of `study_models`.
contrasts = {
    "word_dominant": (["Word"], ["Phoneme", "Random"]),
    "word_discrim_dominant": (["Word discrim2"], ["Phoneme", "Random"]),
    "phone_dominant": (["Phoneme"], ["Word", "Random"]),
    "random_dominant": (["Random"], ["Phoneme", "Word"]),
}

# Directory that receives the resulting contrasts.csv.
output_dir = "."

In [None]:
# Normalize the inputs: every encoder directory becomes a Path, and the model
# codes under study are collected in declaration order.
encoder_dirs = list(map(Path, encoder_dirs))
study_model_codes = [*study_models.values()]

# Translate the display names used by each contrast into model codes.
contrasts = {
    name: (
        [study_models[display_name] for display_name in positive_names],
        [study_models[display_name] for display_name in negative_names],
    )
    for name, (positive_names, negative_names) in contrasts.items()
}

In [None]:
# Load the scores table and keep only the rows of the selected dataset.
scores_df = pd.read_csv(
    scores_path,
    index_col=["dataset", "subject", "model2", "model1"],
).loc[dataset]

if study_model_codes is not None:
    # Explicit selection: keep the study models plus the baseline model.
    scores_df = scores_df.loc[
        scores_df.index.get_level_values("model2").isin(study_model_codes + [baseline_model])
    ]
else:
    # No explicit selection: study every model2 present, each named by its own code.
    study_model_codes = sorted(scores_df.index.get_level_values("model2").unique())
    study_models = dict(zip(study_model_codes, study_model_codes))

# Reverse lookup: model code -> display name.
study_model_code_to_name = dict(zip(study_models.values(), study_models))

In [None]:
# Load the t-test results, select the current dataset, then keep only rows
# whose model2 is one of the study models.
ttest_df = pd.read_csv(
    ttest_results_path,
    index_col=["dataset", "subject", "model2", "model1", "output_dim"],
)
ttest_df = ttest_df.loc[dataset]
ttest_df = ttest_df.loc[pd.IndexSlice[:, study_model_codes], :]
# Work with log10 p-values from here on.
ttest_df["log_pval"] = np.log10(ttest_df["pval"])
ttest_df

In [None]:
# Read electrodes.csv from every encoder directory whose parent directory is
# named "baseline", and stack the tables under a new outer "subject" index
# level keyed by the directory name.
_baseline_dirs = [encoder_dir for encoder_dir in encoder_dirs
                  if encoder_dir.parent.name == "baseline"]
electrode_df = pd.concat(
    [pd.read_csv(encoder_dir / "electrodes.csv", index_col=["electrode_idx"])
     for encoder_dir in _baseline_dirs],
    keys=[encoder_dir.name for encoder_dir in _baseline_dirs],
    names=["subject"],
)
electrode_df

In [None]:
class PValueContrast:
    """
    Defines a qualitative model contrast based on p-value of ttest.

    All comparisons are made on log10 p-values (the ``log_pval`` column of
    ``ttest_df``), so more negative means more significant.

    Parameters:
        ttest_df: t-test results indexed by (subject, model2, model1,
            output_dim) with a ``log_pval`` column.
        scores_df: stored for signature parity with the other contrast
            classes; not used by this class.
        electrode_df: frame whose index yields (subject, electrode) pairs; it
            defines the full set of electrodes to report on.
        study_model_codes: model codes (``model2`` values) to include.
        pval_threshold: p-value below which the positive model set counts as
            significant for an electrode.
    """
    def __init__(self, ttest_df, scores_df, electrode_df,
                 study_model_codes,
                 pval_threshold=1e-4):
        self.ttest_df = ttest_df
        self.scores_df = scores_df
        self.electrode_df = electrode_df
        self.study_model_codes = study_model_codes
        self.pval_threshold = pval_threshold

    def get_contrast_inputs(self):
        """
        Return one log10 p-value per (model2, subject, output_dim).

        Only comparisons against ``model1 == "baseline"`` are used. When a
        combination has several results, the least significant one (largest
        log10 p-value) is kept.
        """
        # get least-significant p-value result per model -- electrode
        electrode_pvals = self.ttest_df.loc[(slice(None), slice(None), "baseline"), "log_pval"] \
            .groupby(["model2", "subject", "output_dim"]).max()

        # Model--electrode combinations without a t-test result get
        # log10(p) = 0, i.e. p = 1 (not significant at all).
        full_index = pd.MultiIndex.from_tuples(
            [(model, subject, output_dim)
             for subject, output_dim in self.electrode_df.index
             for model in self.study_model_codes],
            names=["model2", "subject", "output_dim"])
        electrode_pvals = electrode_pvals.reindex(full_index).fillna(0.)

        return electrode_pvals

    def get_contrast_outcome(self, inputs, positive_models, negative_models):
        """
        Label each electrode for one contrast.

        Parameters:
            inputs: Series returned by :meth:`get_contrast_inputs`.
            positive_models: model codes forming the positive set.
            negative_models: model codes forming the negative set.

        Returns:
            DataFrame indexed by (subject, output_dim), sorted by ascending
            ``contrast_value``, with columns

            - ``contrast_value``: best (minimum) log10 p-value of the positive
              set minus best log10 p-value of the negative set; negative
              values favour the positive set.
            - ``positive_pval``: best log10 p-value of the positive set.
            - ``outcome``: ``"positive"`` / ``"negative"`` / ``"balanced"``,
              or ``None`` when the positive set is not significant or the
              contrast falls between the balanced and dominant bands.
        """
        group_levels = ["subject", "output_dim"]
        positive_log_pval = inputs.loc[positive_models].groupby(group_levels).min()
        negative_log_pval = inputs.loc[negative_models].groupby(group_levels).min()

        outcomes = (positive_log_pval - negative_log_pval) \
            .sort_values(ascending=True) \
            .rename("contrast_value").to_frame()
        outcomes["positive_pval"] = positive_log_pval

        # `positive_pval` holds log10(p), so significance means lying BELOW
        # log10(threshold) (e.g. below -4 for a threshold of 1e-4).
        significant = outcomes["positive_pval"] < np.log10(self.pval_threshold)
        contrast_value = outcomes["contrast_value"]

        # add qualitative label; a gap of one order of magnitude in p-value
        # marks dominance, half an order of magnitude or less marks balance
        outcomes["outcome"] = None
        outcomes.loc[significant & (contrast_value <= -1), "outcome"] = "positive"
        outcomes.loc[significant & (contrast_value >= 1), "outcome"] = "negative"
        outcomes.loc[significant & (contrast_value.abs() <= 0.5), "outcome"] = "balanced"

        return outcomes
    

class R2Contrast:
    """
    Defines a qualitative model contrast based on relative R2 improvement.

    Parameters:
        ttest_df: stored but not used by this class.
        scores_df: scores indexed by (subject, model2, model1), with columns
            ``model``, ``output_dim``, ``fold`` and ``score``.
        electrode_df: stored but not used by this class.
        study_model_codes: model codes under study (stored only).
        mode: ``"absolute"`` contrasts the full-model R2 of the model sets;
            ``"relative"`` contrasts their improvement relative to baseline.
        r2_threshold: minimum value the positive set must exceed before an
            electrode receives any label.
        r2_contrast_threshold: half-width of the "balanced" band around zero;
            expected to be non-negative.

    Raises:
        ValueError: if ``mode`` is not ``"relative"`` or ``"absolute"``.

    Note:
        ``get_contrast_inputs`` reads the module-level ``baseline_model`` name.
    """
    def __init__(self, ttest_df, scores_df, electrode_df,
                 study_model_codes,
                 mode: Literal["relative", "absolute"] = "absolute",
                 r2_threshold=0.1,
                 r2_contrast_threshold=0.1):
        if mode not in ("relative", "absolute"):
            raise ValueError(f"Unknown mode {mode!r}; expected 'relative' or 'absolute'")

        self.ttest_df = ttest_df
        self.scores_df = scores_df
        self.electrode_df = electrode_df
        self.study_model_codes = study_model_codes

        self.mode = mode
        self.r2_threshold = r2_threshold
        self.r2_contrast_threshold = r2_contrast_threshold

    def get_contrast_inputs(self):
        """
        Build per-electrode R2 statistics for every model.

        Returns:
            DataFrame indexed by (model2, subject, output_dim), averaged over
            folds, with the baseline score column, ``full_model``,
            ``absolute_improvement`` and ``relative_improvement``.
        """
        # Work on a copy: the relabelling below must not touch (or warn about
        # touching) the caller's scores_df.
        r2_comparison = self.scores_df.xs(baseline_model, level="model1").copy()
        # every non-baseline row is the "full" model of its comparison
        r2_comparison.loc[r2_comparison.model != baseline_model, "model"] = "full_model"
        r2_comparison = r2_comparison.reset_index().pivot_table(
            index=["subject", "model2", "output_dim", "fold"], columns="model", values="score")

        # avoid using negative values as baseline reference: the absolute
        # improvement is measured against max(baseline, 0), and the relative
        # improvement is only defined (non-NaN) where the baseline is positive
        baseline_reference = r2_comparison[baseline_model]
        baseline_relative_reference = baseline_reference[baseline_reference > 0]
        r2_comparison["absolute_improvement"] = r2_comparison["full_model"] - baseline_reference.clip(lower=0)
        r2_comparison["relative_improvement"] = r2_comparison["absolute_improvement"] / baseline_relative_reference

        # mean across folds
        r2_comparison = r2_comparison.groupby(["subject", "model2", "output_dim"]).mean()

        r2_comparison = r2_comparison.reorder_levels(["model2", "subject", "output_dim"])

        return r2_comparison

    def get_contrast_outcome(self, inputs, positive_models, negative_models):
        """
        Label each electrode for one contrast.

        Parameters:
            inputs: DataFrame returned by :meth:`get_contrast_inputs`.
            positive_models: model codes forming the positive set.
            negative_models: model codes forming the negative set.

        Returns:
            DataFrame indexed by (subject, output_dim), sorted by descending
            ``contrast_value``, holding the per-set summary statistics,
            ``contrast_value`` and ``outcome`` (``"positive"`` /
            ``"negative"`` / ``"balanced"`` or ``None``).
        """
        group_levels = ["subject", "output_dim"]

        # NB most stringent test -- we take the MAXIMUM improvement of the negative models
        # and the MINIMUM improvement of the positive models
        outcomes = pd.DataFrame({
            "positive_r2_relative_improvement": inputs.loc[positive_models, "relative_improvement"].groupby(group_levels).min(),
            "positive_r2_absolute_improvement": inputs.loc[positive_models, "absolute_improvement"].groupby(group_levels).min(),
            "positive_r2_absolute": inputs.loc[positive_models, "full_model"].groupby(group_levels).min(),

            "negative_r2_relative_improvement": inputs.loc[negative_models, "relative_improvement"].groupby(group_levels).max(),
            "negative_r2_absolute_improvement": inputs.loc[negative_models, "absolute_improvement"].groupby(group_levels).max(),
            "negative_r2_absolute": inputs.loc[negative_models, "full_model"].groupby(group_levels).max(),
        })

        # pick the pair of columns the contrast is computed from
        if self.mode == "relative":
            positive_column = "positive_r2_relative_improvement"
            negative_column = "negative_r2_relative_improvement"
        elif self.mode == "absolute":
            positive_column = "positive_r2_absolute"
            negative_column = "negative_r2_absolute"
        else:
            raise ValueError(f"Unknown mode {self.mode!r}; expected 'relative' or 'absolute'")

        # a negative-set value below zero is treated as zero
        outcomes["contrast_value"] = outcomes[positive_column] - outcomes[negative_column].clip(lower=0)
        # exclude overfit models
        overfit = outcomes["positive_r2_absolute"] < 0
        outcomes.loc[overfit, "contrast_value"] = 0

        outcomes = outcomes.sort_values("contrast_value", ascending=False)
        overfit = outcomes["positive_r2_absolute"] < 0

        # add qualitative label: only electrodes whose positive set clears
        # r2_threshold are labelled, and the three bands are mutually exclusive
        qualifies = outcomes[positive_column] > self.r2_threshold
        contrast_value = outcomes["contrast_value"]
        band = self.r2_contrast_threshold

        outcomes["outcome"] = None
        outcomes.loc[qualifies & (contrast_value > band), "outcome"] = "positive"
        outcomes.loc[qualifies & (contrast_value < -band), "outcome"] = "negative"
        outcomes.loc[qualifies & (contrast_value.abs() <= band), "outcome"] = "balanced"
        # exclude overfit models
        outcomes.loc[overfit, "outcome"] = None

        return outcomes


# Registry of contrast methods: method name -> factory called as
# factory(ttest_df, scores_df, electrode_df, study_model_codes=...).
CONTRAST_METHODS = {
    # forward the configured threshold; without this, changing
    # `pval_threshold` in the parameters cell has no effect
    "pval": partial(PValueContrast, pval_threshold=pval_threshold),
    # optional relative-improvement variant (disabled):
    # "relative_r2_10": partial(R2Contrast, mode="relative", r2_threshold=0.1, r2_contrast_threshold=0.1),
    "absolute_r2_1e-3": partial(R2Contrast, r2_threshold=1e-3, r2_contrast_threshold=1e-3),
}

In [None]:
# Evaluate every contrast with every contrast method.
# Result: one row per (contrast_method, contrast, subject, output_dim).
contrast_outcomes = {}
for contrast_method, contraster_factory in CONTRAST_METHODS.items():
    # The contraster and its inputs do not depend on the contrast being
    # evaluated, so build them once per method.
    contraster = contraster_factory(ttest_df, scores_df, electrode_df,
                                    study_model_codes=study_model_codes)
    contrast_inputs = contraster.get_contrast_inputs()
    assert contrast_inputs.index.names == ["model2", "subject", "output_dim"], \
        f"Unexpected index names: {contrast_inputs.index.names}"

    for contrast_name, (positive_models, negative_models) in contrasts.items():
        # Keep only the models that are part of the study, in declared order.
        positive_models_ = [model for model in positive_models if model in study_model_codes]
        negative_models_ = [model for model in negative_models if model in study_model_codes]
        # Check the FILTERED lists: a contrast whose models were all dropped
        # cannot be evaluated.
        if not positive_models_ or not negative_models_:
            raise ValueError("Missing all negative models or all positive models")

        missing_positive = set(positive_models) - set(positive_models_)
        if missing_positive:
            L.warning("Missing some positive models: %s", missing_positive)
        missing_negative = set(negative_models) - set(negative_models_)
        if missing_negative:
            L.warning("Missing some negative models: %s", missing_negative)

        contrast_outcomes[contrast_method, contrast_name] = contraster.get_contrast_outcome(
            contrast_inputs, positive_models_, negative_models_)

contrast_outcomes_df = pd.concat(contrast_outcomes, names=["contrast_method", "contrast"])
contrast_outcomes_df

In [None]:
# When exactly two contrast methods were run, cross-tabulate their outcome
# labels within each contrast and show the table as a heatmap.
_contrast_methods = contrast_outcomes_df.index.get_level_values("contrast_method").unique()
if len(_contrast_methods) == 2:
    cm1, cm2 = _contrast_methods
    _outcome_by_method = contrast_outcomes_df.reset_index().pivot(
        index=["contrast", "subject", "output_dim"],
        columns=["contrast_method"],
        values="outcome",
    )
    contrast_confusion = _outcome_by_method.groupby(["contrast"]).apply(
        lambda group: pd.crosstab(group[cm1], group[cm2]))
    print(contrast_confusion)

    sns.heatmap(contrast_confusion)

In [None]:
# Heatmap of outcome labels: one row per (contrast, subject, output_dim), one
# column per contrast method. Labels are mapped to numbers; unlabelled -> 0.
_outcome_codes = {"negative": -1, "balanced": 1, "positive": 2}
hm = contrast_outcomes_df.reset_index().pivot(
    index=["contrast", "subject", "output_dim"],
    columns=["contrast_method"],
    values="outcome",
)
hm = hm.applymap(_outcome_codes.get).fillna(0.)
sns.heatmap(hm)

In [None]:
# Wide view: one row per electrode (subject, output_dim); columns are keyed by
# value kind ("outcome" / "contrast_value"), contrast method and contrast.
contrast_outcomes_pivot = (
    contrast_outcomes_df
    .reset_index()
    .pivot(
        index=["subject", "output_dim"],
        columns=["contrast_method", "contrast"],
        values=["outcome", "contrast_value"],
    )
)
contrast_outcomes_pivot

In [None]:
# Cluster electrodes by their contrast values.
outcomes_to_plot = contrast_outcomes_pivot["contrast_value"].dropna().astype(float)
# Discard electrodes whose contrast value is zero in every column.
_all_zero = (outcomes_to_plot == 0).all(axis=1)
outcomes_to_plot = outcomes_to_plot.loc[~_all_zero]


def _zscore(column):
    # standardize one column of the stacked frame
    return (column - column.mean()) / column.std()


# normalize within measure: stacking moves the innermost column level into the
# index, so each remaining column is standardized as a whole
outcomes_to_plot = outcomes_to_plot.stack().apply(_zscore).unstack()
sns.clustermap(outcomes_to_plot, col_cluster=False, metric="cosine")

In [None]:
# Write the contrast table to <output_dir>/contrasts.csv. The directory is
# created first because to_csv does not create missing directories.
_contrasts_path = Path(output_dir) / "contrasts.csv"
_contrasts_path.parent.mkdir(parents=True, exist_ok=True)
contrast_outcomes_df.to_csv(_contrasts_path)