In [None]:
# Automatically reload edited modules (e.g. the project's `src.*` code) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
import numpy as np
from omegaconf import OmegaConf
import pandas as pd
from scipy.io import loadmat
import seaborn as sns
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm, trange

from src.encoding.ecog import timit as timit_encoding
from src.encoding.ecog import get_electrode_df
from src.estimate_encoder import prepare_xy

In [None]:
# Directory of the fitted baseline encoder. Its `.hydra` config and `scores.csv` are read below.
encoder_path = "outputs/encoders/timit/baseline/EC212"
# Directory that receives `unique_variance.csv`.
output_dir = "."

In [None]:
# Load the Hydra config stored in the encoder's `.hydra` directory.
# The path is resolved to an absolute one before being handed to Hydra.
hydra_config_dir = Path(encoder_path).resolve() / ".hydra"
with initialize_config_dir(str(hydra_config_dir)):
    config = compose(config_name="config")

    print(OmegaConf.to_yaml(config))

In [None]:
# Scores of the full (baseline) encoder. Later cells rely on this table having
# `fold`, `output_dim` and `score` columns.
baseline_scores_path = Path(encoder_path) / "scores.csv"
baseline_scores = pd.read_csv(baseline_scores_path)

In [None]:
# Every data spec must belong to one subject, because the electrode metadata
# below is loaded for a single subject.
all_subjects = {data_spec.subject for data_spec in config.data}
assert len(all_subjects) == 1, f"All data should be from the same subject. Got: {all_subjects}"
subject = all_subjects.pop()

# Electrode metadata for that subject
electrode_df = get_electrode_df(config, subject)

# Prepare the data for each data spec, then concatenate into one dataset
all_xy = [prepare_xy(config, data_spec) for data_spec in config.data]
X, Y, feature_names, feature_shapes, trial_onsets = timit_encoding.concat_xy(all_xy)

# Outer and inner cross-validation splitters are built from the same config node
cv_outer = instantiate(config.cv)
cv_inner = instantiate(config.cv)

# TODO check match between model sfreq and dataset sfreq

# Keyword arguments for the TRF model
if "model" not in config:
    # Legacy configs have no `model` section; fall back to fixed defaults
    trf_kwargs = {
        "type": "trf",
        "tmin": 0.0,
        "tmax": 0.6,
        "sfreq": 100,
        "fit_intercept": False,
    }
else:
    trf_kwargs = OmegaConf.to_object(config.model)

# `sfreq` is passed separately downstream and `type` is not a model kwarg,
# so both are removed from the kwargs dict.
sfreq = trf_kwargs.pop("sfreq")
trf_kwargs.pop("type")

In [None]:
def estimate_unique_variance(X, Y, target_feature_set, feature_sets,
                             feature_names, feature_shapes,
                             cv_outer, cv_inner, sfreq, **trf_kwargs):
    """
    Score an encoder that is refit with one feature set removed.

    The columns of ``X`` belonging to ``target_feature_set`` are masked out and
    the remaining columns are passed to ``timit_encoding.strf_nested_cv``. The
    resulting per-fold, per-output scores are returned in long format so they
    can be compared against the scores of the full model.

    Args:
        X: 2D design matrix. Its columns are grouped by feature set, in the
            order of ``feature_sets`` and with the widths in ``feature_shapes``.
        Y: 2D response matrix; ``Y.shape[1]`` is the number of outputs.
        target_feature_set: Name of the feature set to drop. Must be a member
            of ``feature_sets``.
        feature_sets: Sequence of feature set names, aligned with
            ``feature_shapes``.
        feature_names: Names of the individual features (formatted as
            ``"{feature_set}_{idx}"``). Presumably one per column of ``X``;
            see the fallback below for the case where it is not.
        feature_shapes: Number of columns of ``X`` occupied by each feature set.
        cv_outer: Outer cross-validation splitter; must provide
            ``get_n_splits()``.
        cv_inner: Inner cross-validation splitter, forwarded unchanged.
        sfreq: Sampling frequency, forwarded unchanged.
        **trf_kwargs: Extra model options, forwarded as ``trf_kwargs``.

    Returns:
        DataFrame with columns ``fold``, ``output_dim``, ``score``,
        ``dropped_feature`` and ``output_name``. Scores are NaN when the
        fitting routine returned no scores.

    Note:
        ``output_name`` is looked up from the notebook-level ``electrode_df``,
        which therefore has to be defined before this function is called.
    """
    # NB we are masking out an entire "feature set", which corresponds to
    # one or more individual features in the TRF.
    # Validate against the argument rather than the notebook-global config, so
    # the function is consistent with whatever `feature_sets` it was given.
    assert target_feature_set in feature_sets, \
        f"Unknown feature set {target_feature_set!r}"

    # Column boundaries of each feature set: set i spans bounds[i]:bounds[i + 1]
    feature_set_bounds = np.cumsum([0, *feature_shapes])
    assert len(feature_set_bounds) - 1 == len(feature_sets)

    # Boolean column mask that is False exactly on the target feature set
    feature_set_idx = feature_sets.index(target_feature_set)
    drop_start = feature_set_bounds[feature_set_idx]
    drop_end = feature_set_bounds[feature_set_idx + 1]
    feature_mask = np.ones(X.shape[1], dtype=bool)
    feature_mask[drop_start:drop_end] = False

    # Select the surviving shapes/names by position, not by name prefix: a
    # prefix test would also discard any other feature set whose name merely
    # starts with the target's name (e.g. "phoneme" vs. "phoneme_onset").
    feature_shapes_masked = [shape for idx, shape in enumerate(feature_shapes)
                             if idx != feature_set_idx]
    if len(feature_names) == X.shape[1]:
        feature_names_masked = [name for name, keep in zip(feature_names, feature_mask)
                                if keep]
    else:
        # Names are not aligned one-to-one with columns; keep the previous
        # prefix-based behaviour rather than guessing at the layout.
        feature_names_masked = [name for name in feature_names
                                if not name.startswith(target_feature_set)]
    assert sum(feature_shapes_masked) == feature_mask.sum()

    _, _, scores, _, _ = timit_encoding.strf_nested_cv(
        X[:, feature_mask], Y, feature_names_masked, feature_shapes_masked,
        sfreq=sfreq, cv_outer=cv_outer, cv_inner=cv_inner, trf_kwargs=trf_kwargs
    )

    n_folds = cv_outer.get_n_splits()
    if len(scores) == 0:
        # No models converged. Save dummy outputs.
        feature_scores_df = pd.DataFrame(
            [(fold, output_dim, np.nan)
             for fold in range(n_folds)
             for output_dim in range(Y.shape[1])],
            columns=["fold", "output_dim", "score"]
        )
    else:
        # One row per fold, one column per output; reshape to long format
        feature_scores_df = pd.DataFrame(
            np.array(scores),
            index=pd.Index(list(range(n_folds)), name="fold"),
            columns=pd.Index(list(range(scores[0].shape[0])), name="output_dim"))
        feature_scores_df = feature_scores_df.reset_index().melt(
            id_vars="fold", var_name="output_dim", value_name="score")
    feature_scores_df["dropped_feature"] = target_feature_set
    # Translate positional output indices into the labels of the notebook-level
    # electrode table.
    feature_scores_df["output_name"] = feature_scores_df.output_dim.map(dict(enumerate(electrode_df.index)))

    return feature_scores_df

In [None]:
# Refit the encoder once per baseline feature set, each time with that feature
# set dropped, and collect the resulting score tables.
baseline_feature_sets = config.feature_sets.baseline_features
feature_score_dfs = []
for feature_set in tqdm(baseline_feature_sets, unit="feature"):
    feature_score_dfs.append(
        estimate_unique_variance(
            X, Y, feature_set, baseline_feature_sets,
            feature_names, feature_shapes,
            cv_outer, cv_inner, sfreq, **trf_kwargs)
    )

In [None]:
# Stack the baseline scores on top of the dropped-feature scores and index the
# result by (dropped feature, fold, output).
merged_df = (
    pd.concat([baseline_scores, *feature_score_dfs])
    .set_index(["dropped_feature", "fold", "output_dim"])
)

In [None]:
# Baseline rows are the ones whose `dropped_feature` index level is NaN.
baseline_score = merged_df.loc[np.nan].score

# NOTE(review): this is (score with the feature dropped) - (baseline score), so a
# feature set that improves the baseline yields a negative value. Confirm that
# consumers of the CSV expect this sign.
unique_variance = (merged_df.score - baseline_score).rename("unique_variance_score")

final_df = pd.merge(merged_df, unique_variance.to_frame(),
                    left_index=True, right_index=True)
final_df.to_csv(Path(output_dir) / "unique_variance.csv")