In [5]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
sns.set()

In [8]:
from rlts.runners.eval_runner import EvalRunner
from rlts.meta.meta_policies.a_star_policy import AStarPolicy
from rlts.meta.meta_policies.random_policy import create_random_search_policy, create_random_search_policy_no_terminate
from rlts.meta.meta_policies.terminator_policy import TerminatorPolicy
from rlts.train.procgen_meta import create_batched_procgen_meta_envs, load_pretrained_q_network
from rlts.train.procgen_meta import reset_object_level_metrics, get_object_level_metrics


def create_test_batched_meta_envs(env_multithreading: bool = False,
                                  patch_terminates: bool = True,
                                  patch_expansions: bool = True,
                                  create_intermediate_search_policies: bool = False,
                                  compute_meta_rewards: bool = False):
    """
    Build the 64-environment batched Procgen meta environment benchmarked in
    this notebook.

    A pretrained object-level Q-network config is loaded from a fixed run and
    handed, together with a fixed set of environment settings, to
    ``create_batched_procgen_meta_envs``. The keyword arguments of this
    function are forwarded unchanged so each benchmark can toggle one option
    at a time.

    Args:
        env_multithreading: Forwarded as ``env_multithreading``.
        patch_terminates: Forwarded as ``patch_terminates``.
        patch_expansions: Forwarded as ``patch_expansions``.
        create_intermediate_search_policies: Forwarded under the same name.
        compute_meta_rewards: Forwarded as ``compute_meta_rewards``.

    Returns:
        Whatever ``create_batched_procgen_meta_envs`` returns for these
        settings.
    """
    # Fixed settings; every entry is also forwarded to the environment factory.
    env_settings = dict(
        pretrained_runs_folder='runs',
        pretrained_run='run-16823527592836354',
        pretrained_percentile=0.75,
        expand_all_actions=True,
        collect_steps_per_iteration=16,
        finish_on_terminate=True,
        max_tree_size=64,
    )

    pretrained_object_config = load_pretrained_q_network(
        folder=env_settings['pretrained_runs_folder'],
        run=env_settings['pretrained_run'],
        percentile=env_settings.get('pretrained_percentile', 0.5),
        verbose=False,
    )

    return create_batched_procgen_meta_envs(
        n_envs=64,
        object_config=pretrained_object_config,
        env_multithreading=env_multithreading,
        patch_terminates=patch_terminates,
        patch_expansions=patch_expansions,
        compute_meta_rewards=compute_meta_rewards,
        create_intermediate_search_policies=create_intermediate_search_policies,
        **env_settings,
    )

In [22]:
from cProfile import Profile
from time import time


def profile_random_steps(batched_meta_env, n_step: int = 50) -> Profile:
    """
    Profile ``n_step`` steps of a random, never-terminating search policy.

    The environment is reset and the first policy action is computed before
    profiling starts, so only the step/action loop is measured.

    Args:
        batched_meta_env: Environment exposing ``reset``, ``current_time_step``
            and ``step``.
        n_step: Number of environment steps to run under the profiler.

    Returns:
        The ``cProfile.Profile`` holding the collected statistics.
    """
    batched_meta_env.reset()

    policy = create_random_search_policy_no_terminate(batched_meta_env)
    ts = batched_meta_env.current_time_step()
    policy_step = policy.action(ts)

    profile = Profile()
    profile.enable()
    try:
        for _ in range(n_step):
            ts = batched_meta_env.step(policy_step.action)
            policy_step = policy.action(ts)
    finally:
        # Always switch the profiler off; otherwise a failing step would leave
        # it enabled for the rest of the kernel session.
        profile.disable()

    return profile


def time_random_steps(batched_meta_env, n_step: int = 50) -> float:
    """
    Time ``n_step`` steps of a random, never-terminating search policy.

    The environment is reset and the first policy action is computed before
    the clock starts, so only the step/action loop is measured.

    Args:
        batched_meta_env: Environment exposing ``reset``, ``current_time_step``
            and ``step``.
        n_step: Number of environment steps to time.

    Returns:
        Elapsed time of the loop in seconds.
    """
    # perf_counter is monotonic and high resolution, unlike wall-clock time().
    from time import perf_counter

    batched_meta_env.reset()

    policy = create_random_search_policy_no_terminate(batched_meta_env)
    ts = batched_meta_env.current_time_step()
    policy_step = policy.action(ts)

    start_time = perf_counter()
    for _ in range(n_step):
        ts = batched_meta_env.step(policy_step.action)
        policy_step = policy.action(ts)

    return perf_counter() - start_time

In [None]:
# One dict drives both the environment factory and the recorded result rows,
# so the logged settings cannot drift from the settings that were measured.
benchmark_config = {
    'env_multithreading': False,
    'create_intermediate_search_policies': False,
    'patch_terminates': True,
    'patch_expansions': True,
    'compute_meta_rewards': False,
}
n_repeats = 5

results = []

for _ in range(n_repeats):
    batched_meta_env = create_test_batched_meta_envs(**benchmark_config)
    time_taken = time_random_steps(batched_meta_env)
    # Store the measurement alongside the configuration that produced it.
    results.append({**benchmark_config, 'time_taken': time_taken})

In [14]:
import pstats

# Patch both terminates and expansions (the factory defaults) and print the 10
# most expensive calls by cumulative time.
batched_meta_env = create_test_batched_meta_envs()
# profile_random_steps returns a single Profile, so there is nothing to unpack.
profile_patch_all = profile_random_steps(batched_meta_env)
pstats.Stats(profile_patch_all).sort_stats('cumulative').print_stats(10)

         8214126 function calls (8128271 primitive calls) in 30.808 seconds

   Ordered by: cumulative time
   List reduced from 807 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   140/50    0.003    0.000   30.442    0.609 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.014    0.000   27.743    0.617 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       45    0.050    0.001   19.725    0.438 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
    49472    9.413    0.000    9.426    0.000 /usr/local/lib/python3.8/dist-packages/gym3/libenv.py:383(call_c_func)
33640/18845    0.054    0.000    7.845    0.000 /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py:138(error_handler)
30665/18845    0.126    0.000    7.736    0.000 /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1162(op_dispatch_handler)
 

<pstats.Stats at 0x7f1db85f9d30>

In [15]:
import pstats

# Patch both terminates and expansions (the defaults), this time with
# env_multithreading switched on; print the top 10 calls by cumulative time.
batched_meta_env = create_test_batched_meta_envs(env_multithreading=True)
profile_patch_all = profile_random_steps(batched_meta_env)
pstats.Stats(profile_patch_all).sort_stats('cumulative').print_stats(10)

         2419413 function calls (2376317 primitive calls) in 42.423 seconds

   Ordered by: cumulative time
   List reduced from 739 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   140/50    0.003    0.000   41.774    0.835 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.012    0.000   39.508    0.878 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       45    0.073    0.002   19.974    0.444 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
       90    0.001    0.000   18.055    0.201 /usr/lib/python3.8/multiprocessing/pool.py:764(get)
       90    0.001    0.000   18.054    0.201 /usr/lib/python3.8/multiprocessing/pool.py:761(wait)
       90    0.001    0.000   18.054    0.201 /usr/lib/python3.8/threading.py:540(wait)
       90    0.002    0.000   18.052    0.201 /usr/lib/python3.8/threading.py:270(wait)
      363   18.050    0.050   18.050    0

<pstats.Stats at 0x7f1d9078be80>

In [11]:
import pstats

# Patch both terminates and expansions (the defaults), with
# create_intermediate_search_policies switched on; print the top 10 calls by
# cumulative time.
batched_meta_env = create_test_batched_meta_envs(create_intermediate_search_policies=True)
profile_patch_all = profile_random_steps(batched_meta_env)
pstats.Stats(profile_patch_all).sort_stats('cumulative').print_stats(10)

         8206282 function calls (8120427 primitive calls) in 33.835 seconds

   Ordered by: cumulative time
   List reduced from 763 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   140/50    0.003    0.000   33.319    0.666 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.019    0.000   30.544    0.679 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       45    0.059    0.001   19.814    0.440 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
    49472    9.719    0.000    9.736    0.000 /usr/local/lib/python3.8/dist-packages/gym3/libenv.py:383(call_c_func)
33640/18845    0.064    0.000    8.138    0.000 /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py:138(error_handler)
30665/18845    0.151    0.000    8.010    0.000 /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1162(op_dispatch_handler)
 

<pstats.Stats at 0x7f1de044f250>

In [12]:
import pstats

# Patch both terminates and expansions (the defaults), with
# compute_meta_rewards switched on; print the top 10 calls by cumulative time.
batched_meta_env = create_test_batched_meta_envs(compute_meta_rewards=True)
profile_patch_all = profile_random_steps(batched_meta_env)
pstats.Stats(profile_patch_all).sort_stats('cumulative').print_stats(10)

         24725931 function calls (23129409 primitive calls) in 70.125 seconds

   Ordered by: cumulative time
   List reduced from 771 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   150/50    0.004    0.000   68.735    1.375 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       50    0.044    0.001   68.724    1.374 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       50    0.156    0.003   33.895    0.678 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
       50    0.003    0.000   25.533    0.511 /tf/rlts/procgen/batched_procgen_meta_env.py:260(create_time_step)
       50    0.012    0.000   24.994    0.500 /tf/rlts/procgen/batched_procgen_meta_env.py:265(<listcomp>)
     3200    0.057    0.000   24.981    0.008 /tf/rlts/procgen/batched_procgen_meta_env.py:271(_create_time_step)
     3200    0.067    0.000   24.271    0.008 /tf/rlts/meta/meta_env.py:660(ob

<pstats.Stats at 0x7f1fd964e400>

In [4]:
import pstats

# Baseline with the factory defaults: terminates and expansions both patched,
# no multithreading, no intermediate search policies, no meta rewards.
batched_meta_env = create_test_batched_meta_envs()
profile_patch_all = profile_random_steps(batched_meta_env)
pstats.Stats(profile_patch_all).sort_stats('cumulative').print_stats(10)

         24486434 function calls (22970102 primitive calls) in 69.030 seconds

   Ordered by: cumulative time
   List reduced from 789 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   140/50    0.013    0.000   67.864    1.357 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.035    0.001   64.978    1.444 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       45    0.145    0.003   32.710    0.727 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
       45    0.004    0.000   25.396    0.564 /tf/rlts/procgen/batched_procgen_meta_env.py:260(create_time_step)
       45    0.012    0.000   24.854    0.552 /tf/rlts/procgen/batched_procgen_meta_env.py:265(<listcomp>)
     2880    0.054    0.000   24.842    0.009 /tf/rlts/procgen/batched_procgen_meta_env.py:271(_create_time_step)
     2880    0.068    0.000   24.133    0.008 /tf/rlts/meta/meta_env.py:660(ob

<pstats.Stats at 0x7fd53907dfa0>

In [5]:
import pstats

# Patch just expansions: patch_terminates is turned off, patch_expansions
# keeps its default of True. Print the top 10 calls by cumulative time.
batched_meta_env = create_test_batched_meta_envs(patch_terminates=False)
profile_just_expansions = profile_random_steps(batched_meta_env)
pstats.Stats(profile_just_expansions).sort_stats('cumulative').print_stats(10)

         25350015 function calls (23842361 primitive calls) in 73.344 seconds

   Ordered by: cumulative time
   List reduced from 783 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   130/50    0.005    0.000   72.438    1.449 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.039    0.001   68.244    1.517 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
       40    0.075    0.002   28.342    0.709 /tf/rlts/procgen/batched_procgen_meta_env.py:219(handle_requests)
       45    0.002    0.000   24.657    0.548 /tf/rlts/procgen/batched_procgen_meta_env.py:260(create_time_step)
       45    0.012    0.000   24.266    0.539 /tf/rlts/procgen/batched_procgen_meta_env.py:265(<listcomp>)
     2880    0.062    0.000   24.253    0.008 /tf/rlts/procgen/batched_procgen_meta_env.py:271(_create_time_step)
     2880    0.070    0.000   23.509    0.008 /tf/rlts/meta/meta_env.py:660(ob

<pstats.Stats at 0x7fd4784f7af0>

In [6]:
import pstats

# Patch just terminates: patch_expansions is turned off, patch_terminates
# keeps its default of True. Print the top 50 calls by cumulative time.
batched_meta_env = create_test_batched_meta_envs(patch_expansions=False)
profile_just_terminates = profile_random_steps(batched_meta_env)
pstats.Stats(profile_just_terminates).sort_stats('cumulative').print_stats(50)

         101461480 function calls (99273187 primitive calls) in 215.790 seconds

   Ordered by: cumulative time
   List reduced from 783 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    60/50    0.002    0.000  215.333    4.307 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       45    0.031    0.001  211.999    4.711 /tf/rlts/procgen/batched_procgen_meta_env.py:245(_step)
     2880    0.015    0.000  192.749    0.067 /tf/rlts/procgen/batched_procgen_meta_env.py:208(collect_expansion_requests)
     2880    0.090    0.000  192.715    0.067 /tf/rlts/meta/meta_env.py:617(act)
     2560    0.021    0.000  191.851    0.075 /tf/rlts/meta/meta_env.py:600(perform_computational_action)
     2560    0.082    0.000  191.477    0.075 /tf/rlts/meta/search_tree.py:520(expand_all)
    20480    0.143    0.000  190.806    0.009 /tf/rlts/meta/search_tree.py:534(expand_action)
    20480    0.113    0.000

<pstats.Stats at 0x7fd46b7b84f0>

In [10]:
from rlts.procgen.procgen_state import ProcgenProcessing

In [22]:
# Profile 10 steps of the object-level environments on their own, outside of
# the meta environment's step logic.
time_step = batched_meta_env.object_envs.reset()

profile2 = Profile()
profile2.enable()

for _ in range(10):
    # The first return value is discarded; only the Q-values are used.
    _, q_vals = ProcgenProcessing.call(time_step.observation)
    # Act greedily: take the argmax of the Q-values over the last axis.
    time_step = batched_meta_env.object_envs.step(q_vals.argmax(axis=-1))

profile2.disable()

In [24]:
# Show the 50 most expensive calls of the object-environment profile, ordered
# by cumulative time.
pstats.Stats(profile2).sort_stats('cumulative').print_stats(50)

         36779 function calls (36439 primitive calls) in 3.467 seconds

   Ordered by: cumulative time
   List reduced from 316 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    3.467    1.733 /usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py:3424(run_code)
        2    0.000    0.000    3.467    1.733 {built-in method builtins.exec}
        1    0.005    0.005    3.467    3.467 /tmp/ipykernel_13/90310075.py:6(<module>)
       10    0.001    0.000    2.615    0.262 /tf/rlts/procgen/procgen_state.py:31(call)
       10    0.013    0.001    2.580    0.258 /usr/local/lib/python3.8/dist-packages/tf_agents/networks/network.py:349(__call__)
  260/250    0.001    0.000    2.523    0.010 /usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py:138(error_handler)
      250    0.003    0.000    2.521    0.010 /usr/local/lib/python3.8/dist-packages/tensorflow/python/uti

<pstats.Stats at 0x7feb30221eb0>

In [4]:
# NOTE(review): an earlier cell imports ProcgenProcessing from
# `rlts.procgen.procgen_state`, while this one uses `mlrl` — confirm which
# package name is current.
from mlrl.procgen.procgen_state import ProcgenProcessing, ProcgenState

# Observation stored on the root node of the first meta environment's tree.
obs = batched_meta_env.envs[0]._gym_env.tree.root_node.state.observation

In [9]:
# Check the shape of the single root observation before batching copies of it.
obs.shape

(1, 64, 64, 3)

In [13]:
%%timeit

# Unbatched baseline: 512 separate ProcgenProcessing.call invocations on the
# same single observation.
for _ in range(512):
    ProcgenProcessing.call(obs)

1.49 s ± 27.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
# Stacking 512 copies of the observation along axis 0 gives one batch.
np.concatenate([obs for _ in range(512)], axis=0).shape

(512, 64, 64, 3)

In [14]:
%%timeit

# Batched version: one ProcgenProcessing.call on 512 stacked observations.
# The concatenation itself is included in the timed region.
ProcgenProcessing.call(np.concatenate([obs for _ in range(512)], axis=0))

60.6 ms ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Ratio of the two %%timeit means above (1.49 s unbatched vs 60.6 ms batched);
# the numbers are copied by hand and must be updated if the cells are re-run.
print(f"That's a speed-up of {1.49 / 0.0606:.0f}x")

That's a speed-up of 25x


In [5]:
from mlrl.procgen.procgen_env import make_vectorised_procgen

In [6]:
# One object environment per (meta environment, object-level action) pair.
n_expansions = len(ProcgenState.ACTIONS)
n_object_envs = batched_meta_env.batch_size * n_expansions

In [7]:
# Single vectorised Procgen environment with n_object_envs sub-environments.
# NOTE(review): `object_config` is not assigned at notebook scope in the cells
# shown here (only inside create_test_batched_meta_envs) — confirm where it
# comes from on a fresh kernel.
object_envs = make_vectorised_procgen(object_config, n_envs=n_object_envs)

In [8]:
from mlrl.experiments.procgen_meta import make_gym_procgen

# Baseline for comparison: the same number of object environments, but as a
# list of separately created gym environments.
object_envs_unbatched = [
    make_gym_procgen(object_config) for _ in range(n_object_envs)
]

In [10]:
# Root-node state of every meta environment, repeated n_expansions times each
# so that there is one entry per object environment.
states = [e._gym_env.tree.root_node.state.state[0] for e in batched_meta_env.envs for _ in range(n_expansions)]

In [80]:
# Fixed object-level action used by the unbatched timing cells.
action = 0

In [81]:
%%timeit

# Unbatched baseline: set the state and step each environment one at a time.
for e, state in zip(object_envs_unbatched, states):
    e.env.env.set_state([state])
    e.step(action)

155 ms ± 5.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
# Action 0 for every sub-environment of the vectorised environment.
joint_action = np.array([0] * n_object_envs)

In [83]:
%%timeit

# Batched version: one set_state and one step for all sub-environments.
object_envs.env.set_state(states)
object_envs.step(joint_action)

86.9 ms ± 875 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [89]:
%%timeit

# Fully unbatched: step each environment and run ProcgenProcessing.call on
# its single observation (with a leading axis added) inside the loop.
for e, state in zip(object_envs_unbatched, states):
    e.env.env.set_state([state])
    o, *_ = e.step(action)
    ProcgenProcessing.call(np.expand_dims(o, 0))

1.68 s ± 41.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit

# Same body as the earlier unbatched set_state/step timing cell; this copy has
# no recorded execution count or output.
for e, state in zip(object_envs_unbatched, states):
    e.env.env.set_state([state])
    e.step(action)

In [13]:
%%timeit

# Fully batched: one set_state, one step and one ProcgenProcessing.call for
# all sub-environments.
object_envs.env.set_state(states)
ts = object_envs.step(joint_action)
ProcgenProcessing.call(ts.observation)

141 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
# Ratio of the hand-copied %%timeit means: 1.68 s fully unbatched vs 141 ms
# fully batched. Update the numbers if the timing cells are re-run.
print(f"That's a speed-up of {1.68 / 0.141:.0f}x")

That's a speed-up of 12x


In [100]:
%%timeit

# Hybrid: step the environments one at a time, but collect the observations
# and run ProcgenProcessing.call once on all of them.
obs = []
for e, state in zip(object_envs_unbatched, states):
    e.env.env.set_state([state])
    o, *_ = e.step(action)
    obs.append(o)


ProcgenProcessing.call(np.array(obs))

251 ms ± 8.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [102]:
# Ratio of the hand-copied %%timeit means: 1.68 s fully unbatched vs 251 ms
# for unbatched stepping with batched processing.
print(f"That's a speed-up of {1.68 / 0.251:.0f}x")

That's a speed-up of 7x


In [104]:
# NOTE(review): unexplained constants — presumably a measured 42.312 s divided
# by the 12x speed-up printed above; confirm where 42.312 comes from.
42.312 / 12

3.526

In [107]:
# NOTE(review): 3.526 is the result of the previous cell; this presumably
# estimates the overall speed-up when a 42.312 s portion of a 64.094 s run
# shrinks to 3.526 s — confirm the source of both timings.
64.094 / (64.094 - 42.312 + 3.526)

2.532558874664138

In [9]:
from mlrl.meta.meta_env import MetaEnv
from mlrl.meta.search_tree import SearchTreeNode
from mlrl.procgen.procgen_state import ProcgenProcessing, ProcgenState
from mlrl.procgen.procgen_env import make_vectorised_procgen

from typing import List, Tuple, Optional, Dict

from multiprocessing import pool
from multiprocessing import dummy as mp_threads
import numpy as np
import gym

import tensorflow as tf
from tf_agents.environments.gym_wrapper import spec_from_gym_space
from tf_agents.environments.py_environment import PyEnvironment
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import nest_utils


class NodeExpansionRequest:
    """
    Record of one deferred search-tree expansion.

    Holds the parent ``node`` whose state the expansion starts from, the
    object-level ``action`` to apply, and the ``child_node`` that
    ``BatchedProcgenMetaEnv.handle_requests`` fills in with the result.
    """

    def __init__(self,
                 node: SearchTreeNode[ProcgenState],
                 child_node: SearchTreeNode[ProcgenState],
                 action: int):
        self.node, self.action, self.child_node = node, action, child_node

    def get_state(self) -> bytes:
        """Return the first entry of the parent node's ``state.state``."""
        parent_state = self.node.state
        return parent_state.state[0]


class BatchedProcgenMetaEnv(PyEnvironment):
    """
    Steps a list of Procgen meta environments as one batched ``PyEnvironment``.

    While the meta actions are applied, ``create_child`` is patched on every
    search-tree node so that expansions are only recorded as
    ``NodeExpansionRequest`` objects. All recorded requests are then fulfilled
    together: one ``set_state``/``step`` on a vectorised object-level
    environment and one ``ProcgenProcessing.call`` on the resulting
    observations.
    """

    def __init__(self,
                 meta_envs: List[MetaEnv],
                 max_expansions: int,
                 object_config: dict,
                 spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None,
                 simplify_box_bounds: bool = True,
                 discount: types.Float = 0.99,
                 multithreading: bool = True,
                 auto_reset: bool = True):
        """
        Args:
            meta_envs: The meta environments to step together. The first one
                supplies the observation and action spaces for the specs.
            max_expansions: The vectorised object environment is created with
                ``max_expansions * len(meta_envs)`` sub-environments, so at
                most that many expansion requests can be fulfilled per step.
            object_config: Passed to ``make_vectorised_procgen``.
            spec_dtype_map: Forwarded to ``spec_from_gym_space``.
            simplify_box_bounds: Forwarded to ``spec_from_gym_space``.
            discount: Discount placed in every time step built by
                ``_create_time_step``.
            multithreading: If true, a thread pool with one worker per meta
                environment is used to apply meta actions and build time steps.
            auto_reset: Forwarded to ``PyEnvironment.__init__``.
        """
        super(BatchedProcgenMetaEnv, self).__init__(auto_reset)

        self.meta_envs = meta_envs
        self.n_meta_envs = len(meta_envs)
        self.n_object_envs = max_expansions * len(meta_envs)
        self.object_config = object_config
        self.object_envs = make_vectorised_procgen(object_config,
                                                   n_envs=self.n_object_envs)

        env = meta_envs[0]

        self._observation_spec = spec_from_gym_space(env.observation_space,
                                                     spec_dtype_map,
                                                     simplify_box_bounds,
                                                     'observation')

        self._action_spec = spec_from_gym_space(env.action_space,
                                                spec_dtype_map,
                                                simplify_box_bounds,
                                                'action')

        # Expansions recorded during the current step; cleared by _step.
        self.expansion_requests: List[NodeExpansionRequest] = []
        self.discount = discount
        self.multithreading = multithreading
        if multithreading:
            # multiprocessing.dummy provides a thread-backed Pool.
            self._pool = mp_threads.Pool(self.n_meta_envs)

    @property
    def batched(self) -> bool:
        # NOTE(review): _reset and _step always stack their time steps, even
        # for a single meta environment — confirm that reporting False when
        # n_meta_envs == 1 is intended.
        return self.n_meta_envs > 1

    @property
    def batch_size(self) -> int:
        return self.n_meta_envs

    def observation_spec(self) -> types.NestedArraySpec:
        return self._observation_spec

    def action_spec(self) -> types.NestedArraySpec:
        return self._action_spec

    def close(self) -> None:
        """Shut down the worker thread pool, if one was created."""
        if self.multithreading:
            self._pool.close()
            self._pool.join()
        super().close()

    def _reset(self):
        """Reset every meta environment and stack the restart time steps."""
        time_steps = [
            ts.restart(env.reset()) for env in self.meta_envs
        ]

        return nest_utils.stack_nested_arrays(time_steps)

    def patch_node(self, node: SearchTreeNode[ProcgenState]):
        """
        Replace ``node.create_child`` with a version that defers the
        object-level environment step.
        """

        def patched_create_child(_, object_action, new_node_id) -> SearchTreeNode:
            """
            Patched create_child method that creates a new node in the search tree
            without a populated state. The state variables will be set later
            when the expansion request is fulfilled.
            """
            # The 0 and False placeholders are presumably reward and done;
            # handle_requests overwrites both attributes on the child.
            new_node = SearchTreeNode(
                new_node_id, node, ProcgenState(), object_action,
                0, False, node.discount, node.q_function
            )

            self.expansion_requests.append(NodeExpansionRequest(node, new_node, object_action))

            return new_node

        node.create_child = patched_create_child

    def collect_expansion_requests(self, env: MetaEnv, meta_action: int):
        """
        Apply ``meta_action`` to ``env`` with every tree node patched, so any
        expansions it triggers are recorded in ``self.expansion_requests``.
        """
        for node in env.tree.node_list:
            self.patch_node(node)

        env.act(meta_action)

    def handle_requests(self):
        """
        Fulfil all recorded expansion requests with one vectorised step and
        write the results into the corresponding child nodes.
        """
        requests = self.expansion_requests
        n_requests = len(requests)

        # Start from the current states and action 0 everywhere, then
        # overwrite the first n_requests slots with the requested values.
        states = self.object_envs.env.get_state()
        object_action = np.array([0] * self.n_object_envs)
        for i, request in enumerate(requests):
            states[i] = request.get_state()
            object_action[i] = request.action

        self.object_envs.env.set_state(states)
        # Named object_time_step so the module alias `ts` is not shadowed.
        object_time_step = self.object_envs.step(object_action)
        new_states = self.object_envs.env.get_state()

        # Only the slots that belong to a request are processed.
        state_vecs, q_values = ProcgenProcessing.call(
            object_time_step.observation[:n_requests, ...])

        # zip stops at the shortest input, i.e. after n_requests entries.
        results = zip(requests,
                      state_vecs,
                      q_values,
                      new_states[:n_requests],
                      object_time_step.observation,
                      object_time_step.reward,
                      object_time_step.is_last())

        for req, state_vec, q_value, new_state, obs, reward, done in results:
            req.child_node.reward = reward
            req.child_node.done = done
            req.child_node.state.set_variables([new_state], state_vec, obs, q_value)

    def _step(self, meta_actions: List[int]) -> ts.TimeStep:
        """
        Apply one meta action per meta environment, fulfil the resulting
        expansion requests in a batch, and return the stacked time steps.
        """
        self.expansion_requests.clear()

        env_action_pairs = zip(self.meta_envs, meta_actions)
        if self.multithreading:
            self._pool.starmap(self.collect_expansion_requests, env_action_pairs)
        else:
            # Same call as the threaded branch, just run sequentially.
            for env, action in env_action_pairs:
                self.collect_expansion_requests(env, action)

        if self.expansion_requests:
            self.handle_requests()

        if self.multithreading:
            time_steps = self._pool.map(self._create_time_step, self.meta_envs)
        else:
            time_steps = [
                self._create_time_step(env) for env in self.meta_envs
            ]
        return nest_utils.stack_nested_arrays(time_steps)

    def _create_time_step(self, env: MetaEnv):
        """Build a MID or LAST time step from ``env.observe()``."""
        observation, reward, done, _ = env.observe()
        step_type = ts.StepType.LAST if done else ts.StepType.MID

        return ts.TimeStep(step_type=step_type,
                           reward=reward,
                           discount=self.discount,
                           observation=observation)

    def render(self, mode='rgb_array'):
        """Render every meta environment and stack the frames vertically."""
        return np.vstack([env.render(mode=mode) for env in self.meta_envs])

In [4]:
# Reuse the underlying meta environments of the existing batched environment
# and wrap them in the notebook's BatchedProcgenMetaEnv.
meta_envs = [e._gym_env for e in batched_meta_env.envs]
batched_procgen_meta = BatchedProcgenMetaEnv(meta_envs, len(ProcgenState.ACTIONS), object_config)

In [32]:
# Smoke test: apply meta action 1 in all 64 meta environments.
time_step = batched_procgen_meta.step([1] * 64)

In [225]:
# Random search policy that never terminates, built for the new environment.
policy = create_random_search_policy_no_terminate(batched_procgen_meta)

In [33]:
# Run one policy-driven step by hand; the trailing semicolon suppresses the
# returned time step in the cell output.
time_step = batched_procgen_meta.current_time_step()
policy_step = policy.action(time_step)
batched_procgen_meta.step(policy_step.action.numpy());

In [228]:
%%timeit

# Time one policy-driven step of the notebook's BatchedProcgenMetaEnv.
time_step = batched_procgen_meta.current_time_step()
policy_step = policy.action(time_step)
batched_procgen_meta.step(policy_step.action.numpy());

1.18 s ± 266 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [229]:
# Reference environment from the library factory, without env multithreading.
# NOTE(review): `args` and `object_config` are only assigned inside
# create_test_batched_meta_envs in the cells shown here, so this cell depends
# on kernel state from elsewhere — confirm it runs on a fresh kernel.
batched_meta_env = create_batched_procgen_meta_envs(n_envs=64, object_config=object_config, env_multithreading=False, **args)
batched_meta_env.reset();

In [230]:
%%timeit

# Time one policy-driven step of the reference environment. `policy` is the
# one created for batched_procgen_meta in an earlier cell.
time_step = batched_meta_env.current_time_step()
policy_step = policy.action(time_step)
batched_meta_env.step(policy_step.action.numpy());

3.7 s ± 1.5 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [231]:
# Ratio of the two hand-copied %%timeit means above (3.7 s vs 1.18 s).
3.7 / 1.18

3.1355932203389836

In [34]:
from cProfile import Profile

# Profile 10 policy-driven steps of the notebook's BatchedProcgenMetaEnv.
policy = create_random_search_policy_no_terminate(batched_procgen_meta)
batched_procgen_meta.reset()

profile = Profile()
profile.enable()

for _ in range(10):
    time_step = batched_procgen_meta.current_time_step()
    policy_step = policy.action(time_step)
    # Unlike the timing cells, the action is passed without `.numpy()`.
    batched_procgen_meta.step(policy_step.action)

profile.disable()


import pstats

# No restriction is given, so every profiled function is printed.
pstats.Stats(profile).sort_stats('cumulative').print_stats()

         237095 function calls (230305 primitive calls) in 16.482 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000   16.482    8.241 /usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py:3361(run_code)
        2    0.000    0.000   16.482    8.241 {built-in method builtins.exec}
        1    0.000    0.000   16.482   16.482 /tmp/ipykernel_32976/1631964753.py:9(<cell line: 9>)
    30/10    0.002    0.000   16.197    1.620 /usr/local/lib/python3.8/dist-packages/tf_agents/environments/py_environment.py:198(step)
       10    0.002    0.000   16.194    1.619 /tmp/ipykernel_32976/3934244552.py:151(_step)
       20    0.000    0.000    9.225    0.461 /usr/lib/python3.8/multiprocessing/pool.py:764(get)
       20    0.000    0.000    9.224    0.461 /usr/lib/python3.8/multiprocessing/pool.py:761(wait)
       20    0.000    0.000    9.224    0.461 /usr/lib/python3.8/threading.py:540(wait

<pstats.Stats at 0x7f2a15f9cdf0>

In [14]:
# Earlier approach, kept for reference:
# video_env = create_batched_procgen_meta_envs(n_envs=64, object_config=object_config, env_multithreading=False, **args)
# NOTE(review): this cell imports from `mlrl`, while the first import cell of
# the notebook uses `rlts` — confirm which package name is current.
from mlrl.experiments.procgen_meta import create_procgen_meta_env

n_envs = 2

# NOTE(review): `args` and `object_config` are not assigned at notebook scope
# in the cells shown here — confirm where they come from on a fresh kernel.
meta_envs = [
    create_procgen_meta_env(
        object_config,
        min_computation_steps=5,
        **args
    )
    for _ in range(n_envs)
]

# Small two-environment instance of the notebook's batched environment.
video_env = BatchedProcgenMetaEnv(meta_envs, len(ProcgenState.ACTIONS), object_config, multithreading=True)


policy = create_random_search_policy_no_terminate(video_env)


# EvalRunner is also imported from rlts.runners.eval_runner in an earlier
# cell; this import rebinds the name.
from mlrl.runners.eval_runner import EvalRunner


# The same environment is used both for evaluation and as the video source.
runner = EvalRunner(100, video_env, policy, video_env=video_env)

In [16]:
from mlrl.utils.render_utils import embed_mp4

# Create an evaluation video with the runner and embed it in the notebook.
embed_mp4(runner.create_policy_eval_video(60))