In [1]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.decomposition import FastICA

from dcmr import task
from dcmr import figures
from dcmr import figures
from dcmr import figures
from dcmr import framework
from cymr import cmr
from psifr import fr

# Input locations. The CFR_RESULTS, CFR_FITS and CFR_FIGURES environment
# variables must be set; a missing one raises KeyError here rather than later.
cfr_dir = Path(os.environ['CFR_RESULTS'])
data_file = cfr_dir / 'cfr_eeg_mixed.csv'
patterns_file = cfr_dir / 'cfr_patterns.hdf5'

# Model variant whose fit directory is read below.
model_name = 'cmrsit_sem-split_cue-focused_dis-cat_sub-list'

fit_dir = Path(os.environ['CFR_FITS']) / 'switchboard4' / model_name

# Output directory for the schematic images. parents=True creates the
# intermediate 'sim3' directory as well, so a fresh figures tree does not
# make mkdir fail with FileNotFoundError.
fig_dir = Path(os.environ['CFR_FIGURES']) / 'sim3' / 'schematic'
fig_dir.mkdir(parents=True, exist_ok=True)

# presumably applies the project's matplotlib style — see dcmr.figures
figures.set_style()

In [2]:
# Model configuration and item patterns.
config_file = fit_dir / 'parameters.json'
param_def = cmr.read_config(config_file)
patterns = cmr.load_patterns(patterns_file)

# Fitted parameters; used below as a mapping from subject to that subject's
# parameter values.
param_file = fit_dir / 'fit.csv'
subj_param = framework.read_fit_param(param_file)

In [3]:
# Display, per subject, the fitted values of the three parameters named here.
shown_params = ('B_enc_cat', 'B_disrupt_cat', 'B_enc_use')
{subject: tuple(fit[name] for name in shown_params) for subject, fit in subj_param.items()}

{1: (0.007463374686649, 0.0197009092327944, 0.8986412753055042),
 2: (5.858458719298223e-05, 0.0599771238998294, 0.7265438455179458),
 3: (0.3406210032997293, 0.8013227918429302, 0.0013423840829961),
 5: (0.150228513088209, 0.5237830202918696, 0.78052750629874),
 8: (0.9814841575535368, 0.1469672687172636, 0.8566819605737734),
 11: (0.0004204475705464, 0.0394325995624002, 0.8075434283194071),
 16: (0.3856232016060198, 0.999835960504424, 0.6733980310404938),
 18: (0.5050951632156815, 1.0, 0.7589211609601685),
 22: (0.2753039095547667, 0.6893959434431676, 0.8177310108256942),
 23: (0.3021920287057242, 0.9952418030305068, 0.4663443480200603),
 24: (0.0, 0.0825042835385697, 0.8850167244539123),
 25: (0.4311479169204801, 1.0, 0.0517755893685623),
 27: (0.2025961443532145, 0.1253310141703842, 0.7643837331879996),
 28: (0.3361882713410586, 0.9433728363630852, 0.9727700678551572),
 29: (0.5949231018057072, 8.043312983387437e-09, 0.9681270277045614),
 31: (0.0306197632898386, 0.0074144476520525

In [4]:
# Collect every (subject, list) pair whose first nine study events contain
# exactly three items from each category present.
data = task.read_study_recall(data_file)
subjects, lists = [], []
for subj_id in data['subject'].unique():
    subj_data = fr.filter_data(data, subjects=subj_id)
    for list_id in subj_data['list'].unique():
        first_nine = fr.filter_data(
            subj_data, lists=list_id, trial_type='study'
        ).iloc[:9]
        per_category = first_nine['category'].value_counts()
        if not (per_category == 3).all():
            continue
        # Parallel lists: subjects[k] and lists[k] identify one qualifying list.
        subjects.append(subj_id)
        lists.append(list_id)

In [5]:
# Use the first qualifying list of subject 22 as the worked example.
ind = np.flatnonzero(np.asarray(subjects) == 22)[0]
ex_subj, ex_list = subjects[ind], lists[ind]

# First nine study events of that list; reset_index() keeps the original
# row labels in an 'index' column and renumbers the rows from 0.
example_events = fr.filter_data(
    data, subjects=ex_subj, lists=ex_list, trial_type='study'
)
study = example_events.reset_index().iloc[:9]
study

Unnamed: 0,index,subject,list,position,trial_type,item,item_index,session,list_type,category,...,response_time,list_category,block,block_pos,block_len,n_block,curr,prev,next,base
0,8868,22,3,1,study,CRADLE,566,1,mixed,obj,...,2.135,mixed,1,1,3,9,obj,,loc,
1,8869,22,3,2,study,DESK,572,1,mixed,obj,...,1.512,mixed,1,2,3,9,obj,,loc,
2,8870,22,3,3,study,BROOM,547,1,mixed,obj,...,1.198,mixed,1,3,3,9,obj,,loc,
3,8871,22,3,4,study,YANKEE STADIUM,508,1,mixed,loc,...,1.969,mixed,2,1,3,9,loc,obj,cel,cel
4,8872,22,3,5,study,TRITON FOUNTAIN,484,1,mixed,loc,...,1.296,mixed,2,2,3,9,loc,obj,cel,cel
5,8873,22,3,6,study,ARC DE TRIOMPHE,266,1,mixed,loc,...,1.635,mixed,2,3,3,9,loc,obj,cel,cel
6,8874,22,3,7,study,RAY CHARLES,208,1,mixed,cel,...,1.368,mixed,3,1,3,9,cel,loc,loc,obj
7,8875,22,3,8,study,ADAM SANDLER,1,1,mixed,cel,...,1.854,mixed,3,2,3,9,cel,loc,loc,obj
8,8876,22,3,9,study,FRAN DRESCHER,88,1,mixed,cel,...,1.574,mixed,3,3,3,9,cel,loc,loc,obj


In [6]:
# Item indices of the nine studied items, and the matching entries of the
# pattern file's item list (used later to name the output images).
item_ind = study.loc[:, 'item_index']
items = patterns['items'][item_ind]

In [7]:
# Run the model over the example study events and record the context ('c').
# The empty dict is the second positional argument of record(); presumably
# no extra fixed parameters are needed beyond subj_param — confirm against cymr.
record_options = dict(
    param_def=param_def,
    patterns=patterns,
    include=['c'],
    study_keys=['block', 'block_pos'],
)
model = cmr.CMR()
state = model.record(study, {}, subj_param, **record_options)

# The first recorded state is also used below for its get_slice() lookup.
net = state[0]

In [8]:
# For each context sublayer, stack the 'item' segment of every recorded
# state into a 2-D array with one row per state.
sublayer_states = {
    name: np.vstack(
        [snapshot.c[net.get_slice('c', name, 'item')] for snapshot in state]
    )
    for name in ('loc', 'cat', 'use')
}

In [9]:
# Reduce the 'use' sublayer states to nine components so that each state can
# be drawn as a 3x3 grid (see the reshape(3, 3) when the matrices are built).
n_components = 9

# The ICA basis is estimated from the raw 'use' vectors in the pattern file
# and then applied to the recorded context states; the seed is fixed for
# reproducible components.
ica = FastICA(n_components=n_components, random_state=42)
ica.fit(patterns['vector']['use'])
use_components = ica.transform(sublayer_states['use'])

# Standardize each component across the recorded states, then clip to
# [-2, 2] to bound the plotted range.
m_use = np.clip(stats.zscore(use_components, axis=0), -2, 2)

In [10]:
# One 3x3 grid per recorded state (nine states) for each sublayer.
n_states = 9
matrix = {
    # 'loc': activation of the nine studied items' units, in study order.
    'loc': [
        sublayer_states['loc'][k, item_ind].reshape(3, 3)
        for k in range(n_states)
    ],
    # 'cat': row j is filled with the j-th unit's value, repeated three times.
    'cat': [
        np.array(
            [[sublayer_states['cat'][k][j]] * 3 for j in range(3)],
            dtype=float,
        )
        for k in range(n_states)
    ],
    # 'use': the nine clipped ICA components of the state.
    'use': [m_use[k].reshape(3, 3) for k in range(n_states)],
}

In [11]:
def print_pattern(x, fig_file, clim=(0, 1), shape=(3, 3)):
    """
    Save a pattern as a borderless grayscale image.

    Parameters
    ----------
    x : numpy.ndarray
        Pattern values; reshaped to `shape` before plotting.
    fig_file : str or path-like
        Destination passed to the figure's savefig.
    clim : tuple of (float, float), optional
        Lower and upper color limits of the gray colormap. Values outside
        this range are drawn at the nearest limit. Default is (0, 1).
    shape : tuple of (int, int), optional
        Grid shape used to lay out the pattern. Default is (3, 3).
    """
    mat = x.reshape(shape)
    fig = plt.figure(frameon=False, figsize=(2, 2))
    try:
        # Axes spanning the whole figure so the image has no margins.
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        h = ax.matshow(mat, cmap="gray")
        h.set_clim(*clim)
        # Save through the figure itself rather than pyplot's current figure.
        fig.savefig(fig_file, pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

In [12]:
# Write one image per recorded state of the 'loc' sublayer, named after the
# item at the same position in the study list.
for pos, grid in enumerate(matrix['loc']):
    out_file = fig_dir / f'loc_{items[pos]}.png'
    print_pattern(grid, out_file)

In [13]:
# Write one image per recorded state of the 'cat' sublayer, named after the
# item at the same position in the study list.
for pos, grid in enumerate(matrix['cat']):
    out_file = fig_dir / f'cat_{items[pos]}.png'
    print_pattern(grid, out_file)

In [14]:
# Write one image per recorded state of the 'use' sublayer, named after the
# item at the same position in the study list.
# NOTE(review): m_use is clipped to [-2, 2], while print_pattern draws with
# color limits (0, 1), so negative components render as black — confirm this
# saturation is intended for the schematic.
for pos, grid in enumerate(matrix['use']):
    out_file = fig_dir / f'use_{items[pos]}.png'
    print_pattern(grid, out_file)

In [15]:
# Report the Python/IPython versions and the versions of the imported
# packages, so the environment that produced the figures is recorded.
%load_ext watermark
%watermark -v -iv

Python implementation: CPython
Python version       : 3.13.6
IPython version      : 9.9.0

cymr      : 0.14.3
dcmr      : 1.0.0a0
matplotlib: 3.10.8
numpy     : 2.4.1
psifr     : 0.10.1
scipy     : 1.17.0
sklearn   : 1.8.0

