In [34]:
# Imports

%reload_ext autoreload
%autoreload 3

import warnings

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.cogponder import CogPonderModel
from src.cogponder.data import CogPonderDataModule, StroopSRODataset, NBackSRODataset

In [35]:
# Parameters

step_duration = 20  # duration of one time step, in ms
batch_size = 512
max_epochs = 1000

# Seed python `random`, numpy and torch (and, via workers=True, the dataloader
# workers) so that model initialization, any random split done by the
# datamodule, and batch shuffling are reproducible across Restart & Run All.
seed = 42
pl.seed_everything(seed, workers=True)

In [211]:
# Data

# N-back SRO dataset with n_back=2. n_subjects=-1 presumably means "use every
# subject" — confirm in NBackSRODataset. train_ratio=.5 is presumably the
# fraction of samples assigned to training — confirm in CogPonderDataModule.
dataset = NBackSRODataset(
    step_duration=step_duration,
    n_back=2,
    n_subjects=-1,
)
datamodule = CogPonderDataModule(
    dataset,
    train_ratio=.5,
    batch_size=batch_size,
    num_workers=8,
)

In [None]:
# Experiment

# Index the dataset once: `dataset[:]` may rebuild the full set of tensors on
# every call (depends on NBackSRODataset.__getitem__), and the config below
# needs three of its columns.
all_data = dataset[:]

# Column meanings are inferred from how each column is used below —
# NOTE(review): confirm against NBackSRODataset.
subject_ids = all_data[1]     # presumably subject identifiers
contexts = all_data[2]        # presumably trial contexts
responses = all_data[4]       # presumably response labels (one output per unique value)
response_steps = all_data[5]  # presumably response times expressed in steps

# Fixed upper bound on model steps. It used to be derived from the data as
# max(response_steps) + 10; warn if the observed data no longer fits under it.
max_response_step = 150
max_observed_step = int(response_steps.max().item())
if max_observed_step > max_response_step:
    warnings.warn(
        f'max_response_step={max_response_step} is below the largest observed '
        f'response step ({max_observed_step}); consider raising it.'
    )

configs = {
    'inputs_dim': 1,
    'outputs_dim': torch.unique(responses).size(0),
    'embeddings_dim': 8,
    'response_loss_beta': 1.,
    'time_loss_beta': 10.,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'n_contexts': torch.unique(contexts).size(0),
    'n_subjects': torch.unique(subject_ids).size(0),
    'subject_embeddings_dim': 2,
    'task': 'nback'
}

model = CogPonderModel(**configs)

trainer = pl.Trainer(max_epochs=max_epochs, accelerator='cpu', log_every_n_steps=2)

trainer.fit(model, datamodule=datamodule)