### Import dependencies

In [55]:
import os
import shutil
import time
import numpy as np

from pathlib import Path
from typing import Callable, Any

import hydra
from omegaconf import DictConfig, OmegaConf

import gym
from gym import Env

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CheckpointCallback,
    CallbackList,
    EveryNTimesteps
)
from stable_baselines3.common.results_plotter import ts2xy, load_results

from supersuit import observation_lambda_v0

from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper
from mlagents_envs.registry import UnityEnvRegistry, default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfig,
    EngineConfigurationChannel,
)
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

### Load Environment, Engine and Model Configurations

In [59]:
# Compose the configuration named 'config.yaml' from the 'config' directory,
# without any command-line style overrides.
GLOBAL_CONFIG = None
with hydra.initialize(config_path='config'):
    GLOBAL_CONFIG = hydra.compose(
        config_name='config.yaml',
        overrides=[]
    )

# Split the composed configuration into the three sections used below.
# The `engine` section is unpacked as keyword arguments, so its keys must
# match the fields accepted by EngineConfig.
ENGINE_CONFIG = EngineConfig(**GLOBAL_CONFIG.engine)
ENV_CONFIG = GLOBAL_CONFIG.environment
MODEL_CONFIG = GLOBAL_CONFIG.model

# Debugging aid: uncomment to dump the composed configuration as YAML.
# print(OmegaConf.to_yaml(GLOBAL_CONFIG))

### Unity Environment SB3

In [3]:
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Create a Unity environment from a local build or from a registry entry.

    A local build takes precedence: ``env`` is treated as a file name when
    ``env`` itself, ``env + '.exe'`` or ``env + '.x86_64'`` exists on disk.
    Otherwise ``env`` is looked up in ``registry``.

    Args:
        env: Path to a local Unity build (with or without the platform
            extension) or the name of a registry entry.
        registry: Registry consulted when no local build is found.
        **kwargs: Forwarded unchanged to the environment constructor.

    Returns:
        The launched Unity environment.

    Raises:
        ValueError: If ``env`` is neither an existing path nor a registry entry.
    """
    # Same probing order as before: Windows build, Linux build, literal path.
    candidate_names = (f'{env}.exe', f'{env}.x86_64', env)
    found_locally = any(
        Path(name).expanduser().absolute().exists() for name in candidate_names
    )

    if found_locally:
        return UnityEnvironment(file_name=env, **kwargs)
    if env in registry:
        return registry.get(env).make(**kwargs)
    raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
        
def make_mla_sb3_env(config: DictConfig, **kwargs: Any) -> VecEnv:
    """Build a vectorised Stable-Baselines3 environment backed by Unity ML-Agents.

    One Unity instance is created per worker (``config.num_env`` in total) and
    the resulting factories are handed to ``SubprocVecEnv``.

    Args:
        config: Environment configuration. The fields read here are
            ``env_path``, ``num_env``, ``base_port``, ``base_seed``,
            ``agent_level``, ``visual_obs`` and ``allow_multiple_obs``.
        **kwargs: Extra keyword arguments forwarded to the Unity environment
            constructor (for example ``no_graphics``). An optional
            ``side_channels`` sequence is extended, per worker, with the
            engine and environment-parameter channels created here.

    Returns:
        A ``SubprocVecEnv`` wrapping one gym environment per worker.

    Note:
        The engine settings come from the module-level ``ENGINE_CONFIG``,
        not from ``config``.
    """

    def handle_obs(obs, space):
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(ENGINE_CONFIG)
            # Unity environment parameters
            environment_configuration_channel = EnvironmentParametersChannel()
            environment_configuration_channel.set_float_parameter('agent_level', config.agent_level)
            # Work on a per-worker copy: writing into the shared ``kwargs``
            # closure would make every factory executed in the same process
            # accumulate the channels of the workers created before it.
            worker_kwargs = dict(kwargs)
            worker_kwargs["side_channels"] = list(kwargs.get("side_channels", [])) + [
                engine_configuration_channel,
                environment_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=default_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                # Distinct seed per worker so the instances do not mirror each other.
                seed=config.base_seed + worker_id,
                **worker_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            # Flatten/convert Tuple observations so SB3 policies can consume them.
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path, worker_id=x) for x in range(config.num_env)
    ]
    return SubprocVecEnv(env_facts)

### Custom Callbacks

In [4]:
class TensorboardCallback(BaseCallback):
    """Callback that records one extra scalar to the training logger per step.

    The recorded value is a placeholder (a uniform random sample) stored under
    the key ``'z_value'``.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Placeholder metric: a fresh random sample on every call.
        self.logger.record('z_value', np.random.random())
        # Always ask the trainer to continue.
        return True

class SummaryCallback(BaseCallback):
    """Print a one-line training summary to the console.

    Each call to ``_on_step`` reloads the monitor results stored in
    ``log_dir`` and prints the current timestep, the wall-clock time since
    training started, and the mean / standard deviation of the reward over
    the most recent ``n_episodes`` episodes.

    Args:
        verbose: Verbosity level forwarded to ``BaseCallback``.
        n_episodes: Number of most recent episodes included in the statistics.
        log_dir: Directory containing the monitor files to read.
    """

    def __init__(self, verbose=0, n_episodes=100, log_dir=None):
        super().__init__(verbose)
        self.n_episodes = n_episodes
        self.log_dir = log_dir
        # Set in _on_training_start; defined here so the attribute always exists.
        self.start_time = None

    def _on_training_start(self) -> None:
        print('Started Training...')
        self.start_time = time.time()

    def _on_step(self) -> bool:
        # Episode rewards recorded so far, indexed by timestep.
        _, rewards = ts2xy(load_results(self.log_dir), 'timesteps')
        recent = rewards[-self.n_episodes:]

        if len(recent) > 0:
            mean_reward = np.mean(recent)
            std_reward = np.std(recent)
        else:
            # No episode has finished yet: report NaN without triggering
            # numpy's empty-slice warnings.
            mean_reward = std_reward = float('nan')

        if self.start_time is None:
            # Defensive: the callback was stepped without a training start.
            self.start_time = time.time()
        elapsed = time.time() - self.start_time

        # Adjacent f-strings (rather than backslash continuations inside one
        # literal) keep source indentation out of the printed line.
        print(
            f'Step {self.num_timesteps}, '
            f'Time Elapsed {elapsed:.3f}s, '
            f'Mean Reward {mean_reward:.3f}, '
            f'Std Reward {std_reward:.3f}'
        )

        return True

### Schedulers and Run Helpers

In [10]:
def scheduler(initial_value: float, schedule: str) -> Callable[[float], float]:
    """Return a schedule function of the remaining training progress.

    Args:
        initial_value: Value the schedule starts from.
        schedule: ``'linear'`` scales ``initial_value`` by the remaining
            progress; ``'constant'`` always yields ``initial_value``.

    Returns:
        A callable taking ``progress_remaining`` and returning the scheduled
        value.

    Raises:
        NameError: If ``schedule`` is not one of the supported names.
    """

    def linear_decay(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    def hold_constant(progress_remaining: float) -> float:
        return initial_value

    if schedule == 'linear':
        return linear_decay
    if schedule == 'constant':
        return hold_constant
    raise NameError(f'Invalid schedule: {schedule}!')


def check_valid_run_id(run_id, force=False, results_dir='results'):
    """Create a fresh output directory for a training run.

    Args:
        run_id: Name of the run; the directory is ``<results_dir>/<run_id>``.
        force: When True, an existing run directory is deleted and recreated.
        results_dir: Parent directory holding all runs. It is created if it
            does not exist yet.

    Raises:
        FileExistsError: If the run directory already exists and ``force`` is
            False. (Previously a bare ``raise`` produced an uninformative
            ``RuntimeError: No active exception to re-raise`` here.)
    """
    dir_path = f'{results_dir}/{run_id}'
    if os.path.isdir(dir_path):
        if not force:
            raise FileExistsError(
                f"Run directory '{dir_path}' already exists; "
                "choose another run_id or pass force=True to overwrite it"
            )
        shutil.rmtree(dir_path)
    # makedirs (not mkdir) so a missing parent results directory is not an error.
    os.makedirs(dir_path)

def close_envs():
    """Best-effort shutdown of the module-level ``env``.

    Safe to call before ``env`` has been created and safe to call repeatedly:
    failures while closing are reported but never propagated. Always prints
    "Closed environment" when done.
    """
    try:
        env.close()
    except NameError:
        # ``env`` has not been created yet (e.g. the first run of the cell).
        pass
    except Exception as err:
        # Stay best-effort, but do not hide what went wrong. Unlike a bare
        # ``except``, this lets KeyboardInterrupt / SystemExit through.
        print(f"Ignoring error while closing environment: {err!r}")
    finally:
        print("Closed environment")

def save_model():
    """Save the module-level ``model`` and tidy the TensorBoard log directory.

    The model is written to ``results/<run_id>/TargetSeeker`` first, so a
    missing or already-renamed log directory can never prevent the save.
    Afterwards ``results/<run_id>/TB_Logs_1`` is renamed to
    ``results/<run_id>/TB_Logs`` when the source exists and the target does not.
    """
    run_dir = f'results/{run_id}'

    model.save(f'{run_dir}/TargetSeeker')
    print("Saved model")

    tb_source = f'{run_dir}/TB_Logs_1'
    tb_target = f'{run_dir}/TB_Logs'
    # The log directory is absent when training stopped before any log was
    # written, and the target may exist from an earlier call; skip the rename
    # in both cases instead of raising.
    if os.path.isdir(tb_source) and not os.path.exists(tb_target):
        os.rename(tb_source, tb_target)

### Start Environment

In [15]:
# Name of this run; all outputs are written below results/<run_id>.
run_id = 'test_run_0'
# When True, an existing results/<run_id> directory is deleted and recreated.
force = True

# Close any environment left over from a previous execution of this cell
# (a no-op, apart from the printed message, when none exists yet).
close_envs()
env = make_mla_sb3_env(
    config=ENV_CONFIG,
    no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.
)
print('Started Environment')

# Validate run-id (creates the results/<run_id> directory).
check_valid_run_id(run_id, force)

# Wrap the vectorised env in a monitor that writes its results under
# results/<run_id>/; SummaryCallback loads them from that directory to
# compute the reward statistics it prints.
env = VecMonitor(env, filename=f'results/{run_id}/')

Closed environment
Started Environment


### Create the model

In [16]:
# Total number of timesteps to train for: one buffer of experience per update.
TOTAL_TAINING_STEPS_GOAL = MODEL_CONFIG.buffer_size * MODEL_CONFIG.updates
# Steps each environment contributes per update. This is a true division,
# so the result is a float and is truncated with int() where it is used.
STEPS_PER_UPDATE = MODEL_CONFIG.buffer_size / ENV_CONFIG.num_env

# Separate policy ("pi") and value ("vf") networks; the layer sizes come from
# MODEL_CONFIG.PI and MODEL_CONFIG.VF respectively.
policy_kwargs = {"net_arch" : [{
    "pi": list(MODEL_CONFIG.PI), 
    "vf": list(MODEL_CONFIG.VF)
}]}

# learning_rate, clip_range and clip_range_vf are each built by scheduler()
# from the matching MODEL_CONFIG entry, which is unpacked as its keyword
# arguments (initial_value, schedule).
model = PPO(
    "MlpPolicy",
    env,
    verbose=0,
    learning_rate=scheduler(**MODEL_CONFIG.learning_rate),
    clip_range=scheduler(**MODEL_CONFIG.clip_range),
    clip_range_vf=scheduler(**MODEL_CONFIG.clip_range_vf),
    tensorboard_log=f'results/{run_id}',
    policy_kwargs=policy_kwargs,
    n_steps=int(STEPS_PER_UPDATE),
    batch_size=MODEL_CONFIG.batch_size,
    n_epochs=MODEL_CONFIG.n_epochs,
    ent_coef=MODEL_CONFIG.beta,
)

# Callbacks
tb_callback = TensorboardCallback()
# The checkpoint frequency is divided by the number of environments and
# clamped to at least 1.
# NOTE(review): presumably save_freq counts vectorised steps rather than
# individual environment steps - confirm against the SB3 documentation.
checkpoint_callback = CheckpointCallback(save_freq=max(int(ENV_CONFIG.check_point_freq/ENV_CONFIG.num_env),1), save_path=f'results/{run_id}/Checkpoints/',
                                         name_prefix='TargetSeeker')
# Print a console summary every ENV_CONFIG.summary_freq timesteps, computed
# over the last 100 episodes found in results/<run_id>.
summary_callback = EveryNTimesteps(n_steps=ENV_CONFIG.summary_freq, callback=SummaryCallback(n_episodes=100, log_dir=f'results/{run_id}'))

# Chain all callbacks
callbacks = CallbackList([tb_callback, checkpoint_callback, summary_callback])

### Train the model

In [None]:
# Start training. Whatever happens, the model is saved and the environments
# are closed afterwards.
try:
    model.learn(total_timesteps=TOTAL_TAINING_STEPS_GOAL, reset_num_timesteps=True, callback=callbacks, tb_log_name='TB_Logs')
except KeyboardInterrupt:
    # Manual stop (e.g. interrupting the kernel): keep what was learned so far.
    print('Training Interrupted!')
except Exception as err:
    # A genuine failure must not be reported as a user interruption; show it
    # and let it propagate once the cleanup below has run.
    print(f'Training failed: {err!r}')
    raise
else:
    print('...Finished Training')
finally:
    try:
        save_model()
    finally:
        # Close the environments even if saving the model fails.
        close_envs()


Started Training...
Step 10000,         Time Elapsed 7.106s,         Mean Reward -0.944,         Std Reward 0.475
Step 20000,         Time Elapsed 14.191s,         Mean Reward -1.043,         Std Reward 0.194
Step 30000,         Time Elapsed 21.423s,         Mean Reward -0.948,         Std Reward 0.478
Step 40000,         Time Elapsed 28.572s,         Mean Reward -0.981,         Std Reward 0.439
Training Interrupted!
Saved model


### Close the environment
Frees up the ports being used.

In [None]:
# Final clean-up: close_envs() ignores errors raised while closing, so this
# can be run even if the training cell already closed the environment.
close_envs()