In [2]:
# Imports

%reload_ext autoreload
%autoreload 3

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Subset, TensorDataset

from src.cogponder import CogPonderModel
from src.cogponder.data import CogPonderDataModule
from src.cogponder.data.stroop_sro import StroopSRODataset


In [104]:
# Explore the Stroop SRO dataset and build a subject-level train/test split.

SPLIT_SEED = 42  # fixed so the subject split is identical across kernel restarts

sro_dataset = StroopSRODataset()

# Materialize all observations; `sro_dataset[:]` yields a tuple of tensors.
trials = TensorDataset(*sro_dataset[:])

# Split the *actual* subject ids rather than range(n_subjects): the ids are not
# guaranteed to be contiguous 0..n-1, and splitting positions would later select
# the wrong subjects when masking by id. Splitting unique ids also guarantees no
# subject appears in both train and test.
subject_ids = sro_dataset._data['subject_ids'].data
unique_subject_ids = np.unique(subject_ids)
n_subjects = unique_subject_ids.shape[0]

train, test = train_test_split(unique_subject_ids, test_size=0.2, random_state=SPLIT_SEED)

# NOTE(review): assumes `subject_ids` is 1-D and that row i of `trials` is
# observation i of `sro_dataset._data` — confirm against StroopSRODataset.__getitem__.
train_trials = Subset(trials, np.flatnonzero(np.isin(subject_ids, train)).tolist())
test_trials = Subset(trials, np.flatnonzero(np.isin(subject_ids, test)).tolist())

# Every trial must land in exactly one split; fails loudly if the alignment assumption breaks.
assert len(train_trials) + len(test_trials) == len(trials)

# Spot-check a single trial.
trials[50015]

(tensor(95),
 tensor(520),
 tensor(0),
 tensor([3., 2.]),
 tensor(2),
 tensor(63),
 tensor(0))

In [None]:
# Experiment: fit CogPonder on the Stroop SRO dataset.

SEED = 42
pl.seed_everything(SEED, workers=True)  # seed python/numpy/torch (and dataloader workers) before init

step_duration = 10  # in ms
batch_size = 512
num_workers = 8
response_step_margin = 10  # extra steps allowed beyond the longest observed response

dataset = StroopSRODataset(response_step_interval=step_duration)

datamodule = CogPonderDataModule(dataset, batch_size=batch_size, num_workers=num_workers)

configs = {
    'inputs_dim': dataset._data['stimuli'].shape[1],
    'outputs_dim': np.unique(dataset._data['responses'].data).shape[0],
    'embeddings_dim': 10,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    # cast the array max to a plain Python int so the hyperparameter is not a NumPy scalar
    'max_response_step': int(dataset._data['response_steps'].data.max()) + response_step_margin,
    'n_contexts': np.unique(dataset._data['contexts'].data).shape[0],
    'n_subjects': np.unique(dataset._data['subject_ids'].data).shape[0],
    'subject_embeddings_dim': 2,
    # FIXME(review): the dataset is Stroop SRO but the task is 'nback'. Confirm which
    # task names CogPonderModel supports and switch to the Stroop task if available.
    'task': 'nback'
}

model = CogPonderModel(**configs)

trainer = pl.Trainer(max_epochs=1000, accelerator='cpu', log_every_n_steps=2)

trainer.fit(model, datamodule=datamodule)