In [1]:
import os
import json
import hashlib
import random
import numpy as np
import pandas as pd
import xarray
from keras.models import model_from_json

from src.data_loader import Shifted_Data_Loader
from src.metrics.dicarlo import r as dicarlo_r
from src.metrics.dicarlo import dprime as dicarlo_dprime
from src.metrics.dicarlo import selectivity as dicarlo_sel
from src.results.experiments import *
from src.results.file_loaders import *
from src.results.processing import make_xr

import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
    
from collections import Counter
import dit
from dit import Distribution
from sklearn.preprocessing import MinMaxScaler

Using TensorFlow backend.


In [2]:
def load_data_loader(conf):
    np.random.seed(7)
    if conf['ecc_max'] == 0.0:
        tx_max = None
    else:
        tx_max = conf['ecc_max']
        
    if 'rot_max' not in conf.keys() or conf['rot_max'] == 0.0:
        rot_max = None
    else:
        rot_max = conf['rot_max']
    DL = Shifted_Data_Loader(dataset=conf['dataset'],flatten=True,
                             rotation=rot_max,
                             translation=tx_max,num_train=60000,
                            )
    return DL

def load_shifts(conf):
    DL = load_data_loader(conf)
    
    return (DL.dx[1]-14,DL.dy[1]-14,DL.y_test)

def load_train_history(run_dir,conf,filename='train_history.parquet'):
    path = os.path.join(run_dir,filename)
    dirname,fname = os.path.split(path)
#     lab_corruption = np.round(float(dirname.split('/')[-1].split('_')[-1]),decimals=1)
#     arch = dirname.split('/')[-2]
    if conf is None:
        return None
    elif conf['recon'] == 0:
        arch = 'no_recon'
    else:
        arch = 'recon'

    if os.path.exists(path):
        hist = pd.read_parquet(path)
        hist['architecture'] = arch
        hist['label_corruption'] = conf['label_corruption']
        hist['ecc_max'] = conf['ecc_max']
        hist['xent'] = conf['xent']
        hist['recon'] = conf['recon']
        hist['epoch'] = list(hist.index.values*3)
#         hist['val_loss'] = sma(hist['val_loss'].values,win_size=3)
#         hist['loss'] = sma(hist['loss'].values,win_size=3)
        hist['val_dL'] = np.gradient(hist['val_loss'])
        hist['test_err'] = 1-hist['val_class_acc']
        hist['train_err'] = 1-hist['class_acc']
        hist['recon_gen_err'] = hist.G_loss - hist.val_G_loss
        hist['gen_err'] = hist.loss - hist.val_loss
        hist['class_gen_err'] = hist.class_loss - hist.val_class_loss
        hist['class_gen_acc'] = hist.class_acc - hist.val_class_acc

        return hist


def mutual_information(X,Y):
    XY_c = Counter(zip(X,Y))
    XY_pmf = {k:v/float(sum(XY_c.values())) for k,v in XY_c.items()}
    XY_jdist = Distribution(XY_pmf)
        
    return dit.shannon.mutual_information(XY_jdist,[0],[1])

def load_I(rd,fn,conf,feat_range=(0,30)):
    
    dxs,dys,y_test = load_shifts(conf)
    z_enc = np.load(os.path.join(rd,fn))
    z_dim = z_enc.shape[-1]
    z_enc_scaled = [MinMaxScaler(feat_range).fit_transform(z_enc[:,i].reshape(-1,1)).tolist() for i in np.arange(z_dim)]
    z_enc_scaled = np.squeeze(np.array(z_enc_scaled,dtype=int))
    z_dx_I = [mutual_information(z_enc_scaled[i],dxs.astype(int)+14) for i in np.arange(z_dim)]
    z_dy_I = [mutual_information(z_enc_scaled[i],dys.astype(int)+14) for i in np.arange(z_dim)]
    z_class_I = [mutual_information(z_enc_scaled[i],y_test) for i in np.arange(z_dim)]
    z_I_df = pd.DataFrame.from_records({'class':z_class_I,'dy':z_dy_I,'dx':z_dx_I})
    z_I_df['class'] = z_I_df['class'].values.round(decimals=1)
    z_I_df['ecc_max'] = conf['ecc_max']
    z_I_df['recon'] = conf['recon']
    z_I_df['xent'] = conf['xent']
    z_I_df['label_corruption'] = conf['label_corruption']
    
    return z_I_df
#     y_enc = np.load(os.path.join(rd,'y_enc.npy'))
#     y_dim = y_enc.shape[-1]
#     y_enc_scaled = [MinMaxScaler(feat_range).fit_transform(y_enc[:,i].reshape(-1,1)).tolist() for i in np.arange(y_dim)]
#     y_enc_scaled = np.squeeze(np.array(y_enc_scaled,dtype=int))
    
    

In [3]:
# exp_root = '/home/elijahc/projects/vae/models/2019-06-03'
exp_root = '/home/elijahc/projects/vae/models/2019-05-24'
runs = []
for branches in os.listdir(exp_root):
    for leaf in os.listdir(os.path.join(exp_root,branches)):
        runs.append(os.path.join(exp_root,branches,leaf))

runs = list(filter(lambda x: 'ipynb_checkpoints' not in x,runs))
configs = [load_config(rd) for rd in runs]
train_historys = [load_train_history(rd,conf) for rd,conf in zip(runs,configs)]
perf = [load_performance(rd,conf,th) for rd,conf,th in zip(runs,configs,train_historys)]
model_specs = [load_model_spec(rd) for rd in runs]

In [4]:
from keras.models import Model
import hashlib

In [5]:
def raw_to_xr(encodings,l_2_depth,stimulus_set):
    obj_names = [
        "T-shirt",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Dress Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    all_das = []
    for layer,activations in encodings.items():
        neuroid_n = activations.shape[1]
        n_idx = pd.MultiIndex.from_arrays([
            pd.Series(['{}_{}'.format(layer,i) for i in np.arange(neuroid_n)],name='neuroid_id'),
            pd.Series([l_2_depth[layer]]*neuroid_n,name='layer'),
            pd.Series([layer]*neuroid_n,name='region')
        ])
        p_idx = pd.MultiIndex.from_arrays([
            stimulus_set.image_id,
            stimulus_set.dx,
            stimulus_set.dy,
            stimulus_set.numeric_label.astype('int8'),
            pd.Series([obj_names[i] for i in stimulus_set.numeric_label],name='object_name'),
            pd.Series(stimulus_set.dx.values/28, name='tx'),
            pd.Series(stimulus_set.dy.values/28, name='ty'),
            pd.Series([1.0]*len(stimulus_set),name='s'),
        ])
        da = xarray.DataArray(activations.astype('float32'),
                         coords={'presentation':p_idx,'neuroid':n_idx},
                         dims=['presentation','neuroid'])
        all_das.append(da)
        
    return xarray.concat(all_das,dim='neuroid')

In [6]:
DL = load_data_loader(configs[0])
slug = [(dx,dy,float(lab),float(random.randrange(20))) for dx,dy,lab in zip(DL.dx[1],DL.dy[1],DL.y_test)]

pixel_enc = {'pixel':DL.sx_test}


encs = []
stim_sets = []

for mod in map(load_model,runs):
    layer_encodings = []
    layer_names = ['y_lat','z_lat','dense_1','dense_2','dense_3']

    for l in layer_names:
        encoder = Model(mod.input,mod.get_layer(l).output)
        enc = encoder.predict(DL.sx_test)
        layer_encodings.append(enc)
    image_id = [hashlib.md5(json.dumps(list(p),sort_keys=True).encode('utf-8')).digest().hex() for p in slug]
    stim_set = pd.DataFrame({'dx':DL.dx[1]-14,'dy':DL.dy[1]-14,'numeric_label':DL.y_test,'image_id':image_id})
    enc = {k:v for k,v in zip(layer_names,layer_encodings)}
    enc.update(pixel_enc)
    encs.append(enc)
    stim_sets.append(stim_set)


input_shape:  (3136,)
dataset:  fashion_mnist
scale:  2
tx_max:  0.8
rot_max:  None
bg_noise: None
loading fashion_mnist...
sx_train:  (60000, 3136)
making training data...
making testing data...


In [7]:
l_to_depth = {
    'dense_3':3,
    'dense_2':2,
    'dense_1':1,
    'pixel':0,
    'y_lat':4,'z_lat':4}

assemblies = [raw_to_xr(encs[i],l_to_depth,stim_sets[i]) for i in np.arange(3)]

In [8]:
from collections import OrderedDict
def save_assembly(da,run_dir,fname,**kwargs):
    da = da.reset_index(da.coords.dims)
    da.attrs = OrderedDict()
    with open(os.path.join(run_dir,fname), 'wb') as fp:
        da.to_netcdf(fp,**kwargs)

In [9]:
for da,rd in zip(assemblies,runs):
    
    save_assembly(da,run_dir=rd,fname='dataset.nc',
        format='NETCDF3_64BIT',
#         engine=
#         encoding=enc,
       )

In [10]:
y = stim_sets[0].dx.values

In [11]:
exp = 2
X = encs[exp]['z_lat']
Y = stim_sets[exp].dx.values
cut_r = lambda n_idx: np.abs(pearsonr(X[:,n_idx],Y)[0])

zr = list(map(cut_r, np.arange(35)))

IndexError: index 5 is out of bounds for axis 1 with size 5

In [None]:
DL = load_data_loader(configs[0])

In [None]:
mod = load_model(runs[2])
l1_enc = Model(mod.input,mod.get_layer('dense_1').output)
classifier = Model(mod.input,mod.layers[-1].output)
classifier.compile(optimizer='nadam',loss='categorical_crossentropy',metrics=['acc'])
classifier.evaluate(x=DL.sx_test,y=DL.y_test_oh)

In [None]:
perf_df = pd.concat(perf)

In [None]:
def load_variations(conf):
    _,_,y_te = load_shifts(conf)
    return y_te

In [None]:
[print(str(c.keys())+'\n') for c in configs]

In [None]:
shifts = [load_shifts(c) for c in configs]

In [None]:
dxs = [s[0].astype(np.int32) for s in shifts]
dys = [s[1].astype(np.int32) for s in shifts]
y_tes = [s[2].astype(np.int8) for s in shifts]

In [None]:
y_encs = [np.load(os.path.join(rd,'y_enc.npy')).astype(np.float32) for rd in runs]
z_encs = [np.load(os.path.join(rd,'z_enc.npy')).astype(np.float32) for rd in runs]

y_dim = [c['y_dim'] for c in configs]
z_dim = [c['z_dim'] for c in configs]
enc_dim = [c['enc_layers'] for c in configs]
l3_dim = [ed[2] for ed in enc_dim]
l2_dim = [ed[1] for ed in enc_dim]
l1_dim = [ed[0] for ed in enc_dim]

y_enc_dfs = [pd.DataFrame(data=ye,columns=['y_{}'.format(i) for i in np.arange(n_units)]) for ye,n_units in zip(y_encs,y_dim)]
z_enc_dfs = [pd.DataFrame(data=ze,columns=['z_{}'.format(i) for i in np.arange(n_units)]) for ze,n_units in zip(z_encs,z_dim)]

In [None]:
l3_encs = [np.load(os.path.join(rd,'l3_enc.npy')).astype(np.float32) for rd in runs]
l3_enc_dfs = [pd.DataFrame(data=l3e,columns=['l3_{}'.format(i) for i in np.arange(n_units[2])]) for l3e,n_units in zip(l3_encs,enc_dim)]

In [None]:
l2_encs = [np.load(os.path.join(rd,'l2_enc.npy')) for rd in runs]
l2_enc_dfs = [pd.DataFrame(data=l2e,columns=['l2_{}'.format(i) for i in np.arange(n_units[1])]) for l2e,n_units in zip(l2_encs,enc_dim)]

In [None]:
l1_encs = [np.load(os.path.join(rd,'l1_enc.npy')) for rd in runs]
l1_enc_dfs = [pd.DataFrame(data=l1e,columns=['l1_{}'.format(i) for i in np.arange(n_units[0])]) for l1e,n_units in zip(l1_encs,enc_dim)]

In [None]:
def conf_fetch(configs,key):
    return [conf[key] for conf in configs]

In [None]:
l1_enc_dfs[0].columns

In [None]:
# from scipy.stats import gaussian_kde
# from scipy.stats import norm
# Z = norm.ppf

# def dicarlo_dprime(df,num_units=10,col='class'):
#     uniq_cls = np.unique(df[col].values)
#     o = []
#     r = []
#     F = []
#     H = []
#     d = []
#     mask = list(~np.any(df.groupby(col).var().values[:,:num_units]==0,axis=0))
# #     print(len(mask))
#     cols = df.columns[mask+([True]*3)]
# #     print(cols)
#     mdf = df[cols]
# #     print(mdf.head())
#     for i in uniq_cls:
#         oi = mdf[mdf[col]==i].values[:,:len(mdf.columns)-3]
#         o_mu = oi.mean(axis=0)
# #         print(oi.shape)
#         o_kde = [gaussian_kde(oi[:,u]) for u in np.arange(oi.shape[1])]
#         o.append(o_kde)

#         ri = mdf[mdf[col]!=i].values[:,:len(mdf.columns)-3]
#         r_kde = [gaussian_kde(ri[:,u]) for u in np.arange(ri.shape[1])]
#         r.append(r_kde)
#         F.append([k.integrate_box_1d(om,k.dataset.max()) for k,om in zip(r_kde,o_mu)])
#         H.append([k.integrate_box_1d(low,om) for k,om,low in zip(o_kde,o_mu,oi.min(axis=0))])
#         d.append(Z(H[-1])-Z(F[-1]))
    
# #     class_d_max = np.argmax(np.abs(d),axis=0)
#     d = np.abs(d)
#     class_d_max = np.argmax(d,axis=0)
#     d_idxs = (class_d_max,np.arange(len(class_d_max)))
#     d = d[d_idxs]
#     return d,mask

# def dicarlo_sel(df,num_units=10,col='class'):
#     mu = df.groupby(col).mean().values[:,:num_units]
#     var = df.groupby(col).var().values[:,:num_units]

#     mu_max_idxs = np.argmax(mu,axis=0)
#     mu_min_idxs = np.argmin(mu,axis=0)

#     mu_max = np.array([mu[maxi,i] for i,maxi in zip(np.arange(len(mu_max_idxs)),mu_max_idxs)])
#     mu_min = np.array([mu[maxi,i] for i,maxi in zip(np.arange(len(mu_min_idxs)),mu_min_idxs)])
    
#     var_b = np.array([var[maxi,i] for i,maxi in zip(np.arange(len(mu_max_idxs)),mu_max_idxs)])
#     var_w = np.array([var[mini,i] for i,mini in zip(np.arange(len(mu_min_idxs)),mu_min_idxs)])
# #     neg_max = [(allv-mu_max)/9.0 for allv,mu_max in zip(all_vals.sum(axis=1),mu_max)]
    
#     sel = [(mu_b-mu_w)/np.sqrt((vb+vw)/2) for mu_b,mu_w,vb,vw in zip(mu_max,mu_min,var_b,var_w)]
#     return sel

In [None]:
# sns.distplot(o[9][3].dataset,hist=False)
# g = sns.distplot(NS[9][3].dataset,hist=False)
# g.legend(['stim','no-stim'])

obj_names = [
    "T-shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Dress Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
exp_xa = []
for y_te,dx,dy,df,zdf,l3df,l2df,l1df,conf in zip(y_tes,dxs,dys,y_enc_dfs,z_enc_dfs,l3_enc_dfs,l2_enc_dfs,l1_enc_dfs,configs):
    df['class'] = y_te
    zdf['class'] = y_te
    l3df['class'] = y_te
    l2df['class'] = y_te
    l1df['class'] = y_te
    
    df['dx'] = dx
    zdf['dx'] = dx
    l3df['dx'] = dx
    l2df['dx'] = dx
    l1df['dx'] = dx
    
    df['dy'] = dy
    zdf['dy'] = dy
    l3df['dy'] = dy
    l2df['dy'] = dy
    l1df['dy'] = dy
    ydf = df
    
    layer_names = ['y','z','l1','l2','l3']
    depths = [4,4,1,2,3]
    names = ['y','z',None,None,None]
    layer_das = [make_xr(d,depth,region=n) for d,depth,n in zip([ydf,zdf,l1df,l2df,l3df],depths,names)]
    ably = xarray.concat(layer_das,dim='neuroid')
    stim = ably.presentation.to_dataframe().reset_index()
    stim['image_id'] = [hashlib.md5(json.dumps(list(p),sort_keys=True).encode('utf-8')).digest().hex() for p in stim['presentation'].values]
    # Add attributes brain-score expects
    ably = ably.assign_attrs({
#         'hyperparameters': pd.DataFrame([conf]),
        'stimulus_set':stim.drop(columns=['presentation'])
    })
    exp_xa.append(ably)

In [None]:
da = exp_xa[1]
da

In [None]:
exp_conds = [
    'Only XEnt',
    'Only Recon',
    'Both',
]

In [None]:
ds = xarray.Dataset({k:e for k,e in zip(exp_conds,exp_xa)})
# da = make_xr(zdf)
# da

In [None]:
print(os.path.join(runs[2],'recon.nc'))

In [None]:
list(ds.coords.dims.keys())

In [None]:
da = exp_xa[0]
da = da.reset_index(da.coords.dims)
# da = da.assign_attrs(stimulus_set=da.attrs['stimulus_set'].values)
# da.attrs = OrderedDict()
print(da)
fname='dataset.nc'
with open(os.path.join(runs[0],fname), 'wb') as fp:
    da.to_netcdf(fp)

In [None]:
da.attrs['stimulus_set']

In [None]:
# os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

for da,cond,rd in zip(exp_xa,exp_conds,runs):
    
    save_assembly(da,run_dir=rd,fname='dataset.nc',
        format='NETCDF3_64BIT',
#         engine=
#         encoding=enc,
       )

# ds.to_netcdf(encoding=)

In [None]:
def xa_calc_sel(xas,confs,props=['class','dx','dy']):
        ex

In [None]:
def calc_sel(xa,confs,num_units,props):
    exp_conds = {
        'xent' : conf_fetch(confs,'xent'),
        'recon': conf_fetch(confs,'recon'),
        'label_corruption' :conf_fetch(confs,'label_corruption'),
        'ecc_max': conf_fetch(confs,'ecc_max')
    }
    exp_conds = pd.DataFrame.from_records(exp_conds)
    exp_conds_l = [{k:v[i] for k,v in exp_conds.iteritems()} for i in np.arange(len(confs))]
    
    affinities = []
    for p in props:
        if p in ['class']:
            dprimes = [dicarlo_dprime(df,num_units=num_units,col=p) for df in dfs]
            prop_sel = [d[0] for d in dprimes]
            units = [np.where(np.array(d[1])==True)[0] for d in dprimes]
        elif p in ['dx','dy']:
            props = [df[p].values for df in dfs]
#             print([p.shape for p in props])
            activations = [df.values[:,:num_units] for df in dfs]
            prop_sel = [dicarlo_r(a,p) for p,a in zip(props,activations)]
            units = [np.arange(num_units)]*len(confs)
        
        recs = [{'unit':u,'affinity':a,'property':p} for i,u,a in zip(np.arange(len(dfs)),units,prop_sel)]
        for r,cond in zip(recs,exp_conds_l):
            r.update(cond)
            
        prop_dfs = [pd.DataFrame.from_records(r) for r in recs]
#         prop_sel_df = prop_sel_df.join(exp_conds)
#         prop_sel_df['property']=p

        affinities.append(pd.concat(prop_dfs))
        
    return pd.concat(affinities)

In [None]:
layer_szs = list(zip(y_dim,z_dim,l3_dim,l2_dim,l1_dim))

In [None]:
# aff = calc_sel(l3_enc_dfs,confs=configs,num_units=500,props=['class','dx'])
# aff.head()

In [None]:
pops = []
layer_depths = [4,4,3,2,1]
layer_names = ['y','z','l3','l2','l1']
for dfs,n_u in zip([y_enc_dfs,z_enc_dfs,l3_enc_dfs,l2_enc_dfs,l1_enc_dfs],layer_szs[0],):
    print('n_u: ',n_u)
    pop_aff = calc_sel(dfs,confs=configs,num_units=n_u,props=['class','dx','dy'])
    pops.append(pop_aff)
    


In [None]:
pops[0].head()

In [None]:
for df,l_depth,l_name in zip(pops,layer_depths,layer_names):
    print('lname: ',l_name)
#     ldf = df.melt(id_vars=['recon','xent','property','ecc_max','label_corruption'],var_name='unit',value_name='affinity')
    df['layer_depth'] = l_depth
    df['layer_name'] = l_name

In [None]:
all_df = pd.concat(pops)
queries = [
        'label_corruption == 0.0 & xent == 15 & recon == 0 & ecc_max == 0.8',
        'label_corruption == 0.0 & xent == 0 & recon == 25 & ecc_max == 0.8',
        'label_corruption == 0.0 & xent == 15 & recon == 25 & ecc_max == 0.8',
              ]
titles = [
    'Only XEnt',
    'Only Recon',
    'Both'
]
subs = []
for cond,q in zip(titles,queries): 
    sub = all_df.query(q)
    sub['condition'] = [cond]*len(sub)
    subs.append(sub)


In [None]:
all_df = pd.concat(subs)
all_c_df = all_df.query('property == "class"')
all_x_df = all_df.query('property == "dx"')
all_y_df = all_df.query('property == "dy"')
all_pos_df = all_df.query('property == "dx" | property == "dy"').pivot_table(index=['recon','xent','layer_name','unit'],columns=['property'],values='affinity')

In [None]:
all_df.head()

In [None]:
all_pos_df.groupby(['recon','xent','layer_name']).count()

In [None]:
def stream_affinity_profile(dset,prop='class',plot_func=sns.barplot,legend_panel=0,**kwargs):
    fig,axs = plt.subplots(3,1,figsize=(4,12),sharey=True,sharex=True)
    queries = [
        'label_corruption == 0.0 & xent == 15 & recon == 0 & ecc_max == 0.8',
        'label_corruption == 0.0 & xent == 0 & recon == 25 & ecc_max == 0.8',
        'label_corruption == 0.0 & xent == 15 & recon == 25 & ecc_max == 0.8',
              ]
    titles = [
        'Only XEnt',
        'Only Recon',
        'Both'
    ]
#     sns.set_context('talk')
    for ax,q,title in zip(axs,queries,titles): 
        x_order = ['l1','l2','l3']
#         if rowi == 0:
#             x_order.append('y')
#         else:
#             x_order.append('z')
            
        g = plot_func(y='affinity',x='layer_depth',
                  hue='layer_name',
#                       data=dset.query(q).query('layer_name != "{}"'.format(exclude)),
                  data=dset.query(q),ax=ax,**kwargs)
        if title != titles[legend_panel]:
            ax.get_legend().remove()
        
        ax.set_xlabel('')
        ax.set_ylabel('{} selectivity'.format(prop))
        ax.set_title(title)
    
    axs[-1].set_xlabel('Model Layers')
    axs[0].legend(loc='upper center', bbox_to_anchor=(1.5, 1.0), shadow=True, ncol=1)
#         axs[rowi,0].set_ylabel(row_lab.format(prop))
#         axs[rowi,1].set_ylabel('')
#         axs[rowi,2].set_ylabel('')

In [None]:
def stream_affinity_profile_by_layer(dset,prop='class',plot_func=sns.barplot):
    fig,axs = plt.subplots(2,5,figsize=(20,10),sharey=True)
#     queries = [
#         'label_corruption == 0.0 & xent == 15 & recon == 0 & ecc_max == 0.8',
#         'label_corruption == 0.0 & xent == 0 & recon == 25 & ecc_max == 0.8',
#         'label_corruption == 0.0 & xent == 15 & recon == 25 & ecc_max == 0.8',
#               ]
#     queries = ['']
    titles=['Layer {}'.format(l) for l in np.arange(3)+1]
    titles.append('Layer y+z')
    for rowi,row_lab,exclude in zip(np.arange(2),['ventral {} affinity (y)','dorsal {} affinity (z)'],["z","y"]):
        for ax,layer_n,title in zip(axs[rowi],[1,2,3,4],titles): 
#             x_order = [1,2,3,'l4']
            if rowi == 0:
                pass
#                 x_order.append('y')
            else:
                pass
#                 x_order.append('z')
            plot_func(y='affinity',x='condition',
                           hue='condition',
                           data=dset.query('layer_depth == {}'.format(layer_n)),ax=ax)
            ax.set_title(title)
        axs[rowi,0].set_ylabel(row_lab.format(prop))
        axs[rowi,1].set_ylabel('')
        axs[rowi,2].set_ylabel('')

In [None]:
# sns.violinplot()

In [None]:
sns.set_context('poster')
stream_affinity_profile(all_c_df,prop='class',plot_func=sns.boxplot)

# Dx and Dy Selectivity Results by layer

In [None]:
stream_affinity_profile(all_x_df,prop='dx',plot_func=sns.boxplot)

In [None]:
stream_affinity_profile(all_y_df,prop='dy',plot_func=sns.boxplot)

In [None]:
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(0,1))
rescaled_class = scaler.fit_transform(all_df.query('property == "class"').dropna()['affinity'].values.reshape(-1,1))

norm_class = all_df.query('property == "class"').dropna()[['recon','xent','ecc_max','label_corruption','unit','layer_depth','layer_name']]
norm_class['property']='norm_class'
norm_class['affinity']=np.squeeze(rescaled_class)
# norm_class = pd.DataFrame.from_records({'affinity': np.squeeze(rescaled_class),'property':'norm_class'})
# norm_class.index = all_df.query('property == "class"').dropna()


In [None]:
all_df_ = all_df.append(norm_class)

In [None]:
all_df = all_df_
ydf = all_df.query('layer_name == "y"').reset_index()
ydf = pd.pivot_table(ydf,index=['xent','recon','unit'],columns='property',values='affinity')
ydf['layer_name']= 'y'

zdf = all_df.query('layer_name == "z"').reset_index()
zdf = pd.pivot_table(zdf,index=['xent','recon','unit'],columns='property',values='affinity')
zdf['layer_name']= 'z'

l4df = pd.concat([ydf,zdf])

l3df = all_df.query('layer_name == "l3"').reset_index()
l3df = pd.pivot_table(l3df,index=['xent','recon','unit'],columns='property',values='affinity')
l3df['layer_name']= 'l3'


l2df = all_df.query('layer_name == "l2"').reset_index()
l2df = pd.pivot_table(l2df,index=['xent','recon','unit'],columns='property',values='affinity')
l2df['layer_name']= 'l2'


l1df = all_df.query('layer_name == "l1"').reset_index()
l1df = pd.pivot_table(l1df,index=['xent','recon','unit'],columns='property',values='affinity')
l1df['layer_name']= 'l1'

# ydf['dx_selectivity'] = y_x_sel_df['dx_selectivity']
# zdf = z_c_sel_df
# zdf['dx_selectivity'] = z_x_sel_df['dx_selectivity']
# l3df = l3_c_sel_df
# l3df['dx_selectivity'] = l3_x_sel_df['dx_selectivity']

In [None]:
def pop_scatter(x='dx',y='dy',dsets=[l1df,l2df,l3df,l4df,],plot_func=sns.scatterplot,**plot_kw):
    filts =[
        'recon == 0 & xent == 15',
        'recon == 25 & xent == 0',
        'recon == 25 & xent == 15',
        
    ]
    fig,axs = plt.subplots(len(dsets),3,figsize=(2.5*3,2.5*len(dsets),),
    #                        subplot_kw={'xlim':(-.1,1),'ylim':(-0.1,1)},
                           sharex=True,sharey=True)
    objective_conds = ['Only Xent','Only Recon','Xent+Recon']
    for filt,coli,obj in zip(filts,np.arange(3),objective_conds):
        for ax,df,rowi in zip(axs[:,coli],dsets,np.arange(len(dsets))):
            plot_func(x=x,y=y,data=df.query(filt),ax=ax,**plot_kw)
            if coli == 0:
                ax.set_ylabel('Layer {}'.format(rowi+1))
            
        axs[0,coli].set_title(obj)
# sns.scatterplot(x=xvar,y=yvar,data=ydf.query(filt),ax=axs[1])
# sns.scatterplot(x=xvar,y=yvar,data=zdf.query(filt),ax=axs[2])

In [None]:
def pop_kde(x='dx',y='dy',dsets=[l1df,l2df,l3df,l4df,],clip=(0,1),rug_plot=False,**plot_kw):
    filts =[
        'recon == 0 & xent == 15',
        'recon == 25 & xent == 0',
        'recon == 25 & xent == 15',
        
    ]
    fig,axs = plt.subplots(3,len(dsets),figsize=(3*len(dsets),3*3,),
                           subplot_kw={'xlim':(0,1),'ylim':(0,1)},
                           sharex=True,sharey='row')
    objective_conds = ['Only Xent','Only Recon','Xent+Recon']
    layer_names = ['1','2','3','y+z',]
    for coli,df,layer_lab in zip(np.arange(len(dsets)),dsets,layer_names):
        for filt,ax,obj in zip(filts,axs[:,coli],objective_conds):
            fdf = df.query(filt).dropna()
            if coli != len(dsets)-1:
                sns.kdeplot(data=fdf[x].values,data2=fdf[y].values,ax=ax,clip=clip,**plot_kw)
            else:
                ax.scatter(fdf.query('layer_name == "y"')[x].values,fdf.query('layer_name == "y"')[y].values,marker='.',**plot_kw)
                ax.scatter(fdf.query('layer_name == "z"')[x].values,fdf.query('layer_name == "z"')[y].values,marker='.',c='orange',**plot_kw)

#                 ax.set_ylim(0,)
                
            if layer_lab == layer_names[-1] and rug_plot:
                sns.rugplot(fdf.query('layer_name == "y"')[x].values,ax=ax)
                sns.rugplot(fdf.query('layer_name == "z"')[x].values,ax=ax,color="g")
                sns.rugplot(fdf.query('layer_name == "y"')[y].values,ax=ax,vertical=True)
                sns.rugplot(fdf.query('layer_name == "z"')[y].values,ax=ax,color="g",vertical=True)
#                 ax.scatter(fdf[x].values,fdf[y].values,c="w",marker="+",linewidths=1)
            if coli == 0:
                ax.set_ylabel('{}'.format(obj))
            
            if obj == objective_conds[-1]:
                ax.set_xlabel(x)
            
        axs[0,coli].set_title('Layer {}'.format(layer_lab))
#         axs[0,coli].set_ylabel()

In [None]:
# l3df.query(f).dropna().head()

In [None]:
# sns.kdeplot()

In [None]:
pop_kde(
#     clip=(-0.1,0.75),
#     shade=True,shade_lowest=False,
)
# plt.xlim(0,0.6)
# plt.ylim(0,0.6)

In [None]:
pop_kde(x='norm_class',y='dx',
#         clip=(-0.1,1),
#         shade=True,shade_lowest=False
       )


In [None]:
# sns.kdeplot()

In [None]:
filt = 'recon == 25 & xent == 15'
xvar = 'norm_class'
yvar = 'dx'

g = sns.JointGrid(x=xvar,y=yvar,data=l4df.query(filt),space=0.5,
#                   xlim=(-0.05,0.5),
#                   ylim=(-0.05,0.5)
                 )
g = g.plot_joint(sns.kdeplot,
                 shade=True,shade_lowest=False,
                )
# g.ax_joint.set_xlim(-0.05,0.5)
# g.ax_joint.set_ylim(-0.05,0.9)
# g = g.plot_marginals(sns.kdeplot)
sns.kdeplot(l4df.query('{} & layer_name == "y"'.format(filt))[xvar],ax=g.ax_marg_x,legend=None)
sns.kdeplot(l4df.query('{} & layer_name == "z"'.format(filt))[xvar],ax=g.ax_marg_x,c="orange",legend=None)
sns.kdeplot(l4df.query('{} & layer_name == "y"'.format(filt))[yvar],ax=g.ax_marg_y,legend=None,vertical=True)
sns.kdeplot(l4df.query('{} & layer_name == "z"'.format(filt))[yvar],ax=g.ax_marg_y,c="orange",vertical=True)
g.ax_marg_y.legend(['y','z'],loc='upper center', bbox_to_anchor=(1.7, 1.1), shadow=True, ncol=1)

g.ax_marg_x.set_title('Both')
# sns.rugplot(l4df.query('{} & layer_name == "y"'.format(filt))[yvar],ax=g.ax_joint,vertical=True)
# sns.rugplot(l4df.query('{} & layer_name == "z"'.format(filt))[yvar],ax=g.ax_joint,c="g",vertical=True)

In [None]:
pop_kde(y='dy',x='norm_class')

In [None]:
pop_scatter(x='dx',y='norm_class',hue='layer_name',legend=None)

In [None]:
pop_scatter(xvar='dx',yvar='norm_class',hue='layer_name')

In [None]:
pop_scatter(xvar='dy',yvar='class')

In [None]:
fig,axs = plt.subplots(1,3,figsize=(15,5),
#                        subplot_kw={'xlim':(0,3),'ylim':(0,3)},
                       sharex=True,sharey=True)

filt = 'recon == 25 & xent == 0'
xvar = 'dx'
yvar = 'dy'
sns.scatterplot(x=xvar,y=yvar,data=l3df.query(filt),ax=axs[0])
sns.scatterplot(x=xvar,y=yvar,data=ydf.query(filt),ax=axs[1])
sns.scatterplot(x=xvar,y=yvar,data=zdf.query(filt),ax=axs[2])

In [None]:
fig,axs = plt.subplots(1,3,figsize=(15,5),
#                        subplot_kw={'xlim':(0,3),'ylim':(0,3)},
                       sharex=True,sharey=True)

filt = 'recon == 25 & xent == 15'
xvar = 'dx'
yvar = 'dy'
sns.scatterplot(x=xvar,y=yvar,data=l3df.query(filt),ax=axs[0])
sns.scatterplot(x=xvar,y=yvar,data=ydf.query(filt),ax=axs[1])
sns.scatterplot(x=xvar,y=yvar,data=zdf.query(filt),ax=axs[2])

In [None]:
# ydf[['recon','xent','unit']]
ydf.head()

In [None]:
fig,axs = plt.subplots(1,3,figsize=(15,5),
#                        subplot_kw={'xlim':(0,3),'ylim':(0,3)},
                       sharex=True,sharey=True)
sns.scatterplot(x='dx',y='class',data=l3df.reset_index().query('recon == 0 & xent == 15'),ax=axs[0])
sns.scatterplot(x='dx',y='class',data=ydf.query('recon == 0 & xent == 15'),ax=axs[1])
sns.scatterplot(x='dx',y='class',data=zdf.query('recon == 0 & xent == 15'),ax=axs[2])

# plt.xlim(0.0,3)
# plt.ylim(0.0,3)
axs[0].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[0].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[0].set_title('l3 latent space')
axs[1].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[1].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[1].set_title('y latent space')
axs[2].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[2].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[2].set_title('z latent space')

In [None]:
fig,axs = plt.subplots(1,3,figsize=(15,5),
#                        subplot_kw={'xlim':(0,3),'ylim':(0,3)},
                       sharex=True,sharey=True)
sns.scatterplot(x='dx',y='class',data=l3df.query('recon == 25 & xent == 15'),ax=axs[0])
sns.scatterplot(x='dx',y='class',data=ydf.query('recon == 25 & xent == 15'),ax=axs[1])
sns.scatterplot(x='dx',y='class',data=zdf.query('recon == 25 & xent == 15'),ax=axs[2])

axs[0].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[0].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[0].set_title('l3 latent space')
axs[1].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[1].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[1].set_title('y latent space')
axs[2].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[2].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[2].set_title('z latent space')

In [None]:
fig,axs = plt.subplots(1,3,figsize=(15,5),
#                        subplot_kw={'xlim':(0,3),'ylim':(0,3)},
                       sharex=True,sharey=True)
sns.scatterplot(x='dx',y='class',data=l3df.query('recon == 25 & xent == 0'),ax=axs[0])
sns.scatterplot(x='dx',y='class',data=ydf.query('recon == 25 & xent == 0'),ax=axs[1])
sns.scatterplot(x='dx',y='class',data=zdf.query('recon == 25 & xent == 0'),ax=axs[2])

# plt.xlim(0.0,3)
# plt.ylim(0.0,3)
axs[0].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[0].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[0].set_title('l3 latent space')
axs[1].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[1].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[1].set_title('y latent space')
axs[2].hlines(y=0.66,xmin=0,xmax=2,linestyles='dashed')
axs[2].vlines(x=0.66,ymin=0,ymax=2,linestyles='dashed')
axs[2].set_title('z latent space')

In [None]:
sns.boxplot(y='class_selectivity',x='xent',hue='ecc_max',data=y_sel_df.query('label_corruption == 0.0 & recon == 25'),palette='Purples')

In [None]:
sns.stripplot(y='class_selectivity',x='ecc_max',hue='xent',data=l2_sel_df.query('label_corruption == 0.0'),palette='Purples')

In [None]:
sns.boxplot(y='class_selectivity',x='recon',hue='ecc_max',data=y_sel_df.query('label_corruption == 0.0'))

In [None]:
sns.boxplot(y='class_selectivity',x='recon',hue='ecc_max',data=z_sel_df.query('label_corruption == 0.0'))

In [None]:
sns.boxplot(y='class_selectivity',x='ecc_max',hue='recon',data=z_sel_df.query('label_corruption == 0.0'))

In [None]:
conf_fetch(configs,'xent')

In [None]:
len(configs)