In [1]:
from rsatoolbox import vis
from rsatoolbox import rdm
import rsatoolbox
import rsatoolbox.data as rsd 
import rsatoolbox.rdm as rsr

import numpy as np
import os
import inspect
import scipy.io
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image
import glob

from sklearn.decomposition import IncrementalPCA
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split, Subset
import torchvision
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torchvision import transforms

from util_function import ImageDataset, get_roi_mapping, seed_everything, image_visualization_transform

In [2]:
brain_rdm_dir = '/home/yuchen/brain_rdm_new/'
model_rdm_dir = '/home/yuchen/model_rdm/'

In [3]:
ls /home/yuchen/model_rdm/

[0m[01;34malexnet[0m/          [01;34mmobilenet_v2_new[0m/         [01;34mresnet50_swav_new[0m/
[01;34malexnet_new[0m/      [01;34mresnet18[0m/                 [01;34mresnet_swav[0m/
[01;34mdensenet161[0m/      [01;34mresnet18_new[0m/             [01;34msqueezenet1_0[0m/
[01;34minceptionv3[0m/      [01;34mresnet50[0m/                 [01;34msqueezenet1_0_new[0m/
[01;34minceptionv3_new[0m/  [01;34mresnet50_new[0m/             [01;34mvgg16[0m/
[01;34mmobilenet_v2[0m/     [01;34mresnet50_swav_allimages[0m/  [01;34mvgg16_new[0m/


In [23]:
lst = glob.glob(model_rdm_dir+'inceptionv3_new/'+'*')
model_rdm_dir_map = {}

for file in lst:
    key = int(file.split('/')[-1].split('_')[0])
    model_rdm_dir_map[key] = file

model_rdm_dir_map = {k: model_rdm_dir_map[k] for k in sorted(model_rdm_dir_map)}

In [24]:
roi = ["V1", "V2", "V4",
       "EBA", "FBA", "OFA", "FFA",
       "OPA", "PPA", "RSC","OWFA", "VWFA" ]

model_name = ''

performance = {}


for roi_ in roi:
    performance[roi_] = [0,'']
    brain_rdm = np.load(f'{brain_rdm_dir}{roi_}.npy')
    for model_layer_dir in model_rdm_dir_map.values():
       model_rdm = np.load(f'{model_layer_dir}')
       
       correlation = rsatoolbox.rdm.compare_correlation(model_rdm, brain_rdm)
       
       if correlation > performance[roi_][0]:
              if 'maxpool' not in model_layer_dir and 'flatten' not in model_layer_dir and 'classifier' not in model_layer_dir and 'avgpool' not in model_layer_dir: 
                     performance[roi_] = [np.round(np.nanmean(correlation),5),model_layer_dir]
       
       print(roi_, model_layer_dir.split('/')[-1], np.round(np.nanmean(correlation),5))

V1 13_Conv2d_1a_3x3.conv.npy 0.08485
V1 14_Conv2d_1a_3x3.bn.npy 0.00408
V1 15_Conv2d_1a_3x3.relu.npy -0.01607
V1 16_Conv2d_2a_3x3.conv.npy -0.15272
V1 17_Conv2d_2a_3x3.bn.npy 0.05882
V1 18_Conv2d_2a_3x3.relu.npy -0.04556
V1 19_Conv2d_2b_3x3.conv.npy -0.04421
V1 20_Conv2d_2b_3x3.bn.npy 0.10709
V1 21_Conv2d_2b_3x3.relu.npy -0.26622
V1 22_maxpool1.npy 0.31037
V1 23_Conv2d_3b_1x1.conv.npy 0.16331
V1 24_Conv2d_3b_1x1.bn.npy 0.14068
V1 25_Conv2d_3b_1x1.relu.npy -0.17285
V1 26_Conv2d_4a_3x3.conv.npy -0.19774
V1 27_Conv2d_4a_3x3.bn.npy 0.0972
V1 28_Conv2d_4a_3x3.relu.npy -0.14825
V1 29_maxpool2.npy 0.38557
V1 30_Mixed_5b.branch1x1.conv.npy 0.23597
V1 31_Mixed_5b.branch1x1.bn.npy 0.10501
V1 32_Mixed_5b.branch1x1.relu.npy -0.03465
V1 33_Mixed_5b.branch5x5_1.conv.npy 0.23892
V1 34_Mixed_5b.branch5x5_1.bn.npy 0.10133
V1 35_Mixed_5b.branch5x5_1.relu.npy 0.07946
V1 36_Mixed_5b.branch5x5_2.conv.npy 0.05062
V1 37_Mixed_5b.branch5x5_2.bn.npy 0.08265
V1 38_Mixed_5b.branch5x5_2.relu.npy -0.00099
V1 39_Mi

In [25]:
for key in performance.keys():
    print(key, performance[key][0],  performance[key][1].split('.npy')[0].split('/')[-1])

V1 0.43171 48_Mixed_5b.avg_pool2d
V2 0.50888 48_Mixed_5b.avg_pool2d
V4 0.30155 48_Mixed_5b.avg_pool2d
EBA 0.10988 102_Mixed_6a.branch3x3dbl_1.conv
FBA 0.10265 227_Mixed_6e.branch7x7dbl_3.conv
OFA 0.17857 125_Mixed_6b.branch7x7dbl_1.conv
FFA 0.10757 227_Mixed_6e.branch7x7dbl_3.conv
OPA 0.1121 48_Mixed_5b.avg_pool2d
PPA 0.08585 116_Mixed_6b.branch7x7_1.conv
RSC 0.06585 151_Mixed_6c.branch7x7_2.conv
OWFA 0.19838 125_Mixed_6b.branch7x7dbl_1.conv
VWFA 0.08766 104_Mixed_6a.branch3x3dbl_1.relu


In [26]:
for key in performance.keys():
    print(performance[key][0])

0.43171
0.50888
0.30155
0.10988
0.10265
0.17857
0.10757
0.1121
0.08585
0.06585
0.19838
0.08766


In [28]:
for key in performance.keys():
    print(performance[key][1].split('.npy')[0].split('/')[-1])

48_Mixed_5b.avg_pool2d
48_Mixed_5b.avg_pool2d
48_Mixed_5b.avg_pool2d
102_Mixed_6a.branch3x3dbl_1.conv
227_Mixed_6e.branch7x7dbl_3.conv
125_Mixed_6b.branch7x7dbl_1.conv
227_Mixed_6e.branch7x7dbl_3.conv
48_Mixed_5b.avg_pool2d
116_Mixed_6b.branch7x7_1.conv
151_Mixed_6c.branch7x7_2.conv
125_Mixed_6b.branch7x7dbl_1.conv
104_Mixed_6a.branch3x3dbl_1.relu
