### Import dependencies

In [1]:
import os

from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Any

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
    EngineConfig
)
from mlagents_envs.side_channel.environment_parameters_channel import EnvironmentParametersChannel
from mlagents_envs.registry import UnityEnvRegistry, default_registry

from supersuit import observation_lambda_v0

import gym
from gym import Env
from gym_unity.envs import UnityToGymWrapper

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv, DummyVecEnv

### Environment and Engine Configuration

In [2]:
# Default engine settings used by LimitedConfig.engine_config: an 800x800 window at
# normal Unity speed. The INFER_ prefix suggests it is meant for inference (watching
# a trained policy) rather than for training.
INFER_ENGINE_CONFIG = EngineConfig(
    # Timing
    time_scale=1,  # 1 = normal (unaccelerated) simulation speed
    target_frame_rate=-1,  # NOTE(review): -1 presumably means "no explicit frame cap" in Unity — confirm
    capture_frame_rate=60,
    # Window / rendering
    width=800,
    height=800,
    quality_level=4,
)

# A small subset of the options found in an ML-Agents config.yaml file — only what
# make_mla_sb3_env needs to launch and configure the Unity environments.
@dataclass
class LimitedConfig:
    # The local path to a Unity executable or the name of an entry in the registry.
    env_path_or_name: str
    # Port handed to UnityEnvironment as base_port; each worker also gets its own worker_id.
    base_port: int
    # Worker i is seeded with base_seed + i.
    base_seed: int = 0
    # Number of Unity instances (vectorized sub-environments) to launch.
    num_env: int = 1
    # Engine settings sent to every instance through an EngineConfigurationChannel.
    engine_config: EngineConfig = INFER_ENGINE_CONFIG
    # Forwarded to UnityToGymWrapper as uint8_visual.
    visual_obs: bool = False
    # Forwarded to UnityToGymWrapper; a resulting Tuple observation space is turned
    # into a Dict (or unwrapped if it has one element) by make_mla_sb3_env.
    allow_multiple_obs: bool = False
    # Registry used when env_path_or_name is not a local build.
    # default_factory instead of a plain default: UnityEnvRegistry implements
    # collections.abc.Mapping, so it is unhashable and @dataclass on Python 3.11+
    # rejects it as a "mutable default". The factory still returns the shared
    # default_registry object, so behavior is unchanged.
    env_registry: UnityEnvRegistry = field(default_factory=lambda: default_registry)
    # Sent to the Unity scene as the float environment parameter "agent_level".
    agent_level: float = 4.0

### Unity Environment Wrapper for Stable-Baselines3

In [3]:
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Create a UnityEnvironment from a local build path or a registry entry name.

    ``env`` is treated as a local build when it exists on disk, either exactly as
    given or with one of the platform-specific player suffixes appended
    (``.exe``, ``.x86_64``, ``.x86``, ``.app``). Otherwise it is looked up by name
    in ``registry``.

    Args:
        env: Path to a Unity build (with or without its platform suffix) or the
            name of a registry entry.
        registry: Registry consulted when ``env`` is not a local build.
        **kwargs: Forwarded to ``UnityEnvironment`` / the registry entry's ``make``.

    Returns:
        The launched ``UnityEnvironment``.

    Raises:
        ValueError: If ``env`` is neither an existing local build nor a registry key.
    """
    base = Path(env).expanduser()
    # UnityEnvironment accepts the build path without its platform suffix, so
    # probe the bare path plus every suffix Unity gives standalone players.
    # The original only checked .exe / .x86_64, rejecting 32-bit Linux and macOS builds.
    suffixes = ("", ".exe", ".x86_64", ".x86", ".app")
    if any(Path(f"{base}{suffix}").exists() for suffix in suffixes):
        return UnityEnvironment(file_name=env, **kwargs)
    if env in registry:
        return registry.get(env).make(**kwargs)
    raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
        
def make_mla_sb3_env(
    config: LimitedConfig,
    *,
    vec_env_cls: Callable[..., VecEnv] = SubprocVecEnv,
    **kwargs: Any,
) -> VecEnv:
    """Build a Stable-Baselines3 vectorized env running ``config.num_env`` Unity instances.

    Each worker gets its own engine-configuration and environment-parameters side
    channels, ``worker_id`` equal to its index, and ``seed = config.base_seed + worker_id``.

    Args:
        config: Environment and engine settings.
        vec_env_cls: VecEnv class that receives the list of env factories.
            Defaults to ``SubprocVecEnv``, which runs one subprocess per worker.
            ``DummyVecEnv`` runs them in-process, which is cheaper for a single env.
        **kwargs: Extra keyword arguments forwarded to ``UnityEnvironment``
            (e.g. ``no_graphics``). A caller-supplied ``side_channels`` list is
            preserved; the two channels created here are appended to a copy of it.

    Returns:
        The vectorized environment built by ``vec_env_cls``.
    """

    def handle_obs(obs, space):
        """Unwrap a 1-element tuple observation, or turn a tuple observation into a dict keyed "0", "1", ..."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        """Space-level counterpart of handle_obs: Tuple -> its single element or a Dict."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        """Return a factory that launches and wraps the Unity env for ``worker_id``."""

        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(config.engine_config)
            # Unity environment parameters
            environment_configuration_channel = EnvironmentParametersChannel()
            environment_configuration_channel.set_float_parameter('agent_level', config.agent_level)
            # Work on a per-call copy. Writing back into the shared closure `kwargs`
            # would make every later factory call in the same process (e.g. with
            # DummyVecEnv) inherit the side channels of the previous workers.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = list(kwargs.get("side_channels", [])) + [
                engine_configuration_channel,
                environment_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=config.env_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)
    ]
    return vec_env_cls(env_facts)

### Start Environment & Load Trained Model

In [6]:
# Close any environment left open by an earlier run of this cell, so its Unity
# process does not keep running alongside the new one.
try:
    env.close()
except NameError:
    pass  # fresh kernel: no environment has been created yet
except Exception as exc:  # best effort: a dead/broken env must not block a restart
    print(f"Could not close previous environment: {exc!r}")

NUM_ENVS = 1
# Set TARGET_SEEKER_BUILD to point at another build instead of editing this path.
build_path = os.environ.get(
    "TARGET_SEEKER_BUILD", "C:/main/MLAgents/TargetSeeker/Build/win/TargetSeeker"
)
agent_level = 4

config = LimitedConfig(
    env_path_or_name=build_path,  # Can use any name from a registry or a path to your own unity build.
    base_port=5005,
    base_seed=0,
    num_env=NUM_ENVS,
    allow_multiple_obs=True,
    agent_level=agent_level,
)

# no_graphics=False keeps the Unity window visible; VecMonitor records episode stats.
env = VecMonitor(make_mla_sb3_env(config=config, no_graphics=False))

# Load the trained policy (SB3 resolves "TargetSeeker" to TargetSeeker.zip) bound to `env`.
model = PPO.load("TargetSeeker", env=env)

### Run Inference

In [7]:
# Number of episodes to evaluate; set to 0 to skip evaluation.
n_eval_episodes = 10

# `> 0` rather than `!= 0`: a negative count would make evaluate_policy average an
# empty list of episode rewards.
if n_eval_episodes > 0:
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=n_eval_episodes)
    print(f'Mean Reward {mean_reward}')
    print(f'Std Reward {std_reward}')

Mean Reward 1.451761245727539


In [23]:
# Roll the trained policy forward for a fixed number of steps so it can be watched
# in the Unity window. SB3 VecEnvs reset finished episodes automatically, so no
# manual reset on `done` is needed.
N_DEMO_STEPS = 100

obs = env.reset()
for step in range(N_DEMO_STEPS):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)

In [8]:
# Shut down the Unity process(es) and drop the loaded policy. Safe on a fresh kernel,
# and only reports success when closing actually succeeded.
try:
    env.close()
except NameError:
    print("No environment to close")
except Exception as exc:  # best effort: report, but still release the model below
    print(f"Failed to close environment: {exc!r}")
else:
    print("Closed environment")

try:
    del model
except NameError:
    pass  # no model was loaded in this session

Closed environment
