In [1]:
%reload_ext autoreload
%autoreload 3

import torch
from torch import nn

from src.cogponder import CogPonderModel

In [6]:
# Smoke test: build a CogPonderModel and run a single forward pass on random inputs.
batch_size = 32
n_trials = 64  # not used in this cell; kept in case other cells rely on it
inputs_dim = 10
outputs_dim = 2
n_contexts = 1

# Size the tensors from the settings above instead of repeating literal 10s,
# so changing `inputs_dim` or `batch_size` cannot desynchronize data and model.
# NOTE(review): assumes the model takes X as (batch, inputs_dim) and one
# context id per sample -- confirm against CogPonderModel.forward.
X = torch.rand(batch_size, inputs_dim)
contexts = torch.randint(0, n_contexts, (batch_size,))  # single context -> all zeros

configs = {
    'inputs_dim': inputs_dim,
    'outputs_dim': outputs_dim,
    'embeddings_dim': 10,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    'max_response_step': 100,
    'n_contexts': n_contexts,
    'task': 'nback'
}

model = CogPonderModel(**configs)

# The forward pass returns three values; their exact meaning is defined by
# CogPonderModel (names suggest per-step outputs, halting probabilities and
# halting steps -- TODO confirm).
y_steps, p_steps, halt_steps = model(X, contexts)


# SRO (Self-Regulation Ontology) Data

In [186]:
%reload_ext autoreload
%autoreload 3

import pandas as pd
import numpy as np

# Settings for loading the Stroop task data.
data_path = 'data/Self_Regulation_Ontology/stroop.csv.gz'
n_subjects = -1  # -1 means "all subjects"
step_duration = 10  # in ms

# Read the raw trials; the first CSV column becomes the row index.
data = pd.read_csv(data_path, index_col=0)

def preprocess(data, n_subjects=None, step_duration=None):
    """Select test-stage Stroop trials and encode their columns as integers.

    Args:
        data: trial-level DataFrame containing the columns ``worker_id``,
            ``exp_stage``, ``condition``, ``key_press``, ``stim_color``,
            ``stim_word``, ``correct`` and ``rt``.
        n_subjects: keep only the first ``n_subjects`` workers (in order of
            first appearance). A negative value keeps every worker. When
            ``None``, the notebook-level ``n_subjects`` is used if it exists,
            otherwise -1 (all workers).
        step_duration: width of one response step, in the same unit as ``rt``
            (the notebook config states milliseconds). When ``None``, the
            notebook-level ``step_duration`` is used if it exists,
            otherwise 10.

    Returns:
        A new DataFrame (the input is left untouched) restricted to rows with
        ``exp_stage == "test"`` of the selected workers, sorted by index, with:

        * ``trial_index``: 1-based running trial count per worker,
        * ``worker_id``: integer codes starting at 1,
        * ``condition``: incongruent=0, congruent=1,
        * ``key_press``/``stim_color``/``stim_word``: timeout=0, blue=1,
          green=2, red=3,
        * ``correct``: int8,
        * ``response_step``: ``rt`` floor-divided by ``step_duration``.

        Values outside the known categories are encoded as -1.

    Raises:
        ValueError: if ``step_duration`` is not positive, or (raised by
            pandas) if ``rt`` contains missing values, which cannot be cast
            to integer steps.
    """
    # Fall back to the notebook-level configuration when not given explicitly.
    settings = globals()
    if n_subjects is None:
        n_subjects = settings.get('n_subjects', -1)
    if step_duration is None:
        step_duration = settings.get('step_duration', 10)
    if step_duration <= 0:
        raise ValueError('step_duration must be positive')

    # Negative n_subjects means "all"; slicing with [:-1] would silently
    # drop the last worker, so only slice for non-negative limits.
    worker_ids = data['worker_id'].unique()
    if n_subjects >= 0:
        worker_ids = worker_ids[:n_subjects]

    # select only worker_ids and test trials
    keep = data['worker_id'].isin(worker_ids) & (data['exp_stage'] == 'test')
    data = data.loc[keep].sort_index().copy()

    data['trial_index'] = data.groupby('worker_id').cumcount() + 1

    sro_conditions = ['incongruent', 'congruent']
    sro_colors = {-1: 'timeout', 66: 'blue', 71: 'green', 82: 'red'}
    color_names = list(sro_colors.values())

    # map key_press to color names
    data['key_press'] = data['key_press'].map(sro_colors)

    # encode workers as consecutive integers, starting at 1
    data['worker_id'] = data['worker_id'].astype('category').cat.codes.astype('int64') + 1

    # encode the remaining categorical columns using a fixed category order,
    # so codes are stable regardless of which values occur in the data
    category_orders = {
        'condition': sro_conditions,
        'key_press': color_names,
        'stim_color': color_names,
        'stim_word': color_names,
    }
    for column, categories in category_orders.items():
        data[column] = (
            data[column]
            .astype('category')
            .cat.set_categories(categories, ordered=True)
            .cat.codes.astype('int8')
        )

    data['correct'] = data['correct'].astype('int8')

    # compute response steps (floor division already rounds down)
    data['response_step'] = (data['rt'] // step_duration).astype('int')

    return data


# Replace the raw frame with its preprocessed version.
data = preprocess(data)

# TODO: handle missing RTs (replace with max trial duration?)

# Which DataFrame column(s) feed each variable of the dataset.
mappings = {
    'trial_ids': ['trial_index'],
    'subject_ids': ['worker_id'],
    'stimuli': ['stim_color', 'stim_word'],
    'contexts': ['condition'],
    'responses': ['key_press'],
    'response_steps': ['response_step'],  # response time expressed in steps
    'corrects': ['correct'],
}

# Dimension names of each variable. Every entry is a real tuple: note the
# trailing comma, since ('observation') without it is just a string.
dimensions = {
    'trial_ids': ('observation',),
    'subject_ids': ('observation',),
    'stimuli': ('observation', 'stimulus_modality'),
    'contexts': ('observation',),
    'responses': ('observation',),
    'response_steps': ('observation',),
    'corrects': ('observation',),
}

import xarray as xr

# Assemble the selected columns into one xarray Dataset, one variable per
# entry of `mappings`, using the dimension names given in `dimensions`.
mapped = xr.Dataset()

for name, columns in mappings.items():
    values = data[columns].to_numpy()  # always 2-D: (n_observations, n_columns)
    if values.shape[1] == 1:
        # Drop only the column axis. A blanket squeeze() would also collapse
        # the observation axis when there is exactly one row.
        values = values[:, 0]
    mapped[name] = (dimensions[name], values)

mapped