# Single-Task CogPonder: Stroop

This notebook implements a basic single-task CogPonder agent that learns the Stroop task by imitating human participants.

In [42]:
%reload_ext autoreload
%autoreload 3

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from IPython.display import clear_output, display
from pytorch_lightning.callbacks import EarlyStopping

from src.cogponder import CogPonderModel
from src.cogponder.data import StroopSRODataset, CogPonderDataModule

In [43]:
# Parameters

# Upper bound on the number of training epochs (passed to the Trainer as `max_epochs`).
# Early stopping is configured separately on the Trainer's EarlyStopping callback.
max_epochs = 1_000

# Mini-batch size handed to CogPonderDataModule.
batch_size = 96

# Step duration handed to StroopSRODataset (units are not shown in this notebook).
step_duration = 10

In [44]:
# Data
#
# Load the Stroop dataset, wrap it in a data module, and derive the model
# hyper-parameters that depend on the data.

print('Loading Stroop dataset... ', end='')

dataset = StroopSRODataset(n_subjects=1, step_duration=step_duration, non_decision_time='auto')

datamodule = CogPonderDataModule(dataset, batch_size=batch_size, num_workers=8)
datamodule.prepare_data()

# determine some parameters from data
# Slice the full dataset once and reuse it; the original re-sliced it for every
# statistic below.
all_fields = datamodule.dataset[:]

# Field positions follow the indices this notebook has always used; the meaning
# of each field is inferred from the parameter derived from it (TODO confirm
# against StroopSRODataset).
n_subjects = all_fields[0].unique().shape[0]
n_contexts = all_fields[2].unique().shape[0]
n_features = all_fields[3].shape[-1]
n_outputs = all_fields[6].unique().shape[0]
# +1 turns the largest observed value in field 5 into a step count.
max_response_step = all_fields[5].max().int().item() + 1

# Keyword arguments forwarded verbatim to CogPonderModel(**configs).
configs = {
    'inputs_dim': n_features,
    'outputs_dim': n_outputs,
    'embeddings_dim': 2,
    'time_loss_beta': 1.,
    'learning_rate': 1e-2,
    'max_response_step': max_response_step,
    'n_contexts': n_contexts,
    'n_subjects': n_subjects,
    'subject_embeddings_dim': 2,
    'task': 'stroop',
    'operator_type': 'spatiotemporal'
}

print('Done!')

Loading Stroop dataset... Done!


In [None]:
# Experiment
#
# Build the model from `configs`, train it with PyTorch Lightning, save a
# checkpoint, and clear the (verbose) training output from the cell.

model = CogPonderModel(**configs)

# Trainer
# NOTE(review): patience=1000 equals max_epochs, so early stopping is unlikely
# to ever trigger within a run; lower it if early stopping is actually wanted.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    min_epochs=200,
    accelerator='auto',
    log_every_n_steps=1,
    callbacks=[
        EarlyStopping(monitor='val/total_loss',
                      patience=1000,
                      mode='min', min_delta=0.0001),
    ])

# Fit the model; 'val/total_loss' is the metric watched by EarlyStopping.
trainer.fit(model, datamodule=datamodule)

# Save a checkpoint named after the epoch counter reached by the trainer.
checkpoint_path = f'models/checkpoints/stroop/cogponder_epochs-{trainer.current_epoch}.ckpt'
trainer.save_checkpoint(checkpoint_path)

# Drop the training logs from the cell output (clear_output is imported from
# IPython.display in the imports cell).
clear_output()

In [46]:
# DEBUG

# Sanity check: reload the checkpoint written by the training cell, then
# report where it lives. The print only runs if loading succeeded.
model_ckpt = CogPonderModel.load_from_checkpoint(checkpoint_path)
print(f'Saved checkpoint to: {checkpoint_path}')

Saved checkpoint to: models/checkpoints/stroop/cogponder_epochs-1000.ckpt
