In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from jsputils import classes, feature_extractor
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import gc
import scipy.stats as stats
import pandas as pd
from IPython.core.debugger import set_trace
from fastprogress import progress_bar
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import MDS
from scipy.spatial.distance import squareform, pdist

from matplotlib.collections import LineCollection
from mpl_toolkits import mplot3d
%matplotlib inline

In [None]:
# Which network to analyse, and which image set is pushed through it as the probe.
probe_imageset_name = 'mc8-shined'
model_name = 'alexnet-barlow-twins'

In [None]:
# Load the network wrapper first, then build the probe image set with the model's own
# input transforms so the images are preprocessed the way this network expects.
DNN = classes.DNNModel(model_name)
probe = classes.ImageSet(
    probe_imageset_name,
    transforms=DNN.transforms,
)

In [None]:
# Extract activations for the probe images. The field name 'probe_features' matters:
# PCA_MDS_experiment later reads DNN.probe_features[layer].
# Use the first GPU when present, otherwise fall back to CPU so the notebook still runs.
# NOTE(review): assumes get_floc_features forwards `device` to torch and accepts 'cpu' — confirm.
feature_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DNN.get_floc_features(probe, field='probe_features', device=feature_device, invert=False)

In [None]:
def PCA_MDS_experiment(layers_to_include, img_domains, nPCs, linewidths, drawplots = True,
                       imgs_per_domain = 30, model = None, imageset = None):
    """Embed per-layer PCA projections of the probe images into one shared 2-D MDS space.

    For each layer in ``['conv1', 'conv3', 'conv5', 'fc7']`` whose name contains one of
    ``layers_to_include``, the stored probe activations are flattened (4-D -> 2-D),
    restricted to the images of ``img_domains`` and reduced to ``nPCs`` principal
    components. The PC scores of all selected layers are stacked layer-major. A
    correlation-distance RDM is computed over all rows, and MDS maps every
    (layer, image) row to 2-D.

    Parameters
    ----------
    layers_to_include : iterable of str
        Substrings selecting which of the four hard-coded layers are analysed.
    img_domains : sequence of int
        Domain indices. Domain ``d`` is taken to be probe images
        ``imgs_per_domain*d`` .. ``imgs_per_domain*(d+1) - 1``.
    nPCs : int
        Number of principal components kept per layer.
    linewidths : float or sequence of float
        Passed to ``LineCollection``. There is one trajectory segment per pair of
        consecutive selected layers.
    drawplots : bool, default True
        If truthy, draw the PC/RDM overview figure and the 2-D trajectory figure.
    imgs_per_domain : int, default 30
        Number of consecutive probe images belonging to one domain.
    model : object, optional
        Object with a ``probe_features`` mapping (layer name -> activations).
        Defaults to the notebook-global ``DNN``.
    imageset : object, optional
        Object with a ``domain_colors`` array, used only for plotting.
        Defaults to the notebook-global ``probe``.

    Returns
    -------
    numpy.ndarray of shape (n_selected_layers * n_images, 2)
        MDS coordinates. Row ``k * n_images + i`` is image ``i`` at the k-th selected layer.

    Raises
    ------
    ValueError
        If ``img_domains`` is empty or ``layers_to_include`` selects no layer.
    """
    seed = 0  # shared by PCA and MDS so reruns give the same embedding
    layers = ['conv1', 'conv3', 'conv5', 'fc7']

    if model is None:
        model = DNN  # notebook-global model built in an earlier cell

    # STEP 1: indices of the images to include, grouped by domain in img_domains order.
    # The colouring in STEP 5 relies on this grouping.
    if len(img_domains) == 0:
        raise ValueError("img_domains must contain at least one domain index")
    incl_idx = np.concatenate([
        np.arange(imgs_per_domain * dom, imgs_per_domain * (dom + 1))
        for dom in img_domains
    ])
    nimg = len(incl_idx)

    # Layers (in network order) whose name contains any requested substring.
    selected_layers = [layer for layer in layers
                       if any(lay_ in layer for lay_ in layers_to_include)]
    if not selected_layers:
        raise ValueError("layers_to_include matched none of the layers %s" % layers)

    # STEP 2: PCA on each selected layer's activations for the included images
    layer_PCs = []
    for layer in progress_bar(selected_layers):
        X = model.probe_features[layer]
        if np.ndim(X) == 4:
            # flatten every axis after the first (image) axis
            X = X.reshape((X.shape[0], np.prod(X.shape[1:])))
        X = X[incl_idx]
        pca = PCA(n_components=nPCs, random_state=seed)
        layer_PCs.append(pca.fit_transform(X))

    # Layer-major stack: all images of the first layer, then all images of the next, ...
    Y = np.vstack(layer_PCs)

    # STEP 3: correlation-distance RDM over every (layer, image) PC vector
    uberRDM = squareform(pdist(Y, 'correlation'))

    if drawplots:
        plt.figure(figsize=(12, 8))
        plt.subplot(121)
        plt.imshow(Y, aspect='auto')
        plt.clim(-100, 100)  # fixed colour range for the PC scores
        plt.colorbar()
        plt.title('PCs from all imgs/layers')
        plt.subplot(122)
        plt.imshow(uberRDM)
        plt.colorbar()
        plt.title('uberRDM over all img/layer PCs')

    # STEP 4: 2-D MDS of the precomputed dissimilarity matrix
    mds = MDS(n_components=2, dissimilarity='precomputed', random_state=seed)
    mds_coords = mds.fit_transform(uberRDM)

    # STEP 5: plot each image's trajectory through the selected layers
    if drawplots:
        if imageset is None:
            imageset = probe  # notebook-global probe ImageSet
        categ_colors = imageset.domain_colors[np.array(img_domains)]

        fig, a = plt.subplots(figsize=(15, 15))
        for i in range(nimg):
            color = categ_colors[i // imgs_per_domain]
            subset = mds_coords[i::nimg, :]  # image i at every selected layer, in layer order

            # small dot = first selected layer, large dot = last selected layer
            a.scatter(subset[0, 0], subset[0, 1], 60, c=color.reshape(1, -1))
            a.scatter(subset[-1, 0], subset[-1, 1], 240, c=color.reshape(1, -1))

            points = subset.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            a.add_collection(LineCollection(segments, linewidths=linewidths,
                                            color=color, alpha=0.3))

        # NOTE(review): fixed +/-1 window; MDS points outside it are clipped from view.
        a.set_xlim(-1, 1)
        a.set_ylim(-1, 1)
        plt.show()

    return mds_coords

In [None]:
# Parameters for the PCA -> MDS trajectory experiment.
layers_to_include = ['conv1', 'conv3', 'conv5', 'fc7']
img_domains = [0, 3, 2, 6]   # probe domains to include, in plotting/colour order
nPCs = 10                    # principal components kept per layer
linewidths = [1, 2, 5, 8]    # trajectory segment widths for the 2-D plot

mds_coords = PCA_MDS_experiment(
    layers_to_include=layers_to_include,
    img_domains=img_domains,
    nPCs=nPCs,
    linewidths=linewidths,
    drawplots=True,
)


In [None]:
# 3-D "channel" view of the MDS trajectories: x = layer position, (y, z) = that layer's
# 2-D MDS coordinates. One PNG is saved for every 10 degrees of azimuth.
IMGS_PER_DOMAIN = 30  # consecutive probe images per domain
CHANNELS_OUT_DIR = "figure_outputs/Figure5-Content-Channeling"

categ_colors = probe.domain_colors[np.array(img_domains)]
nimg = IMGS_PER_DOMAIN * len(img_domains)

# mds_coords is layer-major (all images of layer 0, then layer 1, ...), so the layer
# count follows from its row count instead of being hard-coded.
n_layers = mds_coords.shape[0] // nimg
assert n_layers * nimg == mds_coords.shape[0], "mds_coords rows must equal n_layers * n_images"
layer_x = list(range(n_layers))          # one x position per layer
segment_widths = list(range(1, n_layers))  # segment layer j -> j+1 drawn with width j+1

fig = plt.figure(figsize=(20, 20))
ax = fig.add_subplot(projection='3d')
ax.set_box_aspect((2.25, 1, 1))  # layer (x) axis 2.25x longer than each MDS axis

for i in range(nimg):
    color = categ_colors[i // IMGS_PER_DOMAIN]
    subset = mds_coords[i::nimg, :]  # image i at every layer, in layer order

    for j in range(n_layers):
        # dots grow with layer depth
        ax.scatter3D(layer_x[j], subset[j, 0], subset[j, 1],
                     c=color.reshape(1, -1), s=40 * j + 10)
        if j < n_layers - 1:
            # segment connecting this layer to the next one
            ax.plot3D([layer_x[j], layer_x[j + 1]],
                      [subset[j, 0], subset[j + 1, 0]],
                      [subset[j, 1], subset[j + 1, 1]],
                      color=color, alpha=0.4, linewidth=segment_widths[j])

# Hide panes, grid lines and axis lines of the 3-D axes.
ax.set_axis_off()

os.makedirs(CHANNELS_OUT_DIR, exist_ok=True)
for azim in range(0, 360, 10):
    ax.view_init(elev=0, azim=azim)
    fig.savefig(os.path.join(CHANNELS_OUT_DIR, "channels_%d_deg.png" % azim))