<a href="https://colab.research.google.com/github/gyyang/neurogym/blob/master/neurogym/examples/example_NeuroGym_stable_baselines.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NeuroGym

NeuroGym is a comprehensive toolkit that provides many established neuroscience tasks as OpenAI Gym environments. Any network model can be trained on these tasks with reinforcement learning techniques, or with supervised learning, as in this notebook. The tasks include working memory tasks, value-based decision tasks and context-dependent perceptual categorization tasks.
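NeuroGym tasks are Gym environments, so an agent interacts with them through the classic Gym loop. `reset()` returns the first observation, and `step(action)` returns `(observation, reward, done, info)`. The sketch below shows that loop with a hypothetical toy two-choice task. `ToyTwoChoiceEnv` is not part of NeuroGym, and the policy is hand-written: it simply picks the channel with more evidence. Real tasks such as `RDM-v0` are richer than this, with fixation, stimulus and decision periods that each span several time steps.

```python
import random


class ToyTwoChoiceEnv:
    """Hypothetical Gym-style task (not part of NeuroGym).

    Each trial lasts one step. The observation is [fixation, evidence_1,
    evidence_2], where the channel of the correct choice carries extra signal.
    Action 1 or 2 is rewarded with 1.0 if it matches the correct choice.
    """

    def __init__(self, noise=0.5, seed=0):
        self.noise = noise
        self.rng = random.Random(seed)
        self.ground_truth = None

    def reset(self):
        # Start a new trial and return its observation
        self.ground_truth = self.rng.choice([1, 2])
        evidence = [self.rng.gauss(0.0, self.noise) for _ in range(2)]
        evidence[self.ground_truth - 1] += 1.0
        return [1.0] + evidence

    def step(self, action):
        # Score the action against the current trial, then begin the next one
        reward = 1.0 if action == self.ground_truth else 0.0
        obs = self.reset()
        return obs, reward, False, {"new_trial": True}


env = ToyTwoChoiceEnv(seed=1)
obs = env.reset()
n_steps = 1000
total_reward = 0.0
for _ in range(n_steps):
    action = 1 if obs[1] > obs[2] else 2  # choose the channel with more evidence
    obs, reward, done, info = env.step(action)
    total_reward += reward
accuracy = total_reward / n_steps
print(f"accuracy: {accuracy:.2f}")
```

A trained network replaces the hand-written policy, but it talks to the task through the same `reset`/`step` calls.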

In this notebook we first show how to install the required packages.

We then show how to access the available tasks and their relevant information.

Finally, we train an LSTM network on the Random Dot Motion (RDM) task using standard supervised learning (with Keras) and plot the results.
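Supervised learning requires turning the task into input/target pairs. One common approach, sketched here under our own assumptions, has three steps:

* Run the task for many time steps.
* Record the observation and the correct action at each step.
* Cut the resulting stream into windows of `rollout` time steps.

With 3 input channels and `rollout = 20`, this gives inputs of shape `(batch, 20, 3)`, which is the input shape that appears in the model summary further down. The helper `make_windows` and the random data are hypothetical illustrations. `example_NeuroGym_keras` may build its training batches differently.

```python
import numpy as np


def make_windows(obs, gt, rollout, n_actions):
    """Cut one long recorded trajectory into fixed-length training sequences.

    obs: float array (T, n_inputs), the observation at every time step.
    gt:  int array (T,), the correct action at every time step.
    Returns X with shape (n_windows, rollout, n_inputs) and one-hot targets Y
    with shape (n_windows, rollout, n_actions); any incomplete tail is dropped.
    """
    n_windows = len(obs) // rollout
    t_used = n_windows * rollout
    X = obs[:t_used].reshape(n_windows, rollout, obs.shape[1])
    Y = np.eye(n_actions)[gt[:t_used]].reshape(n_windows, rollout, n_actions)
    return X, Y


# Fake data standing in for a recorded trajectory: 1010 steps, 3 input channels
rng = np.random.default_rng(0)
obs = rng.normal(size=(1010, 3))
gt = rng.integers(0, 3, size=1010)  # correct action (0, 1 or 2) at each step
X, Y = make_windows(obs, gt, rollout=20, n_actions=3)
print(X.shape, Y.shape)  # (50, 20, 3) (50, 20, 3)
```

The last 10 steps do not fill a complete window, so they are dropped rather than padded.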

In [None]:
# Colab-only magic that selects TensorFlow 1.x for this session
%tensorflow_version 1.x

# Install Gym

In [None]:
! pip install gym

# Install NeuroGym

In [None]:
! git clone https://github.com/gyyang/neurogym.git

In [None]:
cd neurogym

In [None]:
! pip install -e .

# Access task and wrapper information

In [None]:
cd /content/neurogym/

In [None]:
from neurogym.meta import tasks_info

### Get list of available tasks

In [None]:
tasks_info.info()

### Get information about a specific task

In [None]:
tasks_info.info('RDM-v0', show_code=True, show_fig=True)

### Get list of available wrappers

In [None]:
tasks_info.info_wrapper()

### Get information about a specific wrapper

In [None]:
tasks_info.info_wrapper('Monitor-v0', show_code=True)

In [None]:
tasks_info.info_wrapper('TrialHistory-v0')
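Wrappers follow Gym's wrapper pattern. A wrapper holds an inner environment, forwards `reset` and `step` to it, and adds some behaviour around those calls. In Gym itself, this is done by subclassing `gym.Wrapper`. The dependency-free sketch below shows the pattern with a hypothetical `RewardLogger` that records every reward, loosely in the spirit of a monitoring wrapper. It is not NeuroGym's `Monitor-v0` implementation, and `ConstantEnv` is a hypothetical toy task.

```python
class ConstantEnv:
    """Hypothetical toy task: action 1 is always correct."""

    def reset(self):
        return [0.0]

    def step(self, action):
        reward = 1.0 if action == 1 else 0.0
        return [0.0], reward, False, {}


class RewardLogger:
    """Hypothetical wrapper: forwards calls to the wrapped env and logs rewards."""

    def __init__(self, env):
        self.env = env
        self.rewards = []

    def reset(self):
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.rewards.append(reward)  # the extra behaviour added by the wrapper
        return obs, reward, done, info


env = RewardLogger(ConstantEnv())
env.reset()
for action in [1, 0, 1, 1]:
    env.step(action)
print(env.rewards)  # [1.0, 0.0, 1.0, 1.0]
```

A wrapper exposes the same `reset`/`step` interface as the task it wraps. Wrappers can therefore be stacked, for example `RewardLogger(RewardLogger(ConstantEnv()))`.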

# Example: train an LSTM on the RDM task
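The next cell sets `dt = 100` and gives every trial period a constant duration of 200. We assume, following the NeuroGym convention, that durations and `dt` are both in milliseconds and that a period lasts `duration / dt` time steps. Under that assumption, each period spans 2 steps and a full trial spans 6 steps. A training sequence of `rollout = 20` steps therefore covers a little over three trials. The hypothetical helper `steps_per_period` below only does this arithmetic for constant durations. It is not a NeuroGym function.

```python
def steps_per_period(timing, dt):
    """Convert constant period durations (ms) into numbers of time steps.

    Only ('constant', duration) entries are handled; other timing
    distributions are outside the scope of this sketch.
    """
    steps = {}
    for name, (kind, duration) in timing.items():
        if kind != 'constant':
            raise ValueError(f"unsupported timing type: {kind}")
        steps[name] = int(round(duration / dt))
    return steps


dt = 100  # ms per time step
timing = {'fixation': ('constant', 200),
          'stimulus': ('constant', 200),
          'decision': ('constant', 200)}
steps = steps_per_period(timing, dt)
trial_len = sum(steps.values())
print(steps)      # {'fixation': 2, 'stimulus': 2, 'decision': 2}
print(trial_len)  # 6
print(f"trials per rollout of 20 steps: {20 / trial_len:.2f}")
```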

In [1]:
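from neurogym.examples import example_NeuroGym_keras as exK

# Training arguments
task = 'RDM-v0'       # Random Dot Motion task
num_trials = 100000   # total number of training trials
rollout = 20          # length (in time steps) of each training sequence
dt = 100              # duration of one time step (ms)
kwargs = {'dt': dt,
          'timing': {'fixation': ('constant', 200),  # period durations (ms)
                     'stimulus': ('constant', 200),
                     'decision': ('constant', 200)}}

# num_h: LSTM units; b_size: batch size; tr_per_ep: trials per epoch
model = exK.train_env_keras_net(task, kwargs=kwargs,
                                rollout=rollout, num_tr=num_trials,
                                num_h=256, b_size=128,
                                tr_per_ep=1000, verbose=1)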


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 20, 3)             0         
_________________________________________________________________
lstm (LSTM)                  (None, 20, 256)           266240    
_________________________________________________________________
time_distributed (TimeDistri (None, 20, 3)             771       
Total params: 267,011
Trainable params: 267,011
Non-trainable params: 0
_________________________________________________________________
Accuracy:  0.6666667
Performance:  0.0
epoch 0 out of 100
remaining time: 0.05
-------------


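The parameter counts in the model summary can be checked by hand.

* **LSTM layer.** A Keras LSTM layer has four gates. Each gate has an input weight matrix, a recurrent weight matrix and a bias vector, so the layer holds `4 * (n_hidden * (n_in + n_hidden) + n_hidden)` parameters.
* **Readout.** We assume the readout is a `Dense` layer with 3 outputs applied at every time step. The summary truncates its name to `TimeDistri...`. Such a layer holds `n_hidden * n_out + n_out` parameters.

```python
def lstm_params(n_in, n_hidden):
    # 4 gates, each with input weights, recurrent weights and a bias vector
    return 4 * (n_hidden * (n_in + n_hidden) + n_hidden)


def dense_params(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out


n_in, n_hidden, n_out = 3, 256, 3  # input channels, LSTM units (num_h), outputs
lstm = lstm_params(n_in, n_hidden)
readout = dense_params(n_hidden, n_out)
print(lstm, readout, lstm + readout)  # 266240 771 267011
```

These values match the 266,240, 771 and 267,011 parameters reported in the summary.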

# Visualize results

In [None]:
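import gym  # NeuroGym tasks were registered with Gym when neurogym was imported
env = gym.make(task, **kwargs)  # same task and settings as used for training
data = tasks_info.plot_struct(env, num_steps_env=10000, n_stps_plt=200,
                              model=model)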
