### Import dependencies

In [1]:
from pathlib import Path
from typing import Callable, Any

import hydra
from omegaconf import DictConfig, OmegaConf

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
    EngineConfig
)
from mlagents_envs.registry import UnityEnvRegistry, default_registry

from supersuit import observation_lambda_v0

import gym
from gym import Env
from gym_unity.envs import UnityToGymWrapper

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel

### Load Environment, Engine and Model Configurations

In [2]:
# Hydra overrides selecting the inference variants of the `engine` and
# `environment` config groups.
CONFIG_OVERRIDES = ['engine=infer', 'environment=infer']

# Compose config.yaml from the `config` directory with the overrides above.
with hydra.initialize(config_path='config'):
    GLOBAL_CONFIG = hydra.compose(config_name='config.yaml', overrides=CONFIG_OVERRIDES)

# Engine section -> EngineConfig for the engine-configuration side channel.
ENGINE_CONFIG = EngineConfig(**GLOBAL_CONFIG.engine)
# Environment section -> passed to make_mla_sb3_env when starting the envs.
ENV_CONFIG = GLOBAL_CONFIG.environment

# Uncomment to inspect the fully composed configuration:
# print(OmegaConf.to_yaml(GLOBAL_CONFIG))

### Unity Environment for Stable-Baselines3 (SB3)

In [3]:
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Create a UnityEnvironment from a local build, falling back to a registry entry.

    A local build takes precedence: ``env`` is treated as a file path if it
    exists as given, or with any of the Unity build suffixes probed below
    appended. Otherwise ``env`` is looked up as a key in ``registry``.

    Args:
        env: Path to a Unity build (with or without its executable suffix;
            ``~`` is expanded) or the name of a registry entry.
        registry: Registry consulted when no local build is found.
        **kwargs: Forwarded to ``UnityEnvironment`` (local build) or to the
            registry entry's ``make`` (registry entry).

    Returns:
        The launched ``UnityEnvironment``.

    Raises:
        ValueError: If ``env`` is neither an existing local build nor a
            registry key.
    """
    # Windows, 64-bit Linux, 32-bit Linux and macOS build suffixes. The build
    # may be referenced without its suffix, as UnityEnvironment accepts.
    build_suffixes = ('.exe', '.x86_64', '.x86', '.app')
    base_path = Path(env).expanduser().absolute()
    candidates = [base_path] + [Path(f'{base_path}{suffix}') for suffix in build_suffixes]

    if any(candidate.exists() for candidate in candidates):
        return UnityEnvironment(file_name=env, **kwargs)
    if env in registry:
        return registry.get(env).make(**kwargs)
    raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
        
def make_mla_sb3_env(config: DictConfig, **kwargs: Any) -> VecEnv:
    """Build a SubprocVecEnv of gym-wrapped Unity ML-Agents environments for SB3.

    One factory is created per worker (``config.num_env`` of them). Each worker
    gets its own engine-configuration and environment-parameter side channels,
    its own ``worker_id`` and the seed ``config.base_seed + worker_id``.

    Args:
        config: Environment config. Fields read here: ``env_path``, ``num_env``,
            ``agent_level``, ``base_port``, ``base_seed``, ``visual_obs`` and
            ``allow_multiple_obs``.
        **kwargs: Extra keyword arguments forwarded to the Unity environment
            constructor (e.g. ``no_graphics``). A caller-supplied
            ``side_channels`` sequence is preserved; the per-worker channels
            are appended to a copy of it.

    Returns:
        A ``SubprocVecEnv`` over ``config.num_env`` environments.
    """
    def handle_obs(obs, space):
        # Must mirror handle_obs_space: unwrap a single-element tuple,
        # otherwise re-key the tuple as {"0": ..., "1": ...}.
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        # Space-level counterpart of handle_obs.
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        """Return a zero-argument factory that builds the env for ``worker_id``."""
        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(ENGINE_CONFIG)
            # Unity environment parameters
            environment_configuration_channel = EnvironmentParametersChannel()
            environment_configuration_channel.set_float_parameter('agent_level', config.agent_level)
            # Work on a per-call copy: `kwargs` is captured by every factory,
            # so mutating it would leak this worker's channels into the next
            # factory called in the same process.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = list(kwargs.get("side_channels", [])) + [
                engine_configuration_channel,
                environment_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=default_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path, worker_id=x) for x in range(config.num_env)
    ]
    return SubprocVecEnv(env_facts)

### Helper Functions

In [4]:
def close_envs(vec_env: Any = None) -> None:
    """Close an environment, tolerating one that was never created.

    Safe to call on a fresh kernel before any environment exists. Errors raised
    by ``close()`` are reported rather than raised, so cleanup cells keep running.

    Args:
        vec_env: Environment to close. Defaults to the notebook-level ``env``
            global, if one has been defined.
    """
    if vec_env is None:
        vec_env = globals().get('env')
    if vec_env is None:
        print("No environment to close")
        return
    try:
        vec_env.close()
    except Exception as exc:
        # Best-effort cleanup: surface the failure instead of silently
        # swallowing it (and without catching KeyboardInterrupt/SystemExit).
        print(f"Failed to close environment: {exc!r}")
        return
    print("Closed environment")

### Start Environment & Load trained model

In [5]:
run_id = 'test_run_0'  # NOTE(review): not referenced by any later cell shown here
MODEL_PATH = 'TargetSeeker'  # saved PPO model to load

# Close any environment left over from an earlier run of this cell before
# starting new Unity workers.
close_envs()

raw_env = make_mla_sb3_env(
    config=ENV_CONFIG,
    no_graphics=False,  # keep rendering on to watch the agents when running locally
)
print('Started Environment')

# Wrap the vectorized env with VecMonitor to track episode statistics.
env = VecMonitor(raw_env)

# Load the trained model and attach it to the monitored environment.
model = PPO.load(MODEL_PATH, env)

Closed environment
Started Environment


### Run Inference

In [None]:
n_eval_episodes = 10  # set to 0 to skip evaluation

# Evaluate the loaded policy. On success the environment is intentionally left
# open so the rollout cell below can reuse it (the last cell closes it). It is
# closed here only if evaluation fails, so Unity worker processes do not leak,
# and the error is re-raised instead of being silently swallowed.
if n_eval_episodes > 0:
    try:
        mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=n_eval_episodes)
    except Exception:
        close_envs()
        raise
    print(f'Mean Reward {mean_reward}')
    print(f'Std Reward {std_reward}')

In [23]:
n_steps = 100  # number of vectorized steps to roll out

# Run the deterministic policy for a fixed number of steps. The loop does not
# reset sub-environments itself; it relies on the SB3 VecEnv handling episode ends.
obs = env.reset()
for _step_idx in range(n_steps):
    action, _hidden_state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)

In [8]:
# Shut down the vectorized environment (and its Unity worker subprocesses)
# once inference is finished.
close_envs()

Closed environment
