In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from glob import glob
import pickle

from PIL import ImageColor
import matplotlib.colors

from utils import NN_FEATURES_DIR, RESULTS_DIR, SUBJECTS, NUM_TEST_STIMULI
from analyses.ridge_regression_decoding import NUM_CV_SPLITS, DECODER_OUT_DIR, calc_rsa, calc_rsa_images, calc_rsa_captions, get_fmri_data, pairwise_accuracy, ACC_CAPTIONS, ACC_IMAGES, get_default_features, get_default_vision_features, get_default_lang_features, MOD_SPECIFIC_IMAGES, MOD_SPECIFIC_CAPTIONS
from notebook_utils import add_avg_subject, create_result_graph, plot_metric_catplot, plot_metric, load_results_data, ACC_MEAN

# Zero-shot cross-modal decoding

In [2]:
models = ["random-flava", "clip", "flava", "imagebind", "blip2"]
all_data = load_results_data(models, recompute_acc_scores=False)


100%|██████████| 468/468 [00:00<00:00, 836.63it/s]


In [3]:
data_models = all_data.copy()

data_models = data_models[data_models["mask"] == "whole_brain"]
data_models = data_models[data_models.surface == False]

data_models

Unnamed: 0,alpha,model,subject,features,test_features,vision_features,lang_features,training_mode,mask,num_voxels,surface,resolution,metric,value,model_feat
0,100000.0,blip2,sub-01,avg,avg,vision_features_mean,lang_features_mean,modality-agnostic,whole_brain,132633,False,fsaverage,pairwise_acc_modality_agnostic,0.875983,blip2_avg
1,100000.0,blip2,sub-01,avg,avg,vision_features_mean,lang_features_mean,modality-agnostic,whole_brain,132633,False,fsaverage,pairwise_acc_captions,0.847619,blip2_avg
2,100000.0,blip2,sub-01,avg,avg,vision_features_mean,lang_features_mean,modality-agnostic,whole_brain,132633,False,fsaverage,pairwise_acc_images,0.944099,blip2_avg
3,100000.0,blip2,sub-01,avg,avg,vision_features_mean,lang_features_mean,modality-agnostic,whole_brain,132633,False,fsaverage,pairwise_acc_cross_images_to_captions,0.944513,blip2_avg
4,100000.0,blip2,sub-01,avg,avg,vision_features_mean,lang_features_mean,modality-agnostic,whole_brain,132633,False,fsaverage,pairwise_acc_cross_captions_to_images,0.847205,blip2_avg
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
715,1000000.0,random-flava,sub-07,avg,avg,vision_features_mean,lang_features_mean,images,whole_brain,123559,False,fsaverage,pairwise_acc_cross_images_to_captions,0.697516,random-flava_avg
716,1000000.0,random-flava,sub-07,avg,avg,vision_features_mean,lang_features_mean,images,whole_brain,123559,False,fsaverage,pairwise_acc_cross_captions_to_images,0.537681,random-flava_avg
717,1000000.0,random-flava,sub-07,avg,avg,vision_features_mean,lang_features_mean,images,whole_brain,123559,False,fsaverage,pairwise_acc_imagery,1.000000,random-flava_avg
718,1000000.0,random-flava,sub-07,avg,avg,vision_features_mean,lang_features_mean,images,whole_brain,123559,False,fsaverage,pairwise_acc_imagery_whole_test_set,0.755869,random-flava_avg


In [4]:
data_cls_feats = data_models.copy()
for model in all_data.model.unique():
    data_cls_feats = data_cls_feats[((data_cls_feats.model == model) & (data_cls_feats.vision_features == "vision_features_cls") & (data_cls_feats.lang_features == "lang_features_cls")) | (data_cls_feats.model != model)]
    
data_matched_feats = data_cls_feats[data_cls_feats.features == "matched"]
data_matched_feats

Unnamed: 0,alpha,model,subject,features,test_features,vision_features,lang_features,training_mode,mask,num_voxels,surface,resolution,metric,value,model_feat


In [5]:
def add_mean_cross_modal_and_within_modal_rows(data):
    extra_rows = []
    for model in data.model.unique():
        for mask in data["mask"].unique():
            for subject in SUBJECTS:
                data_model_subj = data[(data.model == model) & (data.subject == subject)]
                if pd.isna(mask):
                    data_model_subj = data_model_subj[pd.isna(data_model_subj['mask'])]
                else:
                    data_model_subj = data_model_subj[data_model_subj['mask'] == mask]
                # cross-modal
                cross_modal_train_images_eval_captions = data_model_subj[(data_model_subj.training_mode == "images") & (data_model_subj.metric == ACC_CAPTIONS)]
                cross_modal_train_captions_eval_images = data_model_subj[(data_model_subj.training_mode == "captions") & (data_model_subj.metric == ACC_IMAGES)]
        
                if len(cross_modal_train_captions_eval_images) > 0:
                    assert len(cross_modal_train_images_eval_captions) == len(cross_modal_train_captions_eval_images) == 1
                    mean_acc = (cross_modal_train_images_eval_captions.value.item() + cross_modal_train_captions_eval_images.value.item()) / 2
            
                    mean_row = cross_modal_train_images_eval_captions.copy()
                    mean_row["training_mode"] = "cross-modal"
                    mean_row["metric"] = "mean"
                    mean_row["value"] = mean_acc
                    mean_row["condition"] = "cross-modal"
        
                    extra_rows.append(mean_row)
        
                # within-modal
                within_modal_captions = data_model_subj[(data_model_subj.training_mode == "captions") & (data_model_subj.metric == ACC_CAPTIONS)]
                within_modal_images = data_model_subj[(data_model_subj.training_mode == "images") & (data_model_subj.metric == ACC_IMAGES)]
        
                if len(within_modal_captions) > 0:
                    assert len(within_modal_captions) == len(within_modal_images) == 1
                    mean_acc = (within_modal_captions.value.item() + within_modal_images.value.item()) / 2
            
                    mean_row = within_modal_captions.copy()
                    mean_row["training_mode"] = "within-modal"
                    mean_row["metric"] = "mean"
                    mean_row["value"] = mean_acc
                    mean_row["condition"] = "within-modal"
        
                    extra_rows.append(mean_row)

    if len(extra_rows) > 0:
        extra_rows = pd.concat(extra_rows)
        data = pd.concat((data, extra_rows), ignore_index=True)
    return data
    # data_matched_with_mean[data_matched_with_mean.metric == "mean"]

In [6]:
data_matched_feats = add_mean_cross_modal_and_within_modal_rows(data_matched_feats)

In [7]:
DEFAULT_FEAT_OPTIONS = ["vision", "lang", "matched"]

def calc_model_feat_order(data, model_order, feat_options=DEFAULT_FEAT_OPTIONS):
    all_model_feats = data.model_feat.unique()
    all_models = data.model.unique()
    for model in all_models:
        if model not in model_order:
            raise RuntimeError(f"Model missing in order: {model}")
    model_feat_order = []
    for model in model_order:
        for feats in feat_options:
            model_feat = f"{model}_{feats}"
            if model_feat in all_model_feats:
                model_feat_order.append(model_feat)

    return model_feat_order

In [8]:
def create_zero_shot_cross_modal_plot(data, x_variable, ylim=(0.5, 1), y_variable="value", ylabel="pairwise_acc", row_variable="metric", hue_variable="condition", title=None):
    sns.set(font_scale=1.6)
    TRAIN_MODE_ORDER = ["images", "captions", "modality-agnostic"]
    FEAT_ORDER = ["vision models", "language models", "multimodal models"]
    
    data_to_plot = data.copy()
    
    data_to_plot = data_to_plot[data_to_plot.training_mode != "modality-agnostic"]
    
    data_to_plot["features"] = data_to_plot.features.replace({"vision": "vision models", "lang": "language models", "matched": "multimodal models"})
    
    data_to_plot.loc[((data_to_plot.training_mode == "images") & (data_to_plot.metric == ACC_CAPTIONS)) | ((data_to_plot.training_mode == "captions") & (data_to_plot.metric == ACC_IMAGES)), "condition"] = "cross-modal"
    data_to_plot.loc[((data_to_plot.training_mode == "captions") & (data_to_plot.metric == ACC_CAPTIONS)) | ((data_to_plot.training_mode == "images") & (data_to_plot.metric == ACC_IMAGES)), "condition"] = "within-modal"
    
    # model_feat_order = calc_model_feat_order(data_to_plot, model_order)
    
    
    metrics_order = [ACC_CAPTIONS, ACC_IMAGES, "mean"]
    
    height = 4.5
    aspect = 4
    
    condition_order = ["cross-modal", "within-modal"]
    
    for mode in ["captions", "images", "cross-modal", "within-modal"]:
        data_mode = data_to_plot[data_to_plot.training_mode == mode]
        for x_variable_value in data_to_plot[x_variable].unique():
            for condition in condition_order:
                length = len(data_mode[(data_mode[x_variable] == x_variable_value) & (data_mode.condition == condition)])
                expected_num_datapoints = len(SUBJECTS)
                if (length > 0) and (length != expected_num_datapoints):
                    message = f"unexpected number of datapoints: {length} (expected: {expected_num_datapoints}) (model_feat: {model} {mode})"
                    print(f"Warning: {message}")
    
    g = sns.catplot(data_to_plot, kind="bar", x=x_variable, y=y_variable, row=row_variable, row_order=metrics_order, col=None, height=height, aspect=aspect, hue=hue_variable, hue_order=condition_order,
                    palette=None, err_kws={'linewidth': 0.5, 'alpha': 0.99}, width=0.7)

    g.set(ylim=ylim, ylabel=ylabel, xlabel='')
    g.tick_params(axis='x', rotation=80)
    if title:
        g.fig.suptitle(title, y=1.03)
    return data_to_plot

In [9]:
data_plotted = create_zero_shot_cross_modal_plot(data_matched_feats, "model", ylim=(0.3, 1))
plt.savefig(os.path.join(RESULTS_DIR, f"zero_shot_cross_modal.png"), bbox_inches='tight', pad_inches=0, dpi=300)

print(data_plotted[(data_plotted.metric == "mean")].groupby(["model", "condition"])['value'].mean())
print(data_plotted[(data_plotted.metric == "mean")].groupby(["model", "condition"])['value'].count())


ValueError: cannot set a frame with no defined index and a scalar

# Zero-shot cross-modal decoding with Mask

In [None]:
def create_zero_shot_cross_modal_masks_plot(model):
    model_order = [model]
    resolution = "fsaverage7"
    
    data_models = all_data[all_data.model.isin(model_order)].copy()
    
    include_masks = [mask_name for mask_name in data_models["mask"].unique() if not "masks_400" in mask_name and "imagebind" in mask_name] #"thresh_0.001" in mask_name and 
    include_masks += ["whole_brain"]
    data_models = data_models[data_models["mask"].isin(include_masks)]
    # print(data_models["mask"].unique())

    data_models["mask"] = data_models["mask"].apply(lambda x: os.path.basename(x)) #

    
    # data_models = data_models[data_models.resolution == resolution]
    # data_models = data_models[data_models.surface == True]
    
    data_cls_feats = data_models.copy()
    for model in data_models.model.unique():
        data_cls_feats = data_cls_feats[((data_cls_feats.model == model) & (data_cls_feats.vision_features == "vision_features_cls") & (data_cls_feats.lang_features == "lang_features_cls")) | (data_cls_feats.model != model)]
    
    data_matched_feats = data_cls_feats[data_cls_feats.features == "matched"]
    
    data_matched_feats = add_mean_cross_modal_and_within_modal_rows(data_matched_feats)
    data_plotted = create_zero_shot_cross_modal_plot(data_matched_feats, x_variable="mask", title=model)
    # print(data_plotted[(data_plotted.metric == "mean")].groupby(["mask", "condition"]).agg(mean_val=('value', 'mean'), count=('value', 'count')))
    return data_plotted

In [None]:
data_plotted = create_zero_shot_cross_modal_masks_plot("blip2")
create_zero_shot_cross_modal_masks_plot("imagebind")
create_zero_shot_cross_modal_masks_plot("clip")
create_zero_shot_cross_modal_masks_plot("flava")

In [None]:
data_plotted = data_plotted[data_plotted["mask"] != "whole_brain"]
data_plotted = create_zero_shot_cross_modal_plot(data_plotted, x_variable="mask", y_variable="num_voxels", ylim=(0, 2000), ylabel="num_vertices", row_variable=None, hue_variable=None)
_ = plt.title("Mask sizes")
# print(data_plotted[(data_plotted.metric == "mean")].groupby(["mask", "condition"])['num_voxels'].mean())


# Zero-shot cross-modal decoding with GloW

In [None]:
models = ["glow", "glow-contrastive"]

data_models = load_results_data(models)
data_models

In [None]:
data_cls_feats = data_models.copy()
# for model in all_data.model.unique():
#     data_cls_feats = data_cls_feats[((data_cls_feats.model == model) & (data_cls_feats.vision_features == "vision_features_cls") & (data_cls_feats.lang_features == "lang_features_cls")) | (data_cls_feats.model != model)]

data_matched_feats = data_cls_feats[data_cls_feats.features == "matched"]


In [None]:
data_matched_feats = add_mean_cross_modal_and_within_modal_rows(data_matched_feats)

In [None]:
data_plotted = create_zero_shot_cross_modal_plot(data_matched_feats, "model")
plt.savefig(os.path.join(RESULTS_DIR, f"zero_shot_cross_modal_glow.png"), bbox_inches='tight', pad_inches=0, dpi=300)

print(data_plotted[(data_plotted.training_mode == "cross-modal") & (data_plotted.metric == "mean")].groupby("model")['value'].mean())
print(data_plotted[(data_plotted.training_mode == "cross-modal") & (data_plotted.metric == "mean")].groupby("model")['value'].count())
