In [None]:
from dotenv import load_dotenv
load_dotenv()
import os
import numpy as np
import h5py
import sys
sys.path.append(os.getenv('PYTHONPATH')) 
from tqdm import tqdm
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
from PIL import Image
import json
import seaborn as sns
import pandas as pd
from matplotlib.patches import Patch
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
from torchvision.transforms import v2
from pathlib import Path

#local
from src.utils.helpers import FilterDataset
from src.utils.dataset import FMRIDataset
from src.utils.transforms import SelectROIs

In [None]:
def mds(utv, pos=None, n_jobs=1, n_components=2):
    """Embed a representational dissimilarity matrix (RDM) with metric MDS.

    Adapted from the NSD code base.

    Args:
        utv (array): the RDM, either as a 1D condensed upper-triangular vector
            (the form returned by ``scipy.spatial.distance.pdist``) or as a
            full square, symmetric matrix with a zero diagonal.
        pos (array, optional): initial coordinates, one row per item and one
            column per component, used to initialise the MDS solver.
            Defaults to None (solver's own initialisation).
        n_jobs (int, optional): number of cores to distribute to.
            Defaults to 1.
        n_components (int, optional): number of embedding dimensions.
            Defaults to 2.

    Returns:
        np.ndarray: embedded coordinates, one row per item and
        ``n_components`` columns.
    """
    utv = np.asarray(utv)
    # MDS with a precomputed dissimilarity needs the square matrix; expand the
    # condensed vector, but pass an already-square RDM through unchanged
    # (squareform would otherwise collapse it back into a vector).
    rdm = utv if utv.ndim == 2 else squareform(utv)
    # Fixed seed so repeated runs produce the same embedding.
    rng = np.random.RandomState(seed=3)
    embedder = MDS(
        n_components=n_components,
        max_iter=100,
        random_state=rng,
        dissimilarity="precomputed",
        n_jobs=n_jobs,
    )
    return embedder.fit_transform(rdm, init=pos)

In [None]:
#housekeeping
ncomponents = 2  # dimensionality of the MDS and t-SNE embeddings
perplexity = 20  # t-SNE perplexity
save_flag = False  # NOTE(review): the thumbnail-plot cell later sets save_flag = True itself
trialselection = 'all'  # or 'average'; passed to FMRIDataset
root = os.path.join(os.getenv("DATASETS_ROOT", "/default/path/to/datasets"), "MOSAIC")
project_root = os.getenv("PROJECT_ROOT")
if project_root is None:
    # os.path.join(None) would raise an opaque TypeError; say what is missing instead.
    raise RuntimeError("PROJECT_ROOT is not set; define it in your .env file or environment.")
print(f"root: {root}")
print(f"project root: {project_root}")
config = {"fmri": {"dataset_include": ['GOD', 'deeprecon'],
                   "subject_include": None,
                   "use_noiseceiling": True}}
n = 1  # 'avg' or an int; only matters if 'use_noiseceiling' is True

In [None]:
#load the train/test split manifests and keep only the configured subjects/datasets
with open(os.path.join(root, 'train_naturalistic.json'), 'r') as f:
    train_all = json.load(f)
with open(os.path.join(root, 'test_naturalistic.json'), 'r') as f:
    test_all = json.load(f)
dataset_preprocessing = FilterDataset(config['fmri']['subject_include'],
                                      config['fmri']['dataset_include'],
                                      config['fmri']['use_noiseceiling'])
train, subjectID_mapping_train = dataset_preprocessing.filter_splits(train_all)
test, subjectID_mapping_test = dataset_preprocessing.filter_splits(test_all)
# Union of both splits' subject mappings (test entries win on duplicate keys).
all_subjects_dict = {**subjectID_mapping_train, **subjectID_mapping_test}
all_subjects = list(all_subjects_dict.keys())

# Shuffle with a fixed, local seed so the order of `train`/`test` is
# reproducible across kernel restarts (the global np.random state was unseeded).
shuffle_seed = 0
rng = np.random.default_rng(shuffle_seed)
shuffled_indices_train = rng.permutation(len(train))
train = [train[i] for i in shuffled_indices_train]
shuffled_indices_test = rng.permutation(len(test))
test = [test[i] for i in shuffled_indices_test]

max_test_samples = None  # e.g. 200 to debug on a subset of the test split
if max_test_samples is not None:
    test = test[:max_test_samples]

In [None]:
eval_sets = ['train_naturalistic', 'test_naturalistic']
# Filename stems of each split's entries (the first key of every entry dict).
# Stored as sets: they are only used for `in` membership tests against
# f"{subjectID}_stimulus-{stim}" while assembling responses, so a set turns
# each lookup from a linear scan into O(1).
stimuli_list = {
    'train_naturalistic': {Path(next(iter(stim))).stem for stim in train},
    'test_naturalistic': {Path(next(iter(stim))).stem for stim in test},
}

In [None]:
# ROI groups GlasserGroup_1 ... GlasserGroup_5 (other selections tried: ["LO1","LO2"], ["V1"]).
rois = [f"GlasserGroup_{group}" for group in range(1, 6)]
ROI_selection = SelectROIs(selected_rois=rois)
fmri_tsfm = None  # no extra transform; alternative: v2.Compose([ToTensorfMRI(dtype='float32')])
# NOTE(review): the dataset is built from the test split only, while the loading
# loop still routes rows into both train and test eval sets — confirm intended.
dataset = FMRIDataset(
    test,
    ROI_selection,
    config['fmri']['use_noiseceiling'],
    trialselection,
    fmri_transforms=fmri_tsfm,
)

In [None]:
cols = ['fmri', 'stimulus_filename', 'subject_id', 'dataset_id']
all_data = {eval_set: {col: [] for col in cols} for eval_set in eval_sets}
# Noise-ceiling key parts. NOTE(review): the test-phase ceiling is applied to rows
# of every eval set, and this `n` overrides the value from the housekeeping cell.
phase = 'test'
n = 'avg'
use_noiseceiling = config['fmri']['use_noiseceiling']
for subjectID in all_subjects:
    sample = dataset.load_responses_block_hdf5(subjectID, verbose=True)
    dataset_id = subjectID.split('_')[-1]
    noiseceiling_key = f"{subjectID}_phase-{phase}_n-{n}"
    for idx, stim in enumerate(sample['stimulus_filename']):
        subject_stim = f"{subjectID}_stimulus-{stim}"
        # Route the row to the first eval set that lists this stimulus. Stimuli
        # removed by the filtering belong to no eval set and are skipped.
        for eval_set in eval_sets:
            if subject_stim not in stimuli_list[eval_set]:
                continue
            fmri = sample['fmri'][idx, :]
            if use_noiseceiling:
                # Scale the response by the subject's noise-ceiling values.
                fmri = fmri * sample['noiseceiling'][noiseceiling_key]
            split_data = all_data[eval_set]
            split_data['fmri'].append(fmri)
            split_data['stimulus_filename'].append(stim)
            split_data['subject_id'].append(subjectID)
            split_data['dataset_id'].append(dataset_id)
            # Stop at the first match: no need to check the other eval sets
            # (assumes the splits are disjoint).
            break

In [None]:
# Stack the per-row response vectors into one matrix per eval set
# (one row per assembled response, assuming each row is a 1-D vector).
train_data = np.vstack(all_data['train_naturalistic']['fmri'])
test_data = np.vstack(all_data['test_naturalistic']['fmri'])
# Already 2-D after vstack; no need to stack again just to report shapes.
print(train_data.shape)
print(test_data.shape)

In [None]:
#compute ROI RDM
# Pairwise scipy 'correlation' distances between all test-set response rows.
rdm_flat = pdist(test_data, metric='correlation')  # condensed (upper-triangle) vector
rdm = squareform(rdm_flat)  # full symmetric square matrix
print(rdm.shape)
fig_rdm, ax_rdm = plt.subplots(figsize=(6, 5))
rdm_im = ax_rdm.imshow(rdm)
fig_rdm.colorbar(rdm_im, ax=ax_rdm, label='correlation distance')
ax_rdm.set(title='Test-set RDM', xlabel='row index', ylabel='row index')
plt.show()

In [None]:

#compute MDS, then use it to initialise t-SNE on the same precomputed RDM.
# Both steps are seeded (mds() uses seed 3 internally, t-SNE uses 42).
condensed_rdm = squareform(rdm)  # back to the condensed vector mds() expects
Y_mds = mds(condensed_rdm, n_components=ncomponents)
tsne = TSNE(
    n_components=ncomponents,
    perplexity=perplexity,
    metric='precomputed',
    init=Y_mds,
    random_state=42,
    n_jobs=4,
)
X_tsne = tsne.fit_transform(rdm)
print("divergence: ", tsne.kl_divergence_)

In [None]:
save_flag = True  # NOTE(review): overrides the housekeeping value, so this cell always saves
df = pd.DataFrame(all_data['test_naturalistic'])
X = np.vstack(df['fmri'].to_numpy())
print(X.shape)
# df and X_tsne share the row order of test_data (all built from
# all_data['test_naturalistic']), so coordinates can be attached column-wise.
df['X'] = X_tsne[:, 0]
df['Y'] = X_tsne[:, 1]

# Sort subject IDs by the part after the first '_' (e.g. the dataset in IDs like
# 'sub-01_GOD'), then by the part before it, so legend entries group together.
unique_subjects = sorted(df['subject_id'].unique(), key=lambda s: (s.split('_')[1], s.split('_')[0]))
tab10 = sns.color_palette('tab10')
palette = [tab10[i] for i in (0, 2, 3, 4, 5, 1, 6, 7)]
if len(unique_subjects) > len(palette):
    # zip() below would silently drop the extra subjects and color_map[...] would
    # then raise KeyError; use one distinct colour per subject instead.
    palette = sns.color_palette('husl', len(unique_subjects))
color_map = dict(zip(unique_subjects, palette))

fig, ax = plt.subplots(figsize=(10, 8))

# Thumbnails are placed at center + coordinate * stretch, where stretch rescales
# the t-SNE coordinates so that |max| + |min| spans roughly 700 units.
x_center = (X_tsne[:, 0].max() + X_tsne[:, 0].min()) / 2
y_center = (X_tsne[:, 1].max() + X_tsne[:, 1].min()) / 2
stretch = np.floor(700 / (np.abs(X_tsne.max()) + np.abs(X_tsne.min())))
# floor() gives 0 when that span exceeds 700, which would stack every thumbnail
# on the centre point; never shrink below the raw coordinates.
stretch = max(stretch, 1.0)
print("stretch:", stretch)
scaler = 0.02  # plot units per image pixel; adjust to image resolution and number of images

stim_dir = os.path.join(root, "stimuli", "stimuli_compressed_quality-95_size-224")
# Bounding box of all placed thumbnails, used for the axis limits below.
min_x, max_x = np.inf, -np.inf
min_y, max_y = np.inf, -np.inf
print("looping over dataframe rows...")
for _, row in tqdm(df.iterrows(), total=len(df)):
    img_path = os.path.join(stim_dir, f"{row['stimulus_filename']}.JPEG")
    color = color_map[row['subject_id']]
    x_new = x_center + row['X'] * stretch
    y_new = y_center + row['Y'] * stretch

    # Force RGB: imshow draws a 2-D (greyscale) array through a colormap and reads
    # a 4-channel (e.g. CMYK) array as RGBA. The context manager closes the file.
    with Image.open(img_path) as pil_img:
        img = np.asarray(pil_img.convert('RGB'), dtype=np.float64) / 255
    width = scaler * img.shape[1]
    height = scaler * img.shape[0]

    min_x = min(min_x, x_new)
    max_x = max(max_x, x_new + width)
    min_y = min(min_y, y_new)
    max_y = max(max_y, y_new + height)

    # Subject-coloured frame drawn beneath the image.
    ax.add_patch(plt.Rectangle((x_new, y_new), width, height,
                               linewidth=3, edgecolor=color, facecolor='none', zorder=1))
    ax.imshow(img, extent=[x_new, x_new + width, y_new, y_new + height], zorder=2)

padding = scaler * 80  # margin around the outermost thumbnails
ax.set_xlim(min_x - padding, max_x + padding)
ax.set_ylim(min_y - padding, max_y + padding)
ax.set_aspect('equal')
ax.set_axis_off()

# One legend entry per subject, placed outside the axes on the right.
legend_patches = [Patch(color=color_map[subject], label=subject) for subject in unique_subjects]
ax.legend(handles=legend_patches, title="Subject ID", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
fig.subplots_adjust(right=0.8)

if save_flag:
    print("saving plot...")
    plot_fname = f"ROIs-{('-').join(rois)}_subjects-{('-').join(dataset_preprocessing.subjects_to_include)}_usenoiseceiling-{config['fmri']['use_noiseceiling']}_n-{n}_trialselection-{trialselection}_tsne.png"
    save_root = os.path.join(project_root, "src", "fmriDatasetPreparation", "visualizations", "fmri_embeddings_tsne")
    os.makedirs(save_root, exist_ok=True)
    fig.savefig(os.path.join(save_root, plot_fname), dpi=300)

In [None]:
# Same t-SNE embedding as a plain scatter, coloured by dataset instead of thumbnails.
# df is rebuilt from scratch so this cell can be re-run independently.
df = pd.DataFrame(all_data['test_naturalistic'])
X = np.vstack(df['fmri'].to_numpy())
print(X.shape)
df['X'] = X_tsne[:, 0]
df['Y'] = X_tsne[:, 1]

fig_scatter, ax_scatter = plt.subplots(figsize=(10, 8))
sns.scatterplot(data=df, x='X', y='Y', hue='dataset_id', palette='tab10', s=10, alpha=0.6, ax=ax_scatter)
ax_scatter.set(title='TSNE Projection', xlabel='t-SNE 1', ylabel='t-SNE 2')
plt.show()

In [None]:
#save as json for viewing in webpage
# Export a copy without the raw fMRI vectors; df itself is left untouched so the
# cell can be re-run (an in-place drop raises KeyError the second time).
df_export = df.drop(columns='fmri')
json_filename = f"ROIs-{('-').join(rois)}_subjects-{('-').join(dataset_preprocessing.subjects_to_include)}_usenoiseceiling-{config['fmri']['use_noiseceiling']}_n-{n}_tsne.json"
# Secondary export location; set TSNE_EXPORT_ROOT to override the default path.
save_root = os.getenv("TSNE_EXPORT_ROOT", "/data/vision/oliva/blahner/projects/BrainEmbedder/data/tsne")
assets_root = os.path.join(project_root, "assets")
for out_dir in (assets_root, save_root):
    os.makedirs(out_dir, exist_ok=True)
    df_export.to_json(os.path.join(out_dir, json_filename), orient='records', lines=False)