## Reinforcement learning example with stable-baselines

### Task

In [None]:
import gymnasium as gym
import neurogym as ngym
from neurogym.wrappers import pass_reward
import matplotlib.pyplot as plt

In [None]:
# Identifier of the task to build
name = 'contrib.SequenceAlternation-v0'
# Reward value per outcome (defined here but not forwarded to the task; see kwargs)
rewards = dict(correct=+1., fail=0.)

# Keyword arguments handed to the task constructor. Options left disabled:
# 'rewards': rewards, 'opponent_type': opponent_type, 'learning_rate': learning_rate
kwargs = dict(dt=100, cued_epoch_periodicity=3)

# Build the task and print its properties
env = gym.make(name, **kwargs)
print(env)

# Wrap the task with the pass-reward wrapper
env = pass_reward.PassReward(env)
# Plot example trials with a random agent (no model is passed)
data = ngym.utils.plot_env(
    env,
    num_steps=1000,
    fig_kwargs={'figsize': (12, 12)},
)

### Train a network

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C  # ACER, PPO2
from sb3_contrib import RecurrentPPO

# # Optional: PPO2 requires a vectorized environment to run
# # the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

# Build a recurrent PPO agent with an MLP+LSTM policy on the wrapped task.
model = RecurrentPPO("MlpLstmPolicy", env, verbose=1)
# 30*10^6 steps (presumably the length of the full training run; the call
# below only runs 3_000 steps -- TODO confirm)
# Restore the saved weights *in place*. `load` is a classmethod that returns a
# brand-new model and leaves the instance it is called on untouched, so
# `model.load(...)` with its result discarded had no effect on `model`.
model.set_parameters("ppo2_sequencealternation")
model.learn(total_timesteps=3_000, log_interval=500_000)
# model.save("ppo2_sequencealternation")
env.close()

### Visualize results

In [None]:
# Build a new instance of the task and print its properties
env = gym.make(name, **kwargs)
print(env)
# Wrap the task with the pass-reward wrapper
env = pass_reward.PassReward(env)
# env = DummyVecEnv([lambda: env])
# Plot example trials, handing the model to plot_env
data = ngym.utils.plot_env(
    env,
    model=model,
    num_steps=400,
    fig_kwargs={'figsize': (12, 12)},
)

In [None]:
seq_len = 10

# Supervised dataset built on the perceptual decision-making task
dataset = ngym.Dataset(
    'PerceptualDecisionMaking-v0',
    seq_len=seq_len,
    batch_size=16,
)
env = dataset.env
# Sizes read from the environment's action and observation spaces
act_size = env.action_space.n
ob_size = env.observation_space.shape[0]

In [None]:
# Draw one item from the dataset and display the shapes of its two parts
# (presumably a batch of inputs and targets -- confirm against ngym.Dataset).
a, b = next(dataset)
a.shape, b.shape

In [None]:
import numpy as np

def generate_sequence(num_range, sequence_length):
    """Draw a random sequence whose leading elements recur before its last one.

    ``sequence_length // 2 + 1`` distinct values are sampled without
    replacement, then the leading values of that sample are inserted again
    just before its final element, e.g. ``[a, b, c, d, a, b, c, e]`` for
    ``sequence_length=8``.

    Parameters
    ----------
    num_range : int or array-like
        Population forwarded to ``np.random.choice``.
    sequence_length : int
        Determines how many distinct values are drawn and how many of them
        are repeated.

    Returns
    -------
    numpy.ndarray
        The sampled values with the repeated prefix inserted before the last
        sampled value.
    """
    n_distinct = sequence_length // 2 + 1
    n_repeated = sequence_length - n_distinct
    drawn = np.random.choice(num_range, size=n_distinct, replace=False)
    # Index -1 places the repeated prefix immediately before the last value.
    return np.insert(drawn, -1, values=drawn[:n_repeated])

# Example call: sample from 8 candidate values with sequence_length=8
generate_sequence(8, 8)

In [None]:
from itertools import permutations

def generate_unique_sequences(num_range, sequence_length):
    """Enumerate every sequence obtainable from ordered picks of distinct values.

    Each permutation of ``sequence_length // 2 + 1`` values taken from
    ``range(num_range)`` is turned into a sequence by inserting its leading
    values again just before its final element.

    Parameters
    ----------
    num_range : int
        Values are taken from ``range(num_range)``.
    sequence_length : int
        Determines how many distinct values each permutation holds and how
        many of them are repeated.

    Returns
    -------
    list of numpy.ndarray
        One array per permutation, in the order ``itertools.permutations``
        yields them.
    """
    n_distinct = sequence_length // 2 + 1
    n_repeated = sequence_length - n_distinct
    return [
        # Index -1 places the repeated prefix immediately before the last value.
        np.insert(ordering, -1, values=ordering[:n_repeated])
        for ordering in permutations(range(num_range), n_distinct)
    ]

import math

num_range = 7
sequence_length = 8
# Note: this rebinds the notebook-level name `dataset`.
dataset = generate_unique_sequences(num_range, sequence_length)
# Sanity check: there must be exactly one sequence per ordered selection of
# the distinct elements. The stdlib `math.perm` is used because the `np.math`
# alias was removed in NumPy 2.0.
assert math.perm(num_range, sequence_length // 2 + 1) == len(dataset)

def generate_dataset(num_range, sequence_length, batch_size):
    """Repeat the full set of unique sequences ``batch_size`` times.

    Parameters
    ----------
    num_range : int
        Forwarded to ``generate_unique_sequences``.
    sequence_length : int
        Forwarded to ``generate_unique_sequences``.
    batch_size : int
        Number of times the complete list of unique sequences is repeated.

    Returns
    -------
    list
        The unique sequences, concatenated ``batch_size`` times in their
        original order. Repeated entries refer to the same objects.

    Example sequence
    ----------------
        x = [3, 2, 4, 1, 3, 2, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0]
        y = [3, 2, 4, 1, 3, 2, 4, 6, 3, 2, 4, 1, 3, 2, 4, 6]

    NOTE(review): the x/y pair above looks like an intended input/target
    layout; this function does not build such pairs, it only returns the
    repeated sequences -- confirm the intended output format.
    """
    unique_sequences = generate_unique_sequences(num_range, sequence_length)
    return [
        sequence
        for _ in range(batch_size)
        for sequence in unique_sequences
    ]

---