In [None]:
import numpy as np
import pickle
import os
from scipy.spatial.distance import squareform
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import MDS
import pandas as pd
from PIL import Image
from scipy.spatial.distance import pdist
from tqdm import tqdm
import glob as glob
import argparse
from scipy.stats import pearsonr
import hcp_utils as hcp
from pathlib import Path

In [None]:
def mds(utv, pos=None, n_jobs=1, n_components=2, random_state=3, max_iter=100):
    """Embed an RDM with metric multi-dimensional scaling (MDS).

    Adapted from the NSD code.

    Args:

        utv (array): condensed (1D) upper-triangular part of an RDM, e.g. the
            output of ``scipy.spatial.distance.pdist``. It is expanded to a
            square matrix with ``squareform`` and handed to MDS as
            precomputed dissimilarities.

        pos (array, optional): (n_items, n_components) coordinates used to
            initialise MDS. Defaults to None (MDS picks its own start).

        n_jobs (int, optional): number of cores to distribute to.
            Defaults to 1.

        n_components (int, optional): dimensionality of the embedding.
            Defaults to 2.

        random_state (int, optional): seed for the MDS random number
            generator. Defaults to 3.

        max_iter (int, optional): maximum number of SMACOF iterations.
            Defaults to 100.

    Returns:

        array: (n_items, n_components) array of embedding coordinates.
    """
    rdm = squareform(utv)
    # A fresh RandomState per call makes repeated calls reproducible.
    rng = np.random.RandomState(seed=random_state)
    embedder = MDS(
        n_components=n_components,
        max_iter=max_iter,
        random_state=rng,
        dissimilarity="precomputed",
        n_jobs=n_jobs,
    )
    return embedder.fit_transform(rdm, init=pos)

In [None]:
def create_rdm(X: np.ndarray, method: str = 'pearson') -> np.ndarray:
    """Build a full (square) representational dissimilarity matrix.

    Args:
        X: 2D array of shape (numstim, numfeatures), e.g. stimuli x vertices.
        method: dissimilarity measure. Only 'pearson' (1 - Pearson r) is
            currently implemented.

    Returns:
        (numstim, numstim) float array with rdm[i, j] = 1 - r(X[i], X[j]).

    Raises:
        ValueError: if `method` is unsupported or `X` is not 2D.
    """
    if method != 'pearson':
        # Previously an unknown method silently produced an all-zeros RDM.
        raise ValueError(f"unsupported method {method!r}; only 'pearson' is implemented")
    X = np.asarray(X, dtype=np.float64)
    if X.ndim != 2:
        raise ValueError(f"X must be 2D (numstim x numfeatures), got shape {X.shape}")
    numstim = X.shape[0]
    if numstim == 0:
        return np.zeros((0, 0))
    # np.corrcoef treats each row as a variable, so it yields every pairwise
    # Pearson r in one vectorized pass (instead of numstim**2 pearsonr calls).
    # atleast_2d covers the single-row case, where corrcoef returns a scalar.
    corr = np.atleast_2d(np.corrcoef(X))
    return 1 - corr

In [None]:
class arguments:
    """Notebook configuration; a lightweight stand-in for an argparse namespace.

    The defaults below are the values this analysis was run with. Any existing
    attribute can be overridden by keyword, e.g. ``arguments(subject='2')``.
    Unknown keywords raise TypeError so typos are not silently ignored.
    """

    def __init__(self, **overrides) -> None:
        self.subject = '1'  # subject number, as a string
        self.dataset_root = "/data/vision/oliva/scratch/datasets/NaturalObjectDataset"
        self.project_root = "/data/vision/oliva/blahner/projects/SheenBrain/fmriDatasetPreparation/NaturalObjectDataset/validation"
        self.task = 'imagenet'
        # plot stimulus images on top of the t-SNE results
        # NOTE(review): not read by any cell visible here — confirm it is still needed
        self.image_plot = True
        self.n_components = 2  # embedding dimensionality (MDS init and t-SNE)
        self.perplexity = 30  # t-SNE perplexity
        self.verbose = True
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"unknown argument {key!r}")
            setattr(self, key, value)

args = arguments()
subject = f"sub-{int(args.subject):02}"
session = "ses-imagenet01"
hcp_roilist = list(hcp.mmp.labels.values())
# Parcel names include a hemisphere prefix (e.g. "L_MT" / "R_MT").
ROI = "L_MT"  # other candidates: 'LO1', 'LO2', 'PHA1', 'PHA2', 'MT' (with an L_/R_ prefix)
if ROI not in hcp_roilist:
    # Fail here with a clear message rather than deep inside the loading cell.
    raise ValueError(f"ROI {ROI!r} is not an HCP-MMP parcel name (expected e.g. 'L_MT')")

save_root = os.path.join(args.project_root, "output_tsne")
# exist_ok avoids the check-then-create race and is idempotent on re-run.
os.makedirs(save_root, exist_ok=True)

In [None]:
# Load this subject's GLM betas and average them within the ROI.
glm_dir = os.path.join(args.dataset_root, "derivatives", "GLM", subject, session)
fmri_data_wb = np.load(os.path.join(glm_dir, f"{subject}_organized_betas.npy"))  # shape numstim, numreps, numvertices

# Look the parcel's integer label up by dict key rather than by list position,
# so the mapping does not depend on the label keys being contiguous from 0.
roi_label = next((key for key, name in hcp.mmp.labels.items() if name == ROI), None)
if roi_label is None:
    raise ValueError(f"ROI {ROI!r} not found in hcp.mmp.labels")
roi_indices = np.where(hcp.mmp.map_all == roi_label)[0]
if roi_indices.size == 0:
    raise ValueError(f"ROI {ROI!r} (label {roi_label}) has no vertices in hcp.mmp.map_all")
fmri_data_roi = np.mean(fmri_data_wb[:, :, roi_indices], axis=1)  # average over reps -> numstim x ROI vertices

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(glm_dir, f"{subject}_{session}_task-{args.task}_conditionOrderDM.pkl"), 'rb') as f:
    events_run, ses_conds = pickle.load(f)

In [None]:
# ROI RDM: correlation distance (1 - Pearson r) between stimulus response patterns.
rdm_flat = pdist(fmri_data_roi, metric='correlation')  # condensed upper triangle
rdm = squareform(rdm_flat)  # square numstim x numstim matrix
print(rdm.shape)
fig, ax = plt.subplots()
rdm_img = ax.imshow(rdm)
ax.set(title=f"{subject} {ROI} RDM (1 - Pearson r)", xlabel="stimulus", ylabel="stimulus")
fig.colorbar(rdm_img, ax=ax, label="dissimilarity")
plt.show()

In [None]:
# MDS on the RDM gives a deterministic starting layout for t-SNE.
# mds() expects the condensed RDM, which is exactly rdm_flat; there is no need
# to round-trip the square matrix back through squareform.
Y_mds = mds(rdm_flat, n_components=args.n_components)

# t-SNE on the precomputed dissimilarities, initialised from the MDS solution.
tsne = TSNE(n_components=args.n_components, metric='precomputed', init=Y_mds, random_state=42, perplexity=args.perplexity, n_jobs=1)
X_tsne = tsne.fit_transform(rdm)
print("divergence: ", tsne.kl_divergence_)

In [None]:
# Organize the embedding into a DataFrame (one row per stimulus) for flexible plotting.
# Metadata columns (face / no face, indoor / outdoor, people / no people) can be added
# next to stimIDX and used as `hue` below.
assert len(ses_conds) == X_tsne.shape[0], (
    f"{len(ses_conds)} conditions vs {X_tsne.shape[0]} embedded stimuli: rows are misaligned"
)
cols = ['stimIDX'] + [f"tsne{n + 1}" for n in range(args.n_components)]
df = pd.DataFrame(X_tsne, columns=cols[1:])
df.insert(0, 'stimIDX', list(ses_conds))

fig, ax = plt.subplots()
sns.scatterplot(x="tsne1", y="tsne2", data=df, ax=ax)  # no hue
ax.set_title(f"{subject} {ROI} t-SNE")
fig.savefig(os.path.join(save_root, f"{subject}_ROI-{ROI}_tsne_dots.png"))
plt.show()
plt.close(fig)

In [None]:
# Thumbnail plot: draw each stimulus image at its t-SNE location.
# Only the first two embedding dimensions are used (not set up for 3D plots).
assert X_tsne.shape[1] >= 2, "thumbnail plot needs at least 2 t-SNE components"


def _load_rgb(path):
    """Return the image at `path` as an (H, W, 3) float64 array scaled to [0, 1].

    Converting to RGB keeps grayscale / CMYK / palette images from being drawn
    with a colormap or misread channels; the context manager closes the file.
    """
    with Image.open(path) as im:
        return np.asarray(im.convert("RGB"), dtype=np.float64) / 255


fig, ax = plt.subplots()
canvas = 500  # side length of the square canvas, in plot units
x_center = canvas / 2
y_center = canvas / 2
# Scale by the largest absolute coordinate so every point lands inside
# [0, canvas] even when the embedding is not centred on 0 (|max| + |min| does
# not guarantee that), and avoid np.floor collapsing stretch to 0.
max_abs = np.max(np.abs(X_tsne[:, :2]))
stretch = (canvas / 2) / max_abs if max_abs > 0 else 1.0
print("stretch:", stretch)
scaler = 0.04  # plot units per image pixel (thumbnail size)

# Union of all thumbnail extents, so edge thumbnails are not clipped.
x_lo, y_lo, x_hi, y_hi = 0.0, 0.0, float(canvas), float(canvas)
for i, (x, y) in enumerate(X_tsne[:, :2]):
    _dset, nfold, fname = ses_conds[i].split('/')
    image_path = os.path.join(args.dataset_root, "Nifti", "stimuli", args.task, nfold, fname)
    img = _load_rgb(image_path)

    x_new = x_center + x * stretch
    y_new = y_center + y * stretch
    x_end = x_new + scaler * img.shape[1]
    y_end = y_new + scaler * img.shape[0]
    ax.imshow(img, extent=[x_new, x_end, y_new, y_end])

    x_lo, x_hi = min(x_lo, x_new), max(x_hi, x_end)
    y_lo, y_hi = min(y_lo, y_new), max(y_hi, y_end)

ax.set_xlim(x_lo, x_hi)
ax.set_ylim(y_lo, y_hi)
ax.set_aspect('equal')
ax.set_axis_off()
fig.savefig(os.path.join(save_root, f"{subject}_ROI-{ROI}_tsne_plotframes.png"), dpi=300)
plt.show()
plt.close(fig)