In [None]:
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Base matplotlib font size.
# NOTE(review): the sns.set(...) call in the plot-preparation cell resets rcParams,
# so this value is most likely overridden before any figure is drawn.
plt.rcParams['font.size'] = 12

In [None]:
# Hypothesis ("code") RDMs: one 16 x 16 binary matrix per hypothesis, where 1 marks
# the condition pairs that hypothesis groups together.
# Per the labels used here, conditions 0-9 are digits and 10-15 are objects
# (NOTE(review): confirm against the stimulus order of the CNN RDMs).
# Only the strict lower triangle of each code enters the comparison with the CNN RDMs.
dgt_cd, obj_cd, vs_cd, mag_cd, ani_cd = (np.zeros((16, 16)) for _ in range(5))

dgt_cd[:10, :10] = 1                    # Digit recognition: digit-digit pairs
obj_cd[10:, 10:] = 1                    # Object recognition: object-object pairs
vs_cd[:10, 10:] = vs_cd[10:, :10] = 1   # Digit vs. Object: both cross blocks
mag_cd[7:10, 1:4] = 1                   # Magnitude of digits: rows 7-9 x cols 1-3
ani_cd[13:, 10:13] = 1                  # Animacy of objects: rows 13-15 x cols 10-12

cds_list = [dgt_cd, obj_cd, vs_cd, mag_cd, ani_cd]
num_cds = len(cds_list)

In [None]:
# Load the per-layer CNN RDMs.
# This sample data was saved from '04_RSA_From_CNN_Perspective'.
depth = 15  # number of layer RDMs expected in the sample file (13 numbered layers + Fc + Output)

# np.load on an .npz returns an NpzFile that holds the file open; the context
# manager closes it once the array has been read into memory.
with np.load('Sample_CNN_RDM.npz') as npz:
    cnnrdms = npz['cnnrdms']

# Downstream cells index per-layer results with `depth`; fail early on a mismatch.
assert len(cnnrdms) == depth, f'expected {depth} layer RDMs, got {len(cnnrdms)}'

In [None]:
# Shared plot styling for the per-hypothesis bar charts.
sns.set(style='dark', font_scale=1.3)
col_list = sns.color_palette('Spectral_r', depth)  # one colour per layer
col_arr = np.array(col_list)
xticks = np.arange(depth)
# Layers 1..depth-2 are numbered, followed by the fully-connected and output layers.
# Deriving the labels from `depth` keeps them the same length as `xticks`
# (identical to the previous hard-coded list when depth == 15).
xlabels = [str(layer) for layer in range(1, depth - 1)] + ['Fc', 'Out\nput']
tt_list = ['Digit recognition','Object recognition','Digit vs. Object','Magnitude of digits','Animacy of objects'] # Title list, one per hypothesis code

In [None]:
# Helper functions for comparing the hypothesis codes with each CNN layer's RDM
def rdm_reord(org, src=10):
    """Return a copy of a square RDM with condition ``src`` moved to the last position.

    Rows and columns are permuted together, so every entry — including the
    cross blocks pairing conditions before ``src`` with conditions after it —
    follows the new condition order. ``org`` itself is not modified.

    NOTE(review): this appears to align the CNN's condition order with the order
    used by the hypothesis codes; confirm against the notebook that produced
    ``Sample_CNN_RDM.npz``.

    Parameters
    ----------
    org : ndarray, shape (n, n)
        Square RDM to reorder.
    src : int, default 10
        Index of the condition to move to the end.

    Returns
    -------
    ndarray, shape (n, n)
        Reordered copy with the same dtype as ``org``.

    Raises
    ------
    IndexError
        If ``src`` is not a valid condition index.
    """
    n = org.shape[0]
    if not 0 <= src < n:
        raise IndexError(f'src={src} out of range for a {n}x{n} RDM')
    order = [i for i in range(n) if i != src] + [src]
    # np.ix_ applies the same permutation to rows and columns. The previous
    # slice-based version only permuted the [10:, 10:] block and left the
    # [:10, 10:] / [10:, :10] cross blocks in the old order.
    return org[np.ix_(order, order)]

def cos_sim(A, B):
    """Cosine similarity of two vectors: <A, B> / (|A| * |B|).

    Returns nan (with a runtime warning) if either vector has zero norm.
    """
    numerator = np.dot(A, B)
    denominator = norm(A) * norm(B)
    return numerator / denominator
    
def cnn_rdm_cos(cd, cnnrdms):
    """Cosine similarity between a hypothesis code and each CNN layer's RDM.

    Only the strict lower triangle (diagonal excluded) of each matrix is compared.
    Each layer RDM is passed through ``rdm_reord`` before the comparison.

    Parameters
    ----------
    cd : ndarray, shape (n, n)
        Hypothesis code matrix.
    cnnrdms : iterable of ndarray, each shape (n, n)
        One RDM per layer (presumably shape (depth, 16, 16) — TODO confirm).

    Returns
    -------
    ndarray of float, shape (len(cnnrdms),)
        One similarity value per layer, in input order.
    """
    il = np.tril_indices(cd.shape[0], -1)
    code_vec = cd[il]  # loop-invariant: extract the code's lower triangle once
    # Size the result from the data instead of the global `depth`, so a layer-count
    # mismatch can no longer silently zero-pad or overflow the output array.
    return np.array(
        [cos_sim(code_vec, rdm_reord(layrdm)[il]) for layrdm in cnnrdms],
        dtype=float,
    )

def plot_corr_bar(summ_arr, title=None):
    """Bar-plot one cosine-similarity value per CNN layer and show the figure.

    Uses the module-level ``xticks``, ``xlabels`` and ``col_arr`` for bar
    positions, tick labels and colours.

    Parameters
    ----------
    summ_arr : array-like, 1-D
        One value per entry of ``xticks``.
    title : str, optional
        Figure title. When omitted, falls back to ``tt_list[ci]``, where ``ci``
        is the module-level loop variable. This is the original behaviour, kept
        so existing calls still work; pass ``title`` explicitly to avoid
        depending on that hidden global.
    """
    if title is None:
        title = tt_list[ci]
    fig, ax = plt.subplots(1, 1, figsize=(6, 4), facecolor='white', constrained_layout=True)
    ax.set_title(title, fontsize=20)
    # error_kw / ecolor were dropped: matplotlib only uses them when xerr/yerr is given.
    ax.bar(xticks, summ_arr, color=col_arr, width=0.9, linewidth=1.5)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)
    ax.set_xlabel('CNN layer')
    ax.set_ylabel('Cosine similarity')
    plt.show()

In [None]:
# Score every hypothesis code against all CNN layers and draw one bar chart per code.
# `ci` must remain a module-level name: plot_corr_bar reads it to choose the title.
for ci in range(len(cds_list)):
    cd = cds_list[ci]
    summ_arr = cnn_rdm_cos(cd, cnnrdms)
    plot_corr_bar(summ_arr)