In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import pickle as pkl

In [82]:
# Load the stimulus pair table. To restrict the analysis to a single run,
# filter right after loading, e.g.: df = df[df['run'] == 'even']
df = pd.read_csv('stimuli/elli_pairs.csv')

# Label every row "<category>_<class>"; this label drives both the row order
# and the per-condition lookup into the response file.
df['condition'] = df[['category', 'class']].agg('_'.join, axis=1)

# One (pair, condition, run, original_index) tuple per row, ordered by condition.
# sorted() is stable, so rows inside a condition keep their CSV order.
items = sorted(
    zip(df['pair'], df['condition'], df['run'], df.index),
    key=lambda entry: entry[1],
)

# Distinct condition labels in sorted order.
conditions = sorted({entry[1] for entry in items})

In [84]:
# Reorder rows to follow `items` (condition-sorted) and renumber them 0..n-1;
# the last tuple field of each item is the row's original index label.
df = df.loc[[entry[-1] for entry in items]].reset_index(drop=True)

In [96]:
# Load the per-condition model responses.
# NOTE(review): pickle.load can execute arbitrary code — only load response
# files produced by a trusted pipeline.
with open('data/responses/5-5-2.5-48-mean-0/elli.pkl', 'rb') as f:
    responses_by_condition = pkl.load(f)

# Stack condition blocks in the same sorted `conditions` order that was used to
# reorder `df`, so row i of the stacked array lines up with row i of `df`.
# NOTE(review): assumes each condition's responses list items in the same order
# as that condition's CSV rows — confirm against the response generator.
data = np.concatenate([responses_by_condition[condition] for condition in conditions])

# One flattened feature vector per trial: 24 * 784 = 18816 values, matching the
# `model_units` dimension of the Dataset below.
# TODO: document what the 24 and 784 factors denote.
N_MODEL_UNITS = 24 * 784
data = data.reshape(len(df), -1)
assert data.shape[1] == N_MODEL_UNITS, (
    f"expected {N_MODEL_UNITS} values per trial, got {data.shape[1]}"
)

df['activations'] = list(data)

In [106]:
# NOTE(review): this cell was executed (In[106]) after the rename cell (In[101]),
# so the recorded output shows the renamed columns (run_type, trials, ...).
# On a fresh Restart & Run All it runs before the rename and would show the
# original column names ('pair', 'class', 'run', ...) instead.
df

Unnamed: 0,run_type,category,subcategory,trials,activations
0,odd,noun,bird,hawk_sparrow,"[-0.02535043, -0.07031492, -0.16160995, -0.028..."
1,odd,noun,bird,crane_duck,"[-0.039050035, -0.08379563, -0.13613366, -0.02..."
2,odd,noun,bird,goose_peacock,"[-0.026041985, -0.06813426, -0.14598678, -0.02..."
3,odd,noun,bird,owl_fowl,"[-0.02232096, -0.09358954, -0.11189223, -0.003..."
4,odd,noun,bird,goose_owl,"[-0.038984284, -0.086307384, -0.12093766, -0.0..."
...,...,...,...,...,...
571,even,verb,sound,beep_crackle,"[0.056000136, -0.060103483, -0.10159906, 0.028..."
572,even,verb,sound,chime_beep,"[0.06532799, -0.08444446, -0.14366214, 0.01655..."
573,even,verb,sound,fizzle_crackle,"[0.056060184, -0.056188602, -0.12558281, 0.037..."
574,even,verb,sound,jingle_plop,"[0.050238278, -0.06823565, -0.1262034, 0.00627..."


In [98]:
def format_pair(row):
    """Turn a raw stimulus pair string into an underscore-joined trial label.

    The pair is a ', '-separated list of items. For nouns, each item's leading
    'the ' is removed; for verbs, each item's leading 'to ' is removed.
    Examples: 'the hawk, the sparrow' -> 'hawk_sparrow' and
    'to beep, to crackle' -> 'beep_crackle'.

    Only the *leading* article of each item is stripped. Occurrences inside an
    item (e.g. the 'the ' in 'lathe operator') are preserved.

    Args:
        row: Mapping/Series with 'class' and 'pair' entries.

    Returns:
        The formatted label, or ``row['pair']`` unchanged when the class is
        neither 'noun' nor 'verb'.
    """
    prefix = {'noun': 'the ', 'verb': 'to '}.get(row['class'])
    if prefix is None:
        return row['pair']
    parts = row['pair'].split(', ')
    return '_'.join(
        part[len(prefix):] if part.startswith(prefix) else part
        for part in parts
    )

# Replace each raw pair string with its formatted trial label, row by row.
df['pair'] = df.apply(format_pair, axis='columns')

In [101]:
# Map the CSV schema onto the downstream schema in one chain: the old
# `category` becomes `subcategory` and `class` becomes the new `category`
# (rename maps each original label once, so the swap is safe). Then drop the
# bookkeeping columns.
df = (
    df.rename(columns={
        'category': 'subcategory',
        'class': 'category',
        'pair': 'trials',
        'run': 'run_type',
    })
    .drop(columns=['block', 'condition'])
)

In [147]:
# Pluralize part-of-speech labels ('noun' -> 'nouns', 'verb' -> 'verbs'); any
# other value is left as-is. Selections on `category` after this cell must use
# the plural labels. The recorded outputs still show singular labels, so they
# predate this cell's current contents.
df['category'] = df['category'].replace(to_replace={'noun': 'nouns', 'verb': 'verbs'})

Unnamed: 0,run_type,category,subcategory,trials,activations
0,odd,noun,bird,hawk_sparrow,"[-0.02535043, -0.07031492, -0.16160995, -0.028..."
1,odd,noun,bird,crane_duck,"[-0.039050035, -0.08379563, -0.13613366, -0.02..."
2,odd,noun,bird,goose_peacock,"[-0.026041985, -0.06813426, -0.14598678, -0.02..."
3,odd,noun,bird,owl_fowl,"[-0.02232096, -0.09358954, -0.11189223, -0.003..."
4,odd,noun,bird,goose_owl,"[-0.038984284, -0.086307384, -0.12093766, -0.0..."


In [151]:
# Stack the per-trial activation vectors into one 2-D array of shape
# (trials, model_units); the printed Dataset below shows (576, 18816).
activations = np.stack(list(df['activations']))

# Build an xarray Dataset holding a single `activations` variable. `trials`
# (the formatted pair labels) is the dimension coordinate; category,
# subcategory and run_type are auxiliary coordinates along that dimension.
# NOTE(review): pair labels may repeat across runs, which would make
# label-based .sel(trials=...) ambiguous — confirm whether they are unique.
xarr = xr.Dataset(
    {"activations": (["trials", "model_units"], activations)},
    coords={
        'trials': df['trials'].values,
        'category': ('trials', df['category'].values),
        'subcategory': ('trials', df['subcategory'].values),
        'run_type': ('trials', df['run_type'].values),
        'model_units': np.arange(activations.shape[1]),
    },
)

In [152]:
# Bare last expression: Jupyter renders the Dataset's rich (HTML) repr.
xarr

<xarray.Dataset> Size: 44MB
Dimensions:      (trials: 576, model_units: 18816)
Coordinates:
  * trials       (trials) object 5kB 'hawk_sparrow' ... 'sizzle_ring'
    category     (trials) object 5kB 'noun' 'noun' 'noun' ... 'verb' 'verb'
    subcategory  (trials) object 5kB 'bird' 'bird' 'bird' ... 'sound' 'sound'
    run_type     (trials) object 5kB 'odd' 'odd' 'odd' ... 'even' 'even' 'even'
  * model_units  (model_units) int64 151kB 0 1 2 3 4 ... 18812 18813 18814 18815
Data variables:
    activations  (trials, model_units) float32 43MB -0.02535 -0.07031 ... -1.088


In [159]:
# Keep the even-run noun trials, with all model units. Despite the variable
# name, this selects *trials*, not a subset of units.
# `category` holds plural labels ('nouns'/'verbs') after the relabelling cell
# (In[147]), so match 'nouns'. Matching 'noun' would make where(drop=True)
# silently return zero trials on a fresh Restart & Run All.
model_noun_selective_units = xarr['activations'].where(
    (xarr['run_type'] == 'even') & (xarr['category'] == 'nouns'),
    drop=True,
)
assert model_noun_selective_units.sizes['trials'] > 0, (
    "selection matched no trials; check the category/run_type labels"
)

In [160]:
# Bare last expression: Jupyter renders the DataArray's rich (HTML) repr.
model_noun_selective_units

<xarray.DataArray 'activations' (trials: 144, model_units: 18816)> Size: 11MB
array([[-0.00296812, -0.03881348, -0.15995902, ...,  2.1688273 ,
        -0.6697358 ,  0.38221067],
       [-0.02550984, -0.05017655, -0.14002888, ...,  2.2438884 ,
        -1.3391901 , -0.7410905 ],
       [-0.02763878, -0.06981722, -0.16948359, ...,  1.5451136 ,
        -2.301042  , -1.3984326 ],
       ...,
       [-0.01749309, -0.04262029, -0.1317557 , ...,  2.9346542 ,
        -2.2368298 , -1.8565958 ],
       [-0.02165882, -0.03462782, -0.15250704, ...,  1.3295065 ,
        -2.1154113 , -1.7250696 ],
       [-0.0302141 , -0.0485459 , -0.13546866, ...,  1.5614494 ,
        -2.499268  , -2.539213  ]], dtype=float32)
Coordinates:
  * trials       (trials) object 1kB 'ostrich_falcon' ... 'boulder_canyon'
    category     (trials) object 1kB 'noun' 'noun' 'noun' ... 'noun' 'noun'
    subcategory  (trials) object 1kB 'bird' 'bird' ... 'natural' 'natural'
    run_type     (trials) object 1kB 'even' 'even' 'eve