In [1]:
# Reload custom modules imported during runtime, so that edits to the cloned
# source files are picked up without restarting the kernel.
# https://stackoverflow.com/questions/50339549/google-colab-reload-imported-modules

# Mode 2 re-imports all loaded modules before each cell is executed.
%load_ext autoreload
%autoreload 2

In [5]:
import os

# Clone the project branch only once; the check is relative to the current
# working directory.
# NOTE(review): the existence check uses a relative path while %load below uses
# the absolute /content/... path - they agree only when the cwd is /content
# (the Colab default). Confirm before running elsewhere.
if not os.path.exists('NMA-group1'):
    !git clone -b impl_surya https://github.com/JuliaY123/NMA-group1.git
# Pull the environment source into the notebook.
# NOTE(review): later cells register the entry point 'nback_env:NBack', which
# presumably needs the repo directory on sys.path; no cell shown here adds it.
%load '/content/NMA-group1/nback_env.py'

In [None]:
#@title Install dependencies
# swig is installed first, then gymnasium with its 'all' extras.
# NOTE(review): versions are not pinned, so a re-run may install newer releases
# than the ones recorded in the output below.
!pip install swig
!pip install gymnasium['all']

import gymnasium as gym
from gymnasium import spaces
import math
import numpy as np
from numpy.random import default_rng
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import clear_output

# Hide the installer log once everything has been imported.
clear_output()


Collecting swig
  Downloading swig-4.1.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.8 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m1.1 MB/s[0m eta [36m0:00:02[0m[2K     [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/1.8 MB[0m [31m6.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: swig
Successfully installed swig-4.1.1
Collecting gymnasium[all]
  Downloading gymnasium-0.29.0-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.8/953.8 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium[all])
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl 

In [None]:
# @title Define environment
# N-back environment

import numpy as np
import gymnasium as gym
from gymnasium import Env, spaces, utils

class NBack(Env):
    """N-back working-memory task as a Gymnasium environment.

    A sequence of digits (0-9) is presented one symbol per step. From step
    ``N`` onwards the agent must answer whether the current symbol equals the
    one shown ``N`` steps earlier (action 1 = match, action 0 = no match).

    Example with ``N = 2`` and ``rewards = (1, 0, -1, -1)``::

        step_count      = [ 0  1   2  3  4  5  6 ]
        sequence        = [ a  b   c  d  a  d  a ]   (digits 0-9 in practice)
        correct actions = [ ~  ~   0  0  0  1  1 ]
        actions         = [ ~  ~   1  0  0  1  0 ]
        reward_class    = [ ~  ~  FP TN TN TP FN ]
        reward          = [ ~  ~  -1  0  0  1 -1 ]

    Parameters
    ----------
    N : int
        How many steps back the comparison looks (must be >= 1).
    num_trials : int
        Number of scored trials; an episode lasts ``num_trials + N`` steps.
    num_targets : int or None
        Exact number of matching trials per episode, or ``None`` to leave the
        number of matches to chance.
    rewards : sequence of 4 numbers
        Rewards ordered as (TP, TN, FP, FN), "positive" meaning a match.
    num_obs : int
        Length of the observation window (most recent symbols).
    seed : int or None
        Seed applied on the first ``reset`` when that call passes no seed.

    Observations are integer arrays of length ``num_obs`` holding the most
    recent symbols. At the start of an episode the window is left-padded with
    0; the observation returned together with ``done=True`` is right-padded
    with a single 0 because no further symbol exists.

    ``step`` returns the 4-tuple ``(observation, reward, done, info)``; the
    training and rollout loops in this notebook unpack exactly four values.
    """

    # Number of distinct symbols: digits 0..9.
    NUM_SYMBOLS = 10

    def __init__(self, N=2, num_trials=25, num_targets=None, rewards=(1, 1, -1, -1), num_obs=5, seed=2023):

        self.N = N
        self.num_trials = num_trials
        self.episode_length = num_trials + self.N
        self.num_targets = num_targets
        self.rewards = rewards
        self.num_obs = num_obs
        self.num_actions = 2

        # The constructor seed is consumed by the first reset() that does not
        # supply its own seed (see reset).
        self._initial_seed = seed
        self._seeded = False

        # Check that parameters are legal
        assert N >= 1, "N must be a positive integer"
        assert num_obs >= 1, "num_obs must be a positive integer"
        assert len(rewards) == 4, "rewards must be (TP, TN, FP, FN)"
        assert num_targets is None or 0 <= num_targets <= num_trials, \
            "num_targets must lie between 0 and num_trials"

        # Define rewards, observation space and action space
        self.reward_range = (min(rewards), max(rewards))
        # Integer Box whose length follows num_obs, so it matches what
        # _get_obs actually returns.
        self.observation_space = spaces.Box(
            low=0, high=self.NUM_SYMBOLS - 1, shape=(self.num_obs,), dtype=np.int64
        )
        self.action_space = spaces.Discrete(self.num_actions)   # 0 (No match) or 1 (Match)

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        ``seed`` reseeds the environment RNG. When it is omitted on the very
        first reset, the seed given to the constructor is used instead.
        ``options`` is accepted for Gymnasium API compatibility and ignored.
        """
        if seed is None and not self._seeded:
            seed = self._initial_seed
        super().reset(seed=seed)
        self._seeded = True

        # Generate sequence and correct actions
        self._generate_sequence()
        self._get_correct_actions()

        # The first observation shows only the first symbol
        self.step_count = 0

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action):
        """Score ``action`` against the current symbol and advance one step.

        Returns ``(observation, reward, done, info)``. The first ``N`` steps
        have nothing to compare against and always yield reward 0.
        """
        if self.step_count >= self.N:
            if self.correct_actions[self.step_count - self.N]:  # Match
                reward = self.rewards[0] if action else self.rewards[3]  # TP / FN
            else:  # No match
                reward = self.rewards[2] if action else self.rewards[1]  # FP / TN
        else:
            reward = 0

        self.step_count += 1
        observation = self._get_obs()
        info = self._get_info()

        done = self.step_count >= self.episode_length
        return observation, reward, done, info

    def _generate_sequence(self):
        """Draw ``self.sequence`` (length ``episode_length``) from the env RNG.

        With ``num_targets`` set, the sequence is built directly instead of by
        rejection sampling: the matching trials are chosen uniformly, matching
        positions copy the symbol N steps back and all other positions draw
        uniformly from the remaining symbols. This is the same distribution as
        "i.i.d. uniform digits conditioned on exactly num_targets matches",
        but it always terminates after one pass.
        """
        rng = self.np_random
        k = self.NUM_SYMBOLS

        if self.num_targets is None:
            self.sequence = rng.integers(0, k, size=self.episode_length)
            return

        is_target = np.zeros(self.num_trials, dtype=bool)
        is_target[rng.permutation(self.num_trials)[:self.num_targets]] = True

        sequence = np.empty(self.episode_length, dtype=np.int64)
        sequence[:self.N] = rng.integers(0, k, size=self.N)
        for i in range(self.num_trials):
            if is_target[i]:
                sequence[i + self.N] = sequence[i]
            else:
                # An offset in 1..k-1 guarantees a different symbol.
                sequence[i + self.N] = (sequence[i] + rng.integers(1, k)) % k
        self.sequence = sequence

    def _get_obs(self):
        """Return the window of the last ``num_obs`` symbols, zero-padded."""
        end = self.step_count + 1
        start = max(0, end - self.num_obs)
        window = self.sequence[start:end]   # slicing clips past the sequence end

        pad_left = max(0, self.num_obs - end)
        pad_right = self.num_obs - pad_left - len(window)
        if pad_left or pad_right:
            return np.pad(window, (pad_left, pad_right), mode='constant', constant_values=0)
        return window

    def _get_correct_actions(self):
        """Compute, store and return the 0/1 match flag of every trial."""
        earlier = self.sequence[:self.num_trials]
        current = self.sequence[self.N:self.N + self.num_trials]
        self.correct_actions = (earlier == current).astype(int)
        return self.correct_actions

    def _get_info(self):
        """Return the auxiliary info dict (currently only the step counter)."""
        info = {
            'step_count': self.step_count,
            }
        return info



In [None]:
import gymnasium as gym
from gymnasium import register

# Make the custom environment available to gym.make under a fixed id.
register(id='NBack-v0', entry_point='nback_env:NBack')

# 3-back task: 100 scored trials, exactly 10 of them matches; only hits pay.
env = gym.make(
    'NBack-v0',
    N=3,
    num_trials=100,
    num_targets=10,
    rewards=(1, 0, 0, 0),
    num_obs=5,
    seed=2023,
)

observation, info = env.reset()
print(f"reset observation:\t{observation}\n")

# Roll out one episode with uniformly random actions and log every transition.
done = False
while not done:
    action = env.action_space.sample()
    next_observation, reward, done, info = env.step(action)
    print(
        f"observation:\t{observation}",
        f"action:\t{action}",
        f"reward:\t{reward}",
        f"next_observation:\t{next_observation}",
        f"done:\t{done}",
        f"info:\t{info['step_count']}",
        f"\n",
        sep="\n",
    )
    observation = next_observation

In [None]:
import random

class DQN(nn.Module):
    """Fully connected Q-network: two 64-unit ReLU hidden layers and a linear head.

    ``input_size`` is the length of the observation vector and ``output_size``
    the number of actions; ``forward`` returns one value per action.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_size)

    def forward(self, x):
        # ReLU after each hidden layer; the output layer stays linear.
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        return self.fc3(x)

class ReplayMemory:
    """Fixed-capacity ring buffer of (state, action, next_state, reward, done) tuples."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0   # index of the slot the next push writes to

    def push(self, state, action, next_state, reward, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        transition = (state, action, next_state, reward, done)
        if len(self.memory) < self.capacity:
            # Still filling up: the write position is the end of the list.
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        States and next states come back as stacked arrays; actions, rewards
        and done flags come back as tuples.
        """
        drawn = random.sample(self.memory, batch_size)
        states = np.stack([t[0] for t in drawn])
        actions = tuple(t[1] for t in drawn)
        next_states = np.stack([t[2] for t in drawn])
        rewards = tuple(t[3] for t in drawn)
        dones = tuple(t[4] for t in drawn)
        return states, actions, next_states, rewards, dones

    def __len__(self):
        return len(self.memory)


[autoreload of nback_env failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 302, in update_class
    if update_generic(old_obj, new_obj): continue
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/usr/local/lib/python3.10/dist-packages/IPython/extensions/autoreload.py", line 266, in update_function
    setattr(old, name, getattr(new, name))
ValueError: __init__() requires a code object with 1 free vars, not 0


In [None]:
def train(model, memory, optimizer, criterion, batch_size, gamma):
    """Run one Q-learning update on a batch sampled from ``memory``.

    Does nothing until ``memory`` holds at least ``batch_size`` transitions.
    The target is ``reward + gamma * max_a Q(next_state, a)``, with the
    bootstrap term dropped for transitions flagged as done. The same
    ``model`` provides both the prediction and the target.
    """
    if len(memory) < batch_size:
        return

    batch_states, batch_actions, batch_next, batch_rewards, batch_dones = memory.sample(batch_size)
    states = torch.as_tensor(batch_states, dtype=torch.float32)
    actions = torch.as_tensor(batch_actions, dtype=torch.long)
    next_states = torch.as_tensor(batch_next, dtype=torch.float32)
    rewards = torch.as_tensor(batch_rewards, dtype=torch.float32)
    dones = torch.as_tensor(batch_dones, dtype=torch.float32)

    # Q-value of the action that was actually taken in each transition.
    q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # The target is treated as a constant: no gradient flows through it.
    with torch.no_grad():
        best_next_q = model(next_states).max(dim=1).values
        target_q_values = rewards + gamma * best_next_q * (1 - dones)

    loss = criterion(q_values, target_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def exec_training(env, gamma=0.99, num_episodes=100, capacity=1000, batch_size=64,
                  lr=0.001, epsilon_start=0.08, epsilon_decay=0.01, epsilon_min=0.01):
    """Train a fresh DQN agent on ``env`` and return the reward of every episode.

    Parameters
    ----------
    env : environment
        Must expose ``observation_space.shape``, ``action_space.n`` and
        ``action_space.sample()``. ``reset()`` must return a tuple whose first
        item is the observation, and ``step(action)`` must return the 4-tuple
        ``(observation, reward, done, info)``.
    gamma : float
        Discount factor used in the Q-learning target.
    num_episodes : int
        Number of training episodes.
    capacity : int
        Replay-memory size in transitions.
    batch_size : int
        Number of transitions per gradient step.
    lr : float
        Adam learning rate.
    epsilon_start, epsilon_decay, epsilon_min : float
        Exploration schedule. Episode ``e`` explores with probability
        ``max(epsilon_min, epsilon_start - epsilon_decay * e)``.

    Returns
    -------
    list
        Total (undiscounted) reward collected in each episode, in order.
    """
    # Network sized from the environment's spaces
    input_size = env.observation_space.shape[0]
    output_size = env.action_space.n
    model = DQN(input_size, output_size)

    memory = ReplayMemory(capacity)

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    rewards_list = []
    for episode in range(num_episodes):
        # Epsilon depends only on the episode index, so compute it once here.
        epsilon = max(epsilon_min, epsilon_start - epsilon_decay * episode)

        state = env.reset()[0]
        done = False
        total_reward = 0

        while not done:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = model(torch.as_tensor(state, dtype=torch.float32))
                    action = q_values.argmax().item()

            # Take the selected action and observe the next state and reward
            next_state, reward, done, info = env.step(action)

            # Store the transition in the replay memory
            memory.push(state, action, next_state, reward, done)

            # Move to the next state
            state = next_state
            total_reward += reward

            # One gradient step per environment step (no-op until the memory
            # holds batch_size transitions)
            train(model, memory, optimizer, criterion, batch_size, gamma)

        rewards_list.append(total_reward)

    return rewards_list




In [None]:
# Create the environment and train a DQN agent with three discount factors

import gymnasium as gym
from gymnasium import register

register(
    id='NBack-v0',
    entry_point='nback_env:NBack',
)

# 2-back task with 100 scored trials, exactly one of which is a match;
# only correct "match" responses are rewarded.
env = gym.make(
    'NBack-v0',
    N = 2,
    num_trials=100,
    num_targets=1,
    rewards=(1, 0, 0, 0),
    num_obs=5,  # was left empty (a syntax error); 5 matches the earlier cell's window length
    seed=42
    )

observation, info = env.reset()

# Same environment and episode budget; only gamma differs between runs.
re1 = exec_training(env=env, gamma=0.0, num_episodes = 50)
re2 = exec_training(env=env, gamma=0.5, num_episodes = 50)
re3 = exec_training(env=env, gamma=0.99, num_episodes = 50)


In [None]:
from torch.nn.modules import ReplicationPad3d
import plotly.graph_objects as go

# Per-episode total rewards of the three training runs
# (re1: gamma=0.0, re2: gamma=0.5, re3: gamma=0.99).
y1 = np.array(re1)
y2 = np.array(re2)
y3 = np.array(re3)

# Shared x axis: the episode index, sized from the first run
x = np.arange(len(y1))

# One line trace per run
trace1, trace2, trace3 = (
    go.Scatter(x=x, y=series, mode='lines', name=label)
    for series, label in ((y1, 'Line 1'), (y2, 'Line 2'), (y3, 'Line 3'))
)

# Figure title and axis captions
layout = go.Layout(
    title='Line Plot with Three Lines',
    xaxis={'title': 'X-axis'},
    yaxis={'title': 'Y-axis'},
)

# Assemble the figure and render it
data = [trace1, trace2, trace3]
fig = go.Figure(data=data, layout=layout)
fig.show()
