[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/ngym_usage/blob/master/supervised/auto_notebooks/rlGoNogo-v0.ipynb)

### Install packages if on Colab

In [1]:
# Uncomment following lines to install
# ! pip install gym   # Install gym
# ! git clone https://github.com/gyyang/neurogym.git  # Install neurogym
# %cd neurogym/
# ! pip install -e .

### Import packages

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 23 09:33:08 2020

@author: manuel
"""

import os
from pathlib import Path
import json
import importlib
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from neurogym.wrappers import ALL_WRAPPERS
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import LstmPolicy
from stable_baselines.common.callbacks import CheckpointCallback
import gym
import glob
import neurogym as ngym


envid = 'GoNogo-v0'

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [3]:
def get_modelpath(envid):
    # Make a local file directories
    path = Path('.') / 'files'
    os.makedirs(path, exist_ok=True)
    path = path / envid
    os.makedirs(path, exist_ok=True)
    return path


In [4]:
def apply_wrapper(env, wrap_string, params):
    wrap_str = ALL_WRAPPERS[wrap_string]
    wrap_module = importlib.import_module(wrap_str.split(":")[0])
    wrap_method = getattr(wrap_module, wrap_str.split(":")[1])
    return wrap_method(env, **params)


In [5]:
def make_env(env_id, rank, seed=0, wrapps={}, **kwargs):
    """
    Utility function for multiprocessed env.
    :param env_id: (str) the environment ID
    :param rank: (int) index of the subprocess
    :param seed: (int) the inital seed for RNG
    """
    def _init():
        env = gym.make(env_id, **kwargs)
        env.seed(seed + rank)
        for wrap in wrapps.keys():
            if not (wrap == 'MonitorExtended-v0' and rank != 0):
                env = apply_wrapper(env, wrap, wrapps[wrap])
        return env
    set_global_seeds(seed)
    return _init


In [6]:
def get_alg(alg):
    if alg == "A2C":
        from stable_baselines import A2C as algo
    elif alg == "ACER":
        from stable_baselines import ACER as algo
    elif alg == "ACKTR":
        from stable_baselines import ACKTR as algo
    elif alg == "PPO2":
        from stable_baselines import PPO2 as algo
    return algo


### Train network

In [7]:
"""Supervised training networks.

Save network in a path determined by environment ID.

Args:
    envid: str, environment ID.
"""

modelpath = get_modelpath(envid)
config = {
    'dt': 100,
    'hidden_size': 64,
    'lr': 1e-2,
    'alg': 'ACER',
    'rollout': 20,
    'n_thrds': 1,
    'wrappers_kwargs': {},
    'alg_kwargs': {},
    'seed': 0,
    # 'num_steps': 100000,
    'num_steps': 100,
    'envid': envid,
}

env_kwargs = {'dt': config['dt']}
config['env_kwargs'] = env_kwargs

# Save config
with open(modelpath / 'config.json', 'w') as f:
    json.dump(config, f)
algo = get_alg(config['alg'])
# Make supervised dataset
make_envs = [make_env(env_id=envid, rank=i, seed=config['seed'],
                              wrapps=config['wrappers_kwargs'],
                              **env_kwargs)
                     for i in range(config['n_thrds'])]
# env = SubprocVecEnv(make_envs)
env = DummyVecEnv(make_envs)  # Less efficient but more robust
model = algo(LstmPolicy, env, verbose=0, n_steps=config['rollout'],
             n_cpu_tf_sess=config['n_thrds'], tensorboard_log=None,
             policy_kwargs={"feature_extraction": "mlp",
                            "n_lstm": config['hidden_size']},
             **config['alg_kwargs'])
chckpnt_cllbck = CheckpointCallback(save_freq=int(config['num_steps']/10),
                                    save_path=modelpath,
                                    name_prefix='model')
model.learn(total_timesteps=config['num_steps'], callback=chckpnt_cllbck)
print('Finished Training')






Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Please use `layer.__call__` method instead.





Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Finished Training


In [8]:
def infer_test_timing(env):
    """Infer timing of environment for testing."""
    timing = {}
    for period in env.timing.keys():
        period_times = [env.sample_time(period) for _ in range(100)]
        timing[period] = np.median(period_times)
    return timing


In [9]:
def extend_obs(ob, num_threads):
    sh = ob.shape
    return np.concatenate((ob, np.zeros((num_threads-sh[0], sh[1]))))


In [10]:
def order_by_sufix(file_list):
    file_list = [os.path.basename(x) for x in file_list]
    flag = 'model.zip' in file_list
    file_list = [x for x in file_list if x != 'model.zip']
    sfx = [int(x[x.find('_')+1:x.rfind('_')]) for x in file_list]
    sorted_list = [x for _, x in sorted(zip(sfx, file_list))]
    if flag:
        sorted_list.append('model.zip')
    return sorted_list, np.max(sfx)


### Run network after training for analysis

In [14]:
"""Run trained networks for analysis.

Args:
    envid: str, Environment ID

Returns:
    activity: a list of activity matrices
    info: pandas dataframe, each row is information of a trial
    config: dict of network, training configurations
"""
modelpath = get_modelpath(envid)
files = glob.glob(str(modelpath)+'/model*')
if len(files) > 0:
    with open(modelpath / 'config.json') as f:
        config = json.load(f)
    env_kwargs = config['env_kwargs']
    wrappers_kwargs = config['wrappers_kwargs']
    seed = config['seed']
    # Run network to get activity and info
    sorted_models, last_model = order_by_sufix(files)
    model_name = sorted_models[-1]
    algo = get_alg(config['alg'])
    model = algo.load(modelpath / model_name, tensorboard_log=None,
                      custom_objects={'verbose': 0})

    # Environment
    env = make_env(env_id=envid, rank=0, seed=seed, wrapps=wrappers_kwargs,
                   **env_kwargs)()
    env.timing = infer_test_timing(env)
    env.reset(no_step=True)
    # Instantiate the network and print information
    activity = list()
    state_mat = []
    ob = env.reset()
    _states = None
    done = False
    info_df = pd.DataFrame()
    # num_steps = 10 ** 5
    num_steps = 10 ** 3
    for stp in range(int(num_steps)):
        ob = np.reshape(ob, (1, ob.shape[0]))
        done = [done] + [False for _ in range(config['n_thrds']-1)]
        action, _states = model.predict(extend_obs(ob, config['n_thrds']),
                                        state=_states, mask=done)
        action = action[0]
        ob, rew, done, info = env.step(action)
        if done:
            env.reset()
        if isinstance(info, (tuple, list)):
            info = info[0]
            action = action[0]
        state_mat.append(_states[0, int(_states.shape[1]/2):])
        if info['new_trial']:
            gt = env.gt_now
            correct = action == gt
            # Log trial info
            trial_info = env.trial
            trial_info.update({'correct': correct, 'choice': action})
            info_df = info_df.append(trial_info, ignore_index=True)
            # Log stimulus period activity
            activity.append(np.array(state_mat))
            state_mat = []
    env.close()
    
    activity = np.array(activity)
    



### General analysis

In [15]:
def analysis_average_activity(activity, info, config):
    # Load and preprocess results
    plt.figure(figsize=(1.2, 0.8))
    t_plot = np.arange(activity[0].shape[0]) * config['dt']
    plt.plot(t_plot, activity.mean(axis=0).mean(axis=-1))

analysis_average_activity(activity, info, config)

ValueError: operands could not be broadcast together with shapes (11,64) (13,64) 

<Figure size 86.4x57.6 with 0 Axes>

In [None]:
def get_conditions(info):
    """Get a list of task conditions to plot."""
    conditions = info.columns
    # This condition's unique value should be less than 5
    new_conditions = list()
    for c in conditions:
        try:
            n_cond = len(pd.unique(info[c]))
            if 1 < n_cond < 5:
                new_conditions.append(c)
        except TypeError:
            pass
        
    return new_conditions


In [None]:
def analysis_activity_by_condition(activity, info, config):
    conditions = get_conditions(info)
    for condition in conditions:
        values = pd.unique(info[condition])
        plt.figure(figsize=(1.2, 0.8))
        t_plot = np.arange(activity.shape[1]) * config['dt']
        for value in values:
            a = activity[info[condition] == value]
            plt.plot(t_plot, a.mean(axis=0).mean(axis=-1), label=str(value))
        plt.legend(title=condition, loc='center left', bbox_to_anchor=(1.0, 0.5))

analysis_activity_by_condition(activity, info, config)

In [None]:
def analysis_example_units_by_condition(activity, info, config):
    conditions = get_conditions(info)
    if len(conditions) < 1:
        return

    example_ids = np.array([0, 1])    
    for example_id in example_ids:        
        example_activity = activity[:, :, example_id]
        fig, axes = plt.subplots(
                len(conditions), 1,  figsize=(1.2, 0.8 * len(conditions)),
                sharex=True)
        for i, condition in enumerate(conditions):
            ax = axes[i]
            values = pd.unique(info[condition])
            t_plot = np.arange(activity.shape[1]) * config['dt']
            for value in values:
                a = example_activity[info[condition] == value]
                ax.plot(t_plot, a.mean(axis=0), label=str(value))
            ax.legend(title=condition, loc='center left', bbox_to_anchor=(1.0, 0.5))
            ax.set_ylabel('Activity')
            if i == len(conditions) - 1:
                ax.set_xlabel('Time (ms)')
            if i == 0:
                ax.set_title('Unit {:d}'.format(example_id + 1))

analysis_example_units_by_condition(activity, info, config)

In [None]:
def analysis_pca_by_condition(activity, info, config):
    # Reshape activity to (N_trial x N_time, N_neuron)
    activity_reshape = np.reshape(activity, (-1, activity.shape[-1]))
    pca = PCA(n_components=2)
    pca.fit(activity_reshape)
    
    conditions = get_conditions(info)
    for condition in conditions:
        values = pd.unique(info[condition])
        fig = plt.figure(figsize=(2.5, 2.5))
        ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])
        for value in values:
            # Get relevant trials, and average across them
            a = activity[info[condition] == value].mean(axis=0)
            a = pca.transform(a)  # (N_time, N_PC)
            plt.plot(a[:, 0], a[:, 1], label=str(value))
        plt.legend(title=condition, loc='center left', bbox_to_anchor=(1.0, 0.5))
    
        plt.xlabel('PC 1')
        plt.ylabel('PC 2')

analysis_pca_by_condition(activity, info, config)