In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import pandas as pd
import seaborn as sns
from tqdm import tqdm
from glob import glob
import pickle

from PIL import Image, ImageColor
import matplotlib.colors
from sklearn.preprocessing import StandardScaler
from utils import LATENT_FEATURES_DIR, RESULTS_DIR, FMRI_BETAS_SURFACE_DIR, STIM_INFO_PATH, COCO_IMAGES_DIR, METRIC_CROSS_DECODING, FMRI_DATA_DIR, DECODER_ADDITIONAL_TEST_OUT_DIR, FMRI_BIDS_DATA_DIR, SUBJECTS_ADDITIONAL_TEST
from analyses.decoding.ridge_regression_decoding import NUM_CV_SPLITS, pairwise_accuracy, get_run_str, RESULTS_FILE, PREDICTIONS_FILE
from data import MODALITY_AGNOSTIC, MODALITY_SPECIFIC_IMAGES, MODALITY_SPECIFIC_CAPTIONS, TRAINING_MODES, CAPTION, IMAGE, TEST_SPLITS, LatentFeatsConfig, get_stim_info, SPLIT_TRAIN, TEST_IMAGES, TEST_CAPTIONS, SPLIT_IMAGERY, SPLIT_IMAGERY_WEAK, get_latents_for_splits, standardize_latents, IMAGERY
from eval import ACC_MODALITY_AGNOSTIC, ACC_CAPTIONS, ACC_IMAGES, ACC_CROSS_IMAGES_TO_CAPTIONS, ACC_CROSS_CAPTIONS_TO_IMAGES, ACC_IMAGERY, ACC_IMAGERY_WHOLE_TEST, get_distance_matrix
from scipy.stats import ttest_rel

from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity

from notebook_utils import load_predictions, load_betas, get_data_default_feats
from glob import glob

# Global seaborn look for the figures below; several later cells call
# sns.set(font_scale=1.3), which re-applies seaborn's theme settings.
sns.set_style("white")
sns.set_context("paper", font_scale=1.5)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline  

In [None]:
def load_results_data(out_dir=None):
    """Load and concatenate every ``results.csv`` found below the additional-test output directory.

    Args:
        out_dir: root directory to search; defaults to ``DECODER_ADDITIONAL_TEST_OUT_DIR``.
            Result files are expected exactly three directory levels below it
            (``out_dir/*/*/*/results.csv``).

    Returns:
        A single DataFrame with the rows of all result files (files are read in sorted
        path order). Missing ``mask`` values are filled with ``"whole_brain"``.

    Raises:
        FileNotFoundError: if no result file is found. Without this check ``pd.concat``
            fails with the unhelpful message "No objects to concatenate".
    """
    if out_dir is None:
        out_dir = DECODER_ADDITIONAL_TEST_OUT_DIR

    result_files = sorted(glob(os.path.join(out_dir, "*", "*", "*", "results.csv")))
    if not result_files:
        raise FileNotFoundError(f"No results.csv files found below {out_dir}")

    # Read everything first and concatenate once.
    data = pd.concat([pd.read_csv(path) for path in tqdm(result_files)], ignore_index=True)
    data["mask"] = data["mask"].fillna("whole_brain")

    return data

# Load all additional-test results; `data` is the starting point of the analysis cells below.
data = load_results_data()

# with pd.option_context('display.max_rows', None, 'display.max_columns', None):
display(data)

# Sanity check: which subjects have results on disk.
print(f"Subjects: {data.subject.unique()}")

In [None]:
# Results of the main configuration (default features, ImageBind, whole brain, 'train'
# split, standardized predictions, surface data, no imagery sample weighting).
# Later cells start from `filtered` and use NUM_SUBJECTS.
filtered = get_data_default_feats(data).copy()

# LATENT_MODE = 'all_candidate_latents'
LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
MODEL = 'imagebind'

# Keep only the rows that match every setting at once.
filtered = filtered[
    (filtered.model == MODEL)
    & (filtered.standardized_predictions == 'True')
    & (filtered.training_splits == TRAINING_SPLITS)
    & (filtered.latents == LATENT_MODE)
    & (filtered['mask'] == MASK)
    & filtered.imagery_samples_weight.isna()
    & (filtered.surface == True)
]

# The row count must equal subjects x metrics x training modes (one row per combination).
NUM_SUBJECTS = len(SUBJECTS_ADDITIONAL_TEST)
expected_len = (
    NUM_SUBJECTS
    * filtered.metric.nunique(dropna=False)
    * filtered.training_mode.nunique(dropna=False)
)
assert len(filtered) == expected_len, filtered

# Imagery vs imagery (weak)
Imagery (weak) decoding scores lower because it is evaluated against a larger candidate set. When comparable candidate sets are used, imagery (weak) decoding performs better.

In [None]:
# Imagery vs. imagery (weak) pairwise accuracy for each training mode (main configuration).
to_plot = filtered.copy()

ORDER = ['captions', 'images', 'agnostic']
HUE_ORDER = ['imagery', 'imagery_weak']
PALETTE = ['red', 'salmon']
# sns.set(font_scale=1.3)
plt.figure(figsize=(17,10))
plt.title('imagery decoding', y=0.95, fontsize=20)

# One group per training mode, one bar per imagery metric (seaborn's default estimator: mean over subjects).
ax = sns.barplot(data=to_plot, x="training_mode", y="value", hue="metric", order=ORDER, hue_order=HUE_ORDER, palette=PALETTE)
plt.ylabel('pairwise accuracy')
# The y axis starts at 0.5, the chance level of a two-alternative pairwise comparison.
plt.ylim((0.5, 1))
sns.despine()
plt.savefig(os.path.join(RESULTS_DIR, f"imagery_vs_imagery_weak.png"), bbox_inches='tight', pad_inches=0, dpi=300)
plt.show()

# Performance advantage of modality-agnostic over modality-specific decoders for imagery decoding

In [None]:
# Display names of the subjects in figures (note: sub-07 is shown as "Subject 6").
subj_transform = {
    'sub-01': 'Subject 1',
    'sub-02': 'Subject 2',
    'sub-03': 'Subject 3',
    'sub-04': 'Subject 4',
    'sub-05': 'Subject 5',
    'sub-07': 'Subject 6',
}


def subjects_for_plotting(data):
    """Return the 'subject' column of `data` mapped to display names.

    Unknown subject ids raise a KeyError.
    """
    to_display_name = subj_transform.__getitem__
    return data['subject'].apply(to_display_name)

# Display names of the decoder types, keyed by the training-mode constants imported from `data`.
training_mode_transform = {
    MODALITY_AGNOSTIC: 'Modality-agnostic decoder',
    MODALITY_SPECIFIC_IMAGES: 'Modality-specific (images)',
    MODALITY_SPECIFIC_CAPTIONS: 'Modality-specific (captions)',
}


def decoder_for_plotting(data):
    """Return the 'training_mode' column of `data` mapped to decoder display names.

    Unknown training modes raise a KeyError.
    """
    to_display_name = training_mode_transform.__getitem__
    return data['training_mode'].apply(to_display_name)

In [None]:
# Imagery (weak) accuracy per decoder type: gray bars show the mean across subjects,
# colored points show individual subjects. The `to_plot` left by this cell is reused
# by the t-test cell below.
to_plot = filtered.copy()

DECODER_TYPES = ['Modality-agnostic decoder', 'Modality-specific (images)', 'Modality-specific (captions)']
METRICS = ['imagery_weak']
plt.figure(figsize=(11,5))
# Replace raw ids by display names (KeyError for subjects/training modes missing from the mappings).
to_plot['subject'] = subjects_for_plotting(to_plot)
to_plot['decoder_type'] = decoder_for_plotting(to_plot)

to_plot = to_plot[to_plot.metric.isin(METRICS)]
to_plot = to_plot[to_plot.decoder_type.isin(DECODER_TYPES)]

ax = sns.barplot(data=to_plot, x="decoder_type", y="value", order=DECODER_TYPES, errorbar=None, color='gray')
ax = sns.pointplot(data=to_plot, x="decoder_type", y="value", order=DECODER_TYPES, hue="subject")
# Borderless legend laid out in rows above the axes.
lgd = ax.legend(loc='upper left', ncols=5, title='', bbox_to_anchor=(0,1.05), frameon=False)
# lgd.get_frame().set_linewidth(0.0)

plt.ylabel('Pairwise Accuracy')
plt.xlabel('')
plt.ylim((0.5, 1))
sns.despine()
plt.savefig(os.path.join(RESULTS_DIR, f"imagery_weak_decoding_decoder_comparison.png"), bbox_inches='tight', pad_inches=0, dpi=300)

In [None]:
# Paired one-sided t-tests on the per-subject imagery (weak) accuracies plotted in the
# previous cell (`to_plot` comes from there). ttest_rel pairs its two samples by
# position, so the accuracies are pivoted to one row per subject. This aligns them
# subject-by-subject instead of relying on the row order of independent filters.
acc_per_subject = to_plot.pivot(index='subject', columns='training_mode', values='value')
acc_images = acc_per_subject['images']
acc_captions = acc_per_subject['captions']
acc_agnostic = acc_per_subject['agnostic']
assert len(acc_images) == len(SUBJECTS_ADDITIONAL_TEST)
assert len(acc_captions) == len(SUBJECTS_ADDITIONAL_TEST)
assert len(acc_agnostic) == len(SUBJECTS_ADDITIONAL_TEST)
# Every subject needs a result for every decoder, otherwise the pairing is incomplete.
assert not acc_per_subject[['images', 'captions', 'agnostic']].isna().any().any()
print('ttest (agnostic vs. images): ', ttest_rel(acc_agnostic, acc_images, alternative='greater'))
print('ttest (agnostic vs. captions): ', ttest_rel(acc_agnostic, acc_captions, alternative='greater'))


In [None]:
# to_plot = filtered.copy()

# TRAINING_MODES = ['agnostic']
# METRICS = ['imagery_weak', 'test_caption_attended']
# sns.set(font_scale=1.3)
# plt.figure(figsize=(17,10))
# plt.title('decoding', y=0.95, fontsize=20)

# to_plot = to_plot[to_plot.metric.isin(METRICS)]
# to_plot = to_plot[to_plot.training_mode.isin(TRAINING_MODES)]

# ax = sns.barplot(data=to_plot, x="metric", y="value", errorbar=None, color='gray')
# # ax = sns.scatterplot(data=to_plot, x="subject", y="value", hue="training_mode")
# ax = sns.pointplot(data=to_plot, x="metric", y="value", hue="subject")

# plt.ylabel('pairwise accuracy')
# plt.ylim((0.5, 1.1))
# plt.savefig(os.path.join(RESULTS_DIR, f"attention_modulation_imagery.png"), bbox_inches='tight', pad_inches=0, dpi=300)

# display(to_plot)


## For some participants, modality-specific decoders trained on captions do not generalize well to imagery:

In [None]:
# Per-subject imagery (weak) accuracy of the image- and caption-trained modality-specific decoders.
to_plot = filtered.copy()

# NOTE: rebinds TRAINING_MODES, shadowing the constant of the same name imported from `data`.
TRAINING_MODES = ['images', 'captions']
METRICS = ['imagery_weak']
# sns.set(font_scale=1.3)
plt.figure(figsize=(17,10))
plt.title('imagery (weak) decoding', y=0.95, fontsize=20)

to_plot = to_plot[to_plot.metric.isin(METRICS)]
to_plot = to_plot[to_plot.training_mode.isin(TRAINING_MODES)]

# Gray bars: mean over subjects; points: individual subjects.
ax = sns.barplot(data=to_plot, x="training_mode", y="value", errorbar=None, color='gray')
# ax = sns.scatterplot(data=to_plot, x="subject", y="value", hue="training_mode")
ax = sns.pointplot(data=to_plot, x="training_mode", y="value", hue="subject")

plt.ylabel('pairwise accuracy')
plt.ylim((0.5, 1))
sns.despine()



# Imagery (weak) decoding with mask

In [None]:
# Imagery (weak) accuracy of the modality-agnostic decoder trained on different vertex masks.
to_plot = get_data_default_feats(data).copy()

# LATENT_MODE = 'all_candidate_latents'
LATENT_MODE = 'limited_candidate_latents'
TRAINING_SPLITS = 'train'
SUBJECTS = SUBJECTS_ADDITIONAL_TEST
MODEL = 'imagebind'

to_plot = to_plot[to_plot.model == MODEL]
to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery_weak']
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot.subject.isin(SUBJECTS)]

# Reduce mask paths to their file names so they can be matched against HUE_ORDER below.
# NOTE(review): assigning to a column of a filtered frame may emit SettingWithCopyWarning.
to_plot['mask'] = to_plot['mask'].apply(lambda x: os.path.basename(x))
print(to_plot['mask'].unique())
# MASKS = ['mod_agnostic_and_cross_threshold_0.01.p', 'mod_agnostic_and_cross_threshold_0.0001.p', 'whole_brain']
# to_plot = to_plot[to_plot['mask'].isin(MASKS)]


# One row per (subject, mask) is expected.
assert len(to_plot) == len(SUBJECTS) * len(to_plot['mask'].unique()) 

ORDER = None
# Masks shown (in this order) as bars; masks not listed here are not plotted.
HUE_ORDER = [ 
    'random_1000',  'captions$test_caption_attended_1000_vertices.p', 'images$test_image_attended_1000_vertices.p',               'mod_invariant_increase_1000_vertices.p', 'mod_invariant_attended_1000_vertices.p',
'random_10000', 'captions$test_caption_attended_10000_vertices.p',  'images$test_image_attended_10000_vertices.p',               'mod_invariant_increase_10000_vertices.p', 'mod_invariant_attended_10000_vertices.p',
    'random_100000', 'captions$test_caption_attended_100000_vertices.p',  'images$test_image_attended_100000_vertices.p',               'mod_invariant_increase_100000_vertices.p', 'mod_invariant_attended_100000_vertices.p',
 # 'mod_agnostic_and_cross_lh_threshold_0.0001_cluster_0.p'
 # 'mod_agnostic_and_cross_rh_threshold_0.0001_cluster_0.p'
 # 'mod_agnostic_and_cross_rh_threshold_0.0001_cluster_1.p'
 # 'mod_agnostic_and_cross_threshold_0.0001.p'
 # 'mod_agnostic_and_cross_threshold_0.01.p'
 # 'mod_agnostic_and_cross_threshold_200000.0.p'
    'whole_brain']

PALETTE = 'tab10'
# NOTE: sns.set re-applies seaborn's theme and stays in effect for all later cells.
sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))
plt.title('imagery (weak) decoding', y=0.95, fontsize=20)

# ax = sns.barplot(data=to_plot, x='subject', y="value", hue="mask", order=ORDER, hue_order=HUE_ORDER, palette=PALETTE)
# x=None: a single group with one bar per mask (mean over subjects).
ax = sns.barplot(data=to_plot, x=None, y="value", hue="mask", order=ORDER, hue_order=HUE_ORDER, palette=PALETTE)

plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
plt.savefig(os.path.join(RESULTS_DIR, f"imagery_weak_decoding_based_on_modality_agnostic_regions.png"), bbox_inches='tight', pad_inches=0, dpi=300)


# Mean accuracy and mean number of vertices per mask.
display(to_plot.groupby(['mask']).agg(value=('value', 'mean'),n_vertices=('num_voxels', 'mean')).reset_index())


# NOTE(review): this mask name is commented out of HUE_ORDER; if it is absent from the
# results, acc_mask is empty. Both series are only used by the commented-out t-test.
acc_mask = to_plot[to_plot['mask']== 'mod_agnostic_and_cross_threshold_0.01.p'].value
acc_whole_brain = to_plot[to_plot['mask'] == 'whole_brain'].value
# print(acc_mask.mean())
# print(acc_whole_brain.mean())
# ttest_rel(acc_mask, acc_whole_brain, alternative='greater')
plt.show()

In [None]:
def get_distance_matrix(predictions, originals, metric='cosine'):
    """Pairwise distances between every prediction (rows) and every original (columns).

    Thin wrapper around ``scipy.spatial.distance.cdist``. Note that this definition
    shadows ``get_distance_matrix`` imported from ``eval``.
    """
    return cdist(predictions, originals, metric=metric)
    
def dist_mat_to_pairwise_acc(dist_mat, stim_ids, print_details=False, reduce=True):
    """Pairwise accuracy from a prediction-to-candidate distance matrix.

    Row ``i`` holds the distances of prediction ``i`` to all candidates, and its correct
    candidate is column ``i`` (the diagonal). A comparison is correct when the distance
    to the correct candidate is strictly smaller than the distance to another candidate.

    Args:
        dist_mat: 2-D array of shape (n_predictions, n_candidates) with
            n_predictions <= n_candidates (required to broadcast the diagonal).
        stim_ids: stimulus ids of the rows; only used when ``print_details`` is set.
        print_details: if True (and ``reduce`` is True), print each row's accuracy.
        reduce: if True, return the overall accuracy; otherwise one accuracy per row.

    Returns:
        A scalar accuracy (``reduce=True``) or a 1-D array of length n_predictions.
    """
    dist_mat = np.asarray(dist_mat)
    diag = dist_mat.diagonal().reshape(-1, 1)
    # comp_mat[i, j]: the correct candidate of row i is closer than candidate j.
    # Diagonal entries are always False (a value is not smaller than itself).
    comp_mat = diag < dist_mat
    # Each row is compared with every candidate except its own target. This is the number
    # of COLUMNS minus one (the old len(dist_mat) - 1 was wrong for non-square matrices).
    comparisons_per_row = dist_mat.shape[1] - 1

    if not reduce:
        return comp_mat.sum(axis=1) / comparisons_per_row

    if print_details:
        for i, stim_id in enumerate(stim_ids):
            print(stim_id, end=': ')
            print(f'{comp_mat[i].sum() / comparisons_per_row:.2f}')
    # subtract the number of elements of the diagonal as these values are always "False" (not smaller than themselves)
    return comp_mat.sum() / (dist_mat.size - diag.size)



In [None]:
# Per-stimulus imagery (weak) pairwise accuracies of the image-trained and the
# modality-agnostic decoders, compared with a paired one-sided t-test over all
# (subject, stimulus) pairs.
SURFACE = True

MODEL = "imagebind"

SUBJECTS = SUBJECTS_ADDITIONAL_TEST

# NOTE: rebinds TRAINING_MODES, shadowing the constant of the same name imported from `data`.
TRAINING_MODES = ["images", "agnostic"]

BETAS_SUFFIX = 'betas'
BETAS_DIR = os.path.join(FMRI_DATA_DIR, BETAS_SUFFIX)

# Predictions of these splits are used to fit the StandardScaler that is applied to the
# imagery (weak) predictions before distances are computed.
RESTANDARDIZE_PREDS = [SPLIT_IMAGERY_WEAK]

FEATS = 'default'
TEST_FEATS = 'default'
VISION_FEATS = 'default'
LANG_FEATS = 'default'
FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)

all_pairwise_accs = []
for subj in SUBJECTS:
    print(subj)
    # The imagery stimulus ids depend only on the subject, so look them up once per subject.
    stim_ids_imagery, _ = get_stim_info(subj, SPLIT_IMAGERY_WEAK)
    for training_mode in TRAINING_MODES:
        latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN, TEST_IMAGES, SPLIT_IMAGERY_WEAK], training_mode)
        latents = standardize_latents(latents)

        predictions = load_predictions(BETAS_DIR, subj, training_mode, FEATS_CONFIG, surface=SURFACE)

        pred_latents_imagery = predictions[SPLIT_IMAGERY_WEAK]
        if RESTANDARDIZE_PREDS:
            print(f'standardizing imagery predictions with {RESTANDARDIZE_PREDS}')
            refs = np.concatenate([predictions[split] for split in RESTANDARDIZE_PREDS])
            transform = StandardScaler().fit(refs)
            pred_latents_imagery = transform.transform(pred_latents_imagery)

        candidate_latents = latents[SPLIT_IMAGERY_WEAK]
        candidate_latent_ids = stim_ids_imagery

        # One accuracy per imagery prediction (reduce=False).
        dist_mat = get_distance_matrix(pred_latents_imagery, candidate_latents)
        scores = dist_mat_to_pairwise_acc(dist_mat, candidate_latent_ids, reduce=False)
        for stim_id, score in zip(stim_ids_imagery, scores):
            all_pairwise_accs.append({'value': score, 'training_mode': training_mode, 'subject': subj, 'stim_id': stim_id})


df = pd.DataFrame(all_pairwise_accs)
df['subj_stim_id'] = df['subject'] + '_' + df['stim_id'].astype("string")
values_agnostic = df[df.training_mode == 'agnostic'].sort_values(['subject', 'stim_id'])
values_images = df[df.training_mode == 'images'].sort_values(['subject', 'stim_id'])
# ttest_rel pairs its samples by position, so both frames must list exactly the same
# (subject, stimulus) pairs in the same order.
assert values_agnostic.subj_stim_id.tolist() == values_images.subj_stim_id.tolist()
display(values_agnostic)
display(values_images)

ttest_rel(values_images.value, values_agnostic.value, alternative='less')

# Varying the target features for Imagery (weak) decoding
ImageBind avg features (the average of vision and language features) seem to work best; using only language features is almost as good.

In [None]:
# Imagery (weak) accuracy of the modality-agnostic ImageBind decoder for different
# training feature types (avg / lang / vision).
to_plot = data.copy()

# LATENT_MODE = 'all_candidate_latents'
LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
SUBJECTS = SUBJECTS_ADDITIONAL_TEST

to_plot = to_plot[to_plot.model == 'imagebind']
# Keeps rows evaluated with avg test features, plus every row whose training features are
# not avg. NOTE(review): confirm this is the intended selection.
to_plot = to_plot[(to_plot.test_features == 'avg') | (to_plot.features != 'avg')]
to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery_weak']
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot['mask'] == MASK]
to_plot = to_plot[to_plot.subject.isin(SUBJECTS)]


FEAT_OPTIONS = ['avg', 'lang', 'vision']
display(to_plot)

# One row per (subject, feature type) is expected.
assert len(to_plot) == len(SUBJECTS) * len(to_plot.features.unique())

sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))
plt.title('imagery (weak) decoding', y=0.95, fontsize=20)

ax = sns.barplot(data=to_plot, x="features", y="value", order=FEAT_OPTIONS)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
plt.savefig(os.path.join(RESULTS_DIR, f"imagery_weak_feat_comparison.png"), bbox_inches='tight', pad_inches=0, dpi=300)


# Varying the target features for Imagery (weak) decoding at test time
Imagebind avg and lang features work equally well.

In [None]:
# Imagery (weak) accuracy of the modality-agnostic ImageBind decoder trained on avg
# features, for different test-time feature types.
to_plot = data.copy()

# LATENT_MODE = 'all_candidate_latents'
LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
SUBJECTS = SUBJECTS_ADDITIONAL_TEST

to_plot = to_plot[to_plot.model == 'imagebind']
to_plot = to_plot[to_plot.features == 'avg']
to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery_weak']
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot['mask'] == MASK]
to_plot = to_plot[to_plot.subject.isin(SUBJECTS)]


FEAT_OPTIONS = ['avg', 'lang']
display(to_plot)

# One row per (subject, test feature type) is expected.
assert len(to_plot) == len(SUBJECTS) * len(to_plot.test_features.unique())

sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))
plt.title('imagery (weak) decoding', y=0.95, fontsize=20)

ax = sns.barplot(data=to_plot, x="test_features", y="value", order=FEAT_OPTIONS)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
plt.savefig(os.path.join(RESULTS_DIR, f"imagery_weak_test_feat_comparison.png"), bbox_inches='tight', pad_inches=0, dpi=300)


In [None]:
# Imagery (weak) accuracy of the modality-agnostic decoder compared across feature
# models (no model filter is applied here).
to_plot = get_data_default_feats(data).copy()

# LATENT_MODE = 'all_candidate_latents'
LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
SUBJECTS = SUBJECTS_ADDITIONAL_TEST

to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery_weak']
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot['mask'] == MASK]
to_plot = to_plot[to_plot.subject.isin(SUBJECTS)]

display(to_plot)

# One row per (subject, model) is expected.
assert len(to_plot) == len(SUBJECTS) * len(to_plot.model.unique())

sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))

ax = sns.barplot(data=to_plot, x="model", y="value", order=None)
plt.title('imagery (weak) decoding', y=0.95, fontsize=20)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
# plt.savefig(os.path.join(RESULTS_DIR, f"imager_weak_feat_comparison.png"), bbox_inches='tight', pad_inches=0, dpi=300)


# Imagery decoding with varying standardization techniques

In [None]:
# Imagery accuracy (all candidate latents) per training mode, split by the prediction
# standardization variant (hue = standardized_predictions).
to_plot = get_data_default_feats(data).copy()

LATENT_MODE = 'all_candidate_latents'
# LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train'
MODEL = 'imagebind'

to_plot = to_plot[to_plot.model == MODEL]
# to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery']
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot['mask'] == MASK]

# One row per (subject, standardization variant, training mode) is expected;
# NUM_SUBJECTS comes from the main-configuration filter cell.
assert len(to_plot) == NUM_SUBJECTS * len(to_plot.standardized_predictions.unique()) * len(to_plot.training_mode.unique())

ORDER = ['captions', 'images', 'agnostic']
HUE_ORDER = None

PALETTE = 'tab10'
sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))
plt.title('imagery decoding', y=0.95, fontsize=20)

ax = sns.barplot(data=to_plot, x="training_mode", y="value", hue="standardized_predictions", order=ORDER, hue_order=HUE_ORDER, palette=PALETTE)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
# plt.savefig(os.path.join(RESULTS_DIR, f"vary_standardization.png"), bbox_inches='tight', pad_inches=0, dpi=300)


# Imagery decoding with varying training sets

In [None]:
# Imagery accuracy of the modality-agnostic decoder for each set of training splits,
# split by the prediction standardization variant.
to_plot = get_data_default_feats(data).copy()


LATENT_MODE = 'all_candidate_latents'
# LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
# Not used for filtering in this cell (the training_splits filter below is commented out).
TRAINING_SPLITS = 'train'

print(to_plot.training_splits.unique())
MODEL = 'imagebind'
to_plot = to_plot[to_plot.model == MODEL]
# to_plot = to_plot[to_plot.standardized_predictions == 'True']
to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery']
# to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]
to_plot = to_plot[to_plot['mask'] == MASK]
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot.imagery_samples_weight.isna()]


# display(to_plot)
# assert len(to_plot) == NUM_SUBJECTS * len(to_plot.standardized_predictions.unique()) * len(to_plot.training_splits.unique()), to_plot

PALETTE = 'tab10'#['red', 'salmon']
sns.set(font_scale=1.3)
plt.figure(figsize=(20,10))
plt.title('imagery decoding', y=0.95, fontsize=20)

ax = sns.barplot(data=to_plot, x="training_splits", y="value", hue="standardized_predictions", palette=PALETTE)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
# NOTE(review): this figure compares training splits but is saved as
# "attention_modulation_imagery.png" (the same name used by the commented-out
# attention cell above) — confirm the file name is intended.
plt.savefig(os.path.join(RESULTS_DIR, f"attention_modulation_imagery.png"), bbox_inches='tight', pad_inches=0, dpi=300)


# Imagery decoding with varying training sets and sample weight

In [None]:
# Imagery accuracy of the modality-agnostic decoder trained on 'train_imagery_weak'
# for different imagery sample weights: overall (first figure) and per subject (second).
to_plot = data.copy()


LATENT_MODE = 'all_candidate_latents'
# LATENT_MODE = 'limited_candidate_latents'
MASK = 'whole_brain'
TRAINING_SPLITS = 'train_imagery_weak'

print(to_plot.training_splits.unique())

to_plot = to_plot[to_plot.latents == LATENT_MODE]
to_plot = to_plot[to_plot.metric == 'imagery']
to_plot = to_plot[to_plot['mask'] == MASK]
to_plot = to_plot[to_plot.surface == True]
to_plot = to_plot[to_plot.training_mode == 'agnostic']
to_plot = to_plot[to_plot.standardized_predictions == 'all_imagery']

to_plot = to_plot[to_plot.training_splits == TRAINING_SPLITS]

# Rows without an explicit imagery sample weight are shown as weight 1.0.
# A new column is assigned rather than calling `Series.fillna(inplace=True)`. That call
# works on a chained column of a filtered frame: it triggers a chained-assignment
# warning and silently does nothing under pandas Copy-on-Write.
to_plot = to_plot.assign(imagery_samples_weight=to_plot.imagery_samples_weight.fillna(1.0))

sample_weights = [1, 10, 100, 200, 500, 1000]
to_plot = to_plot[to_plot.imagery_samples_weight.isin(sample_weights)]

print(to_plot.imagery_samples_weight.unique())
# One row per (subject, standardization variant, sample weight) is expected.
assert len(to_plot) == NUM_SUBJECTS * len(to_plot.standardized_predictions.unique()) * len(to_plot.imagery_samples_weight.unique()), len(to_plot)

HUE_ORDER = np.sort(to_plot.imagery_samples_weight.unique())

PALETTE = 'tab10'#['red', 'salmon']
sns.set(font_scale=1.3)

plt.figure(figsize=(20,10))
plt.title('imagery decoding', y=0.95, fontsize=20)
ax = sns.barplot(data=to_plot, y="value", hue="imagery_samples_weight", hue_order=HUE_ORDER, palette=PALETTE)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))
# NOTE(review): hard-coded reference level; its source is not documented here.
plt.axhline(y=0.78)

plt.figure(figsize=(20,10))
plt.title('imagery decoding', y=0.95, fontsize=20)
ax = sns.barplot(data=to_plot, x="subject", y="value", hue="imagery_samples_weight", hue_order=HUE_ORDER, palette=PALETTE)
plt.ylabel('pairwise accuracy')
plt.ylim((0.3, 1))


# Qualitative analyses

In [None]:
from feature_extraction.feat_extraction_utils import CoCoDataset
from pdf2image import convert_from_path


In [None]:
# COCO captions and images looked up by COCO id (`captions[...]`, `get_img_by_coco_id`) in the qualitative analyses below.
coco_ds = CoCoDataset(COCO_IMAGES_DIR, STIM_INFO_PATH)

In [None]:
def resize_img(image, length=100):
    """Scale ``image`` so its shorter side equals ``length``, then center-crop it to a square.

    Args:
        image: a PIL image (anything exposing ``size``, ``resize`` and ``crop`` works).
        length: side length in pixels of the returned square image.

    Returns:
        A ``length`` x ``length`` image. The crop box uses integer pixel coordinates, so
        every thumbnail has exactly the same size and they can be stacked with
        ``np.hstack``.
    """
    width, height = image.size
    if width < height:
        # Portrait: fix the width, scale the height, crop equally from top and bottom.
        # max() guards against float rounding producing a side shorter than `length`.
        new_height = max(length, round(height * length / width))
        resized_image = image.resize((length, new_height))
        top = (new_height - length) // 2
        return resized_image.crop(box=(0, top, length, top + length))

    # Landscape or square: fix the height, scale the width, crop equally from both sides.
    new_width = max(length, round(width * length / height))
    resized_image = image.resize((new_width, length))
    left = (new_width - length) // 2
    return resized_image.crop(box=(left, 0, left + length, length))

def display_stimuli(coco_ids, imgs=True, caps=True):
    """Print the captions and/or display the images of the given COCO ids.

    Images are resized to square thumbnails with ``resize_img`` and shown side by side.
    """
    if caps:
        for coco_id in coco_ids:
            print(coco_ds.captions[coco_id])

    if not imgs:
        return

    thumbnails = [np.array(resize_img(coco_ds.get_img_by_coco_id(coco_id))) for coco_id in coco_ids]
    display(Image.fromarray(np.hstack(thumbnails)))

def get_distance_matrix(predictions, originals, metric='cosine'):
    """Pairwise distances between every prediction (rows) and every original (columns).

    Thin wrapper around ``scipy.spatial.distance.cdist``. This definition shadows both
    the earlier one in this notebook and ``get_distance_matrix`` imported from ``eval``.
    """
    return cdist(predictions, originals, metric=metric)
    
def dist_mat_to_pairwise_acc(dist_mat, stim_ids, print_details=False, reduce=True):
    """Pairwise accuracy from a prediction-to-candidate distance matrix.

    This re-definition replaces the earlier one in the notebook. It accepts the same
    ``reduce`` parameter so calls written against the earlier version keep working.

    Row ``i`` holds the distances of prediction ``i`` to all candidates, and its correct
    candidate is column ``i`` (the diagonal). A comparison is correct when the distance
    to the correct candidate is strictly smaller than the distance to another candidate.

    Args:
        dist_mat: 2-D array of shape (n_predictions, n_candidates) with
            n_predictions <= n_candidates (required to broadcast the diagonal).
        stim_ids: stimulus ids of the rows; only used when ``print_details`` is set.
        print_details: if True (and ``reduce`` is True), print each row's accuracy.
        reduce: if True (default), return the overall accuracy; otherwise one accuracy per row.

    Returns:
        A scalar accuracy (``reduce=True``) or a 1-D array of length n_predictions.
    """
    dist_mat = np.asarray(dist_mat)
    diag = dist_mat.diagonal().reshape(-1, 1)
    # comp_mat[i, j]: the correct candidate of row i is closer than candidate j.
    comp_mat = diag < dist_mat
    # Each row is compared with every candidate except its own target (columns - 1).
    comparisons_per_row = dist_mat.shape[1] - 1

    if not reduce:
        return comp_mat.sum(axis=1) / comparisons_per_row

    if print_details:
        for i, stim_id in enumerate(stim_ids):
            print(stim_id, end=': ')
            print(f'{comp_mat[i].sum() / comparisons_per_row:.2f}')
    # subtract the number of elements of the diagonal as these values are always "False" (not smaller than themselves)
    return comp_mat.sum() / (dist_mat.size - diag.size)

def dist_mat_to_rankings(dist_mat, stim_ids, candidate_set_latent_ids):
    """Mean 1-based rank of each stimulus among the candidates, ordered by ascending distance.

    Row ``i`` of ``dist_mat`` holds the distances of the prediction for ``stim_ids[i]``
    to the candidates listed in ``candidate_set_latent_ids`` (same column order).
    Rank 1 means the true stimulus is the nearest candidate.
    """
    candidate_ids = np.array(candidate_set_latent_ids)
    ranks = []
    for stim_id, row in zip(stim_ids, dist_mat):
        ordered_ids = candidate_ids[np.argsort(row)]
        first_match = np.argwhere(ordered_ids == stim_id)[0][0]
        ranks.append(first_match + 1)
    return np.mean(ranks)


# Nearest Neighbors of imagery images

In [None]:
FONTSIZE = 13


def _format_cos(value):
    """Return the title suffix ' (cos=0.12)' for a cosine value, or '' when `value` is None."""
    return '' if value is None else f' (cos={value:.2f})'


def plot_nn_table(stim_ids, nneighbors, subject, stim_type, out_file_name=None, img_length=150, hspace=0.2, wspace=0.05, similarities=None, similarties_pred_to_top_ranks=None):
    """Plot a grid with one stimulus per row followed by its ranked neighbours.

    Column 0 shows the stimulus itself, depending on ``stim_type``:
    - ``IMAGE``: the image only.
    - ``SPLIT_IMAGERY_WEAK``: the image plus its caption.
    - ``CAPTION``: the caption on a blank canvas.
    - ``IMAGERY``: the subject's imagery drawing plus the caption.
    Columns 1..k show the image and caption of each neighbour in ``nneighbors``. The
    "Rank" labels count from 0 in the order given by ``nneighbors``.

    Args:
        stim_ids: COCO ids of the stimuli, one per row.
        nneighbors: per stimulus, a sequence of neighbour COCO ids. The number of columns
            is taken from the first row.
        subject: subject id; only used to locate the imagery drawings for ``IMAGERY``.
        stim_type: one of ``IMAGE``, ``CAPTION``, ``IMAGERY``, ``SPLIT_IMAGERY_WEAK``.
        out_file_name: if given, the figure is saved to
            ``RESULTS_DIR/analysis_ranking/out_file_name`` and the saved file is displayed.
        img_length: side length in pixels of every thumbnail; captions are drawn below it.
        hspace, wspace: subplot spacing passed to ``fig.subplots_adjust``.
        similarities: optional per-row cosine value shown next to the stimulus.
        similarties_pred_to_top_ranks: optional per-row cosine values shown next to
            each neighbour.
        If either of the last two is None, the corresponding "(cos=...)" annotations
        are omitted.
    """
    if stim_type == IMAGERY:
        stimulus_key = 'Imagery sketch and initial\ninstruction'
    elif stim_type == SPLIT_IMAGERY_WEAK:
        stimulus_key = 'Ground Truth'
    else:
        stimulus_key = 'Stimulus'

    n_rows = len(stim_ids)
    n_neighbors = len(nneighbors[0])

    # The default None values used to crash in zip() below; fall back to "no annotation".
    if similarities is None:
        similarities = [None] * n_rows
    if similarties_pred_to_top_ranks is None:
        similarties_pred_to_top_ranks = [[None] * n_neighbors] * n_rows

    # Captions start just below each thumbnail (y = 155 for the default 150 px).
    caption_y = img_length + 5

    def draw_caption(ax, y, text, **text_kwargs):
        # Caption in an invisible box. The wrap width is overridden through a private
        # matplotlib hook so that it scales with img_length.
        txt = ax.text(0, y, text, ha='left', wrap=True, fontsize=FONTSIZE,
                      bbox=dict(boxstyle='square,pad=0', facecolor='none', edgecolor='none'), **text_kwargs)
        txt._get_wrap_line_width = lambda: img_length * 4
        return txt

    # squeeze=False keeps `axes` two-dimensional even when only one stimulus is plotted.
    fig, axes = plt.subplots(n_rows, n_neighbors + 1, figsize=(15, 15), squeeze=False)
    if stim_type == IMAGERY:
        fig.subplots_adjust(wspace=wspace, hspace=hspace, top=0.97, bottom=0.06, left=0.01, right=0.99)
    else:
        fig.subplots_adjust(wspace=wspace, hspace=hspace, top=0.98, bottom=0.03, left=0.01, right=0.99)

    rows = zip(stim_ids, nneighbors, similarities, similarties_pred_to_top_ranks)
    for idx, (stim_id, neighbors, similarity, sim_ranks) in enumerate(rows):
        caption = coco_ds.captions[stim_id].lower()
        img = resize_img(coco_ds.get_img_by_coco_id(stim_id), length=img_length)
        stim_ax = axes[idx][0]

        if stim_type == IMAGE:
            stim_ax.imshow(img)
        elif stim_type == SPLIT_IMAGERY_WEAK:
            stim_ax.imshow(img)
            draw_caption(stim_ax, caption_y, caption, verticalalignment='top')
        elif stim_type == CAPTION:
            # Caption-only stimulus: show the text on a blank white canvas.
            blank = Image.fromarray(np.full((img_length, img_length, 3), 255, dtype=np.uint8), "RGB")
            stim_ax.imshow(blank)
            label = caption
            if similarity is not None:
                # Previously this referenced an undefined name `distance` (NameError).
                label += '\n\n' + f'(cosine similarity: {np.round(similarity, 2)})'
            draw_caption(stim_ax, img_length / 2, label)
        elif stim_type == IMAGERY:
            # The drawing file is looked up by row position (idx + 1); this assumes the rows
            # are in imagery-trial order.
            drawing_path = os.path.join(FMRI_BIDS_DATA_DIR, "stimuli", "imagery_drawings", f"{subject}_imagery_{idx+1}.pdf")
            img_drawing = resize_img(convert_from_path(drawing_path)[0], length=img_length)
            stim_ax.imshow(img_drawing)
            draw_caption(stim_ax, caption_y, caption, verticalalignment='top')

        stim_ax.axis('off')
        if stim_type == SPLIT_IMAGERY_WEAK:
            stim_ax.set_title(f'{stimulus_key}{_format_cos(similarity)}', fontweight="bold", fontsize=12.5)
        elif idx == 0:
            stim_ax.set_title(f'{stimulus_key}', fontweight="bold")

        for n_id, (neighbor_id, sim) in enumerate(zip(neighbors, sim_ranks)):
            neighbor_ax = axes[idx][n_id + 1]
            neighbor_ax.imshow(resize_img(coco_ds.get_img_by_coco_id(neighbor_id), length=img_length))
            neighbor_ax.axis('off')
            if stim_type == SPLIT_IMAGERY_WEAK:
                neighbor_ax.set_title(f'Rank {n_id}{_format_cos(sim)}', fontweight="bold", fontsize=12.5)
            elif idx == 0:
                neighbor_ax.set_title(f'Rank {n_id}{_format_cos(sim)}', fontweight="bold")
            draw_caption(neighbor_ax, caption_y, coco_ds.captions[neighbor_id].lower(), verticalalignment='top')

    if out_file_name is not None:
        out_path = os.path.join(RESULTS_DIR, "analysis_ranking", out_file_name)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        plt.savefig(out_path, dpi=250)

        # Show the saved file so the notebook displays exactly what was written to disk.
        plt.close(fig)
        img = mpimg.imread(out_path)
        plt.figure(figsize=(15,15))
        plt.imshow(img)
        plt.axis('off')
        plt.show()
    else:
        plt.show()


def _similarity_to_target(sim_row, candidate_latent_ids, target_id):
    """Return the cosine similarity between one prediction and its target candidate.

    ``sim_row`` holds the similarities of one prediction to every candidate, aligned with the
    array ``candidate_latent_ids``. If ``target_id`` occurs several times in the candidate set,
    the first occurrence is used; if it does not occur at all, NaN is returned.
    """
    matches = np.flatnonzero(candidate_latent_ids == target_id)
    if matches.size == 0:
        return np.nan
    return sim_row[matches[0]]


def _select_sample_indices(similarities, n_samples, subsample):
    """Pick the row indices of the stimuli to plot.

    'best'  -> the ``n_samples`` highest similarities to target, in descending order.
    'worst' -> the ``n_samples`` lowest similarities, ordered so the very worst comes last.
    other   -> a random subset of size ``n_samples`` drawn without replacement after seeding
               the global NumPy RNG with 1 (so the selection is reproducible).
    """
    if subsample == 'best':
        return np.argsort(similarities)[-n_samples:][::-1]
    if subsample == 'worst':
        return np.argsort(similarities)[:n_samples][::-1]
    np.random.seed(1)
    return np.random.choice(range(len(similarities)), n_samples, replace=False)


def analysis_ranking(test_preds, test_stim_ids, candidate_latents, candidate_latent_ids, subject, stim_type, n_samples=5, num_neighbors=5, out_file_name=None, hspace=0.2, wspace=0.05, subsample='random'):
    """Rank candidate latents against decoded predictions and plot each stimulus' nearest neighbors.

    Computes the cosine similarity and the distance matrix (``get_distance_matrix``) between every
    prediction and every candidate, and prints the pairwise accuracy over all predictions. For every
    ``stim_type`` except ``IMAGERY``, it then subsamples the stimuli to plot and prints their
    similarity to target. Finally, it passes the ``num_neighbors`` closest candidates per plotted
    stimulus to ``plot_nn_table``.

    Args:
        test_preds: predicted latents, one row per test stimulus.
        test_stim_ids: stimulus ids aligned with the rows of ``test_preds``.
        candidate_latents: latents of the candidate set, one row per candidate.
        candidate_latent_ids: stimulus ids aligned with the rows of ``candidate_latents``.
        subject: forwarded to ``plot_nn_table``.
        stim_type: split identifier; ``IMAGERY`` disables subsampling (all stimuli are plotted).
        n_samples: number of stimuli to plot when subsampling.
        num_neighbors: number of nearest candidates shown per stimulus.
        out_file_name: forwarded to ``plot_nn_table``.
        hspace, wspace: subplot spacing forwarded to ``plot_nn_table``.
        subsample: 'best', 'worst', or anything else for a seeded random subset
            (see ``_select_sample_indices``).

    Returns:
        The pairwise accuracy computed over all predictions (before any subsampling).

    Raises:
        ValueError: if subsampling is used and some target stimulus is not in the candidate set.
    """
    sim_mat = cosine_similarity(test_preds, candidate_latents)
    dist_mat = get_distance_matrix(test_preds, candidate_latents)

    acc = dist_mat_to_pairwise_acc(dist_mat, test_stim_ids)
    print(f'pairwise acc: {acc:.3f}')

    # Element-wise `==` in _similarity_to_target needs an array; with a plain list the
    # comparison collapses to a single scalar and the target is never found.
    candidate_latent_ids = np.asarray(candidate_latent_ids)
    test_stim_ids = np.asarray(test_stim_ids)

    # Computed for every stim_type so both arrays are always defined when handed to
    # plot_nn_table (previously they only existed for stim_type != IMAGERY -> NameError).
    similarities = np.array([
        _similarity_to_target(sim_row, candidate_latent_ids, stim_id)
        for stim_id, sim_row in zip(test_stim_ids, sim_mat)
    ])
    # Similarities of the num_neighbors most similar candidates per prediction, descending.
    similarties_pred_to_top_ranks = np.sort(sim_mat, axis=1)[:, ::-1][:, :num_neighbors]

    if stim_type != IMAGERY:
        # Ranking by similarity-to-target is meaningless for stimuli without a target candidate.
        missing = test_stim_ids[np.isnan(similarities)]
        if missing.size > 0:
            raise ValueError(f"target stimuli not found in the candidate set: {missing}")

        sampled_ids = _select_sample_indices(similarities, n_samples, subsample)
        test_stim_ids = test_stim_ids[sampled_ids]
        dist_mat = dist_mat[sampled_ids]
        similarities = similarities[sampled_ids]
        similarties_pred_to_top_ranks = similarties_pred_to_top_ranks[sampled_ids]
        print('Similarities to target:', similarities)

    # IDs of the num_neighbors closest candidates (smallest distance first) for each stimulus.
    nneighbors = [candidate_latent_ids[np.argsort(dist_row)][:num_neighbors] for dist_row in dist_mat]

    plot_nn_table(test_stim_ids, nneighbors, subject, stim_type, out_file_name, hspace=hspace, wspace=wspace, similarities=similarities, similarties_pred_to_top_ranks=similarties_pred_to_top_ranks)

    return acc




# Nearest neighbors of imagery trials

In [None]:
MODEL = "imagebind"
TRAINING_MODE = "agnostic"

FEATS = 'default'
TEST_FEATS = 'default'
VISION_FEATS = 'default'
LANG_FEATS = 'default'
FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)

# Pool the standardized training latents of all additional-test subjects; the cells below use
# this pool to enlarge the candidate set for nearest-neighbor ranking.
all_train_stim_ids = []
all_train_latents = []
for subj in tqdm(SUBJECTS_ADDITIONAL_TEST):
    stim_ids, _ = get_stim_info(subj, SPLIT_TRAIN)

    latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN], TRAINING_MODE)
    latents = standardize_latents(latents)

    all_train_stim_ids.append(stim_ids)
    all_train_latents.append(latents[SPLIT_TRAIN])
all_train_stim_ids = np.concatenate(all_train_stim_ids)
all_train_latents = np.concatenate(all_train_latents)
assert len(all_train_stim_ids) == len(all_train_latents), "stimulus ids and latents are misaligned"

# Stimulus ids can repeat in the pooled array; keep the first latent for each unique id.
unique_stim_ids, indices = np.unique(all_train_stim_ids, return_index=True)
unique_train_latents = all_train_latents[indices]
print(f"unique training stimuli: {len(unique_stim_ids)}")

In [None]:
# plt.ioff()
# N_NEIGHBORS = 5
# N_SAMPLES = 3

# WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET = True
# SURFACE = True

# MODEL = "imagebind"

# TRAINING_MODE = "agnostic"

# MASK = None

# BETAS_SUFFIX = 'betas'

# BETAS_DIR = os.path.join(FMRI_DATA_DIR, BETAS_SUFFIX)

# RESTANDARDIZE_PREDS = [SPLIT_IMAGERY]
# TRAINING_SPLITS = [SPLIT_TRAIN]
# IMAGERY_SAMPLES_WEIGHT = None

# FEATS = 'default'
# TEST_FEATS = 'default'
# VISION_FEATS = 'default'
# LANG_FEATS = 'default'
# FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)

# all_pairwise_accs = []
# for subj in SUBJECTS_ADDITIONAL_TEST:
#     print(subj)

#     stim_ids_test, _ = get_stim_info(subj, TEST_IMAGES)
#     stim_ids_imagery, _ =  get_stim_info(subj, SPLIT_IMAGERY)

#     latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN, TEST_IMAGES, SPLIT_IMAGERY], TRAINING_MODE)
#     latents = standardize_latents(latents)

#     predictions = load_predictions(BETAS_DIR, subj, TRAINING_MODE, FEATS_CONFIG, surface=SURFACE, mask=MASK, training_splits=TRAINING_SPLITS, imagery_samples_weight=IMAGERY_SAMPLES_WEIGHT)

#     pred_latents_imagery = predictions[SPLIT_IMAGERY]
#     if len(RESTANDARDIZE_PREDS)>0:
#         print('standardizing imagery predictions')
#         refs = np.concatenate([predictions[split] for split in RESTANDARDIZE_PREDS])
#         transform = StandardScaler().fit(refs)
#         pred_latents_imagery = transform.transform(pred_latents_imagery)
#         # pred_latents_imagery = StandardScaler().fit_transform(pred_latents_imagery)

#     # test_stim_ids_mod = results['stimulus_ids'][results['stimulus_types'] == IMAGE]
#     # test_latents_mod = results['latents'][results['stimulus_types'] == IMAGE]

#     if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET:
#         # # account for case that sometimes both the image and the caption are part of the training set
#         # unique_stim_ids, indices = np.unique(stim_ids, return_index=True)
#         # unique_train_latents = train_latents[indices]
#         unique_stim_ids, indices = np.unique(all_train_stim_ids, return_index=True)
#         unique_train_latents = all_train_latents[indices]
#         print('candidate set size: ', len(unique_stim_ids))

#         candidate_latents = np.concatenate((latents[SPLIT_IMAGERY], latents[TEST_IMAGES], unique_train_latents))
#         candidate_latent_ids = np.concatenate((stim_ids_imagery, stim_ids_test, unique_stim_ids))
#     else:
#         candidate_latents = latents[SPLIT_IMAGERY]
#         candidate_latent_ids = stim_ids_imagery
    
#     acc = analysis_ranking(pred_latents_imagery, stim_ids_imagery, candidate_latents, candidate_latent_ids, subj, IMAGERY, N_SAMPLES, N_NEIGHBORS, out_file_name=f"{IMAGERY}_{TRAINING_MODE}_decoder_{subj}.png", hspace=0.2)
#     all_pairwise_accs.append(acc)


# print(f'Mean pairwise acc: {np.mean(all_pairwise_accs):.2f}')

In [None]:
# plt.ioff()
# N_NEIGHBORS = 5
# N_SAMPLES = 3

# WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET = True
# SURFACE = True

# MODEL = "imagebind"

# TRAINING_MODE = "agnostic"

# MASK = None

# BETAS_SUFFIX = 'betas'

# BETAS_DIR = os.path.join(FMRI_DATA_DIR, BETAS_SUFFIX)

# RESTANDARDIZE_PREDS = [SPLIT_IMAGERY]
# TRAINING_SPLITS = [SPLIT_TRAIN, SPLIT_IMAGERY_WEAK]
# IMAGERY_SAMPLES_WEIGHT = 500

# FEATS = 'default'
# TEST_FEATS = 'default'
# VISION_FEATS = 'default'
# LANG_FEATS = 'default'
# FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)

# all_pairwise_accs = []
# for subj in SUBJECTS_ADDITIONAL_TEST:
#     print(subj)

#     stim_ids_test, _ = get_stim_info(subj, TEST_IMAGES)
#     stim_ids_imagery, _ =  get_stim_info(subj, SPLIT_IMAGERY)

#     latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN, TEST_IMAGES, SPLIT_IMAGERY], TRAINING_MODE)
#     latents = standardize_latents(latents)

#     predictions = load_predictions(BETAS_DIR, subj, TRAINING_MODE, FEATS_CONFIG, surface=SURFACE, mask=MASK, training_splits=TRAINING_SPLITS, imagery_samples_weight=IMAGERY_SAMPLES_WEIGHT)

#     pred_latents_imagery = predictions[SPLIT_IMAGERY]
#     if len(RESTANDARDIZE_PREDS)>0:
#         print('standardizing imagery predictions')
#         refs = np.concatenate([predictions[split] for split in RESTANDARDIZE_PREDS])
#         print(len(refs))
#         transform = StandardScaler().fit(refs)
#         pred_latents_imagery = transform.transform(pred_latents_imagery)
#         # pred_latents_imagery = StandardScaler().fit_transform(pred_latents_imagery)

#     # test_stim_ids_mod = results['stimulus_ids'][results['stimulus_types'] == IMAGE]
#     # test_latents_mod = results['latents'][results['stimulus_types'] == IMAGE]

#     if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET:
#         # # account for case that sometimes both the image and the caption are part of the training set
#         # unique_stim_ids, indices = np.unique(stim_ids, return_index=True)
#         # unique_train_latents = train_latents[indices]
#         unique_stim_ids, indices = np.unique(all_train_stim_ids, return_index=True)
#         unique_train_latents = all_train_latents[indices]
#         print('candidate set size: ', len(unique_stim_ids))

#         candidate_latents = np.concatenate((latents[SPLIT_IMAGERY], latents[TEST_IMAGES], unique_train_latents))
#         candidate_latent_ids = np.concatenate((stim_ids_imagery, stim_ids_test, unique_stim_ids))
#     else:
#         candidate_latents = latents[SPLIT_IMAGERY]
#         candidate_latent_ids = stim_ids_imagery

#     out_file_name = f"{IMAGERY}_{TRAINING_MODE}_decoder_{subj}_train_with_imagery_samples_weight_{IMAGERY_SAMPLES_WEIGHT}.png"
#     acc = analysis_ranking(pred_latents_imagery, stim_ids_imagery, candidate_latents, candidate_latent_ids, subj, IMAGERY, N_SAMPLES, N_NEIGHBORS, out_file_name=out_file_name, hspace=0.2)
#     all_pairwise_accs.append(acc)

# print(f'Mean pairwise acc: {np.mean(all_pairwise_accs):.2f}')

In [None]:
from data import TEST_IMAGES_ATTENDED, TEST_CAPTIONS_ATTENDED, TEST_IMAGES_UNATTENDED, TEST_CAPTIONS_UNATTENDED

# Weak imagery decoding

In [None]:
N_NEIGHBORS = 5
N_SAMPLES = 5

WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET = False
SURFACE = True
MODEL = "imagebind"

TRAINING_MODE = "agnostic"

MASK = None

BETAS_SUFFIX = 'betas'

BETAS_DIR = os.path.join(FMRI_DATA_DIR, BETAS_SUFFIX)

# Prediction splits whose statistics are used to re-standardize the weak-imagery predictions.
RESTANDARDIZE_PREDS = [SPLIT_IMAGERY_WEAK]

FEATS = 'default'
TEST_FEATS = 'default'
VISION_FEATS = 'default'
LANG_FEATS = 'default'
FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)

all_pairwise_accs = []
all_preds = []

if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET:
    # Pool the training latents of all subjects once; they are shared by every subject's candidate set.
    all_train_stim_ids = []
    all_train_latents = []
    for subj in tqdm(SUBJECTS_ADDITIONAL_TEST):
        stim_ids, _ = get_stim_info(subj, SPLIT_TRAIN)

        latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN], TRAINING_MODE)
        latents = standardize_latents(latents)

        all_train_stim_ids.append(stim_ids)
        all_train_latents.append(latents[SPLIT_TRAIN])
    all_train_stim_ids = np.concatenate(all_train_stim_ids)
    all_train_latents = np.concatenate(all_train_latents)

    # Keep the first latent for each unique stimulus id.
    unique_stim_ids, indices = np.unique(all_train_stim_ids, return_index=True)
    unique_train_latents = all_train_latents[indices]
    print('candidate set size: ', len(unique_stim_ids))

for subj in SUBJECTS_ADDITIONAL_TEST:
    print(subj)

    stim_ids_test, _ = get_stim_info(subj, TEST_IMAGES)
    stim_ids_imagery, _ = get_stim_info(subj, SPLIT_IMAGERY_WEAK)

    latents = get_latents_for_splits(subj, FEATS_CONFIG, [SPLIT_TRAIN, TEST_IMAGES, SPLIT_IMAGERY_WEAK], TRAINING_MODE)
    latents = standardize_latents(latents)

    predictions = load_predictions(BETAS_DIR, subj, TRAINING_MODE, FEATS_CONFIG, surface=SURFACE, mask=MASK)

    pred_latents_imagery = predictions[SPLIT_IMAGERY_WEAK]
    if len(RESTANDARDIZE_PREDS) > 0:
        # Build the reference set before reporting its size (it used to be printed before assignment).
        refs = np.concatenate([predictions[split] for split in RESTANDARDIZE_PREDS])
        print(f'standardizing imagery predictions with {len(refs)} refs ({RESTANDARDIZE_PREDS})')
        transform = StandardScaler().fit(refs)
        pred_latents_imagery = transform.transform(pred_latents_imagery)

    all_preds.append(pred_latents_imagery)

    if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET:
        candidate_latents = np.concatenate((latents[SPLIT_IMAGERY_WEAK], latents[TEST_IMAGES], unique_train_latents))
        candidate_latent_ids = np.concatenate((stim_ids_imagery, stim_ids_test, unique_stim_ids))
        print('candidate set size: ', len(candidate_latent_ids))
    else:
        candidate_latents = np.array(latents[SPLIT_IMAGERY_WEAK])
        candidate_latent_ids = np.array(stim_ids_imagery)

    # Per-subject pairwise accuracy without plotting: the same computation analysis_ranking
    # performs first. Call analysis_ranking here instead to also get per-subject plots.
    dist_mat = get_distance_matrix(pred_latents_imagery, candidate_latents)
    acc = dist_mat_to_pairwise_acc(dist_mat, stim_ids_imagery)
    print(f'pairwise acc: {acc:.3f}')

    all_pairwise_accs.append(acc)

print(f'Mean pairwise acc: {np.mean(all_pairwise_accs):.2f}')

# Weak imagery decoding averaged over subjects

In [None]:
# all_preds_avgd = np.mean(all_preds, axis=0)
# acc = analysis_ranking(all_preds_avgd, stim_ids_imagery, candidate_latents, candidate_latent_ids, subj, SPLIT_IMAGERY_WEAK, N_SAMPLES, N_NEIGHBORS, out_file_name=f"{IMAGERY}_weak_{TRAINING_MODE}_decoder_averaged.png", hspace=0.2, subsample='random')

## Best-decoded samples (highest cosine similarity to target)

In [None]:
# Plot the weak-imagery stimuli whose subject-averaged prediction is most similar to its target.
# NOTE(review): averaging row-wise assumes every subject's predictions share the same stimulus
# order. stim_ids_imagery, candidate_latents, candidate_latent_ids and subj are left over from
# the last iteration of the per-subject loop above.
N_SAMPLES = 5
all_preds_avgd = np.mean(all_preds, axis=0)
out_file_name = (
    f"{IMAGERY}_weak_{TRAINING_MODE}_decoder_averaged_large_candidate_set_best.png"
    if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET
    else f"{IMAGERY}_weak_{TRAINING_MODE}_decoder_averaged_best.png"
)
acc = analysis_ranking(
    all_preds_avgd, stim_ids_imagery, candidate_latents, candidate_latent_ids, subj,
    SPLIT_IMAGERY_WEAK, N_SAMPLES, N_NEIGHBORS,
    out_file_name=out_file_name, hspace=0.2, wspace=0.2, subsample='best',
)

## Worst-decoded samples (lowest cosine similarity to target)

In [None]:
# Same ranking as the cell above, but showing the stimuli with the lowest similarity to target.
# Reuses all_preds_avgd and N_SAMPLES from the previous cell.
out_file_name = (
    f"{IMAGERY}_weak_{TRAINING_MODE}_decoder_averaged_large_candidate_set_worst.png"
    if WHOLE_TRAIN_AND_TEST_SET_AS_CANDIDATE_SET
    else f"{IMAGERY}_weak_{TRAINING_MODE}_decoder_averaged_worst.png"
)
acc = analysis_ranking(
    all_preds_avgd, stim_ids_imagery, candidate_latents, candidate_latent_ids, subj,
    SPLIT_IMAGERY_WEAK, N_SAMPLES, N_NEIGHBORS,
    out_file_name=out_file_name, hspace=0.2, wspace=0.2, subsample='worst',
)

# t-SNE of latents

In [None]:
# train_fmri_betas_full, train_stim_ids, train_stim_types = get_fmri_data(
#     BETAS_DIR,
#     SUBJECT,
#     SPLIT_TRAIN,
#     TRAINING_MODE,
#     surface=SURFACE,
# )
# N_TRAIN_BETAS = 1000
# train_paths = train_paths[np.random.choice(range(len(train_paths)), size=N_TRAIN_BETAS, replace=False)]

# train_fmri_betas, test_fmri_betas, train_fmri_betas_standardized, test_fmri_betas_standardized = load_betas(train_paths, test_paths)

In [None]:
# def plot_betas(train_betas, test_betas, title, binwidth=3):
#     X = np.concatenate((train_betas.flatten(), test_betas.flatten()))
#     hue = ['train'] * train_betas.size + ['test'] * test_betas.size
#     plt.figure(figsize=(20, 10))
#     sns.histplot(x=X, hue=hue, binwidth=binwidth)
#     plt.title(title)

# print(np.nanmean(train_fmri_betas.mean(axis=0)))
# print(np.nanmean(test_fmri_betas.mean(axis=0)))
# print(np.nanmean(train_fmri_betas_standardized.mean(axis=0)))
# print(np.nanmean(test_fmri_betas_standardized.mean(axis=0)))

# plot_betas(train_fmri_betas, test_fmri_betas, title='unstandardized')
# plt.ylim(0, 10000000)
# plt.xlim(-25, 25)

# plot_betas(train_fmri_betas_standardized, test_fmri_betas, title='standardized', binwidth=0.3)
# plt.ylim(0, 4000000)
# plt.xlim(-3, 3)

In [None]:
# SUBJECT = 'sub-01'
# MODEL = "imagebind"
# SURFACE = True

# # TRAINING_MODE = "images"
# TRAINING_MODE = "agnostic"

# BETAS_SUFFIX = 'betas'
# BETAS_DIR = os.path.join(FMRI_DATA_DIR, BETAS_SUFFIX)

# # feats = 'avg'
# # test_feats = 'avg'
# FEATS = 'default'
# TEST_FEATS = 'default'
# # feats = 'lang'
# # test_feats = 'lang'
# # vision_feats = 'vision_features_cls'
# VISION_FEATS = 'default'
# # vision_feats = 'n_a'

# LANG_FEATS = 'default'
# # lang_feats = 'lang_features_cls'
# # lang_feats = 'lang_features_mean'

# FEATS_CONFIG = LatentFeatsConfig(MODEL, FEATS, TEST_FEATS, VISION_FEATS, LANG_FEATS)


# test_fmri_betas, test_stim_ids, test_stim_types = get_fmri_data(
#     BETAS_DIR,
#     SUBJECT,
#     SPLIT_TEST,
#     surface=SURFACE,
# )
# imagery_fmri_betas, imagery_stim_ids, imagery_stim_types = get_fmri_data(
#     BETAS_DIR,
#     SUBJECT,
#     SPLIT_IMAGERY,
#     surface=SURFACE,
# )

# train_latents = get_latent_features(FEATS_CONFIG, SUBJECT, SPLIT_TRAIN, mode=TRAINING_MODE)
# test_latents = get_latent_features(FEATS_CONFIG, SUBJECT, SPLIT_TEST)
# imagery_latents = get_latent_features(FEATS_CONFIG, SUBJECT, SPLIT_IMAGERY)

# train_latents, test_latents, imagery_latents = standardize_latents(
#     train_latents, test_latents, imagery_latents
# )

# results = load_predictions(BETAS_DIR, SUBJECT, TRAINING_MODE, FEATS_CONFIG, surface=SURFACE)

# # stim_ids, stim_types, train_latents, gray_matter_mask, results, train_paths, test_paths = load(BETAS_DIR, MODEL, SUBJECT, MODE, FEATS_CONFIG)


In [None]:
# def plot_latents_tsne(train_lat, test_lat, pred_lat, imagery_pred_lat, title, train_subset=1000):
#     train_latents_subset = train_latents[np.random.choice(range(len(train_lat)), size=train_subset, replace=False)]
    
#     tsne = TSNE(n_components=2, learning_rate='auto', verbose=1, n_jobs=10, n_iter=1000)
#     X_embedded = tsne.fit_transform(np.concatenate((train_latents_subset, test_lat, pred_lat, imagery_pred_lat)))
    
#     print(X_embedded.shape)
#     assert X_embedded.shape[1] == 2
#     hue = ['train'] * len(train_latents_subset) + ['test'] * len(test_lat) + ['predictions'] * len(pred_lat) + ['imagery_predictions'] * len(imagery_pred_lat)
#     # alphas = [0.3] * len(train_latents_subset) + [1] * len(test_lat) + [1] * len(preds)
    
#     plt.figure(figsize=(20, 12))
#     sns.scatterplot(
#         x = X_embedded[:, 0], y = X_embedded[:, 1],
#         hue = hue,
#         alpha = 0.8
#     )
#     plt.title(title)

# plot_latents_tsne(train_latents, results['latents'], results['predictions'], results['imagery_predictions'], title="not standardized")

# pred_latents_standardized = StandardScaler().fit_transform(results['predictions'])   
# imagery_pred_latents_standardized = StandardScaler().fit_transform(results['imagery_predictions'])
# plot_latents_tsne(train_latents, results['latents'], pred_latents_standardized, imagery_pred_latents_standardized, "standardized")

# t-SNE of betas

In [None]:
# def plot_betas_tsne(train_betas, test_betas, title, train_subset=None):
#     if train_subset is not None:
#         train_betas_subset = train_betas[np.random.choice(range(len(train_betas)), size=train_subset, replace=False)]
#     else:
#         train_betas_subset = train_betas
#     train_test = np.concatenate((train_betas_subset, test_betas))
#     tsne = TSNE(n_components=2, learning_rate='auto', verbose=1, n_jobs=10, n_iter=1000)
#     X_embedded = tsne.fit_transform(train_test)
    
#     print(X_embedded.shape)
#     assert X_embedded.shape[1] == 2
#     hue = ['train'] * len(train_betas_subset) + ['test'] * len(test_betas)
#     # alpha = [1] * len(test_betas) + [0.3] * len(train_betas_subset)
    
#     plt.figure(figsize=(20, 12))
#     sns.scatterplot(
#         x = X_embedded[:, 0], y = X_embedded[:, 1],
#         hue = hue,
#         # alpha = alpha
#     )
#     plt.title(title)

# plot_betas_tsne(train_fmri_betas, test_fmri_betas, title="not standardized")
# plot_betas_tsne(train_fmri_betas_standardized, test_fmri_betas_standardized, "standardized")
