- Visualize the unit activations of each CNN layer (and of the raw input images) as 2-D t-SNE embeddings, colored by stimulus class.

In [None]:
import os
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE

import os.path as pth
from PIL import Image
from scipy import io
import random

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Number CUDA devices by PCI bus order so that index "4" always means the same
# physical GPU, then make only that device visible to the process.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="4",  # e.g. "0,1" to expose several GPUs
)

In [None]:
# Project data root (not referenced in the cells shown here).
base = '/users/jmy/data/image_sound'
# The 16 stimulus classes: ten digit words followed by six object words.
# This order defines the integer class index used for labels and colours.
label_list = (
    ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
    + ['bed', 'bird', 'cat', 'dog', 'house', 'tree']
)
# Network input size, passed to PIL resize as (width, height).
imgsize = [224, 224]
# Per-channel means in R, G, B order, shaped (1, 1, 3) so they broadcast over an
# (H, W, 3) RGB image before the channels are flipped to BGR.
vgg_mean = np.asarray([123.68, 116.779, 103.939], dtype=np.float32)[np.newaxis, np.newaxis, :]

In [None]:
data_base_path = pth.join('/data', '01_experiment_data', 'image_sound')


def get_stim_idx(sbj_idx):
    """Return one subject's image-stimulus names, ordered by class then number.

    Reads ``stimuli/run_order.mat`` for subject ``sbj_idx`` (zero-padded to two
    digits in the directory name). It collects every entry starting with ``'i_'``
    from all ``run*`` variables, in sorted variable-name order. The names are
    returned grouped in the class order of ``label_list``. Within a class they are
    sorted by the integer after the class name (e.g. ``'i_cat3'`` -> 3).

    Note: despite the function name, the result is a list of names, not indices.
    """
    subject_path = pth.join(data_base_path, 'prep_new_template', 's' + str(sbj_idx).zfill(2))
    mat_file = io.loadmat(pth.join(subject_path, 'stimuli', 'run_order.mat'))
    run_list = sorted(name for name in mat_file if name.startswith('run'))
    image_name_list = [name.strip() for run_name in run_list
                       for name in mat_file[run_name] if name.startswith('i_')]

    def _stim_number(name, prefix):
        # Slice off the exact prefix and an optional '.wav' suffix. str.lstrip and
        # str.rstrip remove *sets of characters*, which only worked by accident.
        stem = name[len(prefix):]
        if stem.endswith('.wav'):
            stem = stem[:-len('.wav')]
        return int(stem)

    sorted_names = []
    for label in label_list:
        prefix = 'i_' + label
        class_names = [name for name in image_name_list if name.startswith(prefix)]
        sorted_names.extend(sorted(class_names, key=lambda n, p=prefix: _stim_number(n, p)))
    return sorted_names
def get_imgs(sorted_name_list, sbj_id=None):
    """Load, resize and VGG-preprocess one subject's stimulus images.

    Parameters
    ----------
    sorted_name_list : list of str
        Stimulus names such as ``'i_cat3'``. The text after the first ``'_'`` is
        the image file stem; a ``.png`` is tried first, then ``.jpg``.
    sbj_id : int, optional
        Subject number used in the image directory. If omitted, falls back to
        ``sbj_list[sbj_idx]``, read from the *module-level* ``sbj_idx``. That is the
        legacy behaviour, which only works inside a module-level
        ``for sbj_idx in ...`` loop.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(sorted_name_list), imgsize[1], imgsize[0], 3)``
        holding mean-subtracted pixels in BGR channel order.
    """
    if sbj_id is None:
        sbj_id = sbj_list[sbj_idx]  # legacy: relies on a global loop variable
    img_dir = data_base_path + '/subjects/skku/s' + str(sbj_id) + '/stimuli/img/'
    # PIL resize takes (width, height); the resulting array is (height, width, 3).
    imgs = np.zeros((len(sorted_name_list), imgsize[1], imgsize[0], 3))
    for idx, img_name in enumerate(sorted_name_list):
        stem = img_name.split('_')[1]
        try:
            img_file = Image.open(img_dir + stem + '.png')
        except FileNotFoundError:
            img_file = Image.open(img_dir + stem + '.jpg')
        with img_file:
            # LANCZOS is the filter the removed Image.ANTIALIAS constant aliased.
            img = img_file.resize((imgsize[0], imgsize[1]), Image.LANCZOS)
        # Force 3-channel RGB (drops alpha / expands grayscale or palette images).
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img = np.asarray(img).astype('float32') - vgg_mean
        imgs[idx] = img[:, :, ::-1]  # RGB -> BGR
    return imgs
def get_layer_rep(li, imgs, method=None):
    """Return the flattened activations of layer ``li`` for each image in ``imgs``.

    Parameters
    ----------
    li : int
        Index into the module-level ``relus`` list of tensor names.
    imgs : numpy.ndarray
        Preprocessed images stacked along axis 0. Each is fed separately to the
        placeholder ``x`` with ``is_training`` set to False.
    method : str or None
        ``'Rnd'`` keeps a random subset of ``smp_num_list[li]`` units when
        ``li < 14``. The subset is seeded by ``li``, so every call for the same
        layer (e.g. every subject) keeps the same units. Any other value keeps all
        units.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(imgs.shape[0], n_units)``.
    """
    n_imgs = imgs.shape[0]
    op = graph.get_tensor_by_name(relus[li] + ':0')
    # assumes a rank-4 (batch, h, w, c) tensor, as for the conv / 1x1-conv layers used
    ft_sh = (n_imgs, op.shape[1], op.shape[2], op.shape[3])
    rep_ = np.zeros(ft_sh, dtype='float32')
    for i, img in enumerate(imgs):
        rep_[i] = sess.run(op, {x: np.expand_dims(img, 0), is_training: False})
    rep_ = rep_.reshape(n_imgs, -1)
    if method == 'Rnd' and li < 14:
        # A private Random(li) draws the same indices as random.seed(li) followed
        # by random.sample, without reseeding the global `random` module. Sampling
        # from a range avoids building a list of every unit index (millions in
        # the early conv layers).
        rnd_idx = random.Random(li).sample(range(rep_.shape[1]), smp_num_list[li])
        rep_ = rep_[:, rnd_idx]
    return rep_

In [None]:
# Units kept per layer by get_layer_rep(..., 'Rnd'): 1000 for each of the first
# 14 layers. The last entry (16) is never used for subsampling, because
# get_layer_rep only subsamples layers with li < 14.
smp_num_list = np.array([1000] * 14 + [16], dtype=int)

In [None]:
# Checkpoint selection: candidate number, epoch, and an optional variant suffix.
cdd_n, epoc = 5, 56
cnn_v = ''
sess = tf.Session()
# The file for epoch `epoc` is named net-(epoc-1) — presumably 0-based numbering.
mdl_pth = ('/users/jmy/data/nets/16_class/VGG_Base/cdd_' + str(cdd_n).zfill(2) + cnn_v
           + '/net-' + str(epoc - 1) + '.ckpt')
# Rebuild the graph from the .meta file and restore the trained weights into `sess`.
saver = tf.train.import_meta_graph(mdl_pth + '.meta')
saver.restore(sess, mdl_pth)
graph = tf.get_default_graph()
# Input placeholders, looked up by their auto-generated graph names.
x = graph.get_tensor_by_name('Placeholder:0')
is_training = graph.get_tensor_by_name('Placeholder_2:0')
# Layers to visualise: the first 14 ReLU ops in graph order, then fc8's Conv2D output.
relus = [node.name for node in graph.get_operations() if node.type == 'Relu'][:14]
relus.append('fc8/Conv2D')

In [None]:
def scale_01(x):
    """Min-max rescale ``x`` into [0, 1].

    Parameters
    ----------
    x : array_like
        Values to rescale.

    Returns
    -------
    numpy.ndarray
        ``(x - min) / (max - min)``. A constant input returns all zeros instead of
        NaN (which a zero range would otherwise produce). An empty input returns an
        empty float array instead of raising.
    """
    x = np.asarray(x)
    if x.size == 0:
        return x.astype(float)
    lo = np.min(x)
    rg = np.max(x) - lo
    shifted = x - lo
    if rg == 0:
        return np.zeros_like(shifted, dtype=float)
    return shifted / rg

In [None]:
sbj_list = [25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 43, 44]
sbjs_num = len(sbj_list)  # 15 subjects
# Preprocessed images for every subject: (subject, 192 stimuli, 224, 224, 3).
# get_imgs computes pixels in float32, so float32 storage is lossless. It takes
# ~1.7 GB, versus ~3.5 GB for the float64 default.
sbjs_imgs = np.zeros((sbjs_num, 192, 224, 224, 3), dtype=np.float32)
for sbj_idx in range(sbjs_num):
    # get_imgs reads the module-level `sbj_idx` (when no subject id is passed),
    # so keep this loop variable name.
    sbjs_imgs[sbj_idx] = get_imgs(get_stim_idx(sbj_list[sbj_idx]))

In [None]:
# One class index per image row. For each subject there are 12 consecutive
# images of class 0, then 12 of class 1, ..., up to class 15. This assumes each
# subject has exactly 12 images per class, in get_stim_idx's sorted order.
labels = [ci for _ in range(sbjs_num) for ci in range(16) for _ in range(12)]

In [None]:
# One hex colour per class, indexed in the same order as label_list.
color_list = ['#E07F81', '#AC84C8', '#97C9EC', '#7CBC75', '#EDEC2E', '#E1A822',
              '#DA592C', '#8A319B', '#3D59B1', '#6294D3', '#82BEA6', '#4C888E',
              '#418335', '#97C225', '#9A9A33', '#724821']
# Figure titles for the 15 layers indexed by li: 'Conv 1'..'Conv 13', 'FC', 'Output'.
tlist = ['Conv {}'.format(n) for n in range(1, 14)] + ['FC', 'Output']

In [None]:
# Per-layer t-SNE of the (subsampled) unit activations for all subjects' images,
# coloured by class. A separate legend-only figure is drawn at the end.
labels_arr = np.asarray(labels)
for li in range(15):
    # Units per image after get_layer_rep: 1000 sampled units, or 16 for the output layer.
    nn = 1000 if li < 14 else 16
    features = np.zeros((sbjs_num, 192, nn))
    for sbj_idx in range(sbjs_num):
        imgs = sbjs_imgs[sbj_idx]
        features[sbj_idx] = get_layer_rep(li, imgs, 'Rnd')  # (192, nn): images x units
    # Stack subjects so the rows line up with `labels`: (sbjs_num * 192, nn).
    features = features.reshape(sbjs_num * 192, nn)
    tsne = TSNE(n_components=2, perplexity=30).fit_transform(features)

    tx = scale_01(tsne[:, 0])
    ty = scale_01(tsne[:, 1])
    fig = plt.figure(figsize=(8, 8))
    plt.title(tlist[li], fontsize=15)

    for ci in range(16):
        mask = labels_arr == ci
        plt.scatter(tx[mask], ty[mask], c=color_list[ci], label=label_list[ci])
    plt.axis('off')
    plt.show()
    plt.close()

# Legend-only figure mapping each colour to its class name.
fig = plt.figure(figsize=(10, 7))

for ci in range(16):
    plt.scatter([1, 2], [1, 2], c=color_list[ci], label=label_list[ci])
plt.legend(loc='best', fontsize=15)
plt.show()

In [None]:
# t-SNE of the preprocessed input pixels themselves, for comparison with the
# layer plots. Uses the color_list defined together with tlist.
imglen = int(np.prod(sbjs_imgs.shape[2:]))  # 224*224*3 values per image
# One row per image, aligned with `labels`. For a contiguous float64 array this is
# a reshape *view*, avoiding a second ~3.5 GB copy. np.asarray copies only when
# sbjs_imgs is stored in another dtype, so t-SNE always sees float64 input.
features = np.asarray(sbjs_imgs, dtype=np.float64).reshape(sbjs_num * 192, imglen)
tsne = TSNE(n_components=2, perplexity=30).fit_transform(features)

tx = scale_01(tsne[:, 0])
ty = scale_01(tsne[:, 1])
fig = plt.figure(figsize=(8, 8))
plt.title('input', fontsize=15)

labels_arr = np.asarray(labels)
for ci in range(16):
    mask = labels_arr == ci
    plt.scatter(tx[mask], ty[mask], c=color_list[ci], label=label_list[ci])
plt.axis('off')
plt.show()
plt.close()