## Reinforcement learning example with Stable-Baselines3

### Task

Here we build the Random Dots Motion task, specifying the duration of each trial period (fixation, stimulus, decision), and wrap it with the pass-reward wrapper, which appends the previous reward to the observation. We then plot the structure of the task in a figure that shows:
1. The observations received by the agent (top panel). 
2. The actions taken by a random agent and the correct action at each timestep (second panel).
3. The rewards provided by the environment at each timestep (third panel).
4. The performance of the agent at each trial (bottom panel).




In [None]:
import gymnasium as gym
import neurogym as ngym
from neurogym.wrappers import pass_reward
import matplotlib.pyplot as plt

# Keyword arguments forwarded to the task constructor; only `dt` is set.
# Entries currently disabled:
#   'rewards': rewards, 'opponent_type': opponent_type, 'learning_rate': learning_rate
kwargs = dict(dt=100)

# Reward values per outcome (kept for the disabled 'rewards' entry above)
rewards = dict(correct=+1., fail=0.)

# Registered identifier of the task to instantiate
name = 'contrib.SequenceAlternation-v0'

# Instantiate the task and show its properties
env = gym.make(name, **kwargs)
print(env)

# Wrap the task with the pass-reward wrapper
env = pass_reward.PassReward(env)

# Plot example trials (100 steps) driven by a random agent
data = ngym.utils.plot_env(
    env,
    num_steps=100,
    fig_kwargs=dict(figsize=(12, 12)),
)

### Train a network

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C  # ACER, PPO2

# NOTE: wrapping the environment in a vectorized env by hand is optional;
# the env is wrapped automatically when it is passed to the constructor.
# env = DummyVecEnv([lambda: env])

# Create an A2C agent with an MLP policy and train it on the wrapped task
model = A2C(policy="MlpPolicy", env=env, verbose=1)
model.learn(log_interval=1000, total_timesteps=10000)

# Training is done: close the environment used for it
env.close()

### Visualize results

In [None]:
# Build a fresh instance of the task and show its properties
env = gym.make(name, **kwargs)
print(env)

# Wrap the task with the pass-reward wrapper
env = pass_reward.PassReward(env)
# env = DummyVecEnv([lambda: env])

# Plot example trials (100 steps), passing the trained model to the plot
data = ngym.utils.plot_env(
    env,
    model=model,
    num_steps=100,
    fig_kwargs=dict(figsize=(12, 12)),
)

---