# Single-Task CogPonder: Stroop

This notebook implements a basic single-task CogPonder agent that learns the Stroop task by imitating human participants.

In [None]:
%reload_ext autoreload
%autoreload 3

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output

import torch
import pytorch_lightning as pl
import pandas as pd
from pytorch_lightning.callbacks import EarlyStopping
from src.cogponder import CogPonderModel
from src.cogponder.data import StroopSRODataset, CogPonderDataModule
from pathlib import Path

In [None]:
# Parameters

# Maximum number of training epochs. Early stopping is currently disabled in the
# trainer, so training runs until `max_epochs` is reached.
# (Intended early-stopping patience, if re-enabled: 10% of max_epochs, min 10 epochs.)
max_epochs = 3000

# NOTE(review): 96 * 5 presumably means trials-per-subject x subjects — confirm.
batch_size = 96 * 5

# NOTE(review): units of `step_duration` are defined by StroopSRODataset — confirm.
step_duration = 10

# Requested number of subjects; the Data cell overwrites this with the count
# actually present in the loaded dataset.
n_subjects = 5

# Seed Python, NumPy and PyTorch (including dataloader workers) so that data
# splitting, weight initialisation and shuffling are reproducible on re-runs.
seed = 42
pl.seed_everything(seed, workers=True)

In [None]:
# Data

print('Loading Stroop dataset... ', end='')

dataset = StroopSRODataset(n_subjects=n_subjects,
                           step_duration=step_duration,
                           shuffle_subjects=False,
                           non_decision_time='auto')

datamodule = CogPonderDataModule(dataset, batch_size=batch_size, num_workers=8)
datamodule.prepare_data()

# Materialise the full dataset once instead of re-indexing it for every field.
# Tuple positions follow how each field is used below; their exact meaning is
# defined by StroopSRODataset — NOTE(review): confirm against its __getitem__.
all_samples = datamodule.dataset[:]

# determine some parameters from data
# NOTE: this overwrites the requested `n_subjects` with the number actually loaded.
n_subjects = all_samples[0].unique().size(0)
n_contexts = all_samples[2].unique().size(0)
n_features = all_samples[3].size(-1)
max_response_step = all_samples[5].max().int().item() + 1
n_outputs = all_samples[6].unique().size(0)

# Release the reference so the full-dataset tuple does not linger in the kernel.
del all_samples

configs = {
    'inputs_dim': n_features,
    'outputs_dim': n_outputs,
    'embeddings_dim': 4,
    'time_loss_beta': 1.,
    'learning_rate': 1e-3,
    'max_response_step': max_response_step,
    'n_contexts': n_contexts,
    'n_subjects': n_subjects,
    'subject_embeddings_dim': 4,
    'task': 'stroop',
    'operator_type': 'spatiotemporal'
}

print('Done!')

In [None]:
# Experiment

model = CogPonderModel(**configs)

# Resume from an earlier training run only when its checkpoint exists on disk;
# otherwise train from scratch, so the notebook also runs on a fresh checkout.
# NOTE(review): this checkpoint was saved with 5 subjects — resuming it with a
# different `n_subjects` will presumably fail to load; confirm before changing.
resume_ckpt_path = Path('models/checkpoints/stroop/cogponder_5subjects_2554epochs.ckpt')
resume_from = str(resume_ckpt_path) if resume_ckpt_path.is_file() else None

# Trainer
trainer = pl.Trainer(
    max_epochs=max_epochs,
    min_epochs=200,
    accelerator='auto',
    log_every_n_steps=1,
    # Early stopping is currently disabled. To enable it, add e.g.
    # EarlyStopping(monitor='val/total_loss', mode='min') to this list.
    callbacks=[],
)

# Fit the model; `ckpt_path=None` starts training from scratch.
trainer.fit(model,
            ckpt_path=resume_from,
            datamodule=datamodule)

# save checkpoint; the filename records the subject count and
# `trainer.current_epoch` after fitting
ckpt_path = f'models/checkpoints/stroop/cogponder_{n_subjects}subjects_{trainer.current_epoch}epochs.ckpt'
trainer.save_checkpoint(ckpt_path)

# Clear the verbose training logs, then report what happened.
clear_output()

if resume_from is not None:
    print('Resumed training from:', resume_from)
else:
    print('No resume checkpoint found; trained from scratch.')
print('Saved checkpoint to:', ckpt_path)