In [1]:
import os
import random
import json
import hashlib
import numpy as np
import pandas as pd
import xarray
import neptune
import matplotlib.pyplot as plt
import seaborn as sns
from keras.models import Model
from tqdm import tqdm as tqdm
import multiprocessing as mp

from sklearn.linear_model import LinearRegression,Ridge,RidgeCV

import brainscore
from brainscore.assemblies import walk_coords,split_assembly
from brainscore.assemblies import split_assembly
from brainscore.metrics import Score

from brainio_base.assemblies import DataAssembly

from scipy.stats import pearsonr

from src.results.experiments import _DateExperimentLoader
from src.results.utils import raw_to_xr, dprime
from src.results.neptune import get_model_files, load_models, load_assemblies, load_params, load_properties,prep_assemblies
from src.data_loader import Shifted_Data_Loader
from src.data_generator import ShiftedDataBatcher

def set_style():
    # This sets reasonable defaults for font size for
    # a figure that will go in a paper
    sns.set_context("paper")
    
    # Set the font to be serif, rather than sans
    sns.set(font='serif')
    
    # Make the background white, and specify the
    # specific font family
    sns.set_style("white", {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Palatino", "serif"]
    })

Using TensorFlow backend.


In [2]:
os.environ['NEPTUNE_API_TOKEN']="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5tbCIsImFwaV9rZXkiOiI3ZWExMTlmYS02ZTE2LTQ4ZTktOGMxMi0wMDJiZTljOWYyNDUifQ=="
neptune.init('elijahc/DuplexAE')
neptune.set_project('elijahc/DuplexAE')

Project(elijahc/DuplexAE)

In [3]:
proj_root = '/home/elijahc/projects/vae'

In [None]:
exps = neptune.project.get_experiments(id=['DPX-30'])
# mod = next(load_models(proj_root,exps))

In [None]:
e = exps[0]

In [None]:
exp_dir = os.path.join(proj_root,e.get_properties()['dir'])
PARAMS = e.get_parameters()
PROPS = e.get_properties()

In [None]:
# DB = ShiftedDataBatcher(PROPS['dataset'],translation=PARAMS['im_translation'],bg=PARAMS['bg'],
#                         blend=None,
# #                         blend='difference',
#                         batch_size=PARAMS['batch_sz'],
#                        )

In [None]:
PROPS

In [None]:
DL = Shifted_Data_Loader('fashion_mnist',rotation=None,translation=0.75,bg='natural',flatten=False)
sx_test = DL.sx_test

In [None]:
plt.imshow(sx_test[250].reshape(56,56),cmap='gray')

In [None]:
slug = [(dx,dy,float(lab),float(rxy)) for dx,dy,rxy,lab in zip(DL.dx[1]-14,DL.dy[1]-14,DL.dtheta[1],DL.y_test)]
cas = prep_assemblies(proj_root,exps,test_data=sx_test,slug=slug)

In [None]:
encodings,depths,stim_set = next(cas)

In [None]:
def gen_conv_assemblies(encodings,depths,stim_set,n=5):
    enc = {k:encodings[k] for k in ['pixel','y_enc','z_enc']}
    for i in trange(n):
        enc.update({k:encodings[k][:,:,i] for k in ['conv_4','conv_3','conv_2','conv_1']})
        yield raw_to_xr(enc,depths,stim_set)

In [None]:
xrs = list(gen_conv_assemblies(encodings,depths,stim_set,n=10))

In [None]:
for r in np.unique(xrs[3].region.values):
    print(r)
    print(xrs[3].sel(region=r).shape)

In [None]:
# stim_set = pd.DataFrame({'dx':DL.dx[1]-14,'dy':DL.dy[1]-14,'numeric_label':DL.y_test,'rxy':DL.dtheta[1],'image_id':image_id})

# ca = create_assemblies(proj_root,exps,test_data=sx_test,slug=slug)

In [None]:
# lg_both = xrs[0]
# lg_xent = xrs[1]

In [None]:
# exps[0].get_properties()['dir']
# mod_dir = os.path.join(proj_root,exps[0].get_properties()['dir'])
# save_assembly(out,run_dir=mod_dir,fname='dataset.nc',
#     format='NETCDF3_64BIT',
# )

In [None]:
# das = load_assemblies(proj_root,exps)
# lg_both = next(das)

In [None]:
# lg_xent = lg.assemblies[0]
# lg_both = lg.assemblies[1]

In [33]:
neural_data = brainscore.get_assembly(name="dicarlo.Majaj2015")
neural_data.load()
stimulus_set = neural_data.attrs['stimulus_set']

  xr_data.set_index(append=True, inplace=True, **coords_d)


In [34]:
def process_dicarlo(assembly,avg_repetition=True,variation=3,tasks=['ty','tz','rxy']):
    stimulus_set = assembly.attrs['stimulus_set']
    stimulus_set['dy_deg'] = stimulus_set.tz*stimulus_set.degrees
    stimulus_set['dx_deg'] = stimulus_set.ty*stimulus_set.degrees
    stimulus_set['dy_px'] = stimulus_set.dy_deg*32
    stimulus_set['dx_px'] = stimulus_set.dx_deg*32
    
    assembly.attrs['stimulus_set'] = stimulus_set
    
    data = assembly.sel(variation=variation)
    groups = ['category_name', 'object_name', 'image_id']+tasks
    if not avg_repetition:
        groups.append('repetition')
        
    data = data.multi_groupby(groups)     # (2)
    data = data.mean(dim='presentation')
    data = data.squeeze('time_bin')    #   (3)
    data.attrs['stimulus_set'] = stimulus_set.query('variation == {}'.format(variation))
    data = data.T
    
    return data

In [5]:
# pos_samples = .where(lg_both.numeric_label==1,drop=True).values[:,250]
# neg_samples = lg_both.where(lg_both.numeric_label!=1, drop=True).values[:,250]

In [6]:
def xr_exclude_zero_dim(da,neuroid_coord):
    nz_neuroids = da.groupby(neuroid_coord).sum('presentation').values!=0
    return da[:,nz_neuroids]

In [161]:
def SUCorrelation(da,neuroid_coord,correlation_vars,exclude_zeros=True,progress=True):
    if exclude_zeros:
        nz_neuroids = da.groupby(neuroid_coord).sum('presentation').values!=0
        da = da[:,nz_neuroids]
    
    correlations = np.empty((len(da[neuroid_coord]),len(correlation_vars)))
    iterator = enumerate(da[neuroid_coord].values)
    if progress:
        iterator = tqdm(iterator,
                        total=len(da[neuroid_coord]),
                        mininterval=5,desc='pearsonr'
                       )
    
    for i,nid in iterator:
        for j,prop in enumerate(correlation_vars):
            n_act = da.sel(**{neuroid_coord:nid}).squeeze()
            r,p = pearsonr(n_act,prop)
            correlations[i,j] = np.abs(r)

    neuroid_dim = da[neuroid_coord].dims
    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task']=('task',[v.name for v in correlation_vars])
#     print(neuroid_dim)
    result = Score(correlations,
                       coords=c,
                       dims=('neuroid','task'))
    return result

def SUDprime(da,neuroid_coord='neuroid_id',class_coord='numeric_label',exclude_zeros=True,filters={},progress=True):    
    if exclude_zeros:
        nz_neuroids = da.groupby(neuroid_coord).sum('presentation').values!=0
        da = da[:,nz_neuroids]
    if len(filters.items()) > 0:
        filts = [da[k].isin(v).values for k,v in filters.items()]
        if len(filts) > 1:
            logics = np.logical_or(*filts)
        else:
            logics = filts[0]
        da = da[:,logics]
    
    def cat_parts(da,class_coord):
        out = [np.concatenate([da[(da[class_coord]==c).values].values,da[(da[class_coord]!=c).values]],axis=0) for c in class_vals]
        return np.array(out)

    def dprime_1d(vec,cut=1000):
        return dprime(A=vec[:cut],B=vec[cut:],mode='sample',max_value=1,min_value=-1)
    
#     class_vals = np.unique(da[class_coord].values)
#     parts = [((da[class_coord]==c).values,(da[class_coord]!=c).values) for c in class_vals]
    class_vals = np.unique(da[class_coord].values)

    c_parts = cat_parts(da,class_coord)
    dprimes = np.empty((len(da[neuroid_coord]),len(class_vals)))
    iterator = enumerate(da[neuroid_coord].values)
    if progress:
        iterator = tqdm(iterator,
                        total=dprimes.shape[0],mininterval=5,
                        desc='dprime',
                       )

    for i,nid in iterator:
#         da_n = da.sel(**{neuroid_coord:nid})
        dpn = np.apply_along_axis(dprime_1d,1,c_parts[:,:,i])
        dprimes[i] = dpn

    neuroid_dim = da[neuroid_coord].dims
    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task']=('task',[class_coord])
#     print(neuroid_dim)
    result = Score(dprimes.max(axis=1).reshape(-1,1),
                       coords=c,
                       dims=('neuroid','task'))
    return result

def result_to_df(SUC,corr_var_labels):
    df = SUC.neuroid.to_dataframe().reset_index()
    for label in corr_var_labels:
        df[label]=SUC.sel(task=label).values
    
    return df

In [8]:
def cat_parts(da,class_coord):
    class_vals = np.unique(da[class_coord].values)
    out = [np.concatenate([da[(da[class_coord]==c).values].values,da[(da[class_coord]!=c).values]],axis=0) for c in class_vals]
    return out

def dprime_1d(vec,cut=1000):
    return dprime(A=vec[:cut],B=vec[cut:],mode='sample',max_value=1,min_value=-1)

In [9]:
def gu_SUD(da_sets,neuroid_coord):
    pool = mp.Pool(6)
    results = [pool.apply(SUDprime,args=(da,neuroid_coord)) for da in da_sets]
    pool.close()
    pool.join()
    return results

In [175]:
def process_assembly(da,class_coords=['numeric_label'],r_vars=['tx','ty','rxy'],filters={},progress=True):
    # Calculate dprime for single units
    print('Calculating dprime all units...')
    if len(filters.items()) > 0:
        filts = [da[k].isin(v).values for k,v in filters.items()]
        if len(filts) > 1:
            logics = np.logical_or(*filts)
        else:
            logics = filts[0]
        da = da[:,logics]
    
    df_dps = []
    for class_coord in class_coords:
        print('- dprime of {}...'.format(class_coord))
        SUdp_score = SUDprime(da,neuroid_coord='neuroid_id',class_coord=class_coord,progress=progress)
        df_dps.append(result_to_df(SUdp_score,[class_coord]).sort_values(by='neuroid_id'))
    
    corr_vars = [pd.Series(da[v].values,name=v) for v in r_vars]
    print('Calculating pearsonr of all units...')
    corr = SUCorrelation(da,neuroid_coord='neuroid_id',correlation_vars=corr_vars)
    su_df = result_to_df(corr,r_vars).sort_values(by='neuroid_id')
    for dfdp,cc in zip(df_dps,class_coords):
        su_df = su_df.merge(dfdp[['neuroid_id',cc]],on='neuroid_id')
#     df_dp = df_dp.sort_values(by='neuroid_id').reset_index().drop(columns='index')
#     su_df = su_df.reset_index().drop(columns='index')
#     su_df['category'] = df_dp.category

    return su_df

In [None]:
# region_sets = [xr_exclude_zero_dim(lg_both.sel(region=r),'neuroid_id') for r in np.unique(lg_both.region.values)]
# both_cat_results = gu_SUD(region_sets,'neuroid_id')
# both_SUdp_score = [SUDprime(rsets,neuroid_coord='neuroid_id',) for rsets in region_sets]


In [11]:
pix_da = xarray.open_dataarray(os.path.join(proj_root,'data','dicarlo_images','hi_pix.nc'))

In [18]:
pix_da = pix_da.rename(ty='tx',tz='ty')

In [19]:
pix_da = pix_da.set_index({
    'neuroid':['neuroid_id','region','subregion','layer'],
    'presentation':['image_id','object_name','category_name','tx','ty','rxy']
                 })



In [22]:
idxs = np.random.choice(np.arange(len(pix_da.neuroid_id.values)),size=int(len(pix_da.neuroid_id.values)*0.1),replace=False)

In [23]:
npda = pix_da[:,idxs]

In [130]:
cat_dp_score = SUDprime(npda,neuroid_coord='neuroid_id',class_coord='category_name',)


  dp = num / div

  1%|          | 74/6553 [00:00<00:08, 732.77it/s][A
  2%|▏         | 149/6553 [00:00<00:08, 736.53it/s][A
  3%|▎         | 189/6553 [00:00<00:10, 586.38it/s][A
  4%|▎         | 244/6553 [00:00<00:11, 572.15it/s][A
  5%|▍         | 302/6553 [00:00<00:10, 571.57it/s][A
  5%|▌         | 359/6553 [00:00<00:10, 571.08it/s][A
  6%|▋         | 416/6553 [00:00<00:10, 570.04it/s][A
  7%|▋         | 473/6553 [00:00<00:10, 568.64it/s][A
  8%|▊         | 530/6553 [00:00<00:10, 568.04it/s][A
  9%|▉         | 587/6553 [00:01<00:10, 566.83it/s][A
 10%|█         | 657/6553 [00:01<00:09, 599.91it/s][A
 11%|█         | 731/6553 [00:01<00:09, 634.38it/s][A
 12%|█▏        | 805/6553 [00:01<00:08, 660.95it/s][A
 13%|█▎        | 880/6553 [00:01<00:08, 684.20it/s][A
 15%|█▍        | 955/6553 [00:01<00:07, 700.92it/s][A
 16%|█▌        | 1030/6553 [00:01<00:07, 712.87it/s][A
 17%|█▋        | 1105/6553 [00:01<00:07, 721.87it/s][A
 18%|█▊        | 1178/6553 [00:01<00:07, 724.

In [131]:
obj_dp_score = SUDprime(npda,neuroid_coord='neuroid_id',class_coord='object_name',)


  dp = num / div

  0%|          | 9/6553 [00:00<01:16, 85.88it/s][A
  0%|          | 18/6553 [00:00<01:17, 84.52it/s][A
  0%|          | 27/6553 [00:00<01:17, 84.64it/s][A
  1%|          | 36/6553 [00:00<01:16, 84.91it/s][A
  1%|          | 45/6553 [00:00<01:16, 85.23it/s][A
  1%|          | 54/6553 [00:00<01:17, 84.23it/s][A
  1%|          | 63/6553 [00:00<01:17, 83.76it/s][A
  1%|          | 72/6553 [00:00<01:17, 83.40it/s][A
  1%|          | 81/6553 [00:00<01:17, 83.99it/s][A
  1%|▏         | 90/6553 [00:01<01:17, 83.01it/s][A
  2%|▏         | 99/6553 [00:01<01:18, 82.01it/s][A
  2%|▏         | 108/6553 [00:01<01:17, 82.70it/s][A
  2%|▏         | 117/6553 [00:01<01:17, 83.23it/s][A
  2%|▏         | 126/6553 [00:01<01:16, 83.85it/s][A
  2%|▏         | 135/6553 [00:01<01:16, 83.96it/s][A
  2%|▏         | 144/6553 [00:01<01:16, 83.65it/s][A
  2%|▏         | 153/6553 [00:01<01:16, 84.15it/s][A
  2%|▏         | 162/6553 [00:01<01:16, 83.78it/s][A
  3%|▎         | 171/

In [162]:
r_vars = ['tx','ty','rxy']
corr_vars = [pd.Series(npda[v].values,name=v) for v in r_vars]
print('Calculating pearsonr of all units...')
corr = SUCorrelation(npda,neuroid_coord='neuroid_id',correlation_vars=corr_vars,progress=True)
su_df = result_to_df(corr,r_vars).sort_values(by='neuroid_id')

Calculating pearsonr of all units...





pearsonr:   0%|          | 0/6553 [00:00<?, ?it/s][A[A[A


pearsonr:   2%|▏         | 139/6553 [00:05<03:52, 27.63it/s][A[A[A


pearsonr:   4%|▍         | 277/6553 [00:10<03:47, 27.57it/s][A[A[A


pearsonr:   6%|▋         | 416/6553 [00:15<03:42, 27.61it/s][A[A[A


pearsonr:   8%|▊         | 555/6553 [00:20<03:37, 27.61it/s][A[A[A


pearsonr:  11%|█         | 694/6553 [00:25<03:31, 27.66it/s][A[A[A


pearsonr:  13%|█▎        | 833/6553 [00:30<03:26, 27.69it/s][A[A[A


pearsonr:  15%|█▍        | 972/6553 [00:35<03:21, 27.69it/s][A[A[A


pearsonr:  17%|█▋        | 1111/6553 [00:40<03:16, 27.72it/s][A[A[A


pearsonr:  19%|█▉        | 1250/6553 [00:45<03:11, 27.74it/s][A[A[A


pearsonr:  21%|██        | 1390/6553 [00:50<03:05, 27.81it/s][A[A[A


pearsonr:  23%|██▎       | 1529/6553 [00:55<03:00, 27.77it/s][A[A[A


pearsonr:  25%|██▌       | 1667/6553 [01:00<02:56, 27.71it/s][A[A[A


pearsonr:  28%|██▊       | 1807/6553 [01:05<02:51, 27.74it/s][A[

In [164]:
su_df['variation']=6

In [139]:
cat_dp_score = cat_dp_score.assign_coords(task=('task',['category_name']))
obj_dp_score = obj_dp_score.assign_coords(task=('task',['object_name']))

In [166]:
cat_df = result_to_df(cat_dp_score,['category_name']).sort_values(by='neuroid_id')
obj_df = result_to_df(obj_dp_score,['object_name']).sort_values(by='neuroid_id')
cat_df['variation']=6
obj_df['variation']=6

In [167]:
for dpdf,cc in zip([cat_df,obj_df],['category_name','object_name']):
    su_df = su_df.merge(dpdf[['neuroid_id',cc]],on='neuroid_id')

In [170]:
su_df.head()

Unnamed: 0,neuroid_id,region,subregion,layer,neuroid,tx,ty,rxy,variation,category_name,object_name
0,pixel_10000,pixel,pixel,0,"(pixel_10000, pixel, pixel, 0)",,,,6,,
1,pixel_10003,pixel,pixel,0,"(pixel_10003, pixel, pixel, 0)",,,,6,,
2,pixel_10012,pixel,pixel,0,"(pixel_10012, pixel, pixel, 0)",,,,6,,
3,pixel_10063,pixel,pixel,0,"(pixel_10063, pixel, pixel, 0)",0.02286,0.039999,0.018588,6,0.082507,0.040494
4,pixel_10077,pixel,pixel,0,"(pixel_10077, pixel, pixel, 0)",0.018335,0.044148,0.009387,6,0.079289,0.044056


In [31]:
su_dicarlo = su_dicarlo.dropna()

In [43]:
hda = hi_data.to_dataset()['dicarlo.Majaj2015']

In [58]:
tx = pd.Series(hda.ty.values,name='tx')
ty = pd.Series(hda.tz.values,name='ty')
hda = hda.reset_index(['ty','tz'],drop=True)

In [73]:
hda = hda.assign_coords(tx=('presentation',tx.values),ty=('presentation',ty.values))
hda = hda.set_index({'presentation':['tx','ty']},append=True)

In [77]:
su_dicarlo.head()

Unnamed: 0,neuroid_id,region,subregion,layer,neuroid,tx,ty,category
3,pixel_10063,pixel,pixel,0,"(pixel_10063, pixel, pixel, 0)",0.02286,0.039999,0.082507
4,pixel_10077,pixel,pixel,0,"(pixel_10077, pixel, pixel, 0)",0.018335,0.044148,0.079289
5,pixel_10080,pixel,pixel,0,"(pixel_10080, pixel, pixel, 0)",0.021876,0.048534,0.10017
6,pixel_10099,pixel,pixel,0,"(pixel_10099, pixel, pixel, 0)",0.007417,0.068462,0.076571
7,pixel_10105,pixel,pixel,0,"(pixel_10105, pixel, pixel, 0)",0.021272,0.077362,0.091232


In [176]:
su_hi_var = process_assembly(hda,class_coords=['category_name','object_name'])

Calculating dprime all units...
- dprime of category_name...





dprime:   0%|          | 0/296 [00:00<?, ?it/s][A[A[A


  xr_data.set_index(append=True, inplace=True, **coords_d)


- dprime of object_name...





dprime:   0%|          | 0/296 [00:00<?, ?it/s][A[A[A


  xr_data.set_index(append=True, inplace=True, **coords_d)



pearsonr:   0%|          | 0/296 [00:00<?, ?it/s][A[A[A

Calculating pearsonr of all units...





pearsonr:  71%|███████▏  | 211/296 [00:05<00:02, 42.03it/s][A[A[A


  xr_data.set_index(append=True, inplace=True, **coords_d)


In [178]:
su_hi_var['variation']=6

In [185]:
su_dicarlo_hi = su_df.append(su_hi_var,sort=False)

In [187]:
su_dicarlo_hi.groupby('region').count()

Unnamed: 0_level_0,neuroid_id,subregion,layer,neuroid,tx,ty,rxy,variation,category_name,object_name,arr,col,hemisphere,animal,y,x,row
region,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
IT,168,168,0,168,168,168,168,168,168,168,168,168,168,168,168,168,168
V4,128,128,0,128,128,128,128,128,128,128,128,128,128,128,128,128,128
pixel,6553,6553,6553,6553,5110,5110,5110,6553,5110,5110,0,0,0,0,0,0,0


In [181]:
idxs = (su_dicarlo_hi.region=='IT').values
v4_idxs = (su_dicarlo_hi.region=='V4').values

# su_dicarlo_hi.loc[:,['layer']]=4

In [192]:
su_dicarlo_hi.loc[idxs,['layer']]=4
su_dicarlo_hi.loc[v4_idxs, ['layer']]=3
su_dicarlo_hi.layer = su_dicarlo_hi.layer.astype(np.int)
# su_dicarlo_hi.query('region == "V4"')['layer']=3

In [196]:
su_dicarlo_hi.object_name = np.abs(su_dicarlo_hi.object_name.values)

In [197]:
su_dicarlo_hi = su_dicarlo_hi[['neuroid_id','region','subregion','layer','neuroid','variation','tx','ty','rxy','object_name','category_name']]


In [198]:
su_dicarlo_hi.groupby('subregion').count()

Unnamed: 0_level_0,neuroid_id,region,layer,neuroid,variation,tx,ty,rxy,object_name,category_name
subregion,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
V4,128,128,128,128,128,128,128,128,128,128
aIT,17,17,17,17,17,17,17,17,17,17
cIT,75,75,75,75,75,75,75,75,75,75
pIT,76,76,76,76,76,76,76,76,76,76
pixel,6553,6553,6553,6553,6553,5110,5110,5110,5110,5110


In [199]:
fp = os.path.join(proj_root,'data','su_selectivity_dicarlo_hi_var.pqt')
su_dicarlo_hi.drop(columns=['neuroid']).to_parquet(fp)

In [None]:
su_fixed = process_assembly(xrs[0],filters={'layer':[0]})

In [None]:
su_xent = [process_assembly(xrs[i],filters={'layer':[1,2,3,4,5]}) for i in range(3)]
sudfs = [df.append(su_fixed) for df in su_xent]

In [None]:
sudfs[0].groupby('region').count()

In [None]:
su_both = process_assembly(next(ca))

In [None]:
xent_SUdp_score = SUDprime(lg_xent,neuroid_coord='neuroid_id',)

In [None]:
both_df_dp = both_df_dp.sort_values(by='neuroid_id').reset_index().drop(columns='index')
su_both_df = su_both_df.sort_values(by='neuroid_id').reset_index().drop(columns='index')

In [None]:
su_both_df.head()

In [None]:
corr_vars_xent = [pd.Series(lg_xent[v].values,name=v) for v in ['tx','ty']]
corr_xent = SUCorrelation(lg_xent,neuroid_coord='neuroid_id',correlation_vars=corr_vars_xent)
su_xent_df = result_to_df(corr_xent,['tx','ty'])
su_xent_df['norm_ty'] = su_xent_df.ty

In [None]:
xent_df_dp = result_to_df(xent_SUdp_score,['category'])
su_xent_df['category'] = xent_df_dp.category

In [35]:
hi_data = process_dicarlo(neural_data,variation=6)

  result.reset_index(self.multi_group_name, drop=True, inplace=True)
  result.set_index(append=True, inplace=True, **{self.multi_group_name: self.group_coord_names})


In [None]:
dicarlo_corr_vars = [
    pd.Series(hi_data['ty'],name='tx'),
    pd.Series(hi_data['tz'],name='ty'),
    pd.Series(hi_data['rxy'],name='rxy'),

]

# corr_dicarlo_med = SUCorrelation(med_data,neuroid_coord='neuroid_id',correlation_vars=dicarlo_med_corr_vars,exclude_zeros=True)
# dicarlo_med_df = result_to_df(corr_dicarlo_med,['tx','ty','rxy'])
# dicarlo_med_df['variation']=3

corr_dicarlo_hi = SUCorrelation(hi_data,neuroid_coord='neuroid_id',correlation_vars=dicarlo_corr_vars,exclude_zeros=True)
dicarlo_df = result_to_df(corr_dicarlo_hi, ['tx','ty','rxy'])
layer_map = {
    'V4':3,
    'IT':4
}

for reg,layer in zip(['V4','IT'],[3,4]):
    dicarlo_df['layer'] = [layer_map[r] for r in dicarlo_df.region]

In [None]:
from matplotlib import image
hi_ims = []
for imp in hi_data.image_id.values:
    im = image.imread(stimulus_set.get_image(imp))
    hi_ims.append(im[:,:,0])
    
hi_ims = np.stack(hi_ims,axis=0)
hi_ims = hi_ims.reshape(hi_ims.shape[0],np.prod(hi_ims.shape[1:]))
# pix_idxs = np.random.choice(hi_ims.shape[-1],size=int(hi_ims.shape[-1]/2),replace=False)
enc = {'pixel':hi_ims}
p_idx = hi_data.indexes['presentation']
# n_idx_orig = hi_data.indexes['neuroid']
neuroid_n = enc['pixel'].shape[-1]
n_coords = [
            pd.Series(['{}_{}'.format('pixel',i) for i in np.arange(neuroid_n)],name='neuroid_id'),
            pd.Series([0]*neuroid_n,name='layer'),
            pd.Series(['pixel']*neuroid_n,name='region'),
            pd.Series(['pixel']*neuroid_n,name='subregion'),
#             pd.Series(hi_data.arr.values,name='arr'),
#             pd.Series(hi_data.animal.values, name='animal')
        ]
n_idx = pd.MultiIndex.from_arrays(n_coords)

In [None]:
coords = {'presentation':p_idx,'neuroid':n_idx}

pix_xr = xarray.DataArray(enc['pixel'].astype('float32'),
                         coords=coords,
                         dims=['presentation','neuroid'])
del hi_ims
del enc

In [None]:
from src.results.utils import save_assembly

In [None]:
save_assembly(pix_xr,os.path.join(proj_root,'data','dicarlo_images'),fname='hi_pix.nc')

In [None]:
dicarlo_SUdp_df = result_to_df(dicarlo_SUdp_score,['category'])
dicarlo_df['category']=dicarlo_SUdp_df.category

In [None]:
dicarlo_df.groupby('subregion').count()

In [None]:
mod_dirs = [os.path.join(proj_root,e.get_properties()['dir']) for e in exps]

In [None]:
mod_dirs

In [None]:
[e.get_parameters()['recon_weight'] for e in exps]

In [None]:
for i,dr,df in zip(np.arange(3),mod_dirs*3,sudfs):
    fp = os.path.join(dr,'su_selectivity_{}.pqt'.format(i+1))
    print(fp)
    df.drop(columns=['neuroid','neuroid_id']).to_parquet(fp)
#     su_both_df.drop(columns=['neuroid','neuroid_id']).to_parquet(os.path.join(mod_dirs[0],'su_w_recon'))

In [None]:
depths = {'pixel':0}

In [None]:
ss = pd.DataFrame.from_records(hi_data.stimulus_set.to_records()).drop(columns=['index'])
ss = ss.rename(columns={'ty':'tx','tz':'ty','dy_px':'dy','dx_px':'dx'})

In [None]:
np.unique(ss.category_name.values)

In [None]:
hd = hi_data.to_dataset()
hda = hd['dicarlo.Majaj2015']

In [None]:
hda = hda.reset_index(['col','arr','hemisphere','animal','y','x','row'],drop=True)

In [None]:
coords = {'presentation':p_idx,'neuroid_id':n_coords[0],'region':n_coords[1]}
# coords.update({c.name:c.values for c in n_coords[0:1]})
# coords['layer']=[0]
# coords['region']=['pixel']

In [None]:
im = image.imread(stimulus_set.get_image(hi_data.image_id.values[50]))

In [None]:
plt.imshow(im)

In [None]:
sm_ims = np.load(os.path.join(proj_root,'data','dicarlo_images','sm_imgs_56x56.npy'))

In [None]:
plt.imshow(sm_ims[120],cmap='gray')

In [None]:
dicarlo_df.drop(columns='neuroid').to_parquet(os.path.join(proj_root,'data','dicarlo.Majaj_processed'))

In [None]:
sns.set(font_scale=2)
sns.set_context('paper')

In [None]:
fig,axs = plt.subplots(1,2,figsize=(8,4),sharey=True,sharex=True)

mod_order=np.arange(5)

sns.set_context('talk')
for ax,df,order in zip(axs,[su_xent_df,su_both_df,],[mod_order,mod_order]): 
    sns.barplot(x='layer',y='tx',order=order,data=df,ax=ax)
axs[0].set_xticklabels(['pixel','1','2','3','4'])
axs[1].get_yaxis().set_visible(False)
plt.tight_layout()

In [None]:
sns.set_style('whitegrid')
sns.set_context('paper')
sns.set(font_scale=2)

In [None]:
# sns.set_style('whitegrid')
sns.set_context('paper')
sns.set(font_scale=2)

fig,axs = plot_kde('tx','ty',su_xent_df,by='layer',order=np.arange(5),)
plt.tight_layout()
sns.despine(fig=fig)
for i,ax in enumerate(axs):
    pass
    ax.set_ylabel('')
    ax.set_xlabel('')

#         ax.get_xaxis().set_visible(False)
    




In [None]:
sns.set_context('paper')
sns.set(font_scale=2)
fig,axs = plot_kde('ty','category',su_xent_df,by='layer',order=np.arange(5),xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
for i,ax in enumerate(axs):
    pass
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
sns.set_context('paper')
sns.set(font_scale=2)
fig,axs = plot_kde('ty','category',su_both_df,by='layer',order=np.arange(5),xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
for i,ax in enumerate(axs):
    pass
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
plot_kde('ty','category',su_both_df,by='layer',order=np.arange(5))
plt.tight_layout()

In [None]:
sns.set_context('paper')
sns.set(font_scale=2)
fig,axs = plot_kde('ty','category',dicarlo_df,by='region',order=['V4','IT'],xlim=(0,1.1),ylim=(0,1.1))
plt.tight_layout()
sns.despine(fig=fig)
for i,ax in enumerate(axs):
    pass
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
plot_kde('tx','category',dicarlo_df,by='region',order=['V4','IT'])


In [None]:
sns.scatterplot(x='tx',y='ty',data=dicarlo_df.query('region == "IT"'))
plt.ylim(0,0.5)
plt.xlim(0,0.5)

In [None]:
class MURegressor(object):
    def __init__(self,da,train_frac=0.8,n_splits=5,n_units=None,estimator=Ridge):
        if n_units is not None:
            self.neuroid_idxs = [np.array([random.randrange(len(da.neuroid_id)) for _ in range(n_units)]) for _ in range(n_splits)]
        
        self.original_data = da
        self.train_frac = train_frac
        self.n_splits = n_splits
        
        splits = [split_assembly(self.original_data[:,n_idxs]) for n_idxs in tqdm(self.neuroid_idxs,total=n_splits,desc='CV-splitting')]
        self.train = [tr for tr,te in splits]
        self.test = [te for tr,te in splits]
        
        
        self.estimators = [estimator() for _ in range(n_splits)]
        
    def fit(self,y_coord):
        # Get Training data
        for mod,train in tqdm(zip(self.estimators,self.train),total=len(self.train),desc='fitting'):
#             print(train)
            mod.fit(X=train.values,y=train[y_coord])
    
        return self
    
    def predict(self,X=None):
        if X is not None:
            return [e.predict(X) for e in self.estimators]
        else:
            return [e.predict(te.values) for e,te in zip(self.estimators,self.test)]
        
    def score(self,y_coord):
        return [e.score(te.values,te[y_coord].values) for e,te in zip(self.estimators,self.test)]
    
def stratified_regressors(data, filt='region',n_units=126,y_coords=['ty','tz'],task_names=None,estimator=Ridge):
    subsets = np.unique(data[filt].values)
    if task_names is None:
        task_names = y_coords
    dfs = []
    for y,task in zip(y_coords,task_names):
        print('regressing {}...'.format(y))
        regressors = {k:MURegressor(data.sel(**{filt:k}),n_units=n_units,estimator=Ridge).fit(y_coord=y) for k in subsets}
        df = pd.DataFrame.from_records({k:v.score(y_coord=y) for k,v in regressors.items()})
        df = df.melt(var_name='region',value_name='performance')
        df['task']=task
        dfs.append(df)
    
    return pd.concat(dfs)

In [None]:
properties = ['tx','ty']
mu_both_df = stratified_regressors(lg_both,filt='layer',y_coords=properties,n_units=50)

In [None]:
sns.barplot(x='task',y='performance',hue='region',data=mu_both_df)

In [None]:
mu_xent_df = stratified_regressors(lg_xent,filt='layer',y_coords=properties,n_units=50)

In [None]:
sns.barplot(x='task',y='performance',hue='region',data=mu_xent_df)

In [None]:
plot_kde(x='tx',y='ty',df=both_df,by='layer',order=np.arange(5))

In [None]:
plot_kde(x='tx',y='ty',df=xent_df,by='layer',order=np.arange(5))