In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from glob import glob
import pickle

from PIL import ImageColor
import matplotlib.colors
from sklearn.preprocessing import StandardScaler
from utils import LATENT_FEATURES_DIR, RESULTS_DIR, SUBJECTS, FMRI_BETAS_SURFACE_DIR, STIM_INFO_PATH, COCO_IMAGES_DIR, METRIC_CROSS_DECODING, DECODER_ADDITIONAL_TEST_OUT_DIR, SUBJECTS_ADDITIONAL_TEST, FMRI_DATA_DIR
from analyses.decoding.ridge_regression_decoding import NUM_CV_SPLITS, pairwise_accuracy
from data import MODALITY_AGNOSTIC, MODALITY_SPECIFIC_IMAGES, MODALITY_SPECIFIC_CAPTIONS, TRAINING_MODES, CAPTION, IMAGE, SPLIT_TRAIN, TEST_SPLITS, SPLIT_IMAGERY, SPLIT_IMAGERY_WEAK, TEST_IMAGES, TEST_CAPTIONS, LatentFeatsConfig, get_stim_info, get_latents_for_splits, standardize_latents
from eval import ACC_MODALITY_AGNOSTIC, ACC_CAPTIONS, ACC_IMAGES, ACC_CROSS_IMAGES_TO_CAPTIONS, ACC_CROSS_CAPTIONS_TO_IMAGES, ACC_IMAGERY, ACC_IMAGERY_WHOLE_TEST, get_distance_matrix
from scipy.stats import ttest_rel
import statsmodels.formula.api as smf
from scipy.spatial.distance import cdist

from notebook_utils import load_predictions, get_data_default_feats

# Global seaborn appearance applied to the figures created below:
# the "white" axes style and the "paper" context with fonts scaled by 1.5.
sns.set_style("white")
sns.set_context("paper", font_scale=1.5)

In [None]:
def load_results_data():
    """Load and concatenate all ``results.csv`` files of the additional-test decoders.

    Collects every ``results.csv`` located exactly three directory levels below
    ``DECODER_ADDITIONAL_TEST_OUT_DIR`` (files are read in sorted path order),
    stacks them into a single frame and replaces missing values in the ``mask``
    column with ``"whole_brain"``.

    Returns:
        pandas.DataFrame: the rows of all result files with a fresh integer index.

    Raises:
        FileNotFoundError: if no result file matches the search pattern. Without
            this check ``pd.concat`` would fail with the uninformative
            "No objects to concatenate".
    """
    pattern = f"{DECODER_ADDITIONAL_TEST_OUT_DIR}/*/*/*/results.csv"
    result_files = sorted(glob(pattern))
    if not result_files:
        raise FileNotFoundError(f"No results.csv files found matching '{pattern}'")

    frames = [pd.read_csv(result_file_path) for result_file_path in tqdm(result_files)]

    data = pd.concat(frames, ignore_index=True)
    # Rows without a mask entry are labelled "whole_brain" so they can be
    # selected by name later on.
    data["mask"] = data["mask"].fillna("whole_brain")

    return data

# Read every result table, then let get_data_default_feats() post-process it
# (helper imported from notebook_utils).
data = get_data_default_feats(load_results_data())

# To show the table without truncation, wrap the display call in:
# with pd.option_context('display.max_rows', None, 'display.max_columns', None):
display(data)

print("Subjects: {}".format(data.subject.unique()))

In [None]:
# Spot check: metric values of one subject / model / training mode.
# NOTE(review): standardized_predictions is compared against the *string* 'True';
# presumably the column is stored as text - confirm.
# Alternative view: additionally restrict to latents == 'all_candidate_latents',
# drop the standardized_predictions condition and show that column instead.
print(
    data.loc[
        (data.subject == 'sub-01')
        & (data.model == 'imagebind')
        & (data.training_mode == 'agnostic')
        & (data['mask'] == 'whole_brain')
        & (data.standardized_predictions == 'True')
        & (data.training_splits == 'train'),
        ['metric', 'value', 'latents'],
    ]
)

In [None]:
# Configuration of the result subset analysed in the remaining cells.
MODEL = 'imagebind'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
LATENT_MODE = 'limited_candidate_latents'  # alternative: 'all_candidate_latents'

# All conditions combined into one boolean row selector.
keep_row = (
    (data.model == MODEL)
    & (data.standardized_predictions == 'True')
    & (data.training_splits == TRAINING_SPLITS)
    & (data.latents == LATENT_MODE)
    & (data['mask'] == MASK)
    & data.imagery_samples_weight.isna()
    & (data.surface == True)
)
filtered = data[keep_row].copy()

# Completeness check: exactly one row per subject, metric and training mode.
# For debugging, inspect the counts with
# filtered.groupby(['metric', 'training_mode']).agg(num_subjects=('value', 'size')).reset_index()
NUM_SUBJECTS = len(SUBJECTS_ADDITIONAL_TEST)
num_metrics = len(filtered.metric.unique())
num_training_modes = len(filtered.training_mode.unique())
expected_len = NUM_SUBJECTS * num_metrics * num_training_modes
assert len(filtered) == expected_len, filtered

In [None]:
# Bar plot of image decoding accuracy with vs. without attention, per decoder
# training mode, saved to RESULTS_DIR/attention_modulation_images.png.
ORDER = ['captions', 'images', 'agnostic']
HUE_ORDER = ['test_image_attended', 'test_image_unattended']
PALETTE = ['darkblue', 'cornflowerblue']
# Metric drawn as a reference marker on top of the bars.
# NOTE(review): presumably the accuracy on the regular image test set - confirm.
IMAGE_REFERENCE_METRIC = 'test_image'

# `to_plot` is also the input of the statistics cell that follows.
to_plot = filtered[filtered.metric.isin(HUE_ORDER)].copy()

fig, ax = plt.subplots(figsize=(10, 7))
sns.barplot(
    data=to_plot, x="training_mode", y="value", hue="metric",
    order=ORDER, hue_order=HUE_ORDER, palette=PALETTE, ax=ax,
)
ax.set_ylim(0.5, 1)
ax.set_ylabel('pairwise accuracy')

# Mean reference accuracy per training mode. groupby() returns the modes in
# alphabetical order, so the rows are re-ordered to ORDER: the markers then line
# up with the bars even when the x positions are derived from the row order
# instead of from the categorical axis set up by barplot.
to_plot_grouped = (
    filtered[filtered.metric == IMAGE_REFERENCE_METRIC]
    .groupby('training_mode')
    .agg(value=('value', 'mean'))
    .reindex(ORDER)
    .rename_axis('training_mode')
    .reset_index()
)
sns.scatterplot(
    data=to_plot_grouped, x="training_mode", y="value",
    marker="_", color='black', s=500, ax=ax,
)

sns.despine(fig=fig)
fig.savefig(os.path.join(RESULTS_DIR, "attention_modulation_images.png"), bbox_inches='tight', pad_inches=0, dpi=300)

## Hypothesis: modality-agnostic decoders (and cross-decoding) should suffer more from missing attention

In [None]:
# Mixed linear model on the image-decoding accuracies of the 'images' and
# 'agnostic' decoders: accuracy explained by metric (attended / unattended),
# training mode and their interaction, with subject as the grouping variable.
# (A paired t-test via ttest_rel on attended vs. unattended values is a simpler
# alternative check.)
in_comparison = to_plot.training_mode.isin(['images', 'agnostic'])
for_stats = to_plot.loc[in_comparison, ['model', 'subject', 'training_mode', 'value', 'metric']]

model_spec = smf.mixedlm(
    formula="value ~ metric * training_mode",
    data=for_stats,
    groups=for_stats["subject"],
)
mod = model_spec.fit()

print(mod.summary())
print('pvalues:\n', mod.pvalues)
print('\n')


In [None]:
# Bar plot of caption decoding accuracy with vs. without attention, per decoder
# training mode, saved to RESULTS_DIR/attention_modulation_captions.png.
ORDER = ['captions', 'images', 'agnostic']
HUE_ORDER = ['test_caption_attended', 'test_caption_unattended']
PALETTE = ['darkgreen', 'limegreen']
# Metric drawn as a reference marker on top of the bars.
# NOTE(review): presumably the accuracy on the regular caption test set - confirm.
CAPTION_REFERENCE_METRIC = 'test_caption'

# `to_plot` is also the input of the statistics cell that follows.
to_plot = filtered[filtered.metric.isin(HUE_ORDER)].copy()

fig, ax = plt.subplots(figsize=(10, 7))
sns.barplot(
    data=to_plot, x="training_mode", y="value", hue="metric",
    order=ORDER, hue_order=HUE_ORDER, palette=PALETTE, ax=ax,
)
ax.set_ylabel('pairwise accuracy')
ax.set_ylim(0.5, 1)

# Mean reference accuracy per training mode. groupby() returns the modes in
# alphabetical order, so the rows are re-ordered to ORDER: the markers then line
# up with the bars even when the x positions are derived from the row order
# instead of from the categorical axis set up by barplot.
to_plot_grouped = (
    filtered[filtered.metric == CAPTION_REFERENCE_METRIC]
    .groupby('training_mode')
    .agg(value=('value', 'mean'))
    .reindex(ORDER)
    .rename_axis('training_mode')
    .reset_index()
)
sns.scatterplot(
    data=to_plot_grouped, x="training_mode", y="value",
    marker="_", color='black', s=500, ax=ax,
)

sns.despine(fig=fig)
fig.savefig(os.path.join(RESULTS_DIR, "attention_modulation_captions.png"), bbox_inches='tight', pad_inches=0, dpi=300)


In [None]:
# Mixed linear model on the caption-decoding accuracies of the 'captions' and
# 'agnostic' decoders: accuracy explained by metric (attended / unattended),
# training mode and their interaction, with subject as the grouping variable.
# (A paired t-test via ttest_rel on attended vs. unattended values is a simpler
# alternative check.)
in_comparison = to_plot.training_mode.isin(['captions', 'agnostic'])
for_stats = to_plot.loc[in_comparison, ['model', 'subject', 'training_mode', 'value', 'metric']]

model_spec = smf.mixedlm(
    formula="value ~ metric * training_mode",
    data=for_stats,
    groups=for_stats["subject"],
)
mod = model_spec.fit()

banner = "=" * 50
print(banner + "\nGLM\n" + banner)
print(mod.summary())
print('pvalues:\n', mod.pvalues)
print('\n')

