In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from glob import glob
import pickle

from PIL import ImageColor
import matplotlib.colors

from utils import FEATURES_DIR, RESULTS_DIR, SUBJECTS, NUM_TEST_STIMULI
from analyses.ridge_regression_decoding import NUM_CV_SPLITS, DECODER_OUT_DIR, calc_rsa, calc_rsa_images, calc_rsa_captions, get_fmri_data, pairwise_accuracy, ACC_MODALITY_AGNOSTIC, ACC_CAPTIONS, ACC_IMAGES, get_default_features, get_default_vision_features
from notebook_utils import add_avg_subject, create_result_graph, plot_metric_catplot, plot_metric, load_results_data, ACC_MEAN

In [None]:
# Load the results frame via notebook_utils.load_results_data(); the cells below
# filter this frame by model, feature type and mask before ranking and plotting.
all_data = load_results_data()

  3%|▎         | 159/5332 [00:13<09:32,  9.03it/s]

In [None]:
# Whitelist of model names kept in `all_data`. calc_model_feat_order() also walks this
# list front to back when it builds the model/feature ordering.
MODEL_ORDER = [
    "random-flava",
    "vit-b-16", "vit-l-16",
    "resnet-18", "resnet-50", "resnet-152",
    "dino-base", "dino-large", "dino-giant",
    "bert-base-uncased", "bert-large-uncased",
    "llama2-7b", "llama2-13b",
    "mistral-7b", "mixtral-8x7b",
    "gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl",
    "visualbert", "bridgetower-large", "clip", "flava", "imagebind", "vilt",
]

# Discard every row whose model is not part of the whitelist above.
is_listed_model = all_data["model"].isin(MODEL_ORDER)
all_data = all_data[is_listed_model]
all_data

In [None]:
DEFAULT_FEAT_OPTIONS = ["vision", "lang", "matched"]

def calc_model_feat_order(data, feat_options=DEFAULT_FEAT_OPTIONS, model_order=None):
    """Return the `model_feat` labels present in `data`, ordered by model and feature type.

    Candidate labels are built as ``f"{model}_{feats}"`` for every model in
    `model_order` (outer loop) and every entry of `feat_options` (inner loop);
    only candidates that actually occur in ``data.model_feat`` are returned.

    Args:
        data: frame with a `model` and a `model_feat` column.
        feat_options: feature-type suffixes, in the order they should appear
            for each model.
        model_order: model names in the desired order. When omitted, the
            notebook-level ``MODEL_ORDER`` list is used.

    Returns:
        List of `model_feat` labels in plotting order.

    Raises:
        RuntimeError: if `data` contains a model that is not part of the
            model order (it would otherwise be dropped silently).
    """
    if model_order is None:
        # Looked up at call time, so the notebook-level list only has to exist
        # once the function is actually called without an explicit order.
        model_order = MODEL_ORDER

    for model in data.model.unique():
        if model not in model_order:
            raise RuntimeError(f"Model missing in order: {model}")

    # A set gives constant-time membership tests instead of scanning the
    # array of unique labels once per candidate.
    present_model_feats = set(data.model_feat.unique())

    model_feat_order = []
    for model in model_order:
        for feats in feat_options:
            model_feat = f"{model}_{feats}"
            if model_feat in present_model_feats:
                model_feat_order.append(model_feat)

    return model_feat_order

# Models with at least one row whose feature type is "matched", "fused_mean" or
# "fused_cls", listed in order of first appearance in `all_data`.
uses_multimodal_feats = all_data["features"].isin(["matched", "fused_mean", "fused_cls"])
multimodal_models = all_data.loc[uses_multimodal_feats, "model"].unique().tolist()
multimodal_models

In [None]:

# For every model, keep only the rows produced with that model's default
# `features` and default `vision_features`; rows of other models pass through
# untouched on each iteration.
data_default_feats = all_data.copy()
for model in all_data.model.unique():
    default_feats = get_default_features(model)
    default_vision_feats = get_default_vision_features(model)

    is_other_model = data_default_feats.model != model
    has_default_feats = (
        (data_default_feats.model == model)
        & (data_default_feats.features == default_feats)
        & (data_default_feats.vision_features == default_vision_feats)
    )
    data_default_feats = data_default_feats[has_default_feats | is_other_model]

data_default_feats

In [None]:
# model_feat labels that combine a multimodal model with the "_lang", "_vision" or
# "_concat" suffix: all "_lang" labels first, then all "_vision", then all "_concat".
MODEL_FEAT_MULTIMODAL_SINGLE_MODALITY = [
    m + suffix
    for suffix in ('_lang', '_vision', '_concat')
    for m in multimodal_models
]

# Models that have "vision" rows and exactly one distinct value in `features`.
vision_models = [
    m
    for m in all_data.loc[all_data.features == "vision", "model"].unique()
    if len(all_data.loc[all_data.model == m, "features"].unique()) == 1
]

# Print the multimodal model names on one line, separated by spaces.
for m in multimodal_models:
    print(m, end=" ")


### Model performance ranking

In [None]:
# Build a model ranking: 'random-flava' is pinned first, then, for each default
# feature type, the models are appended from lowest to highest mean score.
model_order = ['random-flava']
for features in DEFAULT_FEAT_OPTIONS:
    print(features)

    # Rows for this feature type, restricted to the whole-brain mask, the
    # modality-agnostic training mode and the ACC_MODALITY_AGNOSTIC metric.
    selected = (
        (data_default_feats.features == features)
        & (data_default_feats["mask"] == "whole_brain")
        & (data_default_feats.training_mode == 'modality-agnostic')
        & (data_default_feats.metric == ACC_MODALITY_AGNOSTIC)
    )
    dp = data_default_feats[selected]

    # Each model is expected to contribute len(SUBJECTS) rows; report deviations.
    for model in dp.model.unique():
        n_rows = len(dp[dp.model == model])
        if n_rows != len(SUBJECTS):
            print(f"unexpected number of datapoints for {model}: {n_rows}")

    scores = dp.groupby("model").value.mean().sort_values()
    print(scores)
    model_order.extend(scores.index.values)

model_order

## ROI-based decoding

In [None]:
# Display labels of the three ROI masks, in legend / hue order.
MASK_ORDER = ["high-level visual ROI", "low-level visual ROI", "language ROI"]

# Set2 colours 3..(3 + number of masks), reversed; one colour per entry of MASK_ORDER.
MASK_PALETTE = sns.color_palette('Set2')[3:3+len(MASK_ORDER)][::-1]

# model_feat labels left out of the ROI comparison.
MODEL_FEATS_EXCLUDED = ["bridgetower-large_multi", "random-flava_vision", "random-flava_lang"] + MODEL_FEAT_MULTIMODAL_SINGLE_MODALITY

# Raw mask identifiers -> display labels used in MASK_ORDER.
MASK_LABELS = {
    "anatomical_visual_low_level": "low-level visual ROI",
    "anatomical_lang": "language ROI",
    "anatomical_visual_high_level": "high-level visual ROI",
}

# Apply both row filters at once and take an explicit copy *before* writing the
# "mask" column: assigning into the un-copied result of a boolean filter makes
# pandas emit SettingWithCopyWarning.
roi_rows = (
    ~all_data.model_feat.isin(MODEL_FEATS_EXCLUDED)
    & (all_data.vision_features == 'visual_feature_mean')
)
data_all_masks = all_data[roi_rows].copy()
data_all_masks["mask"] = data_all_masks["mask"].replace(MASK_LABELS)

# Only the three relabelled ROI masks are plotted; every other mask is dropped.
data_all_masks = data_all_masks[data_all_masks["mask"].isin(MASK_ORDER)].copy()

model_feat_order = calc_model_feat_order(data_all_masks)

# One row of the figure per metric, in this order.
metrics_order = ["pairwise_acc_captions", "pairwise_acc_images", ACC_MEAN]

dodge = 0.47
figure, lgd = create_result_graph(
    data_all_masks,
    model_feat_order,
    metrics=metrics_order,
    row_order=metrics_order,
    hue_variable="mask",
    hue_order=MASK_ORDER,
    palette=MASK_PALETTE,
    ylim=(0.5, 1),
    legend_title="Modality-agnostic decoders trained on fMRI data from",
    dodge=dodge,
    legend_bbox=(0.06, 0.99),
)

# Shaded background bands as (x_start, x_end, colour); the first and the last band
# share a colour.
# NOTE(review): the band edges and the x-limits below are hard-coded for 26 x
# positions (-0.5 .. 25.5) - confirm they still match len(model_feat_order)
# whenever the set of plotted models changes.
colors_bg = sns.color_palette('Set2')[:3]
background_bands = [
    (-0.5, 0.5, colors_bg[2]),
    (0.5, 8.5, colors_bg[0]),
    (8.5, 18.5, colors_bg[1]),
    (18.5, 25.5, colors_bg[2]),
]
# `figure.axes` is indexed as [row, column]; shade column 0 of every row.
for i in range(len(figure.axes)):
    for x_start, x_end, band_color in background_bands:
        figure.axes[i, 0].axvspan(x_start, x_end, facecolor=band_color, alpha=0.2, zorder=-100)
plt.xlim((-0.5, 25.5))

# Pass the legend as an extra artist so the tight bounding box does not crop it.
plt.savefig(os.path.join(RESULTS_DIR, "roi_comparison_pairwise_acc.png"), bbox_extra_artists=(lgd,), bbox_inches='tight', pad_inches=0, dpi=300)

In [None]:
# Mean `num_voxels` per mask label, computed over the rows that report a voxel count.
rows_with_voxel_count = data_all_masks[data_all_masks.num_voxels.notna()]
n_voxels_data = {
    mask: rows_with_voxel_count.loc[rows_with_voxel_count["mask"] == mask, "num_voxels"].mean()
    for mask in rows_with_voxel_count["mask"].unique()
}
print(n_voxels_data)

# Disabled leftovers from an earlier version of this cell, kept as notes only:
# - hard-coded voxel counts: whole_brain 214739, visual_high_level 14698,
#   visual_low_level 13955
# - a bar plot of the counts ("Number of voxels for each mask"), saved as
#   num_voxels.png in RESULTS_DIR