In [None]:
import os
import time
import random

import numpy as np
from PIL import Image
import tensorflow as tf
import os.path as pth
from scipy import io
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.size'] = 12  # notebook-wide default font size (changed again before the RDM and PC figures)

In [None]:
base = '/users/jmy/data/image_sound'
# Class names in stimulus order: the ten digits followed by six object words.
label_list = ('zero one two three four five six seven eight nine '
              'bed bird cat dog house tree').split()
# Target (width, height) passed to PIL's resize in get_imgs.
imgsize = [224] * 2
# Per-channel means subtracted from the RGB pixel arrays in get_imgs (before
# the RGB->BGR flip), broadcastable against an (H, W, 3) image.
vgg_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)[np.newaxis, np.newaxis, :]

In [None]:
# Root of the experiment data tree: get_stim_idx reads run orders under
# <root>/prep_new_template/, get_imgs reads stimulus images under <root>/subjects/skku/.
data_base_path = pth.join('/data', '01_experiment_data', 'image_sound')

def get_stim_idx(sbj_idx):
    """Return one subject's image-stimulus names, grouped by class and sorted by exemplar number.

    Reads ``<data_base_path>/prep_new_template/sNN/stimuli/run_order.mat``
    (NN = ``sbj_idx`` zero-padded to two digits). Every entry starting with
    ``'i_'`` is collected from the ``run*`` variables, which are visited in sorted
    key order. The names are then ordered class by class following ``label_list``
    and, within a class, by the integer that follows the ``'i_<label>'`` prefix.

    Parameters
    ----------
    sbj_idx : int
        Subject ID (e.g. 25, as stored in ``sbj_list``), not an index into ``sbj_list``.

    Returns
    -------
    list of str
        Whitespace-stripped names. The parsing assumes names of the form
        ``'i_<label><number>'``, optionally ending in ``'.wav'``. TODO confirm
        against run_order.mat.
    """
    subject_path = pth.join(data_base_path, 'prep_new_template', 's' + str(sbj_idx).zfill(2))
    mat_file = io.loadmat(pth.join(subject_path, 'stimuli', 'run_order.mat'))
    run_list = sorted(key for key in mat_file if key.startswith('run'))
    image_names = [name.strip()
                   for run_name in run_list
                   for name in mat_file[run_name]
                   if name.startswith('i_')]

    def exemplar_number(name, prefix):
        # Remove the prefix/suffix explicitly: str.lstrip/rstrip strip a *set of
        # characters*, not a prefix/suffix, and would eat matching letters.
        stem = name[len(prefix):]
        if stem.endswith('.wav'):
            stem = stem[:-len('.wav')]
        return int(stem)

    sorted_names = []
    for label in label_list:
        prefix = 'i_' + label
        members = [name for name in image_names if name.startswith(prefix)]
        # p=prefix binds the current prefix into the key function.
        sorted_names.extend(sorted(members, key=lambda name, p=prefix: exemplar_number(name, p)))
    return sorted_names
def _open_stim_image(img_dir, stem):
    """Open ``<stem>.png`` in ``img_dir``, falling back to ``<stem>.jpg``."""
    try:
        return Image.open(pth.join(img_dir, stem + '.png'))
    except (IOError, OSError):
        # Missing or unreadable PNG -> try the JPEG variant. Unlike a bare
        # ``except:``, this does not swallow KeyboardInterrupt and similar.
        return Image.open(pth.join(img_dir, stem + '.jpg'))


def get_imgs(sorted_name_list, sbj_num=None):
    """Load, resize and VGG-preprocess one subject's stimulus images.

    Parameters
    ----------
    sorted_name_list : sequence of str
        Stimulus names as returned by ``get_stim_idx``. The text after the first
        underscore is used as the image file stem.
    sbj_num : int, optional
        Subject ID used in the image directory ``.../subjects/skku/s<sbj_num>/stimuli/img``.
        When omitted, falls back to ``sbj_list[sbj_idx]``, i.e. the notebook
        globals the original implementation read implicitly.

    Returns
    -------
    np.ndarray, shape (len(sorted_name_list), imgsize[1], imgsize[0], 3)
        Mean-subtracted images in BGR channel order. The container is float64;
        the values are computed in float32.
    """
    if sbj_num is None:
        # Backward-compatible fallback on notebook globals; prefer passing sbj_num.
        sbj_num = sbj_list[sbj_idx]
    img_dir = pth.join(data_base_path, 'subjects', 'skku', 's' + str(sbj_num), 'stimuli', 'img')

    imgs = np.zeros((len(sorted_name_list), imgsize[1], imgsize[0], 3))
    for idx, stim_name in enumerate(sorted_name_list):
        img = _open_stim_image(img_dir, stim_name.split('_')[1])
        # Check the *mode* (img.size is always a 2-tuple). Convert before
        # resizing, because PIL resamples mode '1'/'P' images with NEAREST
        # whatever filter is requested.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # LANCZOS is the filter the removed Image.ANTIALIAS alias pointed to.
        img = img.resize((imgsize[0], imgsize[1]), Image.LANCZOS)
        arr = np.asarray(img, dtype=np.float32) - vgg_mean
        imgs[idx] = arr[:, :, ::-1]  # RGB -> BGR
    return imgs
def get_layer_rep(li, imgs, method='Whole'):
    """Feed every image through the restored network and collect one layer's activations.

    Parameters
    ----------
    li : int
        Index into the global ``relus`` list of tensor (op) names.
    imgs : np.ndarray
        Image stack. ``imgs[i]`` is fed as a batch of one, presumably the
        (N, 224, 224, 3) BGR arrays produced by get_imgs (confirm).
    method : str
        Unused; kept so existing calls keep working.

    Returns
    -------
    np.ndarray, float32, shape (N, H*W*C)
        One flattened activation vector per image.
    """
    layer_out = graph.get_tensor_by_name(relus[li] + ':0')
    count = imgs.shape[0]
    # The (H, W, C) of the layer output are taken from the graph's static shape.
    reps = np.zeros((count,) + tuple(layer_out.shape[d] for d in (1, 2, 3)), dtype='float32')
    for i in range(count):
        # imgs[i:i + 1] keeps the leading batch axis of size one.
        reps[i] = sess.run(layer_out, {x: imgs[i:i + 1], is_training: False})
    return reps.reshape(count, -1)

In [None]:
sbj_list = [25,26,29,30,31,32,33,34,37,38,39,40,41,43,44]
# Derived from the list so the subject count and the list cannot drift apart.
sbjs_num = len(sbj_list)
# 192 stimuli per subject (16 classes x 12 exemplars, as the RDM cell assumes).
# get_imgs computes pixels in float32, so float32 storage is lossless:
# 15*192*224*224*3 float32 is about 1.7 GB (float64 would be about 3.5 GB).
sbjs_imgs = np.zeros((sbjs_num, 192, imgsize[1], imgsize[0], 3), dtype=np.float32)
for sbj_idx in range(sbjs_num):
    # NOTE: get_imgs (as called here) reads the global ``sbj_idx`` to build the
    # image path, so this loop variable must keep its name.
    sbjs_imgs[sbj_idx] = get_imgs(get_stim_idx(sbj_list[sbj_idx]))

In [None]:
# PCA
savebase = '/users/jmy/data/image_sound/RDMs/CNN/PCA/16x16/'

cdd_n, epoc = 5, 56
    
sess=tf.Session()
mdl_pth = '/users/jmy/data/nets/16_class/VGG_Base/cdd_{}/net-{}.ckpt'.format(str(cdd_n).zfill(2),str(epoc-1))
saver = tf.train.import_meta_graph(mdl_pth+'.meta')
saver.restore(sess, mdl_pth) #tf.train.latest_checkpoint('/users/jmy/data/nets/dualcnn/one/cdd_14'))
graph = tf.get_default_graph()
# placeholder load
x = graph.get_tensor_by_name("Placeholder:0") 
is_training = graph.get_tensor_by_name("Placeholder_2:0")
# get relu layer
relus = [op.name for op in graph.get_operations() if op.type=='Relu']
relus = relus[:14] + ['fc8/Conv2D']

In [None]:
# Spatial side length (H == W) of each of the 13 conv-layer outputs for a 224x224 input.
ft_dim_info = [224] * 2 + [112] * 2 + [56] * 3 + [28] * 3 + [14] * 3
# Channel count of each of the 13 conv layers (matches the VGG-16 conv layout).
chn_info = [64] * 2 + [128] * 2 + [256] * 3 + [512] * 3 + [512] * 3

# CNN RDM

In [None]:
depth = range(15)
# Class averaging below assumes the stimuli are ordered class by class (as
# get_stim_idx returns them) with an equal number of exemplars per class.
n_classes = len(label_list)
# np.savez does not create directories.
if not pth.isdir(savebase):
    os.makedirs(savebase)
for li in depth:
    for sbj_idx in range(sbjs_num):
        rep_ = get_layer_rep(li, sbjs_imgs[sbj_idx])
        # z-score each feature across images, then keep the PCs explaining 90% of the variance
        std_ = StandardScaler()
        x_scld = std_.fit_transform(rep_)
        pca = PCA(n_components=0.90)
        x_ = pca.fit_transform(x_scld)
        # average the exemplars of each class -> (n_classes, n_pcs)
        n_imgs = x_.shape[0]
        assert n_imgs % n_classes == 0, 'expected an equal number of exemplars per class'
        x_16 = x_.reshape(n_classes, n_imgs // n_classes, -1).mean(axis=1)
        # RDM = 1 - Pearson r between the class-mean PC vectors (rows are the variables)
        rdm = 1 - np.corrcoef(x_16)
        # NOTE: 'cdd0{}' assumes a single-digit cdd_n; the loader cell uses the same pattern.
        filename = '_ind192to16_90_P{}_cdd0{}_L{}.npz'.format(str(sbj_list[sbj_idx]),str(cdd_n),str(li+1).zfill(2))
        savepath = savebase + filename
        np.savez(savepath, rdm=rdm)

# Visualize CNN RDM

In [None]:
# load rdms
# Panel titles: the 13 conv layers, then the FC layer and the network output.
# These are the same 15 layers (same order) the RDM cell writes as L01..L15.
tlist = ['Conv ' + str(li) for li in range(1, 14)] + ['FC', 'Output']
# tick labels: capitalised class names, in label_list order
labels = [name.capitalize() for name in label_list]
n_layers, n_classes = len(tlist), len(labels)

# The loop is sized from the titles rather than the global ``depth``, which the
# PC-visualisation cell rebinds to range(13). Reading it here would silently
# leave the last layers as all-zero RDMs.
rdm_grp = np.zeros((sbjs_num, n_layers, n_classes, n_classes))
for li in range(n_layers):
    for sbj_idx in range(sbjs_num):
        filename = '_ind192to16_90_P{}_cdd0{}_L{}.npz'.format(str(sbj_list[sbj_idx]),str(cdd_n),str(li+1).zfill(2))
        savepath = savebase + filename
        # np.load returns an NpzFile that holds the file open until closed.
        with np.load(savepath) as npz:
            rdm_grp[sbj_idx, li] = npz['rdm']
# mask the upper triangle (diagonal included) so each dissimilarity is drawn once
mask = np.triu(np.ones((n_classes, n_classes), dtype=bool))
# default font size
plt.rcParams.update({'font.size': 20})

In [None]:
# average
avg_arr = np.mean(rdm_grp, axis=0)  # (n_layers, n_classes, n_classes) group-mean RDMs


def plot_rdm_grid(rdms, titles, mask, labels, **heatmap_kw):
    """Draw one masked RDM heatmap per layer on a 3x5 grid, show it, then close it.

    rdms    : sequence of square RDMs (one per panel, up to 15)
    titles  : panel titles, aligned with ``rdms``
    mask    : boolean array hiding the cells that are not drawn
    labels  : tick labels for both axes
    heatmap_kw : extra keyword arguments for ``sns.heatmap`` (e.g. vmin/vmax, cbar)
    """
    fig, axes = plt.subplots(3, 5, figsize=(38, 20))
    for ax, rdm, title in zip(axes.flat, rdms, titles):
        ax.set_title(title, fontsize=40)
        sns.heatmap(rdm, mask=mask, cmap='jet', ax=ax,
                    xticklabels=labels, yticklabels=labels, square=True, **heatmap_kw)
    fig.tight_layout()  # once, after every panel is drawn
    plt.show()
    plt.close(fig)


# Group-average RDMs. Previously rdm_grp[sbj_idx] was plotted, i.e. whichever
# subject the leftover loop variable pointed at. First with a per-layer colour scale...
plot_rdm_grid(avg_arr, tlist, mask, labels, cbar=True)
# ...then on a fixed 0-2 scale (the full range of 1 - Pearson r) for cross-layer comparison.
plot_rdm_grid(avg_arr, tlist, mask, labels, vmin=0, vmax=2)

In [None]:
plt.rcParams['font.size'] = 8  # small default font for the dense PC grids that follow

In [None]:
# visualize PCs
# One pass per conv layer described by ft_dim_info / chn_info (13 layers).
# ``depth`` is rebound as before.
depth = range(len(ft_dim_info))
max_pc_rows = 16       # draw every PC when a layer needs at most this many...
fallback_pc_rows = 4   # ...otherwise only the leading 4
n_ch_shown = 4         # randomly chosen feature maps (channels) shown per PC
ch_rng = random.Random(0)  # seeded so the sampled channels are reproducible
for li in depth:
    ft_dim = ft_dim_info[li]
    chn = chn_info[li]
    rep_ = get_layer_rep(li, sbjs_imgs[0])  # first subject's stimuli only
    std_ = StandardScaler()
    x_scld = std_.fit_transform(rep_)
    pca = PCA(n_components=0.995)
    pca.fit(x_scld)
    pc_num = pca.components_.shape[0]
    print(pc_num)
    if pc_num > max_pc_rows:
        pc_num = fallback_pc_rows
    # squeeze=False keeps ``axes`` 2-D even when only one PC row is drawn.
    f, axes = plt.subplots(pc_num, n_ch_shown, figsize=(10, 3 * pc_num),
                           constrained_layout=True, squeeze=False)
    f.suptitle('Conv ' + str(li + 1), fontsize=20)
    for pi in range(pc_num):
        # each PC is a loading vector over the flattened (H, W, C) activations
        pc_map = pca.components_[pi].reshape(ft_dim, ft_dim, chn)
        ch_rnd = ch_rng.sample(range(chn), n_ch_shown)
        for ci, ch in enumerate(ch_rnd):
            ax = axes[pi, ci]
            ax.imshow(pc_map[:, :, ch])
            ax.set_title('PC ' + str(pi + 1) + ' - Feature ' + str(ch + 1))
            ax.axis('off')
    plt.show()
    plt.close(f)  # release each multi-panel figure instead of accumulating 13 of them