In [None]:
from dotenv import load_dotenv
# Load variables from a .env file into the environment before any of them are read below.
load_dotenv()
import os
import sys
# Make project-local packages importable. PYTHONPATH may be unset (getenv returns None)
# or hold several entries joined by os.pathsep, so add each non-empty entry individually
# and skip ones already present so re-running this cell does not grow sys.path.
for _path_entry in (os.getenv('PYTHONPATH') or '').split(os.pathsep):
    if _path_entry and _path_entry not in sys.path:
        sys.path.append(_path_entry)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import glob
from scipy import stats
from scipy.spatial.distance import squareform
from net2brain.feature_extraction import FeatureExtractor, all_networks

from net2brain.utils.download_datasets import DatasetAlgonauts_NSD
import math

#local imports
from sklearn.decomposition import PCA, IncrementalPCA

from src.utils.transforms import SelectROIs
from src.encoding_exp.encoding_utils.models.model import RegressionAlexNet, EncoderMultiHead, Encoder, C8NonSteerableCNN
from src.encoding_exp.neuralpredictors.layers.readouts import SpatialXFeatureLinear
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from tqdm import tqdm
import json
import numpy as np
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import yaml
import pickle
from collections import defaultdict
from pathlib import Path

#local imports
from src.utils.transforms import SelectROIs
from src.encoding_exp.encoding_utils.models.model import RegressionAlexNet, EncoderMultiHead, Encoder, C8NonSteerableCNN, AlexNetCore
from src.encoding_exp.neuralpredictors.layers.readouts import SpatialXFeatureLinear
from src.encoding_exp.neuralpredictors.layers.readouts.factorized import FullLinearReadout
from src.utils.dataset import StimulusDataset
from src.utils.helpers import FilterDataset

In [None]:
def get_lowertriangular(rdm):
    """Return the off-diagonal entries of a square matrix as a 1-D vector.

    Despite the name, the values are read from the strict *upper* triangle
    (diagonal excluded) in row-major order.

    Args:
        rdm: 2-D array; its first dimension is taken as the number of conditions.

    Returns:
        1-D array containing the entries above the main diagonal.
    """
    upper_rows, upper_cols = np.triu_indices(rdm.shape[0], k=1)
    return rdm[upper_rows, upper_cols]

def visualize_RDM(rdm, savefig=False):
    """Show a rank-normalised heatmap of an RDM.

    The off-diagonal entries (as returned by ``get_lowertriangular``) are
    replaced by their ranks, divided by the largest rank, and expanded back
    into a square matrix that is drawn with the 'jet' colormap and a colorbar.

    Args:
        rdm: square dissimilarity matrix.
        savefig: falsy to skip saving; otherwise the value is passed straight
            to ``plt.savefig`` before the figure is shown.
    """
    ranks = stats.rankdata(get_lowertriangular(rdm))
    ranked_square = squareform(ranks / ranks.max())

    plt.imshow(ranked_square, cmap='jet')
    plt.colorbar()

    # Save before show(): the figure is cleared right after it is displayed.
    if savefig:
        plt.savefig(savefig)

    plt.show()
    plt.clf()

def computeRDM(activations_dict, do_pca=True, normalize=True, n_components=100):
    """Build a cosine-distance RDM from a mapping of condition name -> activation.

    Each activation is flattened to one row; rows are optionally reduced with
    PCA, then pairwise ``1 - cosine_similarity`` is computed.

    Args:
        activations_dict: mapping from filename/condition key to an object with
            a ``.flatten()`` method (assumes array-likes that all flatten to
            the same length -- TODO confirm against callers).
        do_pca: if True, project the rows with PCA before computing distances.
        normalize: if True, replace the off-diagonal distances by their ranks
            divided by the largest rank; the returned matrix then has a zero
            diagonal as produced by ``squareform``.
        n_components: upper bound on the number of PCA components. It is capped
            at ``min(n_conditions, n_features)`` because PCA cannot return more
            components than that.

    Returns:
        ``(filenames, rdm)`` where ``filenames`` lists the keys in iteration
        order and ``rdm`` is an ``(ncond, ncond)`` array in that same order.

    Raises:
        ValueError: if ``activations_dict`` is empty.
    """
    ncond = len(activations_dict)
    if ncond == 0:
        raise ValueError("activations_dict is empty; cannot compute an RDM")

    filenames = []  # order of filenames, matches the rows/columns of the RDM
    embeddings = []
    for filename, activation in activations_dict.items():
        filenames.append(filename)
        embeddings.append(activation.flatten())
    arr = np.array(embeddings)

    if do_pca:
        print("running pca")
        # Requesting more components than samples or features makes PCA raise.
        n_keep = min(n_components, *arr.shape)
        arr = PCA(n_components=n_keep).fit_transform(arr)
    print(f"array shape: {arr.shape}")

    rdm = 1 - cosine_similarity(arr)

    assert len(filenames) == ncond, f"number of filenames is {len(filenames)}, should be {ncond}"
    assert rdm.shape == (ncond, ncond), f"shape of rdm is {rdm.shape}, should be ({ncond}, {ncond})"
    # With a single condition there are no off-diagonal entries to rank
    # (taking .max() of an empty rank vector would raise), so skip normalisation.
    if normalize and ncond > 1:
        rdm_rank = stats.rankdata(get_lowertriangular(rdm))
        rdm_rank_norm = rdm_rank/rdm_rank.max()
        rdm = squareform(rdm_rank_norm)

    return filenames, rdm

In [None]:
# Paths and device selection used by the cells below.
project_root = os.getenv("PROJECT_ROOT", "/default/path/to/datasets") # falls back to a placeholder when PROJECT_ROOT is not set. NOTE(review): the fallback is the datasets placeholder path -- confirm it is intended for the project root.
dataset_root = os.path.join(os.getenv("DATASETS_ROOT", "/default/path/to/datasets"), "MOSAIC") # "MOSAIC" subfolder of DATASETS_ROOT; uses the placeholder default if DATASETS_ROOT is not set.
# Expose only GPU index 7 to this process so it does not compete with other programs.
# This has to happen before the CUDA availability check below.
os.environ["CUDA_VISIBLE_DEVICES"] = "7" # hard-coded GPU index; adjust per machine
# Fall back to the CPU when no CUDA device is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Collect every .jpg stimulus and put the paths in lexicographic order.
stimulus_dir = "/data/vision/oliva/datasets/kingdaka/stimuli"
stimulus_paths = glob.glob(os.path.join(stimulus_dir, "*.jpg"))
stimulus_paths.sort()


In [None]:
#download dataset
#NSD_dataset = DatasetAlgonauts_NSD()  # or use DatasetAlgonauts_NSD
#paths = NSD_dataset.load_dataset(path=os.path.join(project_root, "src", "stimulusSetPreparation", "extract_embeddings", "nsd_test_tmp"))


In [None]:
from net2brain.feature_extraction import all_networks

# List the model names registered under the 'Taskonomy' netset.
taskonomy_models = all_networks['Taskonomy']
print(taskonomy_models)


In [None]:
# Print the extractable layer names of every model in the 'Taskonomy' netset.
for netset in all_networks.keys():
    if netset == 'Taskonomy':
        for model_name in all_networks[netset]:
            fx_model = FeatureExtractor(
                model=model_name,
                netset=netset,
                device='cpu',
            )
            available_layers = fx_model.get_all_layers()
            print(f"{model_name}: {available_layers}")

In [None]:

# data_dict[feature_set_name][stimulus_id] -> flattened activation vector.
# The inner mapping is a plain dict: defaultdict(defaultdict) created inner
# defaultdicts without a factory, which behave like dicts anyway.
data_dict = defaultdict(dict)
# The same single layer is extracted from every model; edit this list to extract others.
layers_to_extract = ["layer4.2.conv3"]
for netset in all_networks.keys():
    if netset != 'Taskonomy':
        continue
    for model_name in all_networks[netset]:
        fx_model = FeatureExtractor(model=model_name,
                                    netset=netset,
                                    device='cpu')
        fx_model.get_all_layers()
        ft_path = f'feats3_{model_name}'
        # Build the output directory once so extraction and loading cannot drift apart.
        save_dir = os.path.join(project_root, "src", "encoding_exp", "net2brain", ft_path)
        fx_model.extract(data_path=stimulus_paths,
                         save_path=save_dir,
                         layers_to_extract=layers_to_extract)
        # Sorted so that, if several archives exist, the overwrite order is deterministic.
        npzFiles = sorted(glob.glob(os.path.join(save_dir, "*.npz")))
        for f in npzFiles:
            # Close the archive when done; an open NpzFile keeps a file handle alive.
            with np.load(f) as d:
                # Stimuli are looked up by zero-padded ids "001".."156"
                # (assumes each archive holds all 156 keys -- TODO confirm).
                for stim in range(1,157):
                    data_dict[ft_path][f"{stim:03d}"] = d[f"{stim:03d}"].flatten()

In [None]:
res = data_dict
# One subplot per entry in `res`.
layer_names = list(res.keys())
num_layers = len(layer_names)

# Near-square grid. Clamp to at least 1x1 so an empty `res` neither divides by
# zero nor requests a 0x0 grid.
nrows = max(1, math.ceil(math.sqrt(num_layers)))
ncols = max(1, math.ceil(num_layers / nrows))

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; without it a
# single subplot is returned as a bare Axes and .flatten() fails.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*4, nrows*4), squeeze=False)

# Flatten the axes array for easier indexing
axes = axes.flatten()

# Compute all RDMs up front so every panel can share one colour scale.
all_rdms = []
for layer in layer_names:
    print(layer)
    _, rdm = computeRDM(res[layer], do_pca=False, normalize=True)
    all_rdms.append(rdm)

# default= keeps the limits defined when there is nothing to plot.
global_min = min((np.min(rdm) for rdm in all_rdms), default=0.0)
global_max = max((np.max(rdm) for rdm in all_rdms), default=1.0)

# Define tick positions and labels
# NOTE(review): positions are hard-coded row/column indices; they presumably mark
# category boundaries of a 156-stimulus set -- confirm against the stimulus ordering.
tick_positions = [27, 63, 99, 123, 155]
tick_labels = ["Animals", "Objects", "Scenes", "People", "Faces"]

# Plot each RDM in its own subplot. The grid always has at least num_layers
# cells, so zip never drops a layer.
im = None
for ax, layer, rdm in zip(axes, layer_names, all_rdms):
    # Plot the matrix with consistent color scaling
    im = ax.imshow(rdm, cmap='jet', vmin=global_min, vmax=global_max)

    # Add title
    ax.set_title(f'{layer} RDM')
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')

    # Set custom tick positions and labels
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_yticklabels(tick_labels)

# Hide any unused subplots (all of them when there is nothing to plot).
for unused_ax in axes[num_layers:]:
    unused_ax.axis('off')

# Add a single colorbar for all subplots
if im is not None:  # Only add colorbar if we have at least one plot
    # Colorbar sits in its own axes along the right edge of the figure.
    cbar_width = 0.02
    cbar_padding = 0.05
    cbar_left = 1.0 - cbar_width - 0.01
    cbar_height = 0.7
    cbar_bottom = 0.15

    cbar_ax = fig.add_axes([cbar_left, cbar_bottom, cbar_width, cbar_height])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label('Normalized Distance')

    # Adjust layout to prevent overlap - make room for colorbar
    rect_right = 1.0 - (cbar_width + cbar_padding)
    plt.tight_layout(rect=[0, 0, rect_right, 1])
else:
    plt.tight_layout()
plt.show()