# Imports

In [None]:
# Pin SciPy to 1.7 for this runtime.
# NOTE(review): the captured output shows this downgrade conflicts with the preinstalled arviz
# (arviz requires scipy>=1.8). The reason for the pin is not stated — TODO confirm it is still needed.
!pip install scipy==1.7

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting scipy==1.7
  Downloading scipy-1.7.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl (28.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m28.4/28.4 MB[0m [31m34.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.10.1
    Uninstalling scipy-1.10.1:
      Successfully uninstalled scipy-1.10.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
arviz 0.15.1 requires scipy>=1.8.0, but you have scipy 1.7.0 which is incompatible.[0m[31m
[0mSuccessfully installed scipy-1.7.0


In [None]:
# Mount Google Drive at /content/drive; every data and checkpoint path used later in this
# notebook lives under /content/drive/MyDrive/Thesis.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Install the NSD access helpers and a pinned nibabel.
# NOTE(review): the two git installs are unpinned (no tag/commit), so a re-run may pull different code;
# the captured output shows this cell was cancelled mid-install — re-run it before continuing.
!pip install -q git+https://github.com/tknapen/nsd_access
!pip install -q git+https://github.com/PalaashAgrawal/NSD_exploration.git
!pip install -q nibabel==4.0.2

[31mERROR: Operation cancelled by user[0m[31m
[0mTraceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/cli/base_command.py", line 160, in exc_logging_wrapper
    status = run_func(*args)
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/cli/req_command.py", line 241, in wrapper
    return func(self, options, args)
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/commands/install.py", line 419, in run
    requirement_set = resolver.resolve(
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/resolution/resolvelib/resolver.py", line 73, in resolve
    collected = self.factory.collect_root_requirements(root_reqs)
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/resolution/resolvelib/factory.py", line 491, in collect_root_requirements
    req = self._make_requirement_from_install_req(
  File "/usr/local/lib/python3.9/dist-packages/pip/_internal/resolution/resolvelib/factory.py", line 453, in _make_req

In [None]:
import nibabel as nib

# NSD Explore

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import h5py
from scipy.stats import pearsonr
import pickle
from tqdm import tqdm
import nibabel as nib
from nsd_access import NSDAccess
import os

def get_stim(subject, info_file_path, fmri_path):
	"""Return the sorted trial numbers of ``subject`` whose beta session file is available.

	Reads the stimulus-info CSV, keeps rows where column ``subject<N>`` equals 1, and keeps
	the ``subject<N>_rep0`` trial number of each row whose ``betas_sessionXX.hdf5`` exists
	under ``fmri_path``.

	NOTE(review): the session is taken as ``order // 750 + 1``. If trial numbers are 1-based,
	this maps trial 750 (and 1500, ...) to the following session — TODO confirm. ``getImages``
	uses the same rule, so both keep the same trials.
	"""
	table = pd.read_csv(info_file_path)
	seen_by_subject = table[table[f'subject{subject}'] == 1]
	order_column = f'subject{subject}_rep0'

	def _has_betas(trial):
		session_number = trial // 750 + 1
		return os.path.exists(fmri_path + f'/betas_session{session_number:02d}.hdf5')

	kept = [row[order_column] for _, row in seen_by_subject.iterrows() if _has_betas(row[order_column])]
	return sorted(kept)

def get_chunkified_stimulus_order(stimulus_order:dict):
	"""Group trial numbers by session so each session's hdf5 file is opened only once.

	Args:
		stimulus_order: iterable of trial numbers. Despite the ``dict`` annotation, callers in
			this notebook pass the list returned by ``get_stim``. The numbers are treated as
			1-based (as the original ``(order-1)`` indexing implies), with 750 trials per session
			file — TODO confirm against the NSD stim-info documentation.

	Yields:
		``(session, indices)`` pairs in ascending session order. ``indices`` are 0-based row
		positions inside that session's beta array, in the order the trials were given.
	"""
	chunks = {}
	for order in stimulus_order:
		# Session and in-session index must come from the same 0-based offset: with 1-based
		# trials, trial 750 is row 749 of session 1, not row 749 of session 2.
		session_offset, index = divmod(order - 1, 750)
		chunks.setdefault(session_offset + 1, []).append(index)

	for session in sorted(chunks):
		yield session, chunks[session]

def apply_tfms(x, tfms: list):
	"""Pipe ``x`` through every callable in ``tfms`` (first to last) and return the result."""
	current = x
	for step in tfms:
		current = step(current)
	return current

def noop(x):
	"""Identity transform: hand ``x`` back unchanged."""
	return x


def get_fmri_list(stimulus_order:dict, tfms = None, fmri_dir = None):
	"""Load the beta volumes of the given trials and apply ``tfms`` to them.

	Args:
		stimulus_order: trial numbers (callers pass the list from ``get_stim``). They are
			grouped per session by ``get_chunkified_stimulus_order``.
		tfms: optional iterable of callables applied in order to each session's stacked betas
			(e.g. an ``apply_roi`` mask). Defaults to the identity.
		fmri_dir: directory holding the ``betas_sessionXX.hdf5`` files. Defaults to the
			notebook-level ``fmri_path`` global, which is what this function always used before.

	Returns:
		list of per-trial arrays, session by session and, within a session, in the order the
		trials were given.
	"""
	if fmri_dir is None:
		fmri_dir = fmri_path
	tfms = list(tfms) if tfms else [noop]
	fmri_list = []
	for session, indices in get_chunkified_stimulus_order(stimulus_order):
		fmri_file = fmri_dir + f'/betas_session{session:02d}.hdf5'
		with h5py.File(fmri_file, "r") as f:
			# Read only the requested trials instead of the whole session (`[()]` loaded every
			# trial in the file). h5py list-indexing needs strictly increasing, unique indices,
			# so read the sorted unique set and reorder into the requested order afterwards.
			wanted = sorted({int(i) for i in indices})
			rows = f['betas'][wanted]
			position = {idx: k for k, idx in enumerate(wanted)}
			beta = rows[[position[int(i)] for i in indices]]
		fmris = apply_tfms(beta, tfms)
		fmri_list.extend(list(fmris))
	return fmri_list


# ROI name -> label value in the prf-visualrois volumes; apply_roi keeps the voxels whose
# label equals one of the selected values. Suffix 'v' = ventral, 'd' = dorsal.
roi_index = dict(
	v1v=1.,
	v1d=2.,
	v2v=3.,
	v2d=4.,
	v3v=5.,
	v3d=6.,
	v4=7.,  # hv4
)



def get_roi_data(roi_file_paths):
	"""Load every ROI NIfTI file and return its voxel data with the axis order reversed.

	``transpose(2, 1, 0)`` turns (i, j, k) into (k, j, i). NOTE(review): presumably this matches
	the axis order of the hdf5 betas — TODO confirm.
	"""
	volumes = []
	for path in list(roi_file_paths):
		volumes.append(nib.load(path).get_fdata().transpose(2, 1, 0))
	return volumes

def get_roi_regions(roi_file_index, region, pathway):
	"""Return the label values of every key in ``roi_file_index`` starting with ``region + pathway``.

	``pathway=None`` is treated as '' so both the ventral and dorsal keys of ``region`` match.
	The matched keys are printed.
	"""
	prefix = region + ('' if pathway is None else pathway)
	matched = [name for name in roi_file_index if name.startswith(prefix)]
	print('keys extracted: ', matched)
	return [roi_file_index[name] for name in matched]


class apply_roi():
	"""Callable that masks fMRI arrays to one visual ROI.

	Args (of ``__init__``):
		roi_file_paths (str, Path, or list of str/Path): ROI label volume(s), e.g.
			``<path>/lh.prf-visualrois.nii.gz``. An ROI is often split across several files
			(e.g. lh.prf-visualrois and rh.prf-visualrois); pass a list when
			``roi_file_index`` applies to all of them.
		roi_file_index: dict mapping ROI name to its label value in the ROI volume(s).
		region (str): 'v1', 'v2', 'v3' or 'v4'.
		pathway: None, 'v' (ventral) or 'd' (dorsal). With None, both the ventral and dorsal
			keys of ``region`` are used.

	Raises:
		AssertionError: on an invalid ``pathway``, a pathway given for 'v4', or when no key of
			``roi_file_index`` matches ``region``/``pathway``.
	"""
	def __init__(self, roi_file_paths,roi_file_index = roi_index, region = 'v1', pathway = None):
		#Assertions
		assert roi_file_paths is not None
		if pathway: assert pathway in ('v','d'), f"invalid pathway. Should be None, 'v' or 'd'"
		if region=='v4': assert pathway is None, f"ventral or dorsal pathways dont exist for the hv4 region. Use pathway=None"

		# A single str/Path is allowed; list() on a str would split it into characters.
		if isinstance(roi_file_paths, (str, os.PathLike)):
			roi_file_paths = [roi_file_paths]

		# Resolve the label values once (this used to run, and print, once per ROI file).
		labels = get_roi_regions(roi_file_index, region, pathway)
		prefix = region + (pathway or '')
		# Without this check, sum() over no labels gives 0 and every masked fMRI becomes all zeros.
		assert labels, f"no key of roi_file_index starts with {prefix!r}"

		# Mask = sum over ROI files and matched labels of the (volume == label) indicator arrays.
		self.roi = sum((roi==label).astype(float) for roi in get_roi_data(roi_file_paths) for label in labels)

	def __call__(self,fmri):
		"""Multiply ``fmri`` elementwise by the ROI mask (numpy broadcasting applies)."""
		return np.multiply(fmri,self.roi)



def construct_RDM(activations):
	"""Build the n×n representational dissimilarity matrix (1 - Pearson r) of ``activations``.

	Each element is flattened before correlating. Every pair (row <= col) is computed once, and
	the value is written to both (row, col) and (col, row), including the diagonal.
	"""
	n = len(activations)
	rdm = np.zeros((n, n))
	flat = [a.flatten() for a in activations]
	for row in range(n):
		for col in range(row, n):
			dissimilarity = 1 - pearsonr(flat[row], flat[col])[0]
			rdm[row][col] = dissimilarity
			rdm[col][row] = dissimilarity
	return rdm.astype(float)




# with open(path/f'RDM/RDM_subj{subject}_{region}.pkl','wb') as f: pickle.dump(rdm, f)


In [None]:
subject = 5
fmri_path = '/content/drive/MyDrive/Thesis/Data/NSD/nsddata_betas/ppdata/subj0'+str(subject)+'/func1pt8mm/betas_fithrf_GLMdenoise_RR'
info_file_path = '/content/drive/MyDrive/Thesis/Data/NSD/nsd_coco_cifar_pre.csv'
roi_path = '/content/drive/MyDrive/Thesis/Data/NSD/nsddata/ppdata/subj0'+str(subject)+'/func1pt8mm/roi'

regions = ['v1','v2','v3','v4']
fmri_rdms = []

# Loop-invariant inputs: the trial list and the ROI files do not depend on the region.
stim = get_stim(subject, info_file_path, fmri_path)
roi_paths = [roi_path+'/lh.prf-visualrois.nii.gz', roi_path+'/rh.prf-visualrois.nii.gz']

# Read the betas once; masking each volume per region gives the same values as masking the
# per-session batches, without re-reading every session file for every region.
raw_fmri = get_fmri_list(stim)

for region in regions:
    # Ventral pathway only for V1–V3; hV4 has no ventral/dorsal split (see apply_roi).
    pathway = None if region=='v4' else 'v'
    tfms = apply_roi(roi_paths, region = region, pathway = pathway)
    fmri_activations = [tfms(volume) for volume in raw_fmri]
    fmri_rdm = construct_RDM(fmri_activations)
    fmri_rdms.append(fmri_rdm)


keys extracted:  ['v1v']
keys extracted:  ['v1v']
keys extracted:  ['v2v']
keys extracted:  ['v2v']
keys extracted:  ['v3v']
keys extracted:  ['v3v']
keys extracted:  ['v4']
keys extracted:  ['v4']


In [None]:
# Read the same stimulus-info table that get_stim used (info_file_path). getImages must filter
# exactly the rows get_stim kept, so reuse the variable rather than repeating the path literal.
nsd_df = pd.read_csv(info_file_path)

NSD = NSDAccess('/content/drive/MyDrive/Thesis/Data/NSD')

def getImages(nsd_df, subject):
    """Return the stimulus images of ``subject``'s available trials, in ``get_stim`` order.

    Keeps the rows of ``nsd_df`` with ``subject<N> == 1`` whose beta session file exists under
    the global ``fmri_path`` (same session rule as ``get_stim``). The rows are ordered by
    ``subject<N>_rep0``, which is the trial order in which ``get_stim`` / ``get_fmri_list``
    return the fMRI patterns, so row i of an image RDM matches row i of an fMRI RDM.
    Prints the kept nsdIds in that order.
    """
    order_col = f'subject{subject}_rep0'
    subj_df = nsd_df.loc[nsd_df['subject'+str(subject)] == 1]

    def has_betas(order):
        # Same session rule as get_stim, so both keep exactly the same trials.
        session = order//750 + 1
        return os.path.exists(fmri_path+f'/betas_session{session:02d}.hdf5')

    # Build a keep-mask instead of calling nsd_ids.remove() while iterating nsd_ids,
    # which skipped the element after every removed id.
    keep = [has_betas(order) for order in subj_df[order_col]]
    kept = subj_df.loc[keep].sort_values(order_col, kind='mergesort')
    nsd_ids = list(kept['nsdId'])

    # NOTE(review): NSDAccess.read_images presumably indexes an hdf5 dataset, which requires
    # increasing indices (the original passed sorted ids). So read in ascending-id order, then
    # permute the images back into trial order.
    ids_ascending = sorted(nsd_ids)
    images_ascending = np.asarray(NSD.read_images(ids_ascending, show=False))
    row_of = {nsd_id: k for k, nsd_id in enumerate(ids_ascending)}
    images = images_ascending[[row_of[nsd_id] for nsd_id in nsd_ids]]
    print(nsd_ids)
    return images

In [None]:
# Stimulus images for the current subject, one per kept trial.
images = getImages(nsd_df=nsd_df, subject=subject)

[306, 335, 442, 642, 741, 1314, 1753, 2656, 2945]


# CNN

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torchvision
from torchvision import transforms

# HWC image -> tensor, then per-channel (x - 0.5) / 0.5 normalisation.
CNN_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

## Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small 4-conv CNN whose weights are loaded from a checkpoint below.

    The layer sizes assume 3×32×32 inputs: conv4 then yields a 50×2×2 map, which feeds
    ``fc1`` (200 -> 10).
    """
    def __init__(self):
        super().__init__()
        # Registered first, so named_children() lists it before conv1..conv4.
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(3, 90, 3)
        self.conv2 = nn.Conv2d(90, 145, 2)
        self.conv3 = nn.Conv2d(145, 281, 2)
        self.conv4 = nn.Conv2d(281, 50, 5)
        self.fc1 = nn.Linear(50 * 2 * 2, 10)


    def forward(self, x):
        """Return the 10 ``fc1`` outputs for a (3, 32, 32) image, or (N, 10) for an (N, 3, 32, 32) batch."""
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.pool(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        # Flatten only the trailing (C, H, W) dims; torch.ravel also merged the batch dim,
        # so batched inputs no longer matched fc1's 200 input features.
        x = torch.flatten(x, start_dim=-3)
        x = self.fc1(x)
        return x


In [None]:
CNN = net = Net()
# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only runtime;
# the model and its inputs stay on the CPU throughout this notebook.
CNN.load_state_dict(torch.load('/content/drive/MyDrive/Thesis/checkpoints/4layer_70.pth', map_location='cpu'))

<All keys matched successfully>

## Get Activations

In [None]:
from PIL import Image
import numpy as np

# Reshape images and transform: shrink each NSD image to fit 32×32 (the input size Net expects)
# and apply CNN_transform. Image.LANCZOS is the filter the removed Image.ANTIALIAS alias
# pointed to (ANTIALIAS no longer exists in Pillow >= 10).
imgs_retran = []
for image in images:
    image = Image.fromarray(np.uint8(image)).convert('RGB')
    image.thumbnail((32,32), Image.LANCZOS)
    # thumbnail() keeps the aspect ratio and never upscales; Net needs exactly 32×32.
    assert image.size == (32, 32), f"expected a 32x32 thumbnail, got {image.size}"
    imgs_retran.append(CNN_transform(np.asarray(image)))

In [None]:
# Puts model in evaluation mode
CNN.eval()

# named_children() follows registration order in Net.__init__ (pool, conv1..conv4, fc1),
# so [1:5] selects conv1..conv4; their raw conv outputs (before ReLU/pooling) are recorded.
layers = list(CNN.named_children())[1:5]
activation = {}

# Activation hook function: stores the latest output of the named module.
def get_activation(name):
    def hook(module, inputs, output):
        activation[name] = output.detach()
    return hook

# Hook all four layers and push every image through the model once. Previously each layer got
# its own full pass, and hooks were never removed, so they piled up across layers and re-runs.
handles = [module.register_forward_hook(get_activation(name)) for name, module in layers]
per_layer = {name: [] for name, _ in layers}
try:
    # eval() alone does not disable autograd
    with torch.no_grad():
        for image in imgs_retran:
            CNN(image)
            for name, _ in layers:
                per_layer[name].append(activation[name].numpy())
finally:
    for handle in handles:
        handle.remove()

# One list per layer (in `layers` order), each holding one array per image.
CNN_activations = [per_layer[name] for name, _ in layers]

# CSNN

In [None]:
# Install SpykeTorch, which provides the snn / functional / utils modules used by the CSNN below.
# NOTE(review): unpinned git install — pin a commit for reproducible results.
!pip install -q git+https://github.com/miladmozafari/SpykeTorch.git

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for SpykeTorch-miladmozafari (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torchvision
import numpy as np
from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils
from torchvision import transforms

def time_dim(input):
    """Prepend a length-1 axis at position 0 (presumably the time axis the SpykeTorch transforms expect)."""
    return torch.unsqueeze(input, 0)

# Tensor + (x - 0.5) / 0.5 normalisation, a leading time axis, then SpykeTorch's pointwise
# inhibition and intensity-to-latency coding into 15 spike bins.
CSNN_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    time_dim,
    sf.pointwise_inhibition,
    utils.Intensity2Latency(number_of_spike_bins=15, to_spike=True),
])

## Model

In [None]:
class CSNN(nn.Module):
    """Four-layer convolutional spiking network built from SpykeTorch layers.

    Every ``forward`` call appends the spike output of each conv layer to
    ``self.outputs[<layer name>]``. The lists are never cleared by this class, so they grow
    with every call. The STDP objects and ``max_ap`` are only used for training.
    """
    def __init__(self):
        super(CSNN, self).__init__()

        self.n_layers = 4

        # Pooling Layer (applied to the input before conv1)
        self.pool = snn.Pooling(kernel_size = 3, stride = 1)

        # Conv Layers
        self.conv1 = snn.Convolution(in_channels=3, out_channels=64, kernel_size=4)
        self.conv2 = snn.Convolution(in_channels=64, out_channels=128, kernel_size=3)
        self.conv3 = snn.Convolution(in_channels=128, out_channels=256, kernel_size=3)
        self.conv4 = snn.Convolution(in_channels=256, out_channels=100, kernel_size=1)

        # STDP Functions
        self.stdp1 = snn.STDP(conv_layer = self.conv1, learning_rate = (0.1, -0.005))  #0.1, -0.005
        self.stdp2 = snn.STDP(conv_layer = self.conv2, learning_rate = (0.07, -0.005)) #0.09, -0.005
        self.stdp3 = snn.STDP(conv_layer = self.conv3, learning_rate = (0.05, -0.005))  #0.08 -0.005
        self.stdp4 = snn.STDP(self.conv4, (0.5, -0.05), False, 0.2) #(0.007, -0.01)
        self.anti_stdp = snn.STDP(self.conv4, (-0.0002, 0.075), False, 0.2) # (-0.002, 0.0001)

        # Ignore! -> parameter is only used for training
        self.max_ap = Parameter(torch.Tensor([0.3]))

        # Spike outputs recorded by forward(), keyed by layer name
        self.outputs = {'conv1':[], 'conv2':[], 'conv3':[], 'conv4':[]}

    def forward(self, x, t, layer):
        """Propagate spikes ``x`` through all four layers and return the conv4 winners.

        Each layer's output spikes are recorded via ``store_pot_out``. ``t`` and ``layer`` are
        accepted for call compatibility but are not used.
        """
        x = self.pool(x)
        p = self.conv1(x)
        o, p = sf.fire(p, 10, return_thresholded_potentials=True)
        self.store_pot_out(o, 'conv1')
        x = o
        p = self.conv2(x)
        o, p = sf.fire(p, 5, return_thresholded_potentials=True)
        self.store_pot_out(o, 'conv2')
        x = o
        p = self.conv3(x)
        o, p = sf.fire(p, 2, return_thresholded_potentials=True)
        self.store_pot_out(o, 'conv3')
        p = self.conv4(o)
        o = sf.fire(p,2)
        self.store_pot_out(o, 'conv4')
        winners = sf.get_k_winners(p, kwta=1, inhibition_radius=0, spikes=o)
        return winners

    def store_pot_out(self, out, layer):
        """Append ``out`` to ``self.outputs[layer]``."""
        self.outputs[layer].append(out)


    def reset(self):
        """Re-initialise the weights of all four conv layers via their ``reset_weight``."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.reset_weight()

    def stdp(self, layer, x, p, o, winners):
        """Apply the STDP rule of ``layer`` (1, 2 or 3; any other value uses conv4's rule)."""
        if layer == 1:
            self.stdp1(x, p, o, winners)
        elif layer == 2:
            self.stdp2(x, p, o, winners)
        elif layer == 3:
            self.stdp3(x, p, o, winners)
        else:
            self.stdp4(x, p, o, winners)


    def update_learning_rates(self, stdp_ap, stdp_an, anti_stdp_ap, anti_stdp_an):
        """Update the learning rates of conv4's STDP and anti-STDP rules.

        NOTE(review): ``anti_stdp`` receives ``(anti_stdp_an, anti_stdp_ap)``, the opposite
        argument order from ``stdp4`` — kept as-is; confirm it is intentional.
        """
        self.stdp4.update_all_learning_rate(stdp_ap, stdp_an)
        self.anti_stdp.update_all_learning_rate(anti_stdp_an, anti_stdp_ap)

In [None]:
csnn = CSNN()

# NOTE(review): every checkpoint load below is commented out, so the CSNN used for the RDM
# comparison keeps its freshly constructed (untrained) weights — confirm this is intended.
# If re-enabled: with load_state_dict's default strict=True, each call must supply every
# parameter of the model, so each load overwrites the previous one and only the last (l4)
# checkpoint would take effect.
# l1_dir = '/content/drive/MyDrive/Thesis/checkpoints/saved_l1.net'
# l2_dir = '/content/drive/MyDrive/Thesis/checkpoints/saved_l2.net'
# l3_dir = '/content/drive/MyDrive/Thesis/checkpoints/saved_l3.net'
# l4_dir = '/content/drive/MyDrive/Thesis/checkpoints/saved_l4.net'

# csnn.load_state_dict(torch.load(l1_dir, map_location=torch.device('cpu')))
# csnn.load_state_dict(torch.load(l2_dir, map_location=torch.device('cpu')))
# csnn.load_state_dict(torch.load(l3_dir, map_location=torch.device('cpu')))
# csnn.load_state_dict(torch.load(l4_dir, map_location=torch.device('cpu')))

## Get Activations

In [None]:
from PIL import Image
import numpy as np

# Reshape images and transform: shrink each NSD image to fit 32×32 and apply CSNN_transform.
# Image.LANCZOS is the filter the removed Image.ANTIALIAS alias pointed to (ANTIALIAS no
# longer exists in Pillow >= 10). thumbnail() keeps the aspect ratio, so only square inputs
# come out exactly 32×32.
imgs_retran = []
for image in images:
    image = Image.fromarray(np.uint8(image)).convert('RGB')
    image.thumbnail((32,32), Image.LANCZOS)
    imgs_retran.append(CSNN_transform(np.asarray(image)))

In [None]:
csnn.eval()

# forward() appends to csnn.outputs on every call; clear the lists first so re-running this
# cell does not record every image twice.
for layer_outputs in csnn.outputs.values():
    layer_outputs.clear()

# Inference only — no autograd graph needed.
with torch.no_grad():
    for image in imgs_retran:
        csnn(image,None,4)

# Same dict object as csnn.outputs: layer name -> list of per-image spike tensors.
CSNN_outs = csnn.outputs


# Compare RDMs

In [None]:
# Code adapted from https://github.com/ColinConwell/DeepDive/blob/main/deepdive/mapping_methods.py 
def compare_rdms(rdm1, rdm2):
    """Pearson r between the strict upper triangles (diagonal excluded) of two RDMs."""
    def upper(rdm):
        return rdm[np.triu_indices(rdm.shape[0], k=1)]

    r, _ = pearsonr(upper(rdm1), upper(rdm2))
    return r


## CNN Compare

In [None]:
def CNN_Compare(initial=False, scores_path='/content/drive/MyDrive/Thesis/Data/NSD/CNN_rdm_scores.csv'):
    """Correlate every fMRI ROI RDM with every CNN layer RDM and save the scores.

    Uses the notebook globals ``fmri_rdms``, ``CNN_activations``, ``regions``, ``layers`` and
    ``subject``. Scores are ordered region-major: for each fMRI RDM, every layer in order.

    Args:
        initial: if True, create ``scores_path`` with ``region``/``layer`` label columns and this
            subject's score column. Otherwise, add (or overwrite) this subject's column in the
            existing file, whose rows are assumed to follow the same region/layer order.
        scores_path: CSV file read and written by both branches (the results cells read
            ``CNN_rdm_scores.csv``).
    """
    # A layer's RDM does not depend on the fMRI region, so build each one once.
    layer_rdms = [construct_RDM(activations) for activations in CNN_activations]
    score_col = 'subj'+str(subject)+'_rdm_score'

    CNN_rdm_scores = []
    labels = []
    for i, fmri_rdm in enumerate(fmri_rdms):
        for j, CNN_rdm in enumerate(layer_rdms):
            CNN_rdm_scores.append(compare_rdms(fmri_rdm, CNN_rdm))
            if initial:
                labels.append((regions[i], layers[j][0]))

    if initial:
        # Same file as the update branch (this previously wrote 'CNN_dm_scores.csv', which nothing read).
        scores_df = pd.DataFrame({'region': [l[0] for l in labels], 'layer': [l[1] for l in labels], score_col: CNN_rdm_scores})
    else:
        scores_df = pd.read_csv(scores_path, index_col=False)
        scores_df[score_col] = CNN_rdm_scores
    scores_df.to_csv(scores_path, index=False)

In [None]:
# Add this subject's scores as a new column of the existing CNN score table.
CNN_Compare(initial=False)

## CSNN Compare

In [None]:
def CSNN_Compare(initial=False, scores_path='/content/drive/MyDrive/Thesis/Data/NSD/CSNN_rdm_scores.csv'):
    """Correlate every fMRI ROI RDM with every CSNN layer RDM and save the scores.

    Uses the notebook globals ``fmri_rdms``, ``CSNN_outs``, ``regions`` and ``subject``.
    Scores are ordered region-major: for each fMRI RDM, every CSNN layer in ``CSNN_outs`` order.

    Args:
        initial: if True, create ``scores_path`` with ``region``/``layer`` label columns and this
            subject's score column. Otherwise, add (or overwrite) this subject's column in the
            existing file, whose rows are assumed to follow the same region/layer order.
        scores_path: CSV file read and written by both branches.
    """
    # A layer's RDM does not depend on the fMRI region, so build each one once.
    layer_names = list(CSNN_outs)
    layer_rdms = [construct_RDM(CSNN_outs[name]) for name in layer_names]
    score_col = 'subj'+str(subject)+'_rdm_score'

    CSNN_rdm_scores = []
    labels = []
    for i, fmri_rdm in enumerate(fmri_rdms):
        for name, CSNN_rdm in zip(layer_names, layer_rdms):
            CSNN_rdm_scores.append(compare_rdms(fmri_rdm, CSNN_rdm))
            if initial:
                labels.append((regions[i], name))

    if initial:
        scores_df = pd.DataFrame({'region': [l[0] for l in labels], 'layer': [l[1] for l in labels], score_col: CSNN_rdm_scores})
    else:
        scores_df = pd.read_csv(scores_path, index_col=False)
        scores_df[score_col] = CSNN_rdm_scores
    scores_df.to_csv(scores_path, index=False)

## Run Compare

In [None]:
# Add this subject's scores as a new column of the existing CSNN score table.
CSNN_Compare(initial=False)

# Formulate Results

## CNN

In [None]:
import pandas as pd

In [None]:
CNN_scores_csv = pd.read_csv('/content/drive/MyDrive/Thesis/Data/NSD/CNN_rdm_scores.csv', index_col=False)

# Average only the per-subject score columns. This file also carries an 'Unnamed: 0' column,
# so the positional slice iloc[:, 2:] started at the text 'layer' column. Older pandas silently
# skipped it in mean(); pandas >= 2 raises a TypeError instead.
score_cols = [col for col in CNN_scores_csv.columns if col.endswith('_rdm_score')]
CNN_scores_csv['avg'] = CNN_scores_csv[score_cols].mean(axis=1)

# Subject-averaged score of each CNN layer against the V3 RDM, rounded to 4 decimals.
res = list(CNN_scores_csv.loc[CNN_scores_csv['region'] == 'v3', 'avg'].round(4))

res

[-0.0402, -0.0696, -0.0102, -0.0312]

## CSNN


In [None]:
CSNN_scores_csv = pd.read_csv('/content/drive/MyDrive/Thesis/Data/NSD/CSNN_rdm_scores.csv', index_col=False)

# Average only the per-subject score columns, selected by name, so a text or extra index column
# (like the 'Unnamed: 0' column in the CNN score file) can never leak into the mean.
score_cols = [col for col in CSNN_scores_csv.columns if col.endswith('_rdm_score')]
CSNN_scores_csv['avg'] = CSNN_scores_csv[score_cols].mean(axis=1)

# Subject-averaged score of each CSNN layer against the V3 RDM, rounded to 4 decimals.
res = list(CSNN_scores_csv.loc[CSNN_scores_csv['region'] == 'v3', 'avg'].round(4))

res

[0.001, 0.0058, 0.0195, 0.0195]

In [None]:
# Per-region, per-layer CNN scores for every subject, plus the 'avg' column computed above.
CNN_scores_csv

Unnamed: 0,region,layer,subj1_rdm_score,subj2_rdm_score,subj3_rdm_score,subj5_rdm_score,subj6_rdm_score,subj7_rdm_score,subj4_rdm_score,subj8_rdm_score,avg
0,v1,conv1,0.311771,-0.126675,-0.175067,0.031053,-0.014009,0.151623,0.102377,-0.115993,0.020635
1,v1,conv2,-0.070362,0.046552,-0.014834,-0.312925,0.10155,0.083062,-0.054842,-0.042916,-0.033089
2,v1,conv3,0.121381,-0.074245,0.122989,-0.23867,-0.141318,0.021854,0.480314,0.032313,0.040577
3,v1,conv4,0.021773,0.055425,-0.056948,-0.043401,-0.058174,-0.052108,0.223165,-0.009381,0.010044
4,v2,conv1,0.258415,-0.096434,-0.136211,0.007975,0.00741,0.164984,0.132929,-0.150299,0.023596
5,v2,conv2,-0.04183,-0.01435,0.046855,-0.312665,0.063209,0.097401,0.076488,0.035176,-0.006215
6,v2,conv3,0.134318,-0.118613,0.087072,-0.151564,-0.130425,0.040965,0.616859,-0.017726,0.057611
7,v2,conv4,0.074217,0.025269,-0.050965,-0.015741,-0.039389,-0.009985,0.404086,0.003121,0.048827
8,v3,conv1,0.161547,-0.071932,-0.153216,-0.242585,0.101786,0.098103,-0.140755,-0.074368,-0.040178
9,v3,conv2,-0.083671,-0.062212,0.034185,-0.415852,0.10077,0.083989,-0.210008,-0.004033,-0.069604
