# Investigating RNNs and RL using the N-back cognitive task

**NMA 2023 Group 1 Project**

__Content creators:__ Alan Astudilo, Campbell Border, Disheng, Julia Yin, Koffivi

__Pod TA:__ Suryanarayanan Nagar Anthel Venkatesh

__Project Mentor:__ 

---
# Objective

- 

- 
---

# Project Design
---

# Setup

## Install Dependencies

In [1]:
# @title Install dependencies
!pip install jedi --quiet
!pip install --upgrade pip setuptools wheel --quiet
!pip install numpy==1.23.3 --quiet --ignore-installed
!pip install gymnasium --quiet


#!pip install dm-acme[jax] --quiet
#!pip install dm-sonnet --quiet
#!pip install trfl --quiet
#!pip uninstall seaborn -y --quiet
#!pip install seaborn --quiet

ERROR: Ignored the following versions that require a different python version: 1.22.0 Requires-Python >=3.8; 1.22.1 Requires-Python >=3.8; 1.22.2 Requires-Python >=3.8; 1.22.3 Requires-Python >=3.8; 1.22.4 Requires-Python >=3.8; 1.23.0 Requires-Python >=3.8; 1.23.0rc1 Requires-Python >=3.8; 1.23.0rc2 Requires-Python >=3.8; 1.23.0rc3 Requires-Python >=3.8; 1.23.1 Requires-Python >=3.8; 1.23.2 Requires-Python >=3.8; 1.23.3 Requires-Python >=3.8; 1.23.4 Requires-Python >=3.8; 1.23.5 Requires-Python >=3.8; 1.24.0 Requires-Python >=3.8; 1.24.0rc1 Requires-Python >=3.8; 1.24.0rc2 Requires-Python >=3.8; 1.24.1 Requires-Python >=3.8; 1.24.2 Requires-Python >=3.8; 1.24.3 Requires-Python >=3.8; 1.24.4 Requires-Python >=3.8; 1.25.0 Requires-Python >=3.9; 1.25.0rc1 Requires-Python >=3.9; 1.25.1 Requires-Python >=3.9
ERROR: Could not find a version that satisfies the requirement numpy==1.23.3 (from versions: 1.3.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, 1.6.1, 1.6.2, 1.7.0, 1.7.1, 1.7.2

In [None]:
# @title Imports

#import time
import numpy as np
#import pandas as pd
#import matplotlib.pyplot as plt
#import seaborn as sns
import gymnasium as gym
from gymnasium import spaces

## Figure settings

In [None]:
# @title Figure settings
from IPython.display import clear_output, display, HTML
%matplotlib inline
# sns.set()

---
# Background

## Replace with our own literature review

- Cognitive scientists use standard lab tests to tap into specific processes in the brain and behavior. Some examples of those tests are Stroop, N-back, Digit Span, TMT (Trail Making Test), and WCST (Wisconsin Card Sorting Test).

- Despite an extensive body of research that explains human performance using descriptive what-models, we still need a more sophisticated approach to gain a better understanding of the underlying processes (i.e., a how-model).

- Interestingly, many such tests can be thought of as a continuous stream of stimuli and corresponding actions, which is consonant with the RL formulation. In fact, RL itself is in part motivated by how the brain enables goal-directed behaviors using reward systems, making it a good choice to explain human performance.

- One behavioral test example would be the N-back task.

  - In the N-back, participants view a sequence of stimuli, one by one, and are asked to categorize each stimulus as being either match or non-match. Stimuli are usually numbers, and feedback is given at both timestep and trajectory levels.

  - The agent is rewarded when it correctly reports whether the current stimulus matches the one shown N steps back in the episode. A simpler version of the N-back uses a two-choice action schema, that is, match vs. non-match: when the present stimulus matches the one presented N steps back, the agent is expected to respond `match`.


- Given a trained RL agent, we then look for correlates of its fitted parameters in brain mechanisms. The most straightforward approach could be to correlate model parameters with brain activity.

## Datasets

- HCP WM task ([NMA-CN HCP notebooks](https://github.com/NeuromatchAcademy/course-content/tree/master/projects/fMRI))

Any dataset that used cognitive tests would work. The HCP dataset from NMA-CN contains behavioral and imaging data for 7 cognitive tests, including various versions of the N-back.

Open questions:

- Should we limit the analysis to behavioral data, or also use fMRI?
- Which stimuli and actions should we use? Classic tests can be modeled using bounded symbolic stimuli/actions (e.g., A, B, C), but more sophisticated ones would require text or images (e.g., face vs. neutral images in the social Stroop dataset).

### We need to copy the data and see what's happening

## N-back task

In the N-back task, participants view a sequence of stimuli, one at a time, and are asked to categorize each stimulus as being either match or non-match. Stimuli are usually numbers, and feedback is given at both timestep and trajectory levels.

In a typical neuro setup, both accuracy and response time are measured, but here, for the sake of brevity, we focus only on accuracy of responses.
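
As a minimal sketch of the accuracy measure, the snippet below marks which trials are N-back matches and scores a set of responses against them. The helper names `nback_targets` and `accuracy`, the example sequence, and the responses are all hypothetical illustrations, not part of the HCP data or the project template.

```python
import numpy as np

def nback_targets(sequence, n):
    """Return 1 where a stimulus equals the one shown n steps earlier, else 0.

    Assumes n >= 1. The first n positions have no stimulus n steps back,
    so they are marked -1 and excluded from scoring.
    """
    sequence = np.asarray(sequence)
    targets = np.full(len(sequence), -1, dtype=int)
    targets[n:] = (sequence[n:] == sequence[:-n]).astype(int)
    return targets

def accuracy(responses, targets):
    """Fraction of scorable trials (target != -1) that were answered correctly."""
    responses = np.asarray(responses)
    scorable = targets != -1
    return float(np.mean(responses[scorable] == targets[scorable]))

# Hypothetical 2-back example: 1 = match, 0 = non-match
sequence = [3, 7, 3, 7, 1, 7]
targets = nback_targets(sequence, n=2)
responses = [0, 0, 1, 0, 0, 1]
score = accuracy(responses, targets)
print(targets)
print(score)
```

In this example, positions 2, 3, and 5 are 2-back matches; the hypothetical responses miss the match at position 3, so three of the four scorable trials are correct and the accuracy is 0.75.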

---
## Implementation scheme

### Environment
### I am just copying what's on the template to here
The following cell implements the N-back environment, which we later use to train an RL agent on human data. It is capable of performing two kinds of simulation:
- It rewards the agent when its action is correct (i.e., a normative model of the environment).
- It receives human data (or mock data if you prefer) and returns what participants performed as the observation. This is more useful for preference-based RL.

In [None]:
# @title Define environment

# N-back environment
class NBack(gym.Env):

    # N = 2
    # step_count =        [ 0  1  2  3  4  5  6 ]
    # sequence =          [ a  b  c  d  a  d  a ]
    # correct actions =   [ ~  ~  0  0  0  1  1 ]

    # actions =           [ ~  ~  1  0  0  1  0 ]
    # reward_class =      [ ~  ~  FP TN TN TP FN]
    # reward =            [ ~  ~  -1  0  0  1 -1]

  # Rewards input is structured as (TP, TN, FP, FN) (positive being matches)
  def __init__(self, N=2, episode_length=16, chars="digits", rewards=(1, 0, -1, -1)):

    self.N = N
    self.episode_length = episode_length
    self.chars = chars
    self.rewards = rewards

    # Check that parameters are legal
    assert(chars=="digits" or chars=="letters")
    #assert(episode_length >= 2 and episode_length <= 32)

    # Define rewards, observation space and action space
    self.reward_range = (min(rewards), max(rewards))                                            # Range of rewards based on inputs
    self.observation_space = spaces.Discrete(10) if chars == "digits" else spaces.Discrete(26)  # Single variable with 10 possibilities if using digits or 26 if using letters
    self.action_space = spaces.Discrete(2)                                                      # 0 (No match) or 1 (Match)

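  def step(self, action):

    # Get reward (undefined for the first N steps, where no N-back comparison exists)
    if self.step_count >= self.N:
      if self.correct_actions[self.step_count]: # Match
        reward = self.rewards[0] if action else self.rewards[3] # TP if matched else FN
      else: # No match
        reward = self.rewards[2] if action else self.rewards[1] # FP if matched else TN
    else:
      reward = None

    # Get next character in sequence (or end episode)
    self.step_count += 1
    if self.step_count < self.episode_length:
      observation = self.sequence[self.step_count]
      terminated = False
    else:
      observation = None
      terminated = True
    truncated = False  # Episodes always end by termination, never by an external time limit
    info = {}          # Gymnasium expects a dict here

    # Gymnasium's step API returns (observation, reward, terminated, truncated, info)
    return observation, reward, terminated, truncated, info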

  def reset(self, seed=None, options=None):

    # Seed RNG (gymnasium seeds self.np_random, so draw from it rather than np.random)
    super().reset(seed=seed)

    # Generate sequence of length self.episode_length
    # The upper bound is exclusive, so digits use 10 symbols (0-9).
    # Assumption: letters are encoded as integers 0-25 to match spaces.Discrete(26).
    n_symbols = 10 if self.chars == "digits" else 26
    self.sequence = self.np_random.integers(0, n_symbols, size=self.episode_length)

    # Generate correct sequence of actions
    self.correct_actions = [None] * self.N + [int(self.sequence[i] == self.sequence[i + self.N]) for i in range(self.episode_length - self.N)]

    # Observation is first character
    self.step_count = 0
    observation = self.sequence[self.step_count]
    info = {}  # Gymnasium expects a dict here

    return observation, info



In [None]:
# @title Test environment

env = NBack()
print(f"First char is {env.reset()[0]}")
print(env.sequence)
print(env.correct_actions)
for i in range(16):
  print(env.step(i % 2))
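
The worked example in the environment's class comments (N = 2, sequence `a b c d a d a`) can also be reproduced without the environment, which may help when checking the reward logic by hand. This is only a sketch: `classify_response` and `REWARDS` are hypothetical names, and the reward values mirror the default `(1, 0, -1, -1)` ordering of (TP, TN, FP, FN).

```python
def classify_response(is_match, action):
    """Label a response as TP, TN, FP, or FN, treating 'match' as the positive class."""
    if is_match:
        return "TP" if action == 1 else "FN"
    return "FP" if action == 1 else "TN"

# Same values as the environment's default rewards=(1, 0, -1, -1)
REWARDS = {"TP": 1, "TN": 0, "FP": -1, "FN": -1}

n = 2
sequence = list("abcdada")
actions = [None, None, 1, 0, 0, 1, 0]  # No response is scored for the first n steps

labels, rewards = [], []
for t in range(n, len(sequence)):
    # A trial is a match when the current stimulus equals the one n steps back
    label = classify_response(sequence[t] == sequence[t - n], actions[t])
    labels.append(label)
    rewards.append(REWARDS[label])

print(labels)
print(rewards)
print(sum(rewards))
```

The labels come out as FP, TN, TN, TP, FN with rewards -1, 0, 0, 1, -1, which agrees with the table in the class comments.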


### Define a random agent

In [None]:
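## please write your code for random agents here

# Random agent (minimal sketch; assumes `nn` is torch.nn, which the Imports cell does not load)
import torch.nn as nn

class RandomAgent(nn.Module):
  def __init__(self, action_space):
    super().__init__()
    self.action_space = action_space

  def forward(self, observation=None):
    return self.action_space.sample()  # Ignores the observation; picks 0 (no match) or 1 (match)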

### Define a simple Q-learning agent
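
One possible starting point for this section is tabular Q-learning on a toy version of the task. The sketch below is an illustration under stated assumptions, not the project's agent: it generates its own random sequences instead of using the `NBack` class, uses a 1-back task with 3 symbols so the table stays small, and treats the last `n_back + 1` stimuli as the state. The function name `train_q_learning` and all hyperparameter values are hypothetical.

```python
import numpy as np

def train_q_learning(n_back=1, n_symbols=3, episode_length=16, n_episodes=2000,
                     alpha=0.1, gamma=0.5, epsilon=0.1,
                     rewards=(1, 0, -1, -1), seed=0):
    """Tabular Q-learning on a toy N-back task.

    State: tuple of the last n_back + 1 stimuli.
    Actions: 0 (no match) or 1 (match).
    Rewards are ordered (TP, TN, FP, FN).
    """
    rng = np.random.default_rng(seed)
    tp, tn, fp, fn = rewards
    q_table = {}  # Maps a state tuple to an array of two action values

    def q_values(state):
        return q_table.setdefault(state, np.zeros(2))

    for _ in range(n_episodes):
        sequence = rng.integers(0, n_symbols, size=episode_length)
        for t in range(n_back, episode_length):
            state = tuple(int(s) for s in sequence[t - n_back:t + 1])
            # Epsilon-greedy action selection
            if rng.random() < epsilon:
                action = int(rng.integers(0, 2))
            else:
                action = int(np.argmax(q_values(state)))
            # Reward depends on whether the current stimulus matches the one n_back steps ago
            if sequence[t] == sequence[t - n_back]:
                reward = tp if action == 1 else fn
            else:
                reward = fp if action == 1 else tn
            # Bootstrap from the next state unless the episode has ended
            if t + 1 < episode_length:
                next_state = tuple(int(s) for s in sequence[t - n_back + 1:t + 2])
                target = reward + gamma * np.max(q_values(next_state))
            else:
                target = reward
            q_values(state)[action] += alpha * (target - q_values(state)[action])
    return q_table

q_table = train_q_learning()
greedy_policy = {state: int(np.argmax(values)) for state, values in q_table.items()}
for state in sorted(greedy_policy):
    print(state, greedy_policy[state])
```

In this toy setup the agent's actions do not influence which stimulus comes next, so the learned greedy policy only needs to separate match states from non-match states. The table grows as `n_symbols ** (n_back + 1)`, which is one reason a recurrent function approximator becomes attractive for larger N or richer stimuli.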

### Define a Recurrent Deep Q-learning Agent (RDQN)
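
As a rough sketch of what the recurrent part of such an agent could look like, the block below defines a small GRU-based Q-network in PyTorch and runs a forward pass on random digit sequences. It assumes PyTorch is available (the notebook's install cell does not add it), and the class name `RecurrentQNetwork`, the layer sizes, and the use of a GRU are illustrative choices rather than the project's design. The training loop (replay buffer, target network, loss) is left out.

```python
import torch
import torch.nn as nn

class RecurrentQNetwork(nn.Module):
    """GRU-based Q-network: maps a stimulus sequence to Q-values at every step."""

    def __init__(self, n_symbols=10, embedding_dim=8, hidden_size=32, n_actions=2):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_size, batch_first=True)
        self.q_head = nn.Linear(hidden_size, n_actions)

    def forward(self, stimuli, hidden=None):
        # stimuli: integer tensor of shape (batch, time)
        embedded = self.embedding(stimuli)             # (batch, time, embedding_dim)
        features, hidden = self.gru(embedded, hidden)  # (batch, time, hidden_size)
        return self.q_head(features), hidden           # Q-values: (batch, time, n_actions)

torch.manual_seed(0)
network = RecurrentQNetwork()
stimuli = torch.randint(0, 10, (4, 16))   # 4 episodes of 16 digits each
q_values, hidden = network(stimuli)
greedy_actions = q_values.argmax(dim=-1)  # 0 (no match) or 1 (match) at each step
print(q_values.shape, hidden.shape, greedy_actions.shape)
```

The hidden state is what lets the network carry information about earlier stimuli forward, which is the memory the N-back task requires. Passing `hidden` back in on the next call would allow the same network to be stepped one stimulus at a time inside an environment loop.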

# Section 4: Model(s)

In [None]:
# RNN

class LayeredRNN(nn.Module):
  pass  # Architecture not defined yet