In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import time
import random
import json
import gc

import PIL
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import h5py
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage, stats
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torchvision import transforms
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 
from scipy.stats import sem

# Put the parent of '../..' (resolved against the current working directory) on
# sys.path so the local `research.*` / `pipeline.*` imports below resolve.
# NOTE(review): presumably this is the directory that contains the repository; confirm.
dir2 = os.path.abspath(os.path.join('..', '..'))
dir1 = os.path.split(dir2)[0]
if dir1 not in sys.path:
    sys.path.append(dir1)
    
from research.data.natural_scenes import NaturalScenesDataset
from research.experiments.nsd.nsd_utils import tsne_image_plot
from research.metrics.metrics import cosine_distance, top_knn_test, r2_score, pearsonr
from pipeline.utils import get_data_iterator, DisablePrints, read_patch
from pipeline.compact_json_encoder import CompactJSONEncoder

In [None]:
# Dataset locations. The defaults are the original author's machine layout; set the
# NSD_PATH / COCO_PATH environment variables to run elsewhere without editing code.
nsd_path = Path(os.environ.get('NSD_PATH', 'D:\\Datasets\\NSD\\'))
coco_path = os.environ.get('COCO_PATH', 'X:\\Datasets\\COCO')
nsd = NaturalScenesDataset(nsd_path, coco_path=coco_path)
stimuli_path = nsd_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'
# The file handle is intentionally left open: `stimulus_images` is an h5py dataset
# that still reads from it when indexed.
stimulus_images = h5py.File(stimuli_path, 'r')['imgBrick']

In [None]:
model_name = 'ViT-B=32' #'clip-vit-large-patch14'
group_name = 'group-22'

subjects = [f'subj{i:02d}' for i in range(1, 9)]
embedding_name = 'embedding'
fold_name = 'val'

# Decoder outputs/weights per subject. Kept open because later cells index into it.
embeddings = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')

results_path = nsd_path / f'derivatives/figures/decoding/{model_name}/{group_name}/{fold_name}/{embedding_name}/'
results_path.mkdir(exist_ok=True, parents=True)

# Stimulus embeddings are read fully into memory, so the file can be closed right away.
with h5py.File(nsd_path / f'derivatives/stimulus_embeddings/{model_name}.hdf5', 'r') as stimulus_embeddings_file:
    Y_full = stimulus_embeddings_file[embedding_name][:]

In [None]:
# Load decoder data

from research.models.fmri_decoders import Decoder

# Per-fold accumulators. Each list receives one entry per subject, in `subjects` order.
folds = {
    fold: {'X_all': [], 'Y_all': [], 'Y_pred_all': [], 'stimulus_ids_all': []}
    for fold in ('val', 'test')
}
models_all = []
state_dicts_all = []
indices_all = []

# Set to False to skip rebuilding the decoders and loading the betas.
load_X = True

for subject in subjects:
    print(subject)

    subject_embeddings = embeddings[f'{subject}/{embedding_name}']
    train_mask, val_mask, test_mask = nsd.get_split(subject, 'split-01')

    if load_X:
        # Run configuration stored as HDF5 attributes on the subject's group.
        config = dict(subject_embeddings.attrs)

        # Rebuild the decoder from the stored hyperparameters and weights.
        model_params = {k: config[k] for k in ('layer_sizes', 'dropout_p') if k in config}
        model = Decoder(**model_params).eval()
        state_dict = {k: torch.from_numpy(v[:]) for k, v in subject_embeddings['model'].items()}
        state_dicts_all.append(state_dict)
        model.load_state_dict({k: v.clone() for k, v in state_dict.items()})
        models_all.append(model)

        # Load the betas with the voxel selection recorded in the stored config
        # (presumably the one the decoder was trained with).
        betas_params = {
            k: config[k]
            for k in (
                'subject_name', 'voxel_selection_path',
                'voxel_selection_key', 'num_voxels', 'return_volume_indices', 'threshold'
            )
        }
        if betas_params['threshold'] is not None:
            # A threshold-based voxel selection replaces the fixed voxel count.
            betas_params['num_voxels'] = None
            betas_params['return_tensor_dataset'] = False
        betas, betas_indices = nsd.load_betas(**betas_params)
        folds['val']['X_all'].append(betas[val_mask])
        folds['test']['X_all'].append(betas[test_mask])
        indices_all.append(betas_indices)

    stimulus_params = dict(
        subject_name=subject,
        stimulus_path=f'derivatives/stimulus_embeddings/{model_name}.hdf5',
        stimulus_key=embedding_name,
        delay_loading=False,
        return_tensor_dataset=False,
        return_stimulus_ids=True,
    )
    stimulus, stimulus_ids = nsd.load_stimulus(**stimulus_params)
    for fold, mask in (('val', val_mask), ('test', test_mask)):
        fold_data = folds[fold]
        fold_data['stimulus_ids_all'].append(stimulus_ids[mask])

        # Target embeddings, flattened to 2-D (one row per selected sample).
        Y = stimulus[mask].astype(np.float32)
        Y = Y.reshape(Y.shape[0], -1)
        fold_data['Y_all'].append(Y)

        # Decoded embeddings, L2-normalized per row.
        Y_pred = subject_embeddings[f'{fold}/Y_pred'][:]
        Y_pred = Y_pred / np.linalg.norm(Y_pred, axis=1, keepdims=True)
        fold_data['Y_pred_all'].append(Y_pred)

# Expose the selected fold's per-subject lists as explicit top-level names. This replaces
# a locals().update() call, which hides where the names come from.
X_all = folds[fold_name]['X_all']
Y_all = folds[fold_name]['Y_all']
Y_pred_all = folds[fold_name]['Y_pred_all']
stimulus_ids_all = folds[fold_name]['stimulus_ids_all']

# DBSCAN Clustering

In [None]:
# Load a group of decoders

# Open explicitly read-only: older h5py versions default to a writable mode.
decoding_models = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')
W_decoding_all = []
volume_indices_decoding_all = []
for subject_id, subject_name in enumerate(subjects):
    subject = decoding_models[f'{subject_name}/embedding/']
    # First linear layer of the decoder. It is transposed later so that each row is one voxel.
    W_decoding_all.append(subject['model/layers.0.weight'][:])
    volume_indices_decoding_all.append(subject['volume_indices'][:])

In [None]:
# Alternatively, load a group of decoders trained with multiple runs per participant and average the weights across all runs

group_name = 'group-22_reruns-2'

# Open explicitly read-only: older h5py versions default to a writable mode.
decoding_models = h5py.File(nsd_path / f'derivatives/decoded_features/{model_name}/{group_name}.hdf5', 'r')
num_runs = 50

W_decoding_all = []           # per subject: run-averaged first-layer weights
W_decoding_reruns_all = []    # per subject: all runs stacked along a new leading axis
volume_indices_decoding_all = []
for subject_id, subject_name in enumerate(subjects):
    subject_group = decoding_models[f'{subject_name}/embedding']
    W_runs = np.stack([
        subject_group[f'run_{run_id}/model/layers.0.weight'][:]
        for run_id in tqdm(range(num_runs), desc=subject_name)
    ])
    W_decoding_reruns_all.append(W_runs)
    W_decoding_all.append(W_runs.mean(axis=0))
    # NOTE(review): assumes every run uses the same voxel selection; indices come from the last run.
    volume_indices_decoding_all.append(subject_group[f'run_{num_runs - 1}/volume_indices'][:])

In [None]:
# Compute and cache nearest neighbor queries for the voxel-wise parameter vectors

tag = f'linear_decoding__{group_name}'
# One row per voxel: the decoder's first-layer weights, transposed.
W = [W_subj.T for W_subj in W_decoding_all]
num_models = 1
num_subjects = len(W)

# L2-normalize each voxel's weight vector. All-zero vectors produce 0/0 = NaN, which
# is reset to 0 below; the resulting numpy warnings are therefore silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    W = [W_subj / np.linalg.norm(W_subj, axis=1, keepdims=True) for W_subj in W]
for W_subj in W:
    W_subj[np.isnan(W_subj)] = 0.
nn_all = [NearestNeighbors(radius=1.0, metric='cosine').fit(W_subj) for W_subj in W]
cluster_name = 'density'
print(tag)

out_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}'
out_path.mkdir(exist_ok=True, parents=True)

# Max neighborhood size for the queries. The clustering and expansion radii used later
# filter these results with `< eps`, so they must not exceed this value.
eps = 0.7
# nn_results[i, j]: (distances, indices) of the neighbors among subject i's voxels for
# every voxel of subject j, sorted by distance.
nn_results = np.zeros((num_subjects, num_subjects), dtype=object)
for subject_i in tqdm(range(num_subjects), desc='computing neighbors'):
    for subject_j in range(num_subjects):
        nn_results[subject_i, subject_j] = nn_all[subject_i].radius_neighbors(W[subject_j], radius=eps, sort_results=True)

In [None]:
# Optionally save these queries to the disk

# Object array of per-voxel neighbor lists, stored with pickle (np.save's default allow_pickle=True).
np.save(file=out_path / f'nn_results__num_models-{num_models}_v1.npy', arr=nn_results)

In [None]:
# Load saved queries from the disk

# allow_pickle is needed for the object array; it executes pickle, so only load caches this notebook wrote.
nn_results = np.load(file=out_path / f'nn_results__num_models-{num_models}_v1.npy', allow_pickle=True)

In [None]:
# DBSCAN code

from collections import deque

# Sentinel labels stored in the per-subject `cluster_ids` arrays:
#   UNEXPLORED (-2): voxel not yet visited by the clustering sweep.
#   OUTLIER    (-1): voxel visited as a seed that turned out not to be a core sample.
UNEXPLORED, OUTLIER = -2, -1

     
def cross_participant_clustering(cluster_ids, cluster_id, subject_i, voxel_id, eps, min_neighbors):
    """Grow one cluster from a seed voxel using breadth-first, DBSCAN-style expansion across subjects.

    A voxel is a core sample when it has at least one neighbor (distance strictly
    below ``eps``) in at least ``min_neighbors`` *other* subjects. Core samples push
    all their neighbors, from every subject including their own, onto the fringe.
    Non-core voxels reached from a core sample join the cluster but are not expanded.

    Reads the module-level ``nn_results`` grid. ``nn_results[j, s]`` holds the
    ``(distances, indices)`` neighbor query of subject ``s``'s voxels against subject
    ``j``'s voxels.

    NOTE(review): voxels already labeled with a *different* cluster are not excluded
    from the fringe. A non-core voxel reachable from two clusters therefore ends up with
    the label of the later one; standard DBSCAN keeps the first.

    Args:
        cluster_ids: list with one label array per subject; updated in place.
        cluster_id: label to assign to voxels of this cluster.
        subject_i: subject index of the seed voxel.
        voxel_id: index of the seed voxel within ``subject_i``.
        eps: neighborhood distance threshold (strict ``<``).
        min_neighbors: minimum number of other subjects that must contain a neighbor
            for a voxel to be a core sample.

    Returns:
        List of ``(subject, voxel)`` tuples of the core samples, in visit order. The
        list is empty when the seed is not a core sample; the seed is then labeled
        ``OUTLIER``.
    """
    num_subjects = len(cluster_ids)
    core_samples = []
    seed = (subject_i, voxel_id)
    fringe = deque([seed])
    in_fringe = {seed}

    while fringe:
        check_subject, check_voxel = fringe.popleft()
        in_fringe.remove((check_subject, check_voxel))

        # Neighbors of this voxel inside each subject, restricted to distance < eps.
        nn_ids = []
        neighbor_counts = []
        for subject_j in range(num_subjects):
            query = nn_results[subject_j, check_subject]
            dists = query[0][check_voxel]
            ids = query[1][check_voxel]
            close = dists < eps
            nn_ids.append(ids[close])
            neighbor_counts.append(int(np.count_nonzero(close)))

        # Only other subjects count towards core status.
        neighbor_counts[check_subject] = 0
        num_supporting_subjects = sum(count > 0 for count in neighbor_counts)

        if num_supporting_subjects >= min_neighbors:
            # Core sample: label it and enqueue its unvisited neighbors.
            core_samples.append((check_subject, check_voxel))
            cluster_ids[check_subject][check_voxel] = cluster_id
            for subject_j, neighbor_ids in enumerate(nn_ids):
                for neighbor_id in neighbor_ids:
                    if subject_j == check_subject and neighbor_id == check_voxel:
                        continue
                    if (subject_j, neighbor_id) in in_fringe:
                        continue
                    if cluster_ids[subject_j][neighbor_id] == cluster_id:
                        continue
                    fringe.append((subject_j, neighbor_id))
                    in_fringe.add((subject_j, neighbor_id))
        elif not core_samples:
            # Non-core seed: nothing to grow from.
            cluster_ids[check_subject][check_voxel] = OUTLIER
        else:
            # Non-core voxel reached from a core sample: border member of the cluster.
            cluster_ids[check_subject][check_voxel] = cluster_id
    return core_samples

def distance_expansion(core_samples_all, cluster_ids, eps):
    """Grow every cluster by the same-subject voxels lying within ``eps`` of any member.

    Reads the module-level ``nn_results`` grid. The diagonal entry
    ``nn_results[s, s]`` holds the within-subject ``(distances, indices)`` neighbor
    query for subject ``s``.

    Args:
        core_samples_all: one entry per cluster; only its length (the number of
            clusters) is used here.
        cluster_ids: list with one label array per subject.
        eps: expansion distance threshold (strict ``<``).

    Returns:
        List with one int array per subject, of shape ``(num_clusters, num_voxels)``.
        Entry ``[c, v]`` is 1 if voxel ``v`` is in cluster ``c`` or within ``eps`` of
        one of its members in the same subject, and 0 otherwise.
    """
    num_clusters = len(core_samples_all)
    expanded_clusters_all = []
    for subject_id, subject_labels in enumerate(cluster_ids):
        nn_dists_subj, nn_ids_subj = nn_results[subject_id, subject_id]
        expanded_clusters = np.zeros((num_clusters, subject_labels.shape[0]), dtype=int)
        for c_id in range(num_clusters):
            member_ids = np.flatnonzero(subject_labels == c_id)
            if member_ids.size == 0:
                continue

            dists = np.concatenate([nn_dists_subj[v] for v in member_ids])
            ids = np.concatenate([nn_ids_subj[v] for v in member_ids])
            expanded_clusters[c_id, ids[dists < eps]] = 1
            # Members always belong to their own expanded cluster.
            expanded_clusters[c_id, member_ids] = 1

        expanded_clusters_all.append(expanded_clusters)
    return expanded_clusters_all

def modified_dbscan(min_neighbors, eps, expansion_eps):
    """Cluster voxel weight vectors across all subjects, then expand each cluster within subjects.

    Every still-unexplored voxel of every subject (in subject, then voxel order) is used
    as a seed for ``cross_participant_clustering``. Seeds that yield core samples start
    a new cluster. Afterwards ``distance_expansion`` grows each cluster by
    ``expansion_eps``. Uses the module-level ``W`` (one weight matrix per subject,
    one row per voxel).

    Returns:
        cluster_ids: list with one label array per subject. Labels are a cluster id
            (0, 1, ...) or ``OUTLIER``.
        core_samples_all: list with one ``(num_core, 2)`` array of
            ``(subject, voxel)`` pairs per cluster.
        expanded_clusters_all: output of ``distance_expansion``.
    """
    num_subjects = len(W)
    cluster_ids = [np.full(W[subject_id].shape[0], UNEXPLORED) for subject_id in range(num_subjects)]
    core_samples_all = []

    for subject_i in range(num_subjects):
        for voxel_id in range(W[subject_i].shape[0]):
            if cluster_ids[subject_i][voxel_id] != UNEXPLORED:
                continue
            # Cluster ids are consecutive from 0, so the next id equals the clusters found so far.
            cluster_id = len(core_samples_all)
            core_samples = cross_participant_clustering(cluster_ids, cluster_id, subject_i, voxel_id, eps, min_neighbors)
            if core_samples:
                core_samples_all.append(np.array(core_samples))

    expanded_clusters_all = distance_expansion(core_samples_all, cluster_ids, expansion_eps)
    return cluster_ids, core_samples_all, expanded_clusters_all

def save_results(cluster_name, min_neighbors, eps, reruns_mode, cluster_ids, core_samples_all, expanded_clusters, expansion_suffix):
    """Write the clustering results of one parameter setting to disk.

    All files go under ``nsd_path/derivatives/figures/concept_maps_voxel_v5/{tag}/{cluster_name}``:

    - ``params.json``: ``min_neighbors``, ``eps`` and ``reruns_mode``.
    - ``{subject}/mask.npy``: float array of shape ``(num_clusters, n)``. In
      ``'multiple'`` mode, ``cluster_ids[subject]`` is split into ``num_models`` equal
      chunks (module-level ``num_models``). Each entry is the fraction of those chunks
      in which the voxel belongs to the cluster.
    - ``{subject}/mask{expansion_suffix}.npy``: the expanded masks, averaged over the
      ``num_models`` chunks in ``'multiple'`` mode and saved unchanged otherwise.
    - Any mode other than ``'multiple'`` additionally writes ``{tag}__clusters.npy``
      (all subjects' labels concatenated) and ``{tag}__clusters__{subject}.npy``.
    """
    out_path = nsd_path / f'derivatives/figures/concept_maps_voxel_v5/{tag}/{cluster_name}'
    out_path.mkdir(exist_ok=True, parents=True)

    params = {'min_neighbors': min_neighbors, 'eps': eps, 'reruns_mode': reruns_mode}
    with open(out_path / 'params.json', 'w') as f:
        json.dump(params, f)

    multiple_runs = reruns_mode == 'multiple'
    num_clusters = len(core_samples_all)
    for subj_id, subject_name in enumerate(subjects):
        # Shift labels by one so OUTLIER (-1) becomes class 0, which is dropped after one-hot encoding.
        c_ids = np.stack(np.split(cluster_ids[subj_id] + 1, num_models))
        c_mask = F.one_hot(torch.from_numpy(c_ids).long(), num_classes=num_clusters + 1).float().numpy()[:, :, 1:].mean(axis=0)

        if multiple_runs:
            c_mask_expanded = np.stack(np.split(expanded_clusters[subj_id], num_models, axis=1)).mean(axis=0)
        else:
            c_mask_expanded = expanded_clusters[subj_id]

        subject_path = out_path / subject_name
        subject_path.mkdir(exist_ok=True, parents=True)
        np.save(subject_path / 'mask.npy', c_mask.T)
        np.save(subject_path / f'mask{expansion_suffix}.npy', c_mask_expanded)

    if not multiple_runs:
        np.save(out_path / f'{tag}__clusters.npy', np.concatenate(cluster_ids))
        for subj_id, subject_name in enumerate(subjects):
            np.save(out_path / f'{tag}__clusters__{subject_name}.npy', cluster_ids[subj_id])

In [None]:
# Run modified DBSCAN clustering algorithm for desired min_neighbors and eps combinations

# save_results() special-cases 'multiple' (num_models > 1 models' voxels concatenated
# per subject); any other value takes the single-model path. No other cell defines
# this name, so it is set here explicitly.
reruns_mode = 'multiple' if num_models > 1 else 'single'

min_neighbors_sweep = (3,)  # add 1, 2 or 4 to sweep other core-sample thresholds
# The position in this tuple names the run-{i} output directory, so only append to it.
eps_sweep = (0.55, 0.6, 0.65, 0.5, 0.7, 0.45)

# The loop variable is deliberately not called `eps`, so the neighbor-query radius
# `eps` defined alongside nn_results is not overwritten.
for min_neighbors in min_neighbors_sweep:
    for run_index, cluster_eps in enumerate(eps_sweep):
        print(f'{min_neighbors=}, eps={cluster_eps}')
        # Expansion radius: slightly larger than the clustering radius, capped at 0.65.
        cluster_ids, core_samples_all, expanded_clusters_all = modified_dbscan(
            min_neighbors, cluster_eps, expansion_eps=min(cluster_eps + 0.05, 0.65)
        )
        save_results(
            f'num_models-{num_models}/min_neighbors-{min_neighbors}/run-{run_index}', min_neighbors, cluster_eps, reruns_mode,
            cluster_ids, core_samples_all, expanded_clusters_all,
            expansion_suffix='_expanded',
        )

        # Per-cluster summary: member voxels per subject, then expanded-mask voxels per subject.
        for c_id in range(len(core_samples_all)):
            print('cluster_id', c_id)
            print([(subject_c_ids == c_id).sum() for subject_c_ids in cluster_ids])
            print([expanded[c_id].sum() for expanded in expanded_clusters_all])