In [8]:
import os
import random
import numpy as np
import pandas as pd
import xarray
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm as tqdm
import multiprocessing as mp


from sklearn.linear_model import LinearRegression,Ridge,RidgeCV

import brainscore
from brainscore.assemblies import walk_coords,split_assembly
from brainscore.assemblies import split_assembly
from brainscore.metrics import Score

from brainio_base.assemblies import DataAssembly

from scipy.stats import pearsonr

from src.results.experiments import _DateExperimentLoader

def set_style():
    """Apply seaborn defaults suited to figures destined for a paper.

    Uses the "paper" plotting context, a white background and a serif font
    family (Times New Roman, falling back to Palatino, then generic serif).

    ``sns.set`` resets *every* seaborn option (context, style, palette, font)
    to its defaults, so the context has to be passed to it directly: calling
    ``sns.set_context("paper")`` first and ``sns.set(...)`` afterwards would
    silently put the context back to "notebook".
    """
    # Context and base font in one call so neither is reset by the other.
    sns.set(context="paper", font='serif')

    # Make the background white, and specify the specific font family
    sns.set_style("white", {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Palatino", "serif"]
    })

In [2]:
# Load the experiment logged on 2019-06-04 and tabulate its run configurations
# (lg.configs is presumably a list of per-run config records -- confirm).
lg = _DateExperimentLoader('2019-06-04')
lg.load_configs()
configs = pd.DataFrame.from_records(lg.configs)

In [3]:
# Runs with 5% background noise and a feedforward encoder.
configs[(configs['bg_noise'] == 0.05) & (configs['enc_arch'] == "feedforward")]

Unnamed: 0,batch_size,bg_noise,dataset,dec_blocks,ecc_max,enc_arch,enc_blocks,enc_layers,epochs,label_corruption,...,project,recon,rot_max,run_dir,seed,uploaded_by,xcov,xent,y_dim,z_dim
2,512,0.05,fashion_mnist,"[4, 2, 1]",0.6,feedforward,,"[3000, 2000, 500]",90,0.0,...,vae,0,0,/home/elijahc/projects/vae/logs/0701_145223_fa...,7,elijahc,0,15,35,35
12,512,0.05,fashion_mnist,"[4, 2, 1]",0.6,feedforward,,"[3000, 2000, 500]",10,0.0,...,vae,25,0,/home/elijahc/projects/vae/logs/0701_152033_fa...,7,elijahc,0,15,35,35


In [4]:
# Load only the activation assemblies of the runs selected above.
idxs = configs.loc[(configs['bg_noise'] == 0.05) & (configs['enc_arch'] == "feedforward")].index.values
lg.load_assemblies(subset=list(idxs))

  xr_data.set_index(append=True, inplace=True, **coords_d)


[<xarray.DataAssembly (presentation: 10000, neuroid: 8706)>
 array([[ 2.918898,  1.481666, -6.121837, ...,  0.      ,  0.      ,  0.      ],
        [-1.232504, -0.251372,  0.99593 , ...,  0.      ,  0.      ,  0.      ],
        [ 7.400597,  7.941635, 10.374889, ...,  0.      ,  0.      ,  0.      ],
        ...,
        [ 3.823743, -6.502236,  3.107507, ...,  0.      ,  0.      ,  0.      ],
        [ 5.338862,  5.564245,  5.631349, ...,  0.      ,  0.      ,  0.      ],
        [-0.750803,  0.913194, -2.584277, ...,  0.      ,  0.      ,  0.      ]],
       dtype=float32)
 Coordinates:
   * presentation   (presentation) MultiIndex
   - image_id       (presentation) object '6faf23dc748fc2fa6999ca4f16b90358' ... '691246995f7c97c4af84eb20fffe2128'
   - object_name    (presentation) object 'Ankle boot' 'Pullover' ... 'Dress'
   - dx             (presentation) float64 2.0 -6.0 -7.0 4.0 ... 7.0 0.0 -4.0 6.0
   - dy             (presentation) float64 -8.0 -2.0 4.0 -7.0 ... -2.0 -2.0 -1.0
 

In [7]:
# Reload the cached per-unit tables for both models from the experiment directory.
su_both_df, su_xent_df = [
    pd.read_parquet(os.path.join(lg.experiment_root, fname))
    for fname in ('su_both_processed.df', 'su_xent_processed.df')
]

In [None]:
# Presumably in `subset` order: config row 2 (recon=0, xent only) first, then
# config row 12 (recon=25 + xent) -- confirm against lg.load_assemblies.
lg_xent, lg_both = lg.assemblies[0], lg.assemblies[1]

In [None]:
import os

In [9]:
# Fetch the Majaj et al. 2015 recordings via brainscore, materialise them
# (.load() presumably pulls lazily-backed data into memory), and keep a handle
# on the stimulus metadata table attached to the assembly.
neural_data = brainscore.get_assembly(name="dicarlo.Majaj2015")
neural_data.load()
stimulus_set = neural_data.attrs['stimulus_set']

  xr_data.set_index(append=True, inplace=True, **coords_d)


In [10]:
def process_dicarlo(assembly,avg_repetition=True,variation=3,tasks=['ty','tz','rxy']):
    """Prepare a DiCarlo-lab recording assembly for single-unit analysis.

    Adds degree- and pixel-valued translation columns to the assembly's
    stimulus set (in place), keeps only presentations of one variation
    level, averages responses within each group of stimulus descriptors,
    drops the ``time_bin`` dimension and transposes.

    Parameters
    ----------
    assembly : brainscore assembly
        Must carry ``attrs['stimulus_set']`` with ``ty``, ``tz``, ``degrees``
        and ``variation`` columns.
    avg_repetition : bool
        When False, ``repetition`` joins the grouping keys so repeated
        presentations are kept apart instead of averaged together.
    variation : int
        Variation level to keep.
    tasks : list of str
        Extra presentation coordinates to group by (not modified).

    Returns
    -------
    The averaged, transposed assembly; its ``stimulus_set`` attribute keeps
    only rows of the selected variation.
    """
    stim = assembly.attrs['stimulus_set']
    # Offsets in degrees: dy from tz, dx from ty ...
    stim['dy_deg'] = stim.tz * stim.degrees
    stim['dx_deg'] = stim.ty * stim.degrees
    # ... and in pixels at a factor of 32 per degree.
    stim['dy_px'] = stim.dy_deg * 32
    stim['dx_px'] = stim.dx_deg * 32
    assembly.attrs['stimulus_set'] = stim

    group_keys = ['category_name', 'object_name', 'image_id'] + tasks
    if not avg_repetition:
        group_keys = group_keys + ['repetition']

    averaged = (assembly
                .sel(variation=variation)
                .multi_groupby(group_keys)
                .mean(dim='presentation')
                .squeeze('time_bin'))
    averaged.attrs['stimulus_set'] = stim.query('variation == {}'.format(variation))
    return averaged.T

In [11]:
# Inspect the raw recording assembly (dims neuroid x presentation x time_bin).
neural_data

<xarray.NeuronRecordingAssembly 'dicarlo.Majaj2015' (neuroid: 296, presentation: 268800, time_bin: 1)>
array([[[ 0.060929],
        [-0.686162],
        ...,
        [-0.968256],
        [ 0.183887]],

       [[-0.725592],
        [ 0.292777],
        ...,
        [ 2.449372],
        [ 0.401197]],

       ...,

       [[ 1.121319],
        [ 1.719423],
        ...,
        [ 0.800551],
        [-0.019874]],

       [[-0.518903],
        [ 0.696196],
        ...,
        [-0.603347],
        [-0.175979]]], dtype=float32)
Coordinates:
  * neuroid          (neuroid) MultiIndex
  - neuroid_id       (neuroid) object 'Chabo_L_M_5_9' ... 'Chabo_L_P_7_6'
  - arr              (neuroid) object 'M' 'M' 'M' 'M' 'M' ... 'P' 'P' 'P' 'P'
  - col              (neuroid) int64 9 9 8 9 8 8 7 7 5 6 ... 8 7 8 6 7 5 7 7 6 6
  - hemisphere       (neuroid) object 'L' 'L' 'L' 'L' 'L' ... 'L' 'L' 'L' 'L'
  - subregion        (neuroid) object 'cIT' 'cIT' 'cIT' 'cIT' ... 'V4' 'V4' 'V4'
  - animal           (neur

In [None]:
DEFAULT_DPRIME_MODE = 'binary'

def dprime(A, B=None, mode=DEFAULT_DPRIME_MODE,
           max_value=np.inf, min_value=-np.inf,
           max_ppf_value=np.inf, min_ppf_value=-np.inf,
           **kwargs):
    """Computes the d-prime sensitivity index of predictions
    from various data formats.  Depending on the choice of
    `mode`, this function can take one of the following format:
    * Binary classification outputs (`mode='binary'`; default)
    * Positive and negative samples (`mode='sample'`)
    * True positive and false positive rate (`mode='rate'`)
    * Confusion matrix (`mode='confusionmat'`)
    Parameters
    ----------
    A, B:
        If `mode` is 'binary' (default):
            A: array, shape = [n_samples],
                True values, interpreted as strictly positive or not
                (i.e. converted to binary).
                Could be in {-1, +1} or {0, 1} or {False, True}.
            B: array, shape = [n_samples],
                Predicted values (real).
        If `mode` is 'sample':
            A: array-like,
                Positive sample values (e.g., raw projection values
                of the positive classifier).
            B: array-like,
                Negative sample values.
        If `mode` is 'rate':
            A: array-like, shape = [n_groupings]
                True positive rates
            B: array-like, shape = [n_groupings]
                False positive rates
        if `mode` is 'confusionmat':
            A: array-like, shape = [n_classes (true), n_classes (pred)]
                Confusion matrix, where the element M_{rc} means
                the number of times when the classifier or subject
                guesses that a test sample in the r-th class
                belongs to the c-th class.
            B: ignored
    mode: {'binary', 'sample', 'rate', 'confusionmat'}, optional, (default='binary')
        Directs the interpretation of A and B.
    max_value: float, optional (default=np.inf)
        Maximum possible d-prime value.
    min_value: float, optional (default=-np.inf)
        Minimum possible d-prime value.
    max_ppf_value: float, optional (default=np.inf)
        Maximum possible ppf value.
        Used only when mode is 'rate' or 'confusionmat'.
    min_ppf_value: float, optional (default=-np.inf).
        Minimum possible ppf value.
        Used only when mode is 'rate' or 'confusionmat'.
    kwargs: named arguments, optional
        Passed to ``confusion_matrix_stats()`` and used only when `mode`
        is 'confusionmat'.  By assigning ``collation``,
        ``fudge_mode``, ``fudge_factor``, etc. one can
        change the behavior of d-prime computation
        (see ``confusion_matrix_stats()`` for details).
    Returns
    -------
    dp: float or array of shape = [n_groupings]
        A d-prime value or array of d-primes, where each element
        corresponds to each grouping of positives and negatives
        (when `mode` is 'rate' or 'confusionmat')
    Raises
    ------
    ValueError
        For an unknown `mode`, or when fewer than two positive or negative
        samples are available in 'sample'/'binary' mode.
    References
    ----------
    http://en.wikipedia.org/wiki/D'
    http://en.wikipedia.org/wiki/Confusion_matrix
    """

    # -- basic checks and conversion
    if mode == 'sample':
        pos, neg = np.array(A), np.array(B)

    elif mode == 'binary':
        y_true, y_pred = A, B

        assert len(y_true) == len(y_pred)
        assert np.isfinite(y_true).all()

        y_true = np.array(y_true)
        assert y_true.ndim == 1

        y_pred = np.array(y_pred)
        assert y_pred.ndim == 1

        i_pos = y_true > 0
        i_neg = ~i_pos

        pos = y_pred[i_pos]
        neg = y_pred[i_neg]

    elif mode == 'rate':
        TPR, FPR = np.array(A), np.array(B)
        assert TPR.shape == FPR.shape

    elif mode == 'confusionmat':
        # A: confusion mat
        # row means true classes, col means predicted classes
        # NOTE(review): confusion_matrix_stats is not defined or imported in
        # the visible notebook, so this mode raises NameError until it is.
        P, N, TP, _, FP, _ = confusion_matrix_stats(A, **kwargs)

        TPR = TP / P
        FPR = FP / N

    else:
        raise ValueError('Invalid mode')

    # -- compute d'
    if mode in ('sample', 'binary'):
        assert np.isfinite(pos).all()
        assert np.isfinite(neg).all()

        # ddof=1 variance needs at least two samples per side.
        if pos.size <= 1:
            raise ValueError('Not enough positive samples '
                             'to estimate the variance')
        if neg.size <= 1:
            raise ValueError('Not enough negative samples '
                             'to estimate the variance')

        pos_mean = pos.mean()
        neg_mean = neg.mean()
        pos_var = pos.var(ddof=1)
        neg_var = neg.var(ddof=1)

        num = pos_mean - neg_mean
        div = np.sqrt((pos_var + neg_var) / 2.)

        dp = num / div

    else:   # mode is rate or confusionmat
        # The notebook's import cell does not bring in scipy.stats.norm;
        # import it here so the rate-based modes work.
        from scipy.stats import norm

        ppfTPR = norm.ppf(TPR)
        ppfFPR = norm.ppf(FPR)
        ppfTPR = np.clip(ppfTPR, min_ppf_value, max_ppf_value)
        ppfFPR = np.clip(ppfFPR, min_ppf_value, max_ppf_value)
        dp = ppfTPR - ppfFPR

    # from Dan's suggestion about clipping d' values...
    dp = np.clip(dp, min_value, max_value)

    return dp

In [None]:
# pos_samples = lg_both.where(lg_both.numeric_label==1,drop=True).values[:,250]
# neg_samples = lg_both.where(lg_both.numeric_label!=1, drop=True).values[:,250]

In [None]:
def xr_exclude_zero_dim(da,neuroid_coord):
    """Drop units whose summed response over all presentations is exactly zero.

    Parameters
    ----------
    da : xarray DataArray / brainscore DataAssembly
        Assumed laid out as (presentation, neuroid); the mask is applied
        positionally on the second axis.
    neuroid_coord : str
        Kept for interface compatibility; the mask no longer groups by it.

    Returns
    -------
    ``da`` restricted to units with a non-zero total response.
    """
    # Reduce along 'presentation' directly so the mask follows the assembly's
    # own neuroid order. groupby(neuroid_coord) returns groups sorted by value,
    # which misaligns with da[:, mask] whenever the ids are not already sorted.
    totals = da.sum('presentation').values
    return da[:, totals != 0]

In [None]:
def SUCorrelation(da,neuroid_coord,correlation_vars,exclude_zeros=True):
    """Absolute Pearson correlation of every unit with every task variable.

    Parameters
    ----------
    da : brainscore DataAssembly
        Assumed laid out as (presentation, neuroid).
    neuroid_coord : str
        Neuroid-level coordinate (e.g. 'neuroid_id'); coords sharing its
        dimension are copied onto the result.
    correlation_vars : list of pandas.Series
        One per task, aligned by position with the presentation axis; each
        Series' ``.name`` becomes the task label.
    exclude_zeros : bool
        First drop units whose summed response over presentations is zero.

    Returns
    -------
    Score with dims ('neuroid', 'task') holding |r| values.
    """
    if exclude_zeros:
        # Reduce over 'presentation' directly: groupby(neuroid_coord) orders
        # its groups by sorted id, which misaligns the positional mask.
        nz_neuroids = da.sum('presentation').values != 0
        da = da[:, nz_neuroids]

    neuroid_dim = da[neuroid_coord].dims
    n_units = len(da[neuroid_coord])
    props = [np.asarray(prop) for prop in correlation_vars]

    correlations = np.empty((n_units, len(props)))
    for i in tqdm(range(n_units), total=n_units):
        # Positional pick along the neuroid dimension, hoisted out of the task
        # loop; independent of axis order and of id uniqueness.
        n_act = np.asarray(da.isel(**{neuroid_dim[0]: i}).values).squeeze()
        for j, prop in enumerate(props):
            r, _ = pearsonr(n_act, prop)
            correlations[i, j] = np.abs(r)

    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task'] = ('task', [v.name for v in correlation_vars])
    return Score(correlations,
                 coords=c,
                 dims=('neuroid', 'task'))

def SUDprime(da,neuroid_coord='neuroid_id',class_coord='numeric_label',exclude_zeros=True):
    """Best one-vs-rest d' of every unit over the classes in ``class_coord``.

    For each unit and class c, d' contrasts the unit's responses to
    presentations of class c (positives) with its responses to every other
    presentation (negatives). It uses the same estimator as
    ``dprime(mode='sample')``, clipped to [-1, 1], and reports the maximum
    over classes.

    Parameters
    ----------
    da : brainscore DataAssembly
        Assumed laid out as (presentation, neuroid).
    neuroid_coord : str
        Neuroid-level coordinate; coords sharing its dimension are copied
        onto the result.
    class_coord : str
        Presentation-level coordinate holding the class labels.
    exclude_zeros : bool
        First drop units whose summed response over presentations is zero.

    Returns
    -------
    Score with dims ('neuroid', 'task') and the single task 'category'.

    Raises
    ------
    ValueError
        If a class has fewer than two positive or two negative samples.
    """
    if exclude_zeros:
        # Reduce over 'presentation' directly: groupby(neuroid_coord) orders
        # its groups by sorted id, which misaligns the positional mask.
        nz_neuroids = da.sum('presentation').values != 0
        da = da[:, nz_neuroids]

    responses = np.asarray(da.values)            # (presentation, neuroid)
    labels = np.asarray(da[class_coord].values)
    class_vals = np.unique(labels)
    assert np.isfinite(responses).all()

    dprimes = np.empty((responses.shape[1], len(class_vals)))
    for j, cls in enumerate(tqdm(class_vals, desc='classes')):
        # Split on actual class membership rather than a fixed row count:
        # a fixed cut is only right when every class has exactly that many
        # samples, otherwise negatives leak into the positive set.
        is_pos = labels == cls
        pos, neg = responses[is_pos], responses[~is_pos]
        if pos.shape[0] <= 1:
            raise ValueError('Not enough positive samples to estimate the variance')
        if neg.shape[0] <= 1:
            raise ValueError('Not enough negative samples to estimate the variance')
        # Vectorised over units; avoids stacking a (class, presentation, unit) tensor.
        num = pos.mean(axis=0) - neg.mean(axis=0)
        div = np.sqrt((pos.var(axis=0, ddof=1) + neg.var(axis=0, ddof=1)) / 2.)
        dprimes[:, j] = np.clip(num / div, -1, 1)

    neuroid_dim = da[neuroid_coord].dims
    c = {coord: (dims, values) for coord, dims, values in walk_coords(da) if dims == neuroid_dim}
    c['task'] = ('task', ['category'])
    return Score(dprimes.max(axis=1).reshape(-1, 1),
                 coords=c,
                 dims=('neuroid', 'task'))

def result_to_df(SUC,corr_var_labels):
    """Flatten a (neuroid, task) score into a per-unit DataFrame.

    Each row is one neuroid, with its neuroid coordinates as columns,
    followed by one column per requested task label holding that task's
    values.
    """
    task_columns = {label: SUC.sel(task=label).values for label in corr_var_labels}
    frame = SUC.neuroid.to_dataframe().reset_index()
    for column_name, column_values in task_columns.items():
        frame[column_name] = column_values
    return frame

In [None]:
def cat_parts(da,class_coord):
    """For every class, stack that class's rows followed by all other rows.

    Returns a list with one array per sorted unique value of ``class_coord``.
    Each array holds the positives first and then the negatives,
    concatenated along axis 0.
    """
    labels = da[class_coord]
    parts = []
    for cls in np.unique(labels.values):
        positives = da[(labels == cls).values].values
        negatives = da[(labels != cls).values]
        parts.append(np.concatenate([positives, negatives], axis=0))
    return parts

def dprime_1d(vec,cut=1000):
    """Clipped d' of ``vec[:cut]`` (positives) against ``vec[cut:]`` (negatives).

    ``cut`` must equal the number of positive samples at the front of
    ``vec``; the default of 1000 presumably matches one class of a balanced
    10,000-sample, 10-class set. The result is clipped to [-1, 1].
    """
    positives, negatives = vec[:cut], vec[cut:]
    return dprime(A=positives, B=negatives, mode='sample', max_value=1, min_value=-1)

In [None]:
def gu_SUD(da_sets,neuroid_coord,processes=6):
    """Run ``SUDprime`` on every assembly in ``da_sets`` using a worker pool.

    Parameters
    ----------
    da_sets : iterable of assemblies
    neuroid_coord : str
        Forwarded to ``SUDprime`` as its second positional argument.
    processes : int
        Number of worker processes (default 6).

    Returns
    -------
    list of Scores, in the same order as ``da_sets``.
    """
    jobs = [(da, neuroid_coord) for da in da_sets]
    # Pool.apply blocks until each call returns, so a loop of apply() calls
    # runs the jobs one at a time. starmap dispatches them concurrently and
    # still returns results in input order; the context manager tears the
    # pool down even if a job raises.
    with mp.Pool(processes) as pool:
        results = pool.starmap(SUDprime, jobs)
    return results

In [None]:
# One zero-filtered sub-assembly per region of lg_both, scored in worker processes.
region_sets = [
    xr_exclude_zero_dim(lg_both.sel(region=region_name), 'neuroid_id')
    for region_name in np.unique(lg_both.region.values)
]
both_cat_results = gu_SUD(region_sets, 'neuroid_id')

In [None]:
# Per-unit category d' for lg_xent, spelling out SUDprime's default options.
xent_SUdp_score = SUDprime(lg_xent, neuroid_coord='neuroid_id', class_coord='numeric_label', exclude_zeros=True)

In [None]:
# |r| between each lg_both unit and the tx / ty stimulus offsets.
corr_vars_both = [pd.Series(lg_both[name].values, name=name) for name in ('tx', 'ty')]
corr_both = SUCorrelation(lg_both, neuroid_coord='neuroid_id', correlation_vars=corr_vars_both)
su_both_df = result_to_df(corr_both, ['tx', 'ty'])
# NOTE(review): 'norm_ty' is a plain copy of 'ty'; no normalisation is applied.
su_both_df['norm_ty'] = su_both_df['ty']

In [None]:
# Stack the per-region d' tables into one frame with a fresh 0..n-1 index.
both_df_dp = pd.concat(
    [result_to_df(res, ['category']) for res in both_cat_results],
    ignore_index=True,
)

In [None]:
# Order units by id so rows can be paired with su_both_df below.
both_df_dp = (both_df_dp
              .sort_values(by='neuroid_id')
              .reset_index(drop=True))

In [None]:
# Same ordering as both_df_dp so the two tables line up row by row.
su_both_df = (su_both_df
              .sort_values(by='neuroid_id')
              .reset_index(drop=True))

In [None]:
# Pair each unit's correlation row with its d' row purely by position, so
# first verify that both tables list the same neuroids in the same order;
# otherwise the axis=1 concat silently attaches d' values to the wrong units.
assert np.array_equal(su_both_df['neuroid_id'].values, both_df_dp['neuroid_id'].values), \
    'neuroid_id mismatch between su_both_df and both_df_dp'
# NOTE: this leaves su_both_df with two 'neuroid_id' columns.
su_both_df = pd.concat([su_both_df, both_df_dp[['neuroid_id', 'category']]], axis=1)
su_both_df.head()

In [None]:
# |r| between each lg_xent unit and the tx / ty stimulus offsets.
corr_vars_xent = [pd.Series(lg_xent[name].values, name=name) for name in ('tx', 'ty')]
corr_xent = SUCorrelation(lg_xent, neuroid_coord='neuroid_id', correlation_vars=corr_vars_xent)
su_xent_df = result_to_df(corr_xent, ['tx', 'ty'])
# NOTE(review): 'norm_ty' is a plain copy of 'ty'; no normalisation is applied.
su_xent_df['norm_ty'] = su_xent_df['ty']

In [None]:
# Attach the d' column; pandas aligns on the shared 0..n-1 index, which
# presumes both tables enumerate the same units in the same order.
xent_df_dp = result_to_df(xent_SUdp_score, ['category'])
su_xent_df = su_xent_df.assign(category=xent_df_dp['category'])

In [None]:
# Variation-6 subset of the recordings, repetitions averaged (the default).
hi_data = process_dicarlo(neural_data, avg_repetition=True, variation=6)

In [None]:
# DiCarlo's ty / tz are relabelled tx / ty so the columns line up with the
# model tables (process_dicarlo likewise derives dx from ty and dy from tz).
dicarlo_corr_vars = [
    pd.Series(hi_data['ty'], name='tx'),
    pd.Series(hi_data['tz'], name='ty'),
    pd.Series(hi_data['rxy'], name='rxy'),
]

corr_dicarlo_hi = SUCorrelation(hi_data, neuroid_coord='neuroid_id', correlation_vars=dicarlo_corr_vars, exclude_zeros=True)
dicarlo_df = result_to_df(corr_dicarlo_hi, ['tx', 'ty', 'rxy'])

# Model-layer slot each recorded region is plotted against.
layer_map = {
    'V4': 3,
    'IT': 4,
}

# The column depends only on 'region', so a single assignment suffices.
dicarlo_df['layer'] = [layer_map[r] for r in dicarlo_df.region]

In [None]:
# Category d' per recorded unit, one-vs-rest over category_name.
dicarlo_SUdp_score = SUDprime(
    hi_data,
    neuroid_coord='neuroid_id',
    class_coord='category_name',
    exclude_zeros=True,
)

In [None]:
# Attach the d' column (index-aligned with dicarlo_df).
dicarlo_SUdp_df = result_to_df(dicarlo_SUdp_score, ['category'])
dicarlo_df = dicarlo_df.assign(category=dicarlo_SUdp_df['category'])

In [None]:
# Directory that the cached per-unit parquet tables are read from and written to.
lg.experiment_root

In [None]:
# su_both_df carries two 'neuroid_id' columns after the axis=1 concat, so this
# selection is a two-column frame; take the first copy as the unit id.
su_both_df['unit_id'] = su_both_df['neuroid_id'].iloc[:, 0].values

In [None]:
# Cache the per-unit tables under the file names the reload cell reads
# ('su_*_processed.df'). 'neuroid' -- and, for su_both_df, the duplicated
# 'neuroid_id' columns -- are dropped first, presumably because parquet
# cannot store them.
su_xent_df.drop(columns=['neuroid']).to_parquet(os.path.join(lg.experiment_root, 'su_xent_processed.df'))
su_both_df.drop(columns=['neuroid', 'neuroid_id']).to_parquet(os.path.join(lg.experiment_root, 'su_both_processed.df'))

In [None]:
# Cache the DiCarlo per-unit table next to the model tables.
dicarlo_out_path = os.path.join(lg.experiment_root, 'dicarlo.Majaj_processed')
dicarlo_df.drop(columns=['neuroid']).to_parquet(dicarlo_out_path)

In [None]:
def plot_bars(y,df,by='region',order=None):
    """Draw one bar plot of column ``y`` per distinct value of ``by``.

    Parameters
    ----------
    y : str
        Column plotted on the y axis.
    df : pandas.DataFrame
    by : str
        Column whose values define the panels (also used as the x axis).
    order : sequence, optional
        Panel values in display order; defaults to first-appearance order.

    Returns
    -------
    (fig, axs), where ``axs`` is always a 1-D array of Axes.
    """
    subsets = order if order is not None else df[by].drop_duplicates().values

    plot_scale = 5
    fig, axs = plt.subplots(1, len(subsets), figsize=(plot_scale*len(subsets), plot_scale),
                            sharex=True, sharey=True)
    axs = np.atleast_1d(axs)  # a single panel comes back as a bare Axes

    for ax, sub in zip(axs, subsets):
        # Boolean mask rather than query('{by} == "{sub}"'): the quoted form
        # never matches numeric columns such as 'layer'.
        sub_df = df[df[by] == sub]
        sns.barplot(x=by, y=y, data=sub_df, ax=ax)

    return fig, axs

def plot_kde(x,y,df,by='region',order=None,xlim=(0.0,0.8),ylim=(0.0,0.8)):
    """Bivariate KDE of columns ``x`` vs ``y``, one panel per value of ``by``.

    Parameters
    ----------
    x, y : str
        Columns for the two KDE axes.
    df : pandas.DataFrame
    by : str
        Column whose values define the panels.
    order : sequence, optional
        Panel values in display order; defaults to first-appearance order.
    xlim, ylim : tuple
        Axis limits shared by every panel.

    Returns
    -------
    (fig, axs), where ``axs`` is always a 1-D array of Axes.
    """
    subsets = order if order is not None else df[by].drop_duplicates().values

    plot_scale = 5
    fig, axs = plt.subplots(1, len(subsets), figsize=(plot_scale*len(subsets)*0.8, plot_scale),
                            sharex=True, sharey=True,
                            subplot_kw={
                                'xlim': xlim,
                                'ylim': ylim,
                            })
    axs = np.atleast_1d(axs)  # a single panel comes back as a bare Axes

    for ax, sub in zip(axs, subsets):
        # Boolean mask rather than query('{by} == "{sub}"'): the quoted form
        # compares numeric columns such as 'layer' to a string and matches nothing.
        sub_df = df[df[by] == sub]
        sns.kdeplot(sub_df[x], sub_df[y], ax=ax)
        ax.set_title("{}: {}".format(by, sub))

    return fig, axs
#         sns.despine(ax)
# plot_bars(y='tx',df=both_df,by='layer',order=np.arange(5))

In [None]:
# set_style()

In [None]:
# One call: a subsequent sns.set_context('paper') would reset font_scale to 1,
# discarding the font_scale=2 requested here.
sns.set(context='paper', font_scale=2)

In [None]:
# Set the context before creating the figure; rc changes made afterwards only
# partly reach Axes that already exist.
sns.set_context('talk')

mod_order = np.arange(5)
properties = ['tx', 'ty']
fig, axs = plt.subplots(2, 2, figsize=(8, 6), sharex=True, sharey=True, subplot_kw={'ylim': (0, 0.35)})

# Top row: su_xent_df; bottom row: su_both_df.
for ax_row, df, order in zip(axs, [su_xent_df, su_both_df], [mod_order, mod_order]):
    for ax, prop in zip(ax_row, properties):
        sns.barplot(x='layer', y=prop, order=order, data=df, ax=ax, palette='magma')
        sns.despine(ax=ax)

    ax_row[1].set_ylabel('')
    ax_row[0].set_ylabel('pearson')

for ax in axs[1]:
    ax.set_xticklabels(['pixel', '1', '2', '3', '4'])
for ax, prop in zip(axs[0], properties):
    ax.get_xaxis().set_visible(False)
    ax.set_title(prop)

plt.tight_layout()

In [None]:
# Set the context before creating the figure; rc changes made afterwards only
# partly reach Axes that already exist.
sns.set_context('talk')

properties = ['category', 'tx']
mod_order = np.arange(5)
fig, axs = plt.subplots(2, len(properties), figsize=(len(properties)*4, 6), sharex=True, sharey=True)

# Top row: su_xent_df; bottom row: su_both_df.
for ax_row, df, order in zip(axs, [su_xent_df, su_both_df], [mod_order, mod_order]):
    for ax, prop in zip(ax_row, properties):
        sns.barplot(x='layer', y=prop, order=order, data=df, ax=ax, palette='magma')
        sns.despine(ax=ax)

    ax_row[1].set_ylabel('')
    ax_row[0].set_ylabel('d\'')

for ax in axs[1]:
    ax.set_xticklabels(['pixel', '1', '2', '3', '4'])

for ax, prop in zip(axs[0], properties):
    ax.get_xaxis().set_visible(False)
    ax.set_title(prop)

plt.tight_layout()

In [None]:
# Set the context before creating the figure so the new Axes pick it up.
sns.set_context('talk')
order = [0, 1, 2, 3, 4, 5]
fig, ax = plt.subplots(1, 1, figsize=(5, 4.5))
sns.boxplot(x='layer', y='tx', order=order, data=dicarlo_df, ax=ax, palette='magma')
ax.set_xticklabels(['pixel', 'V1', '?', 'V4', 'IT', ''])
sns.despine(ax=ax)
plt.tight_layout()

In [None]:

# [['tx','ty','rxy','layer','region']]

In [None]:
def topn_su_decode(df,n,props,layers=(3,4),**kwargs):
    """Bar plots of the ``n`` highest-scoring units per layer, one panel per prop.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-unit table with a numeric 'layer' column and one column per prop.
    n : int
        Units kept per layer (``DataFrame.nlargest``).
    props : list of str
        Columns to plot, one panel each.
    layers : iterable of int
        Layers to draw units from; the default (3, 4) are the V4 / IT slots.
    **kwargs
        Forwarded to ``plt.subplots`` (e.g. ``subplot_kw``).
    """
    order = [0, 1, 2, 3, 4]
    # Context first so the Axes created below pick it up.
    sns.set_context('talk')
    fig, axs = plt.subplots(1, len(props), figsize=(len(props)*4, 3), sharey=True, **kwargs)
    axs = np.atleast_1d(axs)  # a single prop comes back as a bare Axes
    for ax, prop in zip(axs, props):
        df_topn = pd.concat([df[df['layer'] == l].nlargest(n, prop) for l in layers])
        sns.barplot(x='layer', y=prop, order=order, data=df_topn, ax=ax, palette='magma')
        ax.set_xticklabels(['pixel', 'V1', '?', 'V4', 'IT'])
        # NOTE(review): labelled d' for every prop, including correlation columns.
        ax.set_ylabel('d\'')
        sns.despine(ax=ax)
    plt.tight_layout()

    
# Top-15 units per region for the two position correlations.
topn_su_decode(dicarlo_df, n=15, props=['tx', 'ty'], subplot_kw=dict(ylim=(0, 0.4)))

In [None]:
# Top-15 units per region for category d' and tx correlation.
topn_su_decode(dicarlo_df, n=15, props=['category', 'tx'], subplot_kw=dict(ylim=(0, 1)))

In [None]:
# Context first so the Axes created below pick it up.
sns.set_context('talk')

mod_order = np.arange(5)
fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharey=True, sharex=True)

# Left: su_xent_df; right: su_both_df.
for ax, df, order in zip(axs, [su_xent_df, su_both_df], [mod_order, mod_order]):
    sns.barplot(x='layer', y='tx', order=order, data=df, ax=ax)
axs[0].set_xticklabels(['pixel', '1', '2', '3', '4'])
axs[1].get_yaxis().set_visible(False)
plt.tight_layout()

In [None]:
# sns.set() resets style and context to seaborn's defaults, so every option
# goes into this one call; set_style/set_context issued before it are undone.
sns.set(context='paper', style='whitegrid', font_scale=2)

In [None]:
# One call: sns.set() after set_context('paper') would reset the context.
sns.set(context='paper', font_scale=2)

fig, axs = plot_kde('tx', 'ty', su_xent_df, by='layer', order=np.arange(5))
plt.tight_layout()
sns.despine(fig=fig)
# Panel titles identify the layer; drop the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')
    




In [None]:
# One call: sns.set() after set_context('paper') would reset the context.
sns.set(context='paper', font_scale=2)
fig, axs = plot_kde('ty', 'category', su_xent_df, by='layer', order=np.arange(5), xlim=(0, 1.1), ylim=(0, 1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Panel titles identify the layer; drop the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# One call: sns.set() after set_context('paper') would reset the context.
sns.set(context='paper', font_scale=2)
fig, axs = plot_kde('ty', 'category', su_both_df, by='layer', order=np.arange(5), xlim=(0, 1.1), ylim=(0, 1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Panel titles identify the layer; drop the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# Same comparison with plot_kde's default axis limits.
plot_kde(x='ty', y='category', df=su_both_df, by='layer', order=np.arange(5))
plt.tight_layout()

In [None]:
# One call: sns.set() after set_context('paper') would reset the context.
sns.set(context='paper', font_scale=2)
fig, axs = plot_kde('ty', 'category', dicarlo_df, by='region', order=['V4', 'IT'], xlim=(0, 1.1), ylim=(0, 1.1))
plt.tight_layout()
sns.despine(fig=fig)
# Panel titles identify the region; drop the per-panel axis labels.
for ax in axs:
    ax.set_ylabel('')
    ax.set_xlabel('')

In [None]:
# tx vs category d' per recorded region.
plot_kde(x='tx', y='category', df=dicarlo_df, by='region', order=['V4', 'IT'])


In [None]:
# tx vs ty correlation of IT units only.
it_units = dicarlo_df[dicarlo_df['region'] == "IT"]
sns.scatterplot(x='tx', y='ty', data=it_units)
plt.ylim(0, 0.5)
plt.xlim(0, 0.5)

In [None]:
class MURegressor(object):
    """Cross-validated multi-unit regression of a presentation coordinate.

    Draws ``n_splits`` index sets of units, splits each unit subset with
    ``split_assembly`` into (train, test) parts, and keeps one fresh
    ``estimator()`` per split.

    Parameters
    ----------
    da : assembly
        Assumed laid out as (presentation, neuroid) with a ``neuroid_id`` coord.
    train_frac : float
        Stored only; the actual split is whatever ``split_assembly`` does.
    n_splits : int
        Number of unit subsets / estimators.
    n_units : int or None
        Units drawn per split, uniformly *with replacement* (a unit can appear
        more than once). ``None`` uses every unit in every split.
    estimator : callable
        Zero-argument factory returning an object with ``fit``, ``predict``
        and ``score``.
    """
    def __init__(self,da,train_frac=0.8,n_splits=5,n_units=None,estimator=Ridge):
        n_available = len(da.neuroid_id)
        if n_units is not None:
            self.neuroid_idxs = [np.array([random.randrange(n_available) for _ in range(n_units)])
                                 for _ in range(n_splits)]
        else:
            # Without this branch neuroid_idxs was never set and n_units=None
            # failed below with AttributeError.
            self.neuroid_idxs = [np.arange(n_available) for _ in range(n_splits)]

        self.original_data = da
        # NOTE(review): not used by the split below.
        self.train_frac = train_frac
        self.n_splits = n_splits

        splits = [split_assembly(self.original_data[:, n_idxs])
                  for n_idxs in tqdm(self.neuroid_idxs, total=n_splits, desc='CV-splitting')]
        self.train = [tr for tr, te in splits]
        self.test = [te for tr, te in splits]

        # One independent estimator per split.
        self.estimators = [estimator() for _ in range(n_splits)]

    def fit(self,y_coord):
        """Fit every split's estimator on its training part to predict ``y_coord``; returns self."""
        for mod, train in tqdm(zip(self.estimators, self.train), total=len(self.train), desc='fitting'):
            mod.fit(X=train.values, y=train[y_coord])
        return self

    def predict(self,X=None):
        """Per-split predictions on ``X`` if given, else on each split's own test part."""
        if X is not None:
            return [e.predict(X) for e in self.estimators]
        return [e.predict(te.values) for e, te in zip(self.estimators, self.test)]

    def score(self,y_coord):
        """Per-split ``estimator.score`` on the test part against ``y_coord``."""
        return [e.score(te.values, te[y_coord].values) for e, te in zip(self.estimators, self.test)]
    
def stratified_regressors(data, filt='region',n_units=126,y_coords=['ty','tz'],task_names=None,estimator=Ridge):
    """Cross-validated regression performance per subset of units and per target.

    Parameters
    ----------
    data : assembly
        Must carry the neuroid-level coordinate ``filt``.
    filt : str
        Neuroid coordinate whose unique values define the unit subsets.
    n_units : int or None
        Units sampled per split (see ``MURegressor``).
    y_coords : list of str
        Presentation coordinates to regress (not modified).
    task_names : list of str, optional
        Labels for the 'task' column; defaults to ``y_coords``.
    estimator : callable
        Estimator factory handed to each ``MURegressor``.

    Returns
    -------
    Long DataFrame with columns 'region' (the ``filt`` value, whatever
    ``filt`` is), 'performance' and 'task'.
    """
    subsets = np.unique(data[filt].values)
    if task_names is None:
        task_names = y_coords
    dfs = []
    for y, task in zip(y_coords, task_names):
        print('regressing {}...'.format(y))
        # Forward the caller's estimator; it was previously hard-wired to Ridge.
        regressors = {k: MURegressor(data.sel(**{filt: k}), n_units=n_units, estimator=estimator).fit(y_coord=y)
                      for k in subsets}
        df = pd.DataFrame.from_records({k: v.score(y_coord=y) for k, v in regressors.items()})
        df = df.melt(var_name='region', value_name='performance')
        df['task'] = task
        dfs.append(df)

    return pd.concat(dfs)

In [None]:
# Multi-unit decoding of tx / ty from 50 sampled units per layer of lg_both.
properties = ['tx', 'ty']
mu_both_df = stratified_regressors(lg_both, filt='layer', n_units=50, y_coords=properties)

In [None]:
# Decoding performance per task, one bar per layer ('region' column).
sns.barplot(data=mu_both_df, x='task', y='performance', hue='region')

In [None]:
# Same decoding analysis for lg_xent.
mu_xent_df = stratified_regressors(lg_xent, filt='layer', n_units=50, y_coords=properties)

In [None]:
# Decoding performance per task, one bar per layer ('region' column).
sns.barplot(data=mu_xent_df, x='task', y='performance', hue='region')

In [None]:
# `both_df` is not defined anywhere in the notebook; the per-unit table for
# the xent+recon model is su_both_df.
plot_kde(x='tx', y='ty', df=su_both_df, by='layer', order=np.arange(5))

In [None]:
# `xent_df` is not defined anywhere in the notebook; the per-unit table for
# the xent-only model is su_xent_df.
plot_kde(x='tx', y='ty', df=su_xent_df, by='layer', order=np.arange(5))