In [1]:
import numpy as np
import pandas as pd
import xarray as xr

In [2]:
import mkgu
from mkgu.metrics.rdm import RSA

In [3]:
# Fetch the "dicarlo.Majaj2015" neural recording assembly through mkgu.
# The data is fetched lazily from disk until `.load()` is called (next cell).
assembly_name = "dicarlo.Majaj2015"
hvm = mkgu.get_assembly(
    name=assembly_name,
)
hvm

<xarray.NeuronRecordingAssembly (neuroid: 296, presentation: 268800, time_bin: 1)>
array([[[ 0.060929],
        [-0.686162],
        ..., 
        [-0.968256],
        [ 0.183887]],

       [[-0.725592],
        [ 0.292777],
        ..., 
        [ 2.449372],
        [ 0.401197]],

       ..., 
       [[ 1.121319],
        [ 1.719423],
        ..., 
        [ 0.800551],
        [-0.019874]],

       [[-0.518903],
        [ 0.696196],
        ..., 
        [-0.603347],
        [-0.175979]]], dtype=float32)
Coordinates:
  * neuroid          (neuroid) MultiIndex
  - neuroid_id       (neuroid) object 'Chabo_L_M_5_9' 'Chabo_L_M_6_9' ...
  - arr              (neuroid) object 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' ...
  - col              (neuroid) int64 9 9 8 9 8 8 7 7 5 6 4 9 9 9 9 9 8 7 9 6 ...
  - hemisphere       (neuroid) object 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' ...
  - subregion        (neuroid) object 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' ...
  - animal           (neuroid) object 'Chabo'

In [4]:
# Pull all of the assembly's data into memory now, rather than fetching it
# lazily from disk on every subsequent access; the loaded assembly is displayed.
hvm.load()

<xarray.NeuronRecordingAssembly (neuroid: 296, presentation: 268800, time_bin: 1)>
array([[[ 0.060929],
        [-0.686162],
        ..., 
        [-0.968256],
        [ 0.183887]],

       [[-0.725592],
        [ 0.292777],
        ..., 
        [ 2.449372],
        [ 0.401197]],

       ..., 
       [[ 1.121319],
        [ 1.719423],
        ..., 
        [ 0.800551],
        [-0.019874]],

       [[-0.518903],
        [ 0.696196],
        ..., 
        [-0.603347],
        [-0.175979]]], dtype=float32)
Coordinates:
  * neuroid          (neuroid) MultiIndex
  - neuroid_id       (neuroid) object 'Chabo_L_M_5_9' 'Chabo_L_M_6_9' ...
  - arr              (neuroid) object 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' ...
  - col              (neuroid) int64 9 9 8 9 8 8 7 7 5 6 4 9 9 9 9 9 8 7 9 6 ...
  - hemisphere       (neuroid) object 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' ...
  - subregion        (neuroid) object 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' ...
  - animal           (neuroid) object 'Chabo'

In [5]:
# Keep only presentations whose `variation` coordinate is 6, then only the
# neuroids whose `region` coordinate is "IT" (the displayed output shows both
# the presentation and neuroid counts shrinking).
hvm_it_v6 = (
    hvm
    .sel(variation=6)
    .sel(region="IT")
)
hvm_it_v6

<xarray.NeuronRecordingAssembly (neuroid: 168, presentation: 120320, time_bin: 1)>
array([[[ 0.121288],
        [-1.784021],
        ..., 
        [-0.968256],
        [ 0.183887]],

       [[ 0.512693],
        [ 0.512693],
        ..., 
        [ 2.449372],
        [ 0.401197]],

       ..., 
       [[-0.384445],
        [-0.384445],
        ..., 
        [-0.377575],
        [-0.393541]],

       [[-1.358922],
        [ 1.665531],
        ..., 
        [ 0.208243],
        [-0.358707]]], dtype=float32)
Coordinates:
  * neuroid          (neuroid) MultiIndex
  - neuroid_id       (neuroid) object 'Chabo_L_M_5_9' 'Chabo_L_M_6_9' ...
  - arr              (neuroid) object 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' ...
  - col              (neuroid) int64 9 9 8 9 8 8 7 7 5 6 4 8 8 6 7 5 7 7 6 6 ...
  - hemisphere       (neuroid) object 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' ...
  - subregion        (neuroid) object 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' 'cIT' ...
  - animal           (neuroid) object 'Chabo'

In [6]:
# Average the IT responses over all presentations that share a
# (category_name, object_name) pair, giving one mean response per object.
dim = "presentation"
group_coord_names = ["category_name", "object_name"]
hvm_it_v6_grp = hvm_it_v6.multi_groupby(group_coord_names)
# Drop the singleton time_bin axis, then order dims explicitly by name instead of
# relying on `.T`'s positional reversal, so rows are objects and columns neuroids
# regardless of the dim order the groupby-mean happens to produce.
hvm_it_v6_obj = (
    hvm_it_v6_grp
    .mean(dim=dim)
    .squeeze("time_bin")
    .transpose(dim, "neuroid")
)
# Expected shape is derived from the data rather than hard-coded: one row per
# distinct (category_name, object_name) pair, and every IT neuroid retained.
n_objects = len(set(zip(*(hvm_it_v6[name].values for name in group_coord_names))))
assert hvm_it_v6_obj.shape == (n_objects, hvm_it_v6.sizes["neuroid"])
hvm_it_v6_obj

<xarray.NeuronRecordingAssembly (presentation: 64, neuroid: 168)>
array([[ 0.087381, -0.035234,  0.16607 , ..., -0.083298, -0.04096 ,  0.009889],
       [ 0.10009 ,  0.07415 ,  0.136218, ...,  0.049036, -0.040681, -0.049637],
       [ 0.196117,  0.095425, -0.012496, ...,  0.088381, -0.079695,  0.023809],
       ..., 
       [ 0.03505 ,  0.009078, -0.098099, ..., -0.030416, -0.081772, -0.185693],
       [ 0.065006,  0.072046, -0.032077, ..., -0.127431, -0.025176, -0.034318],
       [ 0.102832,  0.102089,  0.089803, ..., -0.055868, -0.108431, -0.083379]], dtype=float32)
Coordinates:
  * neuroid        (neuroid) MultiIndex
  - neuroid_id     (neuroid) object 'Chabo_L_M_5_9' 'Chabo_L_M_6_9' ...
  - arr            (neuroid) object 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' 'M' ...
  - col            (neuroid) int64 9 9 8 9 8 8 7 7 5 6 4 8 8 6 7 5 7 7 6 6 5 ...
  - hemisphere     (neuroid) object 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' 'L' ...
  - subregion      (neuroid) object 'cIT' 'cIT' 'cIT' 'cIT'

In [7]:
# Compute the object-by-object RSA matrix for the averaged IT responses.
rsa_characterization = RSA()
rsa = rsa_characterization(hvm_it_v6_obj)
# Sanity checks instead of a hard-coded [64, 64]:
#  - square, with one row/column per object (row) of the input assembly;
#  - symmetric with a unit diagonal, as seen in the displayed output.
# atol absorbs floating-point rounding from the float32 responses.
n_objects = hvm_it_v6_obj.shape[0]
assert tuple(rsa.shape) == (n_objects, n_objects)
np.testing.assert_allclose(rsa.values, rsa.values.T, atol=1e-6)
np.testing.assert_allclose(np.diag(rsa.values), 1.0, atol=1e-6)
rsa

<xarray.DataAssembly (presentation: 64)>
array([[ 1.      ,  0.53554 ,  0.687936, ...,  0.47855 ,  0.618492,  0.551413],
       [ 0.53554 ,  1.      ,  0.438187, ...,  0.303748,  0.413138,  0.359837],
       [ 0.687936,  0.438187,  1.      , ...,  0.348462,  0.232529,  0.262504],
       ..., 
       [ 0.47855 ,  0.303748,  0.348462, ...,  1.      ,  0.332639,  0.275846],
       [ 0.618492,  0.413138,  0.232529, ...,  0.332639,  1.      ,  0.549457],
       [ 0.551413,  0.359837,  0.262504, ...,  0.275846,  0.549457,  1.      ]])
Coordinates:
    time_bin       object (70, 170)
  * presentation   (presentation) MultiIndex
  - category_name  (presentation) object 'Animals' 'Animals' 'Animals' ...
  - object_name    (presentation) object 'bear' 'cow' 'dog' 'elephant' ...