In [13]:
import os
import random
import json
import hashlib
import numpy as np
import pandas as pd
import xarray
import neptune
import matplotlib.pyplot as plt
import seaborn as sns
from keras.models import Model
from tqdm import tqdm as tqdm
import multiprocessing as mp

from sklearn.linear_model import LinearRegression,Ridge,RidgeCV

import brainscore
from brainscore.assemblies import walk_coords,split_assembly
from brainscore.assemblies import split_assembly
from brainscore.metrics import Score

from brainio_base.assemblies import DataAssembly

from scipy.stats import pearsonr

from src.results.experiments import _DateExperimentLoader
from src.results.utils import raw_to_xr, dprime
from src.results.neptune import get_model_files, load_models, load_assemblies, load_params, load_properties,create_assemblies
from src.data_loader import Shifted_Data_Loader
from src.data_generator import ShiftedDataBatcher

def set_style():
    """Configure seaborn for paper figures: paper context, white background, serif fonts.

    ``sns.set`` re-applies seaborn's default context and style along with the
    font, so it is called first; the paper context and white style are applied
    afterwards so that they are not overwritten.
    """
    # Set the font to be serif, rather than sans. This call also resets the
    # context/style to seaborn's defaults, so it must precede the calls below.
    sns.set(font='serif')

    # This sets reasonable defaults for font size for
    # a figure that will go in a paper
    sns.set_context("paper")

    # Make the background white, and specify the
    # specific font family
    sns.set_style("white", {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Palatino", "serif"]
    })

In [2]:
# The Neptune API token is a credential and must not live in the notebook.
# Export NEPTUNE_API_TOKEN in the shell that starts the kernel instead.
if not os.environ.get('NEPTUNE_API_TOKEN'):
    raise RuntimeError('NEPTUNE_API_TOKEN is not set; export it before running this notebook.')
neptune.init('elijahc/DuplexAE')
neptune.set_project('elijahc/DuplexAE')

Project(elijahc/DuplexAE)

In [8]:
# Project root. Defaults to the original location; set VAE_PROJ_ROOT to run elsewhere.
proj_root = os.environ.get('VAE_PROJ_ROOT','/home/elijahc/projects/vae')

In [9]:
# Fetch the two Neptune experiments (DPX-10 and DPX-16) analysed in this notebook.
exps = neptune.project.get_experiments(id=['DPX-10','DPX-16'])
# mod = next(load_models(proj_root,exps))

In [11]:
# First returned experiment; its properties and parameters are read in the following cells.
e = exps[0]

In [14]:
# Read the experiment's properties and parameters once and reuse them
# (previously get_properties() was called twice; presumably each call goes to
# the Neptune backend — TODO confirm).
PROPS = e.get_properties()
PARAMS = e.get_parameters()
exp_dir = os.path.join(proj_root,PROPS['dir'])

In [15]:
# Batcher configured from the experiment's dataset, translation, background and batch size.
# NOTE(review): DB is not referenced again in the visible cells.
DB = ShiftedDataBatcher(PROPS['dataset'],translation=PARAMS['im_translation'],bg=PARAMS['bg'],
                        blend=None,
#                         blend='difference',
                        batch_size=PARAMS['batch_sz'],
                       )

In [None]:
# Loader for translated fashion_mnist images on a natural background.
# sx_test is later passed to create_assemblies as the test stimuli.
DL = Shifted_Data_Loader('fashion_mnist',rotation=None,translation=0.75,bg='natural',flatten=False)
sx_test = DL.sx_test

In [None]:
# Sanity check: show test image 250 reshaped to 56x56 in grayscale.
plt.imshow(sx_test[250].reshape(56,56),cmap='gray')

In [None]:
# One (dx, dy, label, rotation) tuple per test image. 14 is subtracted from the
# loader's x/y values — presumably to centre the offsets; TODO confirm.
slug = [(dx,dy,float(lab),float(rxy)) for dx,dy,rxy,lab in zip(DL.dx[1]-14,DL.dy[1]-14,DL.dtheta[1],DL.y_test)]
# stim_set = pd.DataFrame({'dx':DL.dx[1]-14,'dy':DL.dy[1]-14,'numeric_label':DL.y_test,'rxy':DL.dtheta[1],'image_id':image_id})

# `ca` is consumed one assembly at a time with next(ca) in later cells.
ca = create_assemblies(proj_root,exps,test_data=sx_test,slug=slug)

In [None]:
# lg_both = xrs[0]
# lg_xent = xrs[1]

In [None]:
# exps[0].get_properties()['dir']
# mod_dir = os.path.join(proj_root,exps[0].get_properties()['dir'])
# save_assembly(out,run_dir=mod_dir,fname='dataset.nc',
#     format='NETCDF3_64BIT',
# )

In [None]:
# das = load_assemblies(proj_root,exps)
# lg_both = next(das)

In [None]:
# lg_xent = lg.assemblies[0]
# lg_both = lg.assemblies[1]

In [None]:
# Load the dicarlo.Majaj2015 neural assembly and keep a handle on its stimulus set.
neural_data = brainscore.get_assembly(name="dicarlo.Majaj2015")
neural_data.load()
stimulus_set = neural_data.attrs['stimulus_set']

In [None]:
def process_dicarlo(assembly,avg_repetition=True,variation=3,tasks=['ty','tz','rxy']):
    """Reduce a neural assembly to averaged responses for one variation level.

    Adds ``dy_deg``/``dx_deg`` (``tz``/``ty`` times ``degrees``) and
    ``dy_px``/``dx_px`` (the degree columns times 32) to the assembly's stimulus
    set in place. It then selects ``variation``, groups presentations by
    category, object, image and the ``tasks`` coordinates (plus ``repetition``
    when ``avg_repetition`` is False), averages over ``presentation`` within
    each group, squeezes ``time_bin`` and transposes the result.

    Returns the transposed assembly; its ``stimulus_set`` attribute (set before
    the transpose) holds only the rows of the requested variation.
    """
    stim = assembly.attrs['stimulus_set']

    # Offsets in degrees first, then in pixels, keeping the original column order.
    for column,source in (('dy_deg','tz'),('dx_deg','ty')):
        stim[column] = stim[source]*stim.degrees
    for column,source in (('dy_px','dy_deg'),('dx_px','dx_deg')):
        stim[column] = stim[source]*32

    assembly.attrs['stimulus_set'] = stim

    group_keys = ['category_name', 'object_name', 'image_id']+tasks
    if not avg_repetition:
        group_keys.append('repetition')

    averaged = (assembly.sel(variation=variation)
                .multi_groupby(group_keys)
                .mean(dim='presentation')
                .squeeze('time_bin'))
    averaged.attrs['stimulus_set'] = stim.query('variation == {}'.format(variation))

    return averaged.T

In [None]:
# pos_samples = .where(lg_both.numeric_label==1,drop=True).values[:,250]
# neg_samples = lg_both.where(lg_both.numeric_label!=1, drop=True).values[:,250]

In [None]:
def xr_exclude_zero_dim(da,neuroid_coord):
    """Return ``da`` without the neuroids whose summed response is zero.

    Responses are grouped by ``neuroid_coord`` and summed over 'presentation';
    the resulting non-zero mask indexes the second dimension of ``da``.
    """
    totals = da.groupby(neuroid_coord).sum('presentation')
    keep = totals.values!=0
    return da[:,keep]

In [None]:
def SUCorrelation(da,neuroid_coord,correlation_vars,exclude_zeros=True):
    """Absolute Pearson correlation between every unit and every variable.

    Parameters
    ----------
    da : assembly indexed as [presentation, neuroid].
    neuroid_coord : name of the coordinate identifying units (e.g. 'neuroid_id').
    correlation_vars : sequence of named 1-d series (each needs a ``.name``),
        one value per presentation.
    exclude_zeros : drop units whose responses sum to zero over 'presentation'.

    Returns
    -------
    Score with dims ('neuroid', 'task') holding ``abs(r)``; the task coordinate
    carries the names of ``correlation_vars`` and the neuroid coordinates of
    ``da`` are carried over.
    """
    if exclude_zeros:
        nz_neuroids = da.groupby(neuroid_coord).sum('presentation').values!=0
        da = da[:,nz_neuroids]

    neuroid_ids = da[neuroid_coord].values
    correlations = np.empty((len(neuroid_ids),len(correlation_vars)))
    for i,nid in tqdm(enumerate(neuroid_ids),total=len(neuroid_ids)):
        # Select the unit once: the selection does not depend on the variable
        # being correlated, so it is hoisted out of the inner loop.
        n_act = da.sel(**{neuroid_coord:nid}).squeeze()
        for j,prop in enumerate(correlation_vars):
            r,_ = pearsonr(n_act,prop)
            correlations[i,j] = np.abs(r)

    # Carry over every coordinate that lives on the neuroid dimension.
    neuroid_dim = da[neuroid_coord].dims
    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task']=('task',[v.name for v in correlation_vars])
    result = Score(correlations,
                       coords=c,
                       dims=('neuroid','task'))
    return result

def SUDprime(da,neuroid_coord='neuroid_id',class_coord='numeric_label',exclude_zeros=True):
    """Best one-vs-rest d' of every unit.

    For each unit and each distinct value of ``class_coord``, the unit's
    responses are split into presentations of that class (``A``) and all other
    presentations (``B``) and passed to ``dprime``. The largest value across
    classes is kept for each unit.

    Parameters
    ----------
    da : assembly indexed as [presentation, neuroid].
    neuroid_coord : name of the coordinate identifying units.
    class_coord : name of the presentation coordinate holding class labels.
    exclude_zeros : drop units whose responses sum to zero over 'presentation'.

    Returns
    -------
    Score with dims ('neuroid', 'task') and a single task, 'category'.
    """
    if exclude_zeros:
        nz_neuroids = da.groupby(neuroid_coord).sum('presentation').values!=0
        da = da[:,nz_neuroids]

    labels = da[class_coord].values
    class_vals = np.unique(labels)
    # One boolean presentation mask per class. Splitting by mask rather than at
    # a fixed row (previously 1000) keeps positives and negatives correct when
    # classes have different numbers of presentations, and avoids materialising
    # a classes x presentations x units copy of the data.
    class_masks = [labels==c for c in class_vals]

    activity = da.values
    n_units = len(da[neuroid_coord])

    dprimes = np.empty((n_units,len(class_vals)))
    for i in tqdm(range(n_units),total=n_units):
        unit = activity[:,i]
        for j,in_class in enumerate(class_masks):
            dprimes[i,j] = dprime(A=unit[in_class],B=unit[~in_class],mode='sample',max_value=1,min_value=-1)

    # Carry over every coordinate that lives on the neuroid dimension.
    neuroid_dim = da[neuroid_coord].dims
    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task']=('task',['category'])
    result = Score(dprimes.max(axis=1).reshape(-1,1),
                       coords=c,
                       dims=('neuroid','task'))
    return result

def result_to_df(SUC,corr_var_labels):
    """Flatten a ('neuroid', 'task') score into a DataFrame with one row per unit.

    The neuroid coordinates become columns (via ``reset_index``) and one column
    is appended per label in ``corr_var_labels``, filled with the score values
    selected for that task.
    """
    frame = SUC.neuroid.to_dataframe().reset_index()
    for task_name in corr_var_labels:
        task_values = SUC.sel(task=task_name).values
        frame[task_name] = task_values

    return frame

In [None]:
def cat_parts(da,class_coord):
    """For each class value, stack the rows of that class on top of all other rows.

    Returns a list with one array per distinct value of ``da[class_coord]``
    (in ``np.unique`` order); each array holds the matching rows first,
    followed by the non-matching rows.
    """
    labels = da[class_coord]
    stacked = []
    for c in np.unique(labels.values):
        members = (labels==c).values
        others = (labels!=c).values
        stacked.append(np.concatenate([da[members].values,da[others]],axis=0))
    return stacked

def dprime_1d(vec,cut=1000):
    """d' between the first ``cut`` entries of ``vec`` and the remaining entries."""
    head,tail = vec[:cut],vec[cut:]
    return dprime(A=head,B=tail,mode='sample',max_value=1,min_value=-1)

In [None]:
def gu_SUD(da_sets,neuroid_coord):
    """Run ``SUDprime`` on every assembly in ``da_sets`` using worker processes.

    Returns the list of scores in the same order as ``da_sets``; an empty input
    yields an empty list without starting any workers.
    """
    jobs = [(da,neuroid_coord) for da in da_sets]
    if not jobs:
        return []

    # starmap spreads the jobs over the workers. Pool.apply blocks until each
    # call finishes, so the previous loop ran the jobs one at a time. The
    # context manager also releases the workers if a job raises.
    with mp.Pool(min(6,len(jobs))) as pool:
        return pool.starmap(SUDprime,jobs)

In [None]:
def process_assembly(da):
    """Build the per-unit selectivity table for one assembly.

    Combines the absolute correlation of every unit with the 'tx' and 'ty'
    coordinates and the unit's best category d' into one DataFrame, with rows
    ordered by 'neuroid_id'.
    """
    def _ordered(frame):
        # Sort rows by unit id and renumber them from zero.
        return frame.sort_values(by='neuroid_id').reset_index().drop(columns='index')

    # Calculate dprime for single units
    print('Calculating dprime of all units...')
    category_df = result_to_df(SUDprime(da,neuroid_coord='neuroid_id',),['category'])

    position_names = ['tx','ty']
    position_vars = [pd.Series(da[name].values,name=name) for name in position_names]
    position_scores = SUCorrelation(da,neuroid_coord='neuroid_id',correlation_vars=position_vars)

    selectivity = _ordered(result_to_df(position_scores,position_names))
    # Both frames are ordered by neuroid_id, so the index-aligned assignment pairs units up.
    selectivity['category'] = _ordered(category_df).category

    return selectivity

In [None]:
# region_sets = [xr_exclude_zero_dim(lg_both.sel(region=r),'neuroid_id') for r in np.unique(lg_both.region.values)]
# both_cat_results = gu_SUD(region_sets,'neuroid_id')
# both_SUdp_score = [SUDprime(rsets,neuroid_coord='neuroid_id',) for rsets in region_sets]


In [None]:
# Selectivity table for the first assembly produced by create_assemblies.
# NOTE(review): the name implies this is the cross-entropy-only model — confirm
# that the order of `ca` matches that assumption.
su_xent = process_assembly(next(ca))

In [None]:
# `su_xent = su` was removed: `su` is never assigned in this notebook, so the
# statement raised NameError on a fresh kernel (and, had it run, would have
# overwritten the su_xent computed by process_assembly).

In [None]:
# Selectivity table for the next assembly produced by create_assemblies.
su_both = process_assembly(next(ca))

In [None]:
# NOTE(review): lg_xent is only assigned in commented-out cells of this
# notebook, so this cell fails on a fresh kernel unless it is defined elsewhere.
xent_SUdp_score = SUDprime(lg_xent,neuroid_coord='neuroid_id',)

In [None]:
# Order both tables by unit id and renumber the rows.
# NOTE(review): both_df_dp and su_both_df are not assigned in any visible cell,
# and each line overwrites its own input, so the cell is not re-runnable as is.
both_df_dp = both_df_dp.sort_values(by='neuroid_id').reset_index().drop(columns='index')
su_both_df = su_both_df.sort_values(by='neuroid_id').reset_index().drop(columns='index')

In [None]:
# Preview the first rows of the table.
su_both_df.head()

In [None]:
# Per-unit |Pearson r| with the 'tx' and 'ty' coordinates of lg_xent.
# NOTE(review): lg_xent is not assigned in any visible (uncommented) cell.
corr_vars_xent = [pd.Series(lg_xent[v].values,name=v) for v in ['tx','ty']]
corr_xent = SUCorrelation(lg_xent,neuroid_coord='neuroid_id',correlation_vars=corr_vars_xent)
su_xent_df = result_to_df(corr_xent,['tx','ty'])
# NOTE(review): 'norm_ty' is a plain copy of 'ty'; no normalisation is applied here.
su_xent_df['norm_ty'] = su_xent_df.ty

In [None]:
# Attach the per-unit category d' to the correlation table.
# The assignment aligns on the row index, so both frames must list units in the same order.
xent_df_dp = result_to_df(xent_SUdp_score,['category'])
su_xent_df['category'] = xent_df_dp.category

In [None]:
# Averaged neural responses for variation level 6.
hi_data = process_dicarlo(neural_data,variation=6)

In [None]:
# The assembly's 'ty'/'tz' coordinates are exposed under the names 'tx'/'ty'
# used by the model tables.
dicarlo_corr_vars = [
    pd.Series(hi_data['ty'],name='tx'),
    pd.Series(hi_data['tz'],name='ty'),
    pd.Series(hi_data['rxy'],name='rxy'),

]

# corr_dicarlo_med = SUCorrelation(med_data,neuroid_coord='neuroid_id',correlation_vars=dicarlo_med_corr_vars,exclude_zeros=True)
# dicarlo_med_df = result_to_df(corr_dicarlo_med,['tx','ty','rxy'])
# dicarlo_med_df['variation']=3

corr_dicarlo_hi = SUCorrelation(hi_data,neuroid_coord='neuroid_id',correlation_vars=dicarlo_corr_vars,exclude_zeros=True)
dicarlo_df = result_to_df(corr_dicarlo_hi, ['tx','ty','rxy'])
layer_map = {
    'V4':3,
    'IT':4
}

# Map each unit's region to its layer index. A single pass over the column is
# enough; the previous loop repeated this same assignment once per region.
dicarlo_df['layer'] = [layer_map[r] for r in dicarlo_df.region]

In [None]:
# Per-unit best one-vs-rest d' over the 'category_name' labels of the neural data.
dicarlo_SUdp_score = SUDprime(hi_data,neuroid_coord='neuroid_id',class_coord='category_name')

In [None]:
# Attach the category d' to the neural table.
# NOTE(review): the assignment aligns on the row index and, unlike
# process_assembly, neither frame is sorted by neuroid_id first — confirm the
# unit order matches.
dicarlo_SUdp_df = result_to_df(dicarlo_SUdp_score,['category'])
dicarlo_df['category']=dicarlo_SUdp_df.category

In [None]:
# Output directory of every experiment, in the same order as `exps`.
mod_dirs = [os.path.join(proj_root,e.get_properties()['dir']) for e in exps]

In [None]:
# Display the resolved directories.
mod_dirs

In [None]:
# Display each experiment's 'recon_weight' parameter, in the same order as `exps`.
[e.get_parameters()['recon_weight'] for e in exps]

In [None]:
# Write each selectivity table into its experiment directory as parquet.
# Tables are paired with directories by position, so su_xent goes to the first
# experiment in `exps` and su_both to the second.
for dr,df in zip(mod_dirs,[su_xent,su_both]):
    fp = os.path.join(dr,'su_selectivity.pqt')
    print(fp)
    df.drop(columns=['neuroid','neuroid_id']).to_parquet(fp)
#     su_both_df.drop(columns=['neuroid','neuroid_id']).to_parquet(os.path.join(mod_dirs[0],'su_w_recon'))

In [None]:
# Save the neural-data table as parquet under <proj_root>/data.
# NOTE(review): the file name has no extension, unlike 'su_selectivity.pqt' above.
dicarlo_df.drop(columns='neuroid').to_parquet(os.path.join(proj_root,'data','dicarlo.Majaj_processed'))

In [None]:
def plot_bars(y,df,by='region',order=None):
    """Draw one bar-plot panel of ``y`` per subset of ``df``.

    Parameters
    ----------
    y : column plotted on the y axis.
    df : DataFrame holding ``y`` and ``by``.
    by : column whose values define the subsets (one panel each).
    order : optional explicit list of subset values; defaults to the distinct
        values of ``df[by]`` in order of appearance.

    Returns
    -------
    (fig, axs) exactly as returned by ``plt.subplots``.
    """
    if order is not None:
        subsets = order
    else:
        subsets = df[by].drop_duplicates().values

    plot_scale = 5
    fig,axs = plt.subplots(1,len(subsets),figsize=(plot_scale*len(subsets),plot_scale),sharex=True,sharey=True)

    # plt.subplots returns a bare Axes for a single panel; make it iterable.
    for ax,sub in zip(np.atleast_1d(axs),subsets):
        # A boolean mask works for numeric and string subset values alike.
        sub_df = df[df[by]==sub]
        # Plot the subset itself (previously no data was passed to barplot).
        sns.barplot(x=by,y=y,data=sub_df,ax=ax)
        ax.set_title("{}: {}".format(by,sub))

    return fig,axs

def plot_kde(x,y,df,by='region',order=None,xlim=(0.0,0.8),ylim=(0.0,0.8)):
    """Draw one 2-d KDE panel of ``x`` against ``y`` per subset of ``df``.

    Parameters
    ----------
    x, y : columns plotted on the two axes.
    df : DataFrame holding ``x``, ``y`` and ``by``.
    by : column whose values define the subsets (one panel each).
    order : optional explicit list of subset values; defaults to the distinct
        values of ``df[by]`` in order of appearance.
    xlim, ylim : axis limits shared by all panels.

    Returns
    -------
    (fig, axs) exactly as returned by ``plt.subplots``.
    """
    if order is not None:
        subsets = order
    else:
        subsets = df[by].drop_duplicates().values

    plot_scale = 5
    fig,axs = plt.subplots(1,len(subsets),figsize=(plot_scale*len(subsets)*0.8,plot_scale),sharex=True,sharey=True,
                           subplot_kw={
                               'xlim':xlim,
                               'ylim':ylim,
                           })

    # plt.subplots returns a bare Axes for a single panel; make it iterable.
    for ax,sub in zip(np.atleast_1d(axs),subsets):
        # Select by boolean mask: the previous string-quoted query
        # ('by == "sub"') never matched numeric values such as layer indices.
        sub_df = df[df[by]==sub]
        sns.kdeplot(sub_df[x],sub_df[y],ax=ax)
        ax.set_title("{}: {}".format(by,sub))

    return fig,axs
#         sns.despine(ax)
# plot_bars(y='tx',df=both_df,by='layer',order=np.arange(5))

In [None]:
# set_style()

In [None]:
# Apply the paper context and the font scale in one call: a separate
# set_context('paper') afterwards resets the font scale to its default of 1.
sns.set(context='paper',font_scale=2)

In [None]:
def topn_su_decode(df,n,props,**kwargs):
    """Bar-plot the ``n`` best units of layers 3 and 4 for each property.

    Parameters
    ----------
    df : per-unit table with a 'layer' column and one column per property.
    n : number of top-scoring units kept per layer and property.
    props : property columns to plot, one panel each.
    **kwargs : forwarded to ``plt.subplots``.
    """
    order=[0,1,2,3,4]
    fig,axs = plt.subplots(1,len(props),figsize=(len(props)*4,3),sharey=True,**kwargs)
    sns.set_context('talk')
    # plt.subplots returns a bare Axes when there is a single property; make it iterable.
    for ax,prop in zip(np.atleast_1d(axs),props):
        # Only layers 3 and 4 contribute bars; the other positions in `order` stay empty.
        df_topn = pd.concat([df.query('layer == {}'.format(l)).nlargest(n,prop) for l in [3,4]])
        sns.barplot(x='layer',y=prop,order=order,data=df_topn,ax=ax,palette='magma')
        ax.set_xticklabels(['pixel','V1','?','V4','IT'])
#         ax.set_ylabel('d\'')
        ax.set_ylabel('perf')
        ax.set_title(prop)
        sns.despine(ax=ax)
    plt.tight_layout()

    
# Top-50 units per layer in the neural data, for category d' and 'ty' correlation.
topn_su_decode(dicarlo_df,n=50,props=['category','ty'],
#                subplot_kw={'ylim':(0,0.5)},
              )

In [None]:
# 2x2 grid: one row per model table (su_xent, then su_both), one column per property.
fig,axs = plt.subplots(2,2,figsize=(8,6),sharex=True,sharey=True)

mod_order=np.arange(5)
# mod_order = ['pixel','dense_2','dense_3','y_lat','z_lat']

sns.set_context('talk')
properties = ['tx','ty']
for ax_row,df,order in zip(axs,[su_xent,su_both,],[mod_order,mod_order]): 
    for ax,prop in zip(ax_row,properties):
        sns.barplot(x='layer',y=prop,order=order,data=df,ax=ax,palette='magma')
        sns.despine(ax=ax)
    
    # Label the y axis on the first column only.
    ax_row[1].set_ylabel('')
    ax_row[0].set_ylabel('pearson')
    

# Tick labels go on the bottom row; the top row's x axis is hidden below.
for ax in axs[1]:
    ax.set_xticklabels(['pixel','1','2','3','4'])
#     ax.set_xticklabels(['pixel','L2','L3','y_lat','z_lat'])

for ax,prop in zip(axs[0],properties):
    ax.get_xaxis().set_visible(False)
    ax.set_title(prop)

plt.tight_layout()

In [None]:
# One row per model table (su_xent, then su_both), one column per property.
properties = ['category','tx','ty']
fig,axs = plt.subplots(2,len(properties),figsize=(len(properties)*4,6),sharex=True,sharey=True)

# mod_order=np.arange(5)
# Only layers 0, 2 and 3 are drawn.
mod_order = [0,2,3]
# mod_order = ['pixel','dense_2','dense_3','y_lat']

sns.set_context('talk')
for ax_row,df,order in zip(axs,[su_xent,su_both,],[mod_order,mod_order]): 
    for ax,prop in zip(ax_row,properties):
        sns.barplot(x='layer',y=prop,order=order,data=df,ax=ax,palette='magma')
        sns.despine(ax=ax)
    
    # NOTE(review): the y axis is shared and labelled d' on the first
    # ('category') column, but the 'tx'/'ty' columns hold |Pearson r| values
    # as produced by process_assembly.
    ax_row[1].set_ylabel('')
    ax_row[2].set_ylabel('')
    ax_row[0].set_ylabel('d\'')
    

for ax in axs[1]:
    ax.set_xticklabels(['pixel','2','3'])

#     ax.set_xticklabels(mod_order)

for ax,prop in zip(axs[0],properties):
    ax.get_xaxis().set_visible(False)
    ax.set_title(prop)

plt.tight_layout()

In [None]:
# Distribution of category d' per layer in the neural data.
# Six x positions are reserved; the sixth gets an empty tick label.
order=[0,1,2,3,4,5]
fig,ax = plt.subplots(1,1,figsize=(5,4.5))
sns.set_context('talk')
sns.boxplot(x='layer',y='category',order=order,data=dicarlo_df,ax=ax,palette='magma')
ax.set_xticklabels(['pixel','V1','?','V4','IT',''])
sns.despine(ax=ax)
plt.tight_layout()

In [None]:

# [['tx','ty','rxy','layer','region']]

In [None]:
# Top-100 units per layer in the neural data, for category d' and both position correlations.
topn_su_decode(dicarlo_df,n=100,props=['category','tx','ty'],
#                subplot_kw={'ylim':(0,0.8)},
              )

In [None]:
# Side-by-side 'tx' correlation per layer for the two model tables.
# NOTE(review): su_xent_df is built from lg_xent and su_both_df is never
# assigned in a visible cell; the tables computed by process_assembly are named
# su_xent / su_both.
fig,axs = plt.subplots(1,2,figsize=(8,4),sharey=True,sharex=True)

mod_order=np.arange(5)

sns.set_context('talk')
for ax,df,order in zip(axs,[su_xent_df,su_both_df,],[mod_order,mod_order]): 
    sns.barplot(x='layer',y='tx',order=order,data=df,ax=ax)
axs[0].set_xticklabels(['pixel','1','2','3','4'])
axs[1].get_yaxis().set_visible(False)
plt.tight_layout()

In [None]:
# Apply style, context and font scale in one call: a trailing
# sns.set(font_scale=2) on its own re-applies seaborn's default style and
# context, discarding the whitegrid/paper settings made just before it.
sns.set(context='paper',style='whitegrid',font_scale=2)

In [None]:
# sns.set_style('whitegrid')
# Paper context and font scale in one call; sns.set(font_scale=2) after
# set_context('paper') would reset the context to seaborn's default.
sns.set(context='paper',font_scale=2)

# 'tx' vs 'ty' selectivity density, one panel per layer.
fig,axs = plot_kde('tx','ty',su_xent_df,by='layer',order=np.arange(5),)
plt.tight_layout()
sns.despine(fig=fig)
# Strip the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

#         ax.get_xaxis().set_visible(False)
    




In [None]:
# Paper context and font scale in one call; sns.set(font_scale=2) after
# set_context('paper') would reset the context to seaborn's default.
sns.set(context='paper',font_scale=2)
# 'ty' correlation vs category d' density, one panel per layer.
fig,axs = plot_kde('ty','category',su_xent_df,by='layer',order=np.arange(5),xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Strip the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# Paper context and font scale in one call; sns.set(font_scale=2) after
# set_context('paper') would reset the context to seaborn's default.
sns.set(context='paper',font_scale=2)
# 'ty' correlation vs category d' density, one panel per layer.
fig,axs = plot_kde('ty','category',su_both_df,by='layer',order=np.arange(5),xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Strip the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# Same panels as the previous figure but with plot_kde's default axis limits (0.0-0.8).
plot_kde('ty','category',su_both_df,by='layer',order=np.arange(5))
plt.tight_layout()

In [None]:
# Paper context and font scale in one call; sns.set(font_scale=2) after
# set_context('paper') would reset the context to seaborn's default.
sns.set(context='paper',font_scale=2)
# 'ty' correlation vs category d' density in the neural data, one panel per region.
fig,axs = plot_kde('ty','category',dicarlo_df,by='region',order=['V4','IT'],xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Strip the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# 'tx' correlation vs category d' density in the neural data, one panel per region.
plot_kde('tx','category',dicarlo_df,by='region',order=['V4','IT'])


In [None]:
# 'tx' vs 'ty' correlation of the IT units only, with both axes limited to 0-0.5.
sns.scatterplot(x='tx',y='ty',data=dicarlo_df.query('region == "IT"'))
plt.ylim(0,0.5)
plt.xlim(0,0.5)

In [None]:
class MURegressor(object):
    """Fit one regressor per random split on a random subset of units.

    Parameters
    ----------
    da : assembly indexed as [presentation, neuroid] with a ``neuroid_id`` coordinate.
    train_frac : stored on the instance; it is not passed to ``split_assembly`` here.
    n_splits : number of independent unit draws / train-test splits.
    n_units : units drawn per split; ``None`` uses every unit in each split.
    estimator : zero-argument factory for a scikit-learn style regressor.
    """
    def __init__(self,da,train_frac=0.8,n_splits=5,n_units=None,estimator=Ridge):
        n_available = len(da.neuroid_id)
        if n_units is not None:
            # Units are drawn with replacement, so a unit can appear more than once in a split.
            self.neuroid_idxs = [np.array([random.randrange(n_available) for _ in range(n_units)]) for _ in range(n_splits)]
        else:
            # No subsampling requested: every split uses all units. Without this
            # branch the default n_units=None raised AttributeError below.
            self.neuroid_idxs = [np.arange(n_available) for _ in range(n_splits)]

        self.original_data = da
        self.train_frac = train_frac
        self.n_splits = n_splits

        splits = [split_assembly(self.original_data[:,n_idxs]) for n_idxs in tqdm(self.neuroid_idxs,total=n_splits,desc='CV-splitting')]
        self.train = [tr for tr,te in splits]
        self.test = [te for tr,te in splits]

        self.estimators = [estimator() for _ in range(n_splits)]

    def fit(self,y_coord):
        """Fit every estimator on its training split, using coordinate ``y_coord`` as target."""
        for mod,train in tqdm(zip(self.estimators,self.train),total=len(self.train),desc='fitting'):
            # Pass plain arrays for both X and y, as score() does.
            mod.fit(X=train.values,y=train[y_coord].values)

        return self

    def predict(self,X=None):
        """Return one prediction per estimator, for ``X`` or for each estimator's own test split."""
        if X is not None:
            return [e.predict(X) for e in self.estimators]
        else:
            return [e.predict(te.values) for e,te in zip(self.estimators,self.test)]

    def score(self,y_coord):
        """Return each estimator's ``score`` on its own test split for target ``y_coord``."""
        return [e.score(te.values,te[y_coord].values) for e,te in zip(self.estimators,self.test)]
    
def stratified_regressors(data, filt='region',n_units=126,y_coords=['ty','tz'],task_names=None,estimator=Ridge):
    """Score multi-unit regressors separately for every subset of ``data``.

    Parameters
    ----------
    data : assembly with the ``filt`` coordinate and the ``y_coords`` coordinates.
    filt : coordinate whose distinct values define the subsets.
    n_units : units drawn per split inside each ``MURegressor``.
    y_coords : target coordinates, regressed one at a time.
    task_names : labels written to the 'task' column; defaults to ``y_coords``.
    estimator : regressor factory handed to ``MURegressor``.

    Returns
    -------
    Long-format DataFrame with columns 'region' (the subset value, whatever
    ``filt`` is), 'performance' (one score per split) and 'task'.
    """
    subsets = np.unique(data[filt].values)
    if task_names is None:
        task_names = y_coords
    dfs = []
    for y,task in zip(y_coords,task_names):
        print('regressing {}...'.format(y))
        # Pass the caller's estimator through (it was previously ignored in favour of Ridge).
        regressors = {k:MURegressor(data.sel(**{filt:k}),n_units=n_units,estimator=estimator).fit(y_coord=y) for k in subsets}
        df = pd.DataFrame.from_records({k:v.score(y_coord=y) for k,v in regressors.items()})
        df = df.melt(var_name='region',value_name='performance')
        df['task']=task
        dfs.append(df)

    return pd.concat(dfs)

In [None]:
# Multi-unit regression of 'tx' and 'ty' per layer, 50 units per split.
# NOTE(review): lg_both is only assigned in commented-out cells of this notebook.
properties = ['tx','ty']
mu_both_df = stratified_regressors(lg_both,filt='layer',y_coords=properties,n_units=50)

In [None]:
# Regression performance per task. The 'region' column holds the layer value
# here because stratified_regressors was called with filt='layer'.
sns.barplot(x='task',y='performance',hue='region',data=mu_both_df)

In [None]:
# Same regression for the other model.
# NOTE(review): lg_xent is only assigned in commented-out cells of this notebook.
mu_xent_df = stratified_regressors(lg_xent,filt='layer',y_coords=properties,n_units=50)

In [None]:
# Regression performance per task. The 'region' column holds the layer value
# here because stratified_regressors was called with filt='layer'.
sns.barplot(x='task',y='performance',hue='region',data=mu_xent_df)

In [None]:
# NOTE(review): both_df is not assigned in any visible cell.
plot_kde(x='tx',y='ty',df=both_df,by='layer',order=np.arange(5))

In [None]:
# NOTE(review): xent_df is not assigned in any visible cell.
plot_kde(x='tx',y='ty',df=xent_df,by='layer',order=np.arange(5))