In [None]:
# Imports

%reload_ext autoreload
%autoreload 3

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.cogponder import CogPonderModel
from src.cogponder.data import CogPonderDataModule, StroopSRODataset, NBackSRODataset

In [None]:
# Parameters

step_duration = 20  # in ms; passed to the datasets as `step_duration`
batch_size = 512
max_epochs = 1000
seed = 42

# Seed Python/NumPy/torch RNGs (and dataloader workers, since the datamodule uses
# several) before any data is split or any weights are initialized, so that the
# train/val split and training run are reproducible under Restart & Run All.
pl.seed_everything(seed, workers=True)

In [None]:
# Data
#
# Both task datasets are constructed here, but only the Stroop dataset is handed
# to the datamodule; add `'nback': nback_dataset` to the mapping to include it.
# NOTE(review): `n_subjects=-1` presumably means "all subjects" — confirm in the dataset classes.

nback_dataset = NBackSRODataset(n_subjects=-1, n_back=2, step_duration=step_duration)
stroop_dataset = StroopSRODataset(n_subjects=-1, step_duration=step_duration)

datamodule = CogPonderDataModule(
    {'stroop': stroop_dataset},
    batch_size=batch_size,
    num_workers=8,
    train_ratio=.5,
)
datamodule.prepare_data()

In [None]:
# Experiment

# Materialize the dataset once and reuse it for every size computed below,
# instead of re-indexing `datamodule.dataset[:]` separately for each config entry.
# NOTE(review): column meanings are inferred from the config keys they feed
# (0 -> subject ids, 2 -> context ids, 4 -> responses) — confirm against the dataset classes.
all_data = datamodule.dataset[:]
subject_ids = all_data[0]
context_ids = all_data[2]
responses = all_data[4]

configs = {
    'inputs_dim': 1,
    # Number of distinct response labels. NOTE(review): if labels are not the
    # contiguous range 0..K-1, `responses.max() + 1` may be the safer size.
    'outputs_dim': responses.unique().size(0),
    'embeddings_dim': 8,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    # Fixed cap; a data-driven alternative is `all_data[5].max().int().item() + 10`.
    'max_response_step': 150,
    'n_contexts': context_ids.unique().size(0),
    'n_subjects': subject_ids.unique().size(0),
    'subject_embeddings_dim': 2,
    'task': 'stroop',  # presumably must match the dataset key passed to CogPonderDataModule
    'operator_type': 'simple',  # alternative: 'spatiotemporal'
}

model = CogPonderModel(**configs)

trainer = pl.Trainer(max_epochs=max_epochs, accelerator='cpu', log_every_n_steps=2)

trainer.fit(model, datamodule=datamodule)