# Data Collection

## Vectorise Environment

In [None]:
from crowd_sim.envs.crowd_sim_sgan import CrowdSimSgan
from crowd_sim.envs.crowd_sim_sgan_apf import CrowdSimSganApf
from crowd_sim.envs.crowd_sim_no_pred import CrowdSimNoPred
import gym
import time
import numpy as np

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecTransposeImage
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import PPO, A2C

In [None]:
from arguments import get_args
from crowd_nav.configs.config import Config

# Project configuration object; it is handed to every environment built by
# make_env (see the SubprocVecEnv construction below).
config = Config()

In [None]:
def make_env(seed, rank, env_config, envNum=1, env_cls=None):
    """
    Build a zero-argument factory that creates one configured, seeded env.

    Vectorised-env wrappers such as SubprocVecEnv take a list of such
    factories and call each one to construct its environment.

    :param seed: (int) the base seed for the RNG
    :param rank: (int) index of this sub-environment; it is added to ``seed``
        so that each environment gets a different seed and therefore does
        not generate the same experiences as its siblings
    :param env_config: configuration object passed to ``env.configure``
    :param envNum: (int) total number of environments, forwarded to
        ``env.setup`` as ``num_of_env``
    :param env_cls: environment class to instantiate; ``None`` (the default)
        means ``CrowdSimSganApf``, matching the previous hard-coded choice
    :return: (callable) a function taking no arguments that returns the env
    """

    def _init():
        # Resolve the default at call time, not definition time.
        factory_cls = CrowdSimSganApf if env_cls is None else env_cls
        env_seed = seed + rank  # distinct seed per sub-environment
        env = factory_cls()
        env.configure(env_config)
        env.seed(env_seed)
        env.setup(seed=env_seed, num_of_env=envNum)
        return env

    return _init

In [None]:
num_cpu = 1  # Number of processes to use
seed = 0

# One env factory per worker; make_env offsets the seed by each worker's rank.
env_fns = [make_env(seed, rank, config, num_cpu) for rank in range(num_cpu)]
venv = VecTransposeImage(SubprocVecEnv(env_fns))

In [None]:
# obs = venv.reset()
# obs.shape

## Collecting the Dataset

In [None]:
from imitation_learning import rollout
# Seed the generator passed to rollout.rollout with the same base seed used
# for the environments, so this RNG produces the same draws on every run.
rng = np.random.default_rng(seed)

In [None]:
# Keep sampling until at least 60 episodes have been collected.
# NOTE(review): dataset_path (save section) is named "..._100eps" while this
# requests min_episodes=60 — confirm which is intended and keep them in sync.
sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=60)

# NOTE(review): policy=None delegates action choice to the rollout helper;
# check imitation_learning.rollout for what None means here (random policy
# vs. an expert supplied by the env) — not visible from this notebook.
rollouts = rollout.rollout(
    policy=None,
    venv=venv,
    rng=rng,
    sample_until=sample_until,
    unwrap=False,
    exclude_infos=True,
    verbose=True,
)

## Saving and Loading the Dataset

In [None]:
from imitation.data import serialize
# Location the collected rollouts are saved to and loaded from.
# NOTE(review): the name says "100eps", but the collection cell requests
# min_episodes=60 — confirm which is intended and keep the two in sync.
dataset_path = './train/dataset/8ppl_240mapsize_100eps'

In [None]:
# Save the collected rollouts to dataset_path so they can be reloaded later
# (see the load cell) without repeating the collection step.
serialize.save(dataset_path, rollouts)

In [None]:
# Load rollouts from dataset_path, rebinding `rollouts`. This lets the GAIL
# cells below train on a previously saved dataset instead of a fresh
# collection.
rollouts = serialize.load(dataset_path)

# GAIL

In [None]:
from stable_baselines3 import PPO
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm

In [None]:
# Generator (policy) that GAIL trains against the learned reward.
# NOTE(review): these hyperparameters appear to coincide with
# Stable-Baselines3's PPO defaults — confirm for the installed version.
learner = PPO(
    policy='CnnPolicy',
    env=venv,
    learning_rate=3e-4,
    n_epochs=10,
    batch_size=64,
    ent_coef=0.0,
)

In [None]:
# Reward network used as GAIL's discriminator, built from the venv's spaces
# and using RunningNorm as its input-normalisation layer.
# NOTE(review): observations pass through VecTransposeImage, so they are
# presumably images; imitation's CnnRewardNet may suit them better — confirm.
reward_net = BasicRewardNet(
    observation_space=venv.observation_space,
    action_space=venv.action_space,
    normalize_input_layer=RunningNorm,
)

In [None]:
# GAIL trainer: alternates discriminator updates (on demonstrations vs.
# generator samples) with PPO updates of `learner`.
gail_trainer = GAIL(
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    demonstrations=rollouts,
    demo_batch_size=64,
    gen_replay_buffer_capacity=1024,
    n_disc_updates_per_round=4,
    # NOTE(review): imitation discourages variable-horizon episodes because
    # episode length can leak information to the discriminator — confirm
    # this is acceptable for the crowd-sim episodes.
    allow_variable_horizon=True,
)

In [None]:
# Run adversarial training for a budget of 300,000 total timesteps.
gail_trainer.train(total_timesteps=300_000)

In [None]:
# Persist the trained PPO generator policy.
learner.save(path='./train/GAIL/model1')