### Import dependencies

In [1]:
import os
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Any

import numpy as np

import gym
from gym import Env

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import (
    BaseCallback,
    CheckpointCallback,
    CallbackList,
    EveryNTimesteps
)
from stable_baselines3.common.results_plotter import ts2xy, load_results

from supersuit import observation_lambda_v0

from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper
from mlagents_envs.registry import UnityEnvRegistry, default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfig,
    EngineConfigurationChannel,
)
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

### Environment and Engine Configurations

In [2]:
# Engine settings used when a LimitedConfig does not override them.
# Intended to mirror the ML-Agents CLI defaults (see cli_utils.py).
# NOTE(review): the CLI default for quality_level appears to be 5, not 4 — confirm 4 is intentional.
DEFAULT_ENGINE_CONFIG = EngineConfig(
    # Render resolution (pixels).
    width=84, height=84,
    # Unity graphics quality preset index.
    quality_level=4,
    # Simulation speed-up factor and frame-rate settings.
    time_scale=20,
    target_frame_rate=-1, capture_frame_rate=60,
)

# Some config subset of an actual config.yaml file for MLA.
@dataclass
class LimitedConfig:
    # The local path to a Unity executable or the name of an entry in the registry.
    env_path_or_name: str
    base_port: int
    base_seed: int = 0
    num_env: int = 1
    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG
    visual_obs: bool = False
    # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.
    # WARNING: Make sure to use MultiInputPolicy if you turn this on.
    allow_multiple_obs: bool = False
    env_registry: UnityEnvRegistry = default_registry
    agent_level: float = 4

### Unity Environment Wrapper for SB3

In [3]:
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Launch a Unity environment from a local build or from a registry entry.

    A local build wins when any of ``<env>.exe``, ``<env>.x86_64`` or ``<env>``
    itself exists (after ``~`` expansion). Otherwise ``env`` is looked up by name
    in ``registry``.

    :param env: Path to a Unity build (with or without extension) or a registry name.
    :param registry: Registry consulted when no local build is found.
    :param kwargs: Forwarded to ``UnityEnvironment`` or to the registry entry's ``make``.
    :raises ValueError: If ``env`` is neither an existing path nor a registry key.
    """
    candidate_paths = (f'{env}.exe', f'{env}.x86_64', env)
    if any(Path(candidate).expanduser().absolute().exists() for candidate in candidate_paths):
        return UnityEnvironment(file_name=env, **kwargs)
    if env in registry:
        return registry.get(env).make(**kwargs)
    raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
        
def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:
    """Launch ``config.num_env`` Unity workers and wrap them as an SB3 vectorized env.

    Each worker is configured through two side channels: the engine settings from
    ``config.engine_config``, and the float environment parameter ``agent_level``
    from ``config.agent_level``. Worker ``i`` is created with ``worker_id=i`` and
    ``seed=config.base_seed + i``.

    :param config: Launch settings (build path/registry name, port, seeds, engine
        config, observation options, registry, agent level).
    :param kwargs: Extra keyword arguments forwarded to ``UnityEnvironment`` (or to
        the registry entry's ``make``), e.g. ``no_graphics``. A caller-supplied
        ``side_channels`` list is kept, and each worker's own channels are appended
        to a private copy of it.
    :return: A ``SubprocVecEnv`` with one Unity worker per subprocess.
    """

    def handle_obs(obs, space):
        # Must mirror handle_obs_space: unwrap 1-element tuples, and turn larger
        # tuples into a dict keyed by "0", "1", ... (stable baselines can handle
        # spaces.Dict but not spaces.Tuple).
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        # Space-level counterpart of handle_obs.
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(config.engine_config)
            # Unity environment parameters
            environment_configuration_channel = EnvironmentParametersChannel()
            environment_configuration_channel.set_float_parameter('agent_level', config.agent_level)
            # Build per-worker kwargs instead of writing into the ``kwargs`` dict
            # shared by every factory. The previous in-place update made each call
            # append to the channel list left by the previous call, so any further
            # worker created in the same process (e.g. with DummyVecEnv) also
            # received the earlier workers' channels.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = list(kwargs.get("side_channels") or []) + [
                engine_configuration_channel,
                environment_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=config.env_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)
    ]
    return SubprocVecEnv(env_facts)

### Custom Callbacks

In [4]:
class TensorboardCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    - On every step, records the scalar ``z_value`` through the model's logger.
      It is currently a uniform random number (a placeholder metric).
    - When training ends, renames the run's log directory (e.g. ``TB_Logs_1``)
      to ``TB_Logs`` in the same parent directory.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Placeholder metric: a random value. Replace with a real quantity to track.
        value = np.random.random()
        self.logger.record('z_value', value)
        return True

    def _on_training_end(self) -> None:
        print('...Finished Training!')
        # Only rename when a tensorboard_log directory was configured for the
        # model; otherwise the logger's directory is not one this notebook created.
        if self.model.tensorboard_log is None:
            return
        # Ask the logger where it actually wrote instead of rebuilding the path
        # from a notebook-global run_id (which may not belong to this model).
        log_dir = os.path.normpath(self.logger.get_dir())
        target_dir = os.path.join(os.path.dirname(log_dir), 'TB_Logs')
        if log_dir == target_dir:
            return
        if os.path.exists(target_dir):
            print(f'Not renaming {log_dir}: {target_dir} already exists')
            return
        # Close the event-file writers first; Windows typically refuses to rename
        # a directory that still contains open files.
        self.logger.close()
        os.rename(log_dir, target_dir)

class SummaryCallback(BaseCallback):
    """
    Summary callback for printing values on console.

    Each time it is stepped, it reloads the monitor files in ``log_dir`` (via
    ``load_results``). It then prints the elapsed wall-clock time, plus the mean
    and standard deviation of the returns of the last ``n_episodes`` finished
    episodes. Because every call reads the monitor files from disk, wrap it in
    ``EveryNTimesteps`` so it only runs periodically.

    :param verbose: Verbosity level passed to ``BaseCallback``.
    :param n_episodes: Number of most recent episodes included in the statistics.
    :param log_dir: Directory containing the monitor files written by ``VecMonitor``.
    """

    def __init__(self, verbose=0, n_episodes=100, log_dir=None):
        super().__init__(verbose)
        self.n_episodes = n_episodes
        self.log_dir = log_dir
        self.start_time = None

    def _on_training_start(self) -> None:
        print('Started Training...')
        self.start_time = time.time()

    def _on_step(self) -> bool:
        if self.start_time is None:
            # The training-start hook did not reach this callback; time from now on.
            self.start_time = time.time()
        elapsed = time.time() - self.start_time

        _, episode_returns = ts2xy(load_results(self.log_dir), 'timesteps')
        recent_returns = episode_returns[-self.n_episodes:]
        if len(recent_returns) == 0:
            # Avoid nan and a "Mean of empty slice" warning before the first episode ends.
            print(f'Step {self.num_timesteps}, Time Elapsed {elapsed:.3f}s, no finished episodes yet')
            return True

        # Implicit string concatenation: a backslash continuation inside the
        # f-string would embed the source indentation into the printed line.
        print(
            f'Step {self.num_timesteps}, '
            f'Time Elapsed {elapsed:.3f}s, '
            f'Mean Reward {np.mean(recent_returns):.3f}, '
            f'Std Reward {np.std(recent_returns):.3f}'
        )
        return True

### Schedulers and Run-Directory Helpers

In [5]:
def scheduler(initial_value: float, schedule: str) -> Callable[[float], float]:
    """Build a schedule callable for SB3 hyper-parameters (learning rate, clip range).

    The returned function takes ``progress_remaining``; SB3 documents this value
    as going from 1 at the start of training to 0 at the end.

    :param initial_value: Value at the start of training.
    :param schedule: ``'linear'`` scales ``initial_value`` by ``progress_remaining``;
        ``'constant'`` always returns ``initial_value``.
    :raises NameError: If ``schedule`` is not one of the supported names.
    """
    if schedule == 'linear':
        def linear_schedule(progress_remaining: float) -> float:
            return progress_remaining * initial_value
        return linear_schedule

    if schedule == 'constant':
        def constant_schedule(progress_remaining: float) -> float:
            return initial_value
        return constant_schedule

    raise NameError(f'Invalid schedule: {schedule}!')


def check_valid_run_id(run_id, force=False, results_dir='results'):
    """Create a fresh, empty output directory ``<results_dir>/<run_id>`` for a training run.

    :param run_id: Name of the run; used as the sub-directory name.
    :param force: If True, an existing directory for this run is deleted (with
        all its contents) and recreated.
    :param results_dir: Parent directory holding all runs; created if missing.
    :raises FileExistsError: If the run directory already exists and ``force`` is False.
    """
    dir_path = os.path.join(results_dir, run_id)
    if os.path.isdir(dir_path):
        if not force:
            # A bare ``raise`` here (outside an except block) would only produce
            # "RuntimeError: No active exception to reraise".
            raise FileExistsError(
                f"Run directory '{dir_path}' already exists; "
                f"choose another run_id or pass force=True."
            )
        shutil.rmtree(dir_path)
    # makedirs (not mkdir) so a missing parent results directory does not fail.
    os.makedirs(dir_path)

### Start Environment

In [6]:
# -----------------
# Close an env left over from a previous run of this cell so its Unity workers
# release their ports before new ones are launched.
try:
    env.close()
except NameError:
    pass  # First run in this kernel: no env has been created yet.
except Exception as exc:
    # Best effort: report the problem instead of silently swallowing it, but keep going.
    print(f'Could not close previous environment: {exc!r}')
# -----------------

NUM_ENVS = 8
build_path = 'C:/main/MLAgents/TargetSeeker/Build/win/TargetSeeker'
run_id = 'test_run_0'
force = True  # Delete results/<run_id> if it already exists.
summary_freq = 10_000
check_point_freq = 1_000_000
n_check_points = 5  # NOTE(review): not referenced by the cells shown here.
agent_level = 0

config = LimitedConfig(
    env_path_or_name=build_path,  # Can use any name from a registry or a path to your own unity build.
    base_port=5005,
    base_seed=0,
    num_env=NUM_ENVS,
    allow_multiple_obs=True,
    agent_level=agent_level,
)

# Validate the run id BEFORE launching Unity: if this raises, no worker
# processes are left running and holding ports.
check_valid_run_id(run_id, force)

env = make_mla_sb3_env(
    config=config,
    no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.
)

# Writes per-episode stats under results/<run_id>/ (read back by SummaryCallback via load_results).
env = VecMonitor(env, filename=f'results/{run_id}/')

### Create the model

In [7]:
# Training budget: each PPO update uses a rollout of BUFFER_SIZE environment
# steps in total (STEPS_PER_UPDATE per worker x NUM_ENVS workers), repeated UPDATES times.
BATCH_SIZE = 1024
BUFFER_SIZE = 10240
UPDATES = 2000
TOTAL_TRAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES
# Backward-compatible alias for the original (misspelled) name used by later cells.
TOTAL_TAINING_STEPS_GOAL = TOTAL_TRAINING_STEPS_GOAL
BETA = 0.0005  # Entropy coefficient (passed to PPO as ent_coef).
N_EPOCHS = 3
# Per-worker rollout length. Integer division so the value is an exact int and
# STEPS_PER_UPDATE * NUM_ENVS never exceeds BUFFER_SIZE.
STEPS_PER_UPDATE = BUFFER_SIZE // NUM_ENVS

# Policy and Value function with 2 layers of 128 units each and no shared layers.
policy_kwargs = {"net_arch": [{"pi": [128, 128], "vf": [128, 128]}]}

model = PPO(
    "MlpPolicy",
    env,
    verbose=0,
    learning_rate=scheduler(0.0003, 'linear'),
    clip_range=scheduler(0.2, 'linear'),
    clip_range_vf=scheduler(0.2, 'linear'),
    tensorboard_log=f'results/{run_id}',
    policy_kwargs=policy_kwargs,
    n_steps=int(STEPS_PER_UPDATE),
    batch_size=BATCH_SIZE,
    n_epochs=N_EPOCHS,
    ent_coef=BETA,
)

### Train the model

In [8]:
tb_callback = TensorboardCallback()
# CheckpointCallback counts callback calls; each call corresponds to one
# vectorized step (NUM_ENVS env steps), hence the division.
checkpoint_callback = CheckpointCallback(
    save_freq=max(check_point_freq // NUM_ENVS, 1),
    save_path=f'results/{run_id}/Checkpoints/',
    name_prefix='TargetSeeker',
)
summary_callback = EveryNTimesteps(
    n_steps=summary_freq,
    callback=SummaryCallback(n_episodes=100, log_dir=f'results/{run_id}'),
)

# Chain all callbacks
callback = CallbackList([tb_callback, checkpoint_callback, summary_callback])

# Start train
try:
    model.learn(total_timesteps=TOTAL_TAINING_STEPS_GOAL, reset_num_timesteps=True, callback=callback, tb_log_name='TB_Logs')
except KeyboardInterrupt:
    # Stopping by hand is expected; fall through to saving the partial model.
    # Any other exception still propagates after the cleanup below.
    print('Training interrupted by user')
finally:
    try:
        model.save(f'results/{run_id}/TargetSeeker')
        print("Saved model")
    finally:
        # Close the env even if saving failed, so the Unity workers release their ports.
        env.close()
        print("Closed environment")
    del model

# for i in range(UPDATES):
    # print(f'\rTraining round {i + 1}/{UPDATES}', end='')
    # model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=(i == 0), callback=callback, tb_log_name='TB_Logs')
    # model.policy.eval()

Started Training...
Step 10000,         Time Elapsed 7.297s,         Mean Reward -0.984,         Std Reward 0.387
Step 20000,         Time Elapsed 14.405s,         Mean Reward -1.041,         Std Reward 0.212
Step 30000,         Time Elapsed 21.565s,         Mean Reward -1.004,         Std Reward 0.319
Step 40000,         Time Elapsed 28.926s,         Mean Reward -1.007,         Std Reward 0.343
Step 50000,         Time Elapsed 36.154s,         Mean Reward -0.943,         Std Reward 0.507
Step 60000,         Time Elapsed 43.395s,         Mean Reward -0.969,         Std Reward 0.480



KeyboardInterrupt



### Close the environment
Frees up the ports used by the Unity environment workers. Safe to run more than once.

In [None]:
# Best-effort cleanup that is safe to run repeatedly: close the vectorized env
# (frees the Unity workers' ports), then drop the model if it still exists.
try:
    env.close()
except NameError:
    print("No environment to close")
except Exception as exc:
    print(f"Failed to close environment: {exc!r}")
else:
    print("Closed environment")

try:
    del model
except NameError:
    pass  # Already deleted (the training cell deletes it) or never created.