In [2]:
import os

import nibabel
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from glob import glob
import pickle

from PIL import ImageColor, Image
import matplotlib.colors

from utils import NN_FEATURES_DIR, RESULTS_DIR, SUBJECTS, NUM_TEST_STIMULI, FMRI_SURFACE_LEVEL_DIR, STIM_INFO_PATH, COCO_IMAGES_DIR, STIMULI_IDS_PATH
from analyses.ridge_regression_decoding import NUM_CV_SPLITS, RIDGE_DECODER_OUT_DIR, calc_rsa, calc_rsa_images, calc_rsa_captions, get_fmri_data, pairwise_accuracy, \
ACC_MODALITY_AGNOSTIC, ACC_CAPTIONS, ACC_IMAGES, ACC_CROSS_IMAGES_TO_CAPTIONS, ACC_CROSS_CAPTIONS_TO_IMAGES, ACC_IMAGERY, ACC_IMAGERY_WHOLE_TEST, get_default_features, get_default_vision_features, get_default_lang_features, Standardize, IMAGE, CAPTION, get_distance_matrix, dist_mat_to_pairwise_acc, get_fmri_data_paths, get_nn_latent_data

from notebook_utils import add_avg_subject, create_result_graph, plot_metric_catplot, plot_metric, load_results_data, ACC_MEAN, ACC_CROSS_MEAN, PALETTE_BLACK_ONLY, METRICS_ERROR_ANALYSIS, get_data_default_feats, METRICS_BASE

from feature_extraction.feat_extraction_utils import CoCoDataset 

# Nearest Neighbors of test images

In [3]:
from utils import FMRI_DATA_DIR

SUBJECT = 'sub-01'
MODEL = "imagebind"
STIM_TYPE = 'image'
BETAS_DIR = os.path.join(FMRI_DATA_DIR, 'betas_new')
train_paths, stim_ids, stim_types = get_fmri_data_paths(BETAS_DIR, SUBJECT, "train")

latents, _ = get_nn_latent_data(MODEL, 'avg',
                                    'vision_features_cls',
                                    'lang_features_cls',
                                    stim_ids,
                                    stim_types,
                                    SUBJECT,
                                    "train"
                                    )

train_paths = np.array(train_paths)[stim_types == STIM_TYPE]

In [4]:
from tqdm import trange
from preprocessing.create_gray_matter_masks import get_graymatter_mask_path
import nibabel
test_fmri = []

gray_matter_mask_path = get_graymatter_mask_path(SUBJECT)
gray_matter_mask_img = nibabel.load(gray_matter_mask_path)
gray_matter_mask_data = gray_matter_mask_img.get_fdata()
gray_matter_mask = gray_matter_mask_data == 1
print(f"Gray matter mask size: {gray_matter_mask.sum()}")

fmri_train_betas = []
for idx in trange(10, desc="loading fmri data"):
    sample = nibabel.load(train_paths[idx])
    sample = sample.get_fdata()
    sample = sample[gray_matter_mask].astype('float32').reshape(-1)
    fmri_train_betas.append(sample)

fmri_train_betas = np.array(fmri_train_betas)


Gray matter mask size: 162649


loading fmri data: 100%|██████████| 10/10 [01:47<00:00, 10.74s/it]


array([[-0.35332873,  1.1735638 ,  0.4762943 , ..., -0.17560871,
         0.88654995,  2.4132216 ],
       [-2.9066963 ,  2.524966  ,  5.192939  , ..., -1.594213  ,
        -1.8071494 , -3.6754138 ],
       [-3.338414  , -2.1419594 , -0.20552808, ..., -0.20531629,
         1.6601925 , -0.460278  ],
       ...,
       [-1.4724239 , -0.7705569 ,  0.58517325, ...,  3.7909627 ,
        -1.8592083 , -0.3197928 ],
       [-6.6351905 , -2.5074031 , -1.6858445 , ...,  2.8640425 ,
        -1.885861  , -3.570221  ],
       [-4.3158445 , -3.3233457 , -3.5394206 , ..., -1.3201085 ,
         1.633338  ,  2.0146437 ]], dtype=float32)

In [5]:
test_paths, test_stim_ids, test_stim_types = get_fmri_data_paths(BETAS_DIR, SUBJECT, "test")
test_paths = np.array(test_paths)[test_stim_types == STIM_TYPE]
test_paths.shape

(70,)

In [6]:
from tqdm import trange
from preprocessing.create_gray_matter_masks import get_graymatter_mask_path
import nibabel
test_fmri = []

gray_matter_mask_path = get_graymatter_mask_path(SUBJECT)
gray_matter_mask_img = nibabel.load(gray_matter_mask_path)
gray_matter_mask_data = gray_matter_mask_img.get_fdata()
gray_matter_mask = gray_matter_mask_data == 1
print(f"Gray matter mask size: {gray_matter_mask.sum()}")

fmri_betas = []
for idx in trange(10, desc="loading fmri data"):
    sample = nibabel.load(test_paths[idx])
    sample = sample.get_fdata()
    sample = sample[gray_matter_mask].astype('float32').reshape(-1)
    fmri_betas.append(sample)

fmri_betas = np.array(fmri_betas)


Gray matter mask size: 162649


loading fmri data: 100%|██████████| 10/10 [01:39<00:00,  9.97s/it]


In [7]:
test_stim_ids[0]
test_stim_types[0]

'caption'

In [15]:
fmri_train_betas[:, 20000:20020].max(axis=0).mean()
fmri_train_betas.max(axis=0).mean()


7.950005

In [16]:
fmri_betas[:, 20000:20020].max(axis=0).mean()
fmri_betas.max(axis=0).mean()


2.2271032

In [None]:

# targets = pickle.load(open('/home/mitja/Downloads/targets.p', 'rb'))[test_stim_types == STIM_TYPE]
# preds = pickle.load(open('/home/mitja/Downloads/preds.p', 'rb'))[test_stim_types == STIM_TYPE]
# print(targets.shape)
# print(preds.shape)


In [19]:
results = pickle.load(open('/home/mitja/data/multimodal_decoding/whole_brain_decoding/train_image/sub-01/imagebind_avg_test_avg_vision_features_cls_lang_features_cls/results.p', 'rb'))
preds = results['predictions'][test_stim_types == STIM_TYPE]
targets = results['latents'][test_stim_types == STIM_TYPE]

print(targets.shape)
print(preds.shape)



(70, 1024)
(70, 1024)


In [34]:
preds.std(axis=0)
targets.std(axis=0)
preds

array([[ 0.1731828 , -0.04481555, -0.4342584 , ...,  0.10285233,
         0.08117852, -0.65347534],
       [-0.17113288, -0.25586656,  0.54344887, ...,  0.10178385,
        -0.15762343, -0.03463428],
       [-0.26914987,  0.01100233, -0.17898676, ..., -0.60147053,
         1.1573492 , -0.29404137],
       ...,
       [-0.25330424,  0.16691068, -0.3516128 , ..., -0.52294624,
         0.36219645, -0.24241231],
       [-0.28799325, -0.9923302 ,  0.57216525, ..., -0.7669344 ,
         0.48000363,  0.48541442],
       [ 0.93688256,  0.4988965 , -1.5358694 , ...,  0.8586009 ,
         0.13756365, -0.7165827 ]], dtype=float32)

In [None]:
test_preds = pickle.load(open('/home/mitja/Downloads/test_preds.p', 'rb'))
test_preds = test_preds[test_stim_types == STIM_TYPE]
test_preds

In [22]:
test_stim_ids = test_stim_ids[test_stim_types == STIM_TYPE]
# test_stim_types = test_stim_types[test_stim_types == STIM_TYPE]

In [None]:
train_preds = pickle.load(open('/home/mitja/Downloads/train_preds.p', 'rb'))
train_preds

In [None]:
from analyses.ridge_regression_decoding import load_latents_transform

nn_latent_transform = load_latents_transform(
            SUBJECT, MODEL, 'avg', 'vision_features_cls',
                                    'lang_features_cls', 'train_image'
        )
nn_latent_transform['image']

In [None]:
train_preds_transform = Standardize(train_preds.mean(axis=0), train_preds.std(axis=0))
train_preds_transform.std

In [None]:
concat_preds = np.concatenate((train_preds, preds), axis=0)
print(concat_preds.shape)
concat_preds_transform = Standardize(concat_preds.mean(axis=0), concat_preds.std(axis=0))
train_preds_transform.std

In [None]:
concat_preds_transform(preds).std(axis=0)


In [None]:
preds_transformed = concat_preds_transform(preds)
print(preds_transformed.shape)
preds_transformed.std(axis=0)

In [None]:
nn_latent_transform['image'](preds)


In [25]:
def display_stimuli(coco_ids):
    # print(coco_ids)
    for coco_id in coco_ids:
        print(coco_ds.captions[coco_id], end="\n")

coco_ds = CoCoDataset(COCO_IMAGES_DIR, STIM_INFO_PATH, STIMULI_IDS_PATH, 'caption')

In [None]:
display_stimuli([139])

In [26]:
test_stim_ids.shape

(70,)

## With predictions averaged over all subjects:

In [36]:
NUM_SAMPLES = 5
N_NEIGHBORS = 5
# training_mode = 'modality-agnostic'
training_mode = 'images'



print(f"fMRI stimulus modality: {STIM_TYPE}")
# for subject in SUBJECTS:
    # df_subj = df[df.subject == subject]
    # 
    # test_predictions = df_subj[df_subj.metric == "predictions"].value.item()
    # test_latents = df_subj[df_subj.metric == "latents"].value.item()
    # test_stimulus_ids = df_subj[df_subj.metric == "stimulus_ids"].value.item()
    # stimulus_types = df_subj[df_subj.metric == "stimulus_types"].value.item()
    # 
    # test_latents = test_latents[stimulus_types == stimulus_type]
    # test_stimulus_ids = test_stimulus_ids[stimulus_types == stimulus_type]
    # test_predictions = test_predictions[stimulus_types == stimulus_type]
    
# candidate_set_latents = targets
# candidate_set_latents_ids = test_stim_ids

test_predictions = preds
# preds_standardize = Standardize(preds.mean(axis=0), preds.std(axis=0))
# test_predictions = preds_standardize(test_predictions)

# all_test_predictions.append(test_predictions)

# test_predictions_averaged = np.mean(all_test_predictions, axis=0)

candidate_set_latents = np.concatenate((latents, targets))
candidate_set_latents_ids = np.concatenate((stim_ids, test_stim_ids))
    
dist_mat = get_distance_matrix(test_predictions, candidate_set_latents)
acc = dist_mat_to_pairwise_acc(dist_mat)
print(f"Pairwise acc: {acc:.2f}")
np.random.seed(7)
sampled_ids = np.random.choice(range(len(test_stim_ids)), NUM_SAMPLES, replace=False)
test_stimulus_ids = test_stim_ids[sampled_ids]
dist_mat = dist_mat[sampled_ids]

for test_stimulus_id, nneighbors_row in zip(test_stimulus_ids, dist_mat):
    print(f"test stimulus: {test_stimulus_id}")
    display_stimuli([test_stimulus_id])
    # if STIM_TYPE == CAPTION:
    #     print(coco_ds.get_stimuli_by_coco_id(test_stimulus_id)[1])
    # else:
    #     img = coco_ds.get_img_by_coco_id(test_stimulus_id)
    #     new_width  = 400
    #     new_height = round(new_width * img.height / img.width)
    #     display(img.resize((new_width, new_height)))
    target_location = np.argwhere(candidate_set_latents_ids == test_stimulus_id)[0][0]
    nneighbors_ids = candidate_set_latents_ids[np.argsort(nneighbors_row)]
    rank = np.argwhere(nneighbors_ids == test_stimulus_id)[0][0]
    # all_ranks.append(rank)
    # display(f"distance to target: {nneighbors_row[target_location]:.3f} | rank: {rank} of {len(nneighbors_row)}")
    # display(f"nearest neighbors distances: {np.sort(nneighbors_row)[:10]}")
    print("nearest neighbors: ")

    display_stimuli(nneighbors_ids[:N_NEIGHBORS])
    print("\n")
    
# print(f"mean rank: {np.mean(all_ranks)}")


fMRI stimulus modality: image
Pairwise acc: 0.52
test stimulus: 163240
A train coming up the tracks through trees
nearest neighbors: 
a train that is parked by many others
A city train stopped at the train station
a number of trains on tracks near a building
a train at a standstill next to a platform
A train sitting inside of a train station


test stimulus: 16764
A small dog standing inside a car
nearest neighbors: 
A cat sitting on a toilet in a room
A cat sitting on top of a vase
A white toilet sitting in a corner of a room
A person sitting on a chair with a cake
a room with a toilet and a dustbin


test stimulus: 450719
A giraffe drinking water from a man-made pond
nearest neighbors: 
A person grinding down a rail with skis
a person riding skis on a rail
A dog that is climbing onto a bowl
A snowboarder going on a rail in the snow
a cat sitting in a bathroom sink


test stimulus: 195406
The teddy bear has a big bright red bow
nearest neighbors: 
A very cute cat laying in a bowl
a ca

## Per-subject:

In [None]:
NUM_SAMPLES = 5
N_NEIGHBORS = 5

NUM_SUBJECTS = 2

training_mode = 'modality-agnostic'
# training_mode = 'images'

df = data_default_feats.copy()

df = df[df.model == MODEL]
df = df[df.training_mode == training_mode]
df = df[df.surface == False]

assert len(df[df.metric == "predictions"]) == len(SUBJECTS)

for subject in SUBJECTS[:NUM_SUBJECTS]:
    print(f"\n\nSubject: {subject}")

    for stimulus_type in [IMAGE, CAPTION]:
        all_ranks = []
        print(f"fMRI stimulus modality: {stimulus_type}")
        df_subj = df[df.subject == subject]
        
        test_predictions = df_subj[df_subj.metric == "predictions"].value.item()
        test_latents = df_subj[df_subj.metric == "latents"].value.item()
        test_stimulus_ids = df_subj[df_subj.metric == "stimulus_ids"].value.item()
        stimulus_types = df_subj[df_subj.metric == "stimulus_types"].value.item()
    
        test_latents_in_mod = test_latents[stimulus_types == stimulus_type]
        test_stimulus_ids_in_mod = test_stimulus_ids[stimulus_types == stimulus_type]
        test_predictions_in_mod = test_predictions[stimulus_types == stimulus_type]
       
        candidate_set_latents = np.concatenate((train_latents[subject], test_latents_in_mod))
        candidate_set_latents_ids = np.concatenate((train_stim_ids[subject], test_stimulus_ids_in_mod))
    
        preds_standardize = Standardize(test_predictions_in_mod.mean(axis=0), test_predictions_in_mod.std(axis=0))
        test_predictions_in_mod = preds_standardize(test_predictions_in_mod)
    
        # targets_standardize = Standardize(candidate_set_latents.mean(axis=0), candidate_set_latents.std(axis=0))
        # candidate_set_latents = targets_standardize(candidate_set_latents)
    
        dist_mat = get_distance_matrix(test_predictions_in_mod, candidate_set_latents) #, metric="euclidean"
        acc = dist_mat_to_pairwise_acc(dist_mat)
        print(f"Pairwise acc: {acc:.2f}")
        np.random.seed(7)
        sampled_ids = np.random.choice(range(len(test_stimulus_ids_in_mod)), NUM_SAMPLES, replace=False)
        test_stimulus_ids_in_mod = test_stimulus_ids_in_mod[sampled_ids]
        dist_mat = dist_mat[sampled_ids]
        
        for test_stimulus_id, nneighbors_row in zip(test_stimulus_ids_in_mod, dist_mat):
            print(f"test stimulus: {test_stimulus_id}")
            if stimulus_type == CAPTION:
                print(coco_ds.get_stimuli_by_coco_id(test_stimulus_id)[1])
            else:
                display(coco_ds.get_img_by_coco_id(test_stimulus_id))
            target_location = np.argwhere(candidate_set_latents_ids == test_stimulus_id)[0][0]
            nneighbors_ids = candidate_set_latents_ids[np.argsort(nneighbors_row)]
            rank = np.argwhere(nneighbors_ids == test_stimulus_id)[0][0]
            all_ranks.append(rank)
            # display(f"distance to target: {nneighbors_row[target_location]:.3f} | rank: {rank} of {len(nneighbors_row)}")
            # display(f"nearest neighbors distances: {np.sort(nneighbors_row)[:10]}")
            display_stimuli(nneighbors_ids[:N_NEIGHBORS])
            print("\n")
        
        # print(f"mean rank: {np.mean(all_ranks)}")
        print("\n\n")
