<a href="https://colab.research.google.com/github/google/evojax/blob/main/examples/notebooks/TutorialNonVectorTask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial: Non-Vectorized Tasks
Sometimes you need to evolve policies for tasks that cannot be vectorized, i.e. tasks that can't be written entirely in JAX and jitted to run quickly on a GPU or TPU. The task code then has to run on the CPU, so training speed is limited by the CPU as well as by the GPU/TPU.
This tutorial shows how to use EvoJAX to train policies for these kinds of tasks, using the CarRacing environment from OpenAI's Gym library. See https://www.gymlibrary.dev/.


In [None]:
!nvidia-smi

Wed Sep 14 06:59:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    58W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   39C    P0    67W / 400W |      0MiB / 40536MiB |     19%      Default |
|

## Prerequisites
We need to import the EvoJAX library, the Gym library (which contains the CarRacing task that we are solving), and libraries that let us record and visualise the trained policy.


In [None]:
from IPython.display import clear_output, Image

!pip install evojax

# Rendering Dependencies
!pip install pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
# Gym Dependencies
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[box2d]==0.16 > /dev/null 2>&1
clear_output()

In [None]:
import os
import numpy as np
import jax
import jax.numpy as jnp

import gym
from gym.wrappers import Monitor

from evojax.policy.mlp import MLPPolicy
from evojax.algo import PGPE
from evojax import Trainer
from evojax.util import create_logger


In [None]:
import logging
# Let's create a directory to save logs and models.
log_dir = './log'
logger = create_logger(name='EvoJAX', log_dir=log_dir, debug=True)
logger.setLevel(logging.INFO)
logger.info('Welcome to the tutorial on Task creation!')

logger.info('Jax backend: {}'.format(jax.local_devices()))
!nvidia-smi --query-gpu=name --format=csv,noheader

EvoJAX: 2022-09-14 06:59:34,142 [INFO] Welcome to the tutorial on Task creation!
absl: 2022-09-14 06:59:34,143 [DEBUG] Initializing backend 'interpreter'
absl: 2022-09-14 06:59:34,146 [DEBUG] Backend 'interpreter' initialized
absl: 2022-09-14 06:59:34,147 [DEBUG] Initializing backend 'cpu'
absl: 2022-09-14 06:59:34,151 [DEBUG] Backend 'cpu' initialized
absl: 2022-09-14 06:59:34,152 [DEBUG] Initializing backend 'tpu_driver'
absl: 2022-09-14 06:59:34,153 [INFO] Starting the local TPU driver.
absl: 2022-09-14 06:59:34,153 [INFO] Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
absl: 2022-09-14 06:59:34,154 [DEBUG] Initializing backend 'gpu'
absl: 2022-09-14 06:59:34,155 [INFO] Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
absl: 2022-09-14 06:59:34,155 [DEBUG] Initializing backend 'tpu'
absl: 2022-09-14 06:59:34,156 [INFO] Unabl

NVIDIA A100-SXM4-40GB
NVIDIA A100-SXM4-40GB


In [None]:
!lscpu

Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              24
On-line CPU(s) list: 0-23
Thread(s) per core:  2
Core(s) per socket:  12
Socket(s):           1
NUMA node(s):        1
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:            7
CPU MHz:             2200.220
BogoMIPS:            4400.44
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            39424K
NUMA node0 CPU(s):   0-23
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm a

We need to evaluate a batch of policies, but a single Gym environment can only run one policy at a time. The following code is a Gym task wrapper that runs several Gym environments in parallel, one per subprocess.
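The command/response pattern the wrapper relies on can be sketched in a few lines. This is a minimal illustration, not the wrapper itself: `CounterEnv` and `toy_worker` are hypothetical names, and the workers run in threads rather than subprocesses to keep the example self-contained. The real wrapper below uses `multiprocessing.Process`.

```python
import threading
from multiprocessing import Pipe


class CounterEnv:
    """Hypothetical toy environment: the observation is a step counter."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        reward = float(action)
        done = self.t >= self.horizon
        return self.t, reward, done, {}


def toy_worker(remote, env):
    """Serve 'reset', 'step' and 'close' commands received over a pipe."""
    while True:
        cmd, data = remote.recv()
        if cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'step':
            obs, reward, done, info = env.step(data)
            if done:
                obs = env.reset()  # auto-reset when an episode ends
            remote.send((obs, reward, done, info))
        elif cmd == 'close':
            remote.close()
            break


n_envs = 4
remotes, work_remotes = zip(*[Pipe() for _ in range(n_envs)])
workers = [threading.Thread(target=toy_worker, args=(w, CounterEnv()), daemon=True)
           for w in work_remotes]
for worker in workers:
    worker.start()

# Reset every environment and collect the initial observations.
for remote in remotes:
    remote.send(('reset', None))
first_obs = [remote.recv() for remote in remotes]

# Send one action to each worker first, then gather all the results.
actions = [0.0, 1.0, 2.0, 3.0]
for remote, action in zip(remotes, actions):
    remote.send(('step', action))
results = [remote.recv() for remote in remotes]
obs, rewards, dones, infos = zip(*results)

# Shut the workers down cleanly.
for remote in remotes:
    remote.send(('close', None))
for worker in workers:
    worker.join()
```

Sending every action before receiving any result is what lets the environments step concurrently instead of one after another.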


In [None]:
#@title VecEnv
from abc import ABC, abstractmethod
import pickle

import cloudpickle


class AlreadySteppingError(Exception):
    """
    Raised when an asynchronous step is running while
    step_async() is called again.
    """

    def __init__(self):
        msg = 'already running an async step'
        Exception.__init__(self, msg)


class NotSteppingError(Exception):
    """
    Raised when an asynchronous step is not running but
    step_wait() is called.
    """

    def __init__(self):
        msg = 'not running an async step'
        Exception.__init__(self, msg)


class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.

    :param num_envs: (int) the number of environments
    :param observation_space: (Gym Space) the observation space
    :param action_space: (Gym Space) the action space
    """

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self, seeds):
        """
        Reset all the environments and return an array of
        observations, or a tuple of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.

        :return: ([int] or [float]) observation
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        pass

    @abstractmethod
    def close(self):
        """
        Clean up the environment's resources.
        """
        pass

    def step(self, actions):
        """
        Step the environments with the given action

        :param actions: ([int] or [float]) the action
        :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
        """
        self.step_async(actions)
        return self.step_wait()

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    def render(self, *args, **kwargs):
        """
        Gym environment rendering

        :param mode: (str) the rendering type
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self


class VecEnvWrapper(VecEnv):
    """
    Vectorized environment base class

    :param venv: (VecEnv) the vectorized environment to wrap
    :param observation_space: (Gym Space) the observation space (can be None to load from venv)
    :param action_space: (Gym Space) the action space (can be None to load from venv)
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
                        action_space=action_space or venv.action_space)

    def step_async(self, actions):
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self, seeds):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        return self.venv.close()

    def render(self, *args, **kwargs):
        return self.venv.render(*args, **kwargs)

    def get_images(self):
        return self.venv.get_images()


class CloudpickleWrapper(object):
    def __init__(self, var):
        """
        Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)

        :param var: (Any) the variable you wish to wrap for pickling with cloudpickle
        """
        self.var = var

    def __getstate__(self):
        return cloudpickle.dumps(self.var)

    def __setstate__(self, obs):
        self.var = pickle.loads(obs)

In [None]:
#@title SubprocVecEnv
from multiprocessing import Process, Pipe

import numpy as np


def _worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.var()
    while True:
        try:
            cmd, data = remote.recv()
            if cmd == 'step':
                observation, reward, done, info = env.step(data)
                if done:
                    observation = env.reset()
                remote.send((observation, reward, done, info))
            elif cmd == 'reset':
                _ = env.seed(data)
                observation = env.reset()
                remote.send(observation)
            elif cmd == 'render':
                remote.send(env.render(*data[0], **data[1]))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_space, env.action_space))
            elif cmd == 'env_method':
                method = getattr(env, data[0])
                remote.send(method(*data[1], **data[2]))
            elif cmd == 'get_attr':
                remote.send(getattr(env, data))
            elif cmd == 'set_attr':
                remote.send(setattr(env, data[0], data[1]))
            else:
                raise NotImplementedError
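        except EOFError as e:
            # The parent end of the pipe was closed; record the error and stop the worker.
            with open('log.txt', 'a') as log_file:
                log_file.write('{}\n'.format(e))
            break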


def tile_images(img_nhwc):
    """
    Tile N images into one big PxQ image
    (P,Q) are chosen to be as close as possible, and if N
    is square, then P=Q.

    :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc
        n = batch index, h = height, w = width, c = channel
    :return: (numpy float) img_HWc, ndim=3
    """
    img_nhwc = np.asarray(img_nhwc)
    n_images, height, width, n_channels = img_nhwc.shape
    # new_height was named H before
    new_height = int(np.ceil(np.sqrt(n_images)))
    # new_width was named W before
    new_width = int(np.ceil(float(n_images) / new_height))
    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
    # img_HWhwc
    out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels)
    # img_HhWwc
    out_image = out_image.transpose(0, 2, 1, 3, 4)
    # img_Hh_Ww_c
    out_image = out_image.reshape(new_height * height, new_width * width, n_channels)
    return out_image


class SubprocVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments

    :param env_fns: ([Gym Environment]) Environments to run in subprocesses
    """

    def __init__(self, env_fns):
        self.waiting = False
        self.closed = False
        n_envs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_envs)])
        self.processes = [Process(target=_worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                          for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for process in self.processes:
            process.daemon = True  # if the main process crashes, we should not cause things to hang
            process.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self, seeds):
        for i, remote in enumerate(self.remotes):
            remote.send(('reset', int(seeds[i])))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for process in self.processes:
            process.join()
        self.closed = True

    def render(self, mode='human', *args, **kwargs):
        for pipe in self.remotes:
            # gather images from subprocesses
            # `mode` will be taken into account later
            pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
        imgs = [pipe.recv() for pipe in self.remotes]
        # Create a big image by tiling images from subprocesses
        bigimg = tile_images(imgs)
        if mode == 'human':
            import cv2
            cv2.imshow('vecenv', bigimg[:, :, ::-1])
            cv2.waitKey(1)
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        for pipe in self.remotes:
            # The worker unpacks the payload as (args, kwargs).
            pipe.send(('render', ((), {'mode': 'rgb_array'})))
        imgs = [pipe.recv() for pipe in self.remotes]
        return imgs

    def env_method(self, method_name, *method_args, **method_kwargs):
        """
        Provides an interface to call arbitrary class methods of vectorized environments

        :param method_name: (str) The name of the env class method to invoke
        :param method_args: (tuple) Any positional arguments to provide in the call
        :param method_kwargs: (dict) Any keyword arguments to provide in the call
        :return: (list) List of items returned by each environment's method call
        """

        for remote in self.remotes:
            remote.send(('env_method', (method_name, method_args, method_kwargs)))
        return [remote.recv() for remote in self.remotes]

    def get_attr(self, attr_name):
        """
        Provides a mechanism for getting class attributes from vectorized environments
        (note: attribute value returned must be picklable)

        :param attr_name: (str) The name of the attribute whose value to return
        :return: (list) List of values of 'attr_name' in all environments
        """

        for remote in self.remotes:
            remote.send(('get_attr', attr_name))
        return [remote.recv() for remote in self.remotes]

    def set_attr(self, attr_name, value, indices=None):
        """
        Provides a mechanism for setting arbitrary class attributes inside vectorized environments
        (note:  this is a broadcast of a single value to all instances)
        (note:  the value must be picklable)

        :param attr_name: (str) Name of attribute to assign new value
        :param value: (obj) Value to assign to 'attr_name'
        :param indices: (list,tuple) Iterable containing indices of envs whose attr to set
        :return: (list) in case env access methods might return something, they will be returned in a list
        """

        if indices is None:
            indices = range(len(self.remotes))
        elif isinstance(indices, int):
            indices = [indices]
        for remote in [self.remotes[i] for i in indices]:
            remote.send(('set_attr', (attr_name, value)))
        return [remote.recv() for remote in [self.remotes[i] for i in indices]]

In [None]:
#@title VecVideoRecorder
import os

from gym.wrappers.monitoring import video_recorder


class VecVideoRecorder(VecEnvWrapper):
    """
    Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
    It requires ffmpeg or avconv to be installed on the machine.

    :param venv: (VecEnv or VecEnvWrapper)
    :param video_folder: (str) Where to save videos
    :param record_video_trigger: (func) Function that defines when to start recording.
    The function takes the current number of step,
    and returns whether we should start recording or not.
    :param video_length: (int)  Length of recorded videos
    :param name_prefix: (str) Prefix to the video name
    """

    def __init__(self, venv, video_folder, record_video_trigger,
                 video_length=200, name_prefix='rl-video'):

        VecEnvWrapper.__init__(self, venv)

        self.env = venv
        # Temp variable to retrieve metadata
        temp_env = venv

        metadata = temp_env.get_attr('metadata')[0]

        self.env.metadata = metadata

        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0

    def reset(self, seed):
        obs = self.venv.reset(seed)
        self.start_video_recorder()
        return obs

    def start_video_recorder(self):
        self.close_video_recorder()

        video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
                                                    self.step_id + self.video_length)
        base_path = os.path.join(self.video_folder, video_name)
        self.video_recorder = video_recorder.VideoRecorder(
                env=self.env,
                base_path=base_path,
                metadata={'step_id': self.step_id}
                )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self):
        return self.record_video_trigger(self.step_id)

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()

        self.step_id += 1
        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if self.recorded_frames > self.video_length:
                logger.info("Saving video to %s", self.video_recorder.path)
                self.close_video_recorder()
        elif self._video_enabled():
            self.start_video_recorder()

        return obs, rews, dones, infos

    def close_video_recorder(self):
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        self.recorded_frames = 1

    def close(self):
        VecEnvWrapper.close(self)
        self.close_video_recorder()

    def __del__(self):
        self.close()

We need to set up a virtual display since the Gym environment expects a screen to display the environment.

In [None]:
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

pyvirtualdisplay.abstractdisplay: 2022-09-15 02:10:20,035 [DEBUG] command: ['Xvfb', '-br', '-nolisten', 'tcp', '-screen', '0', '1400x900x24', '-displayfd', '69']
pyvirtualdisplay.abstractdisplay: 2022-09-15 02:10:20,097 [DEBUG] set $DISPLAY=:1


<pyvirtualdisplay.display.Display at 0x7ff8d535fef0>

# Task
This task simply wraps the Gym environment. The Gym environment is not written in JAX and so cannot be jitted with `jax.jit`, which is why this has to be a non-vectorized task. We use `SubprocVecEnv` here to run several Gym environments in parallel.
We also count the consecutive steps on which each environment returns a negative reward and stop its simulation early once that count gets too high. This speeds up training significantly at the start.
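The early-stopping rule can be sketched in plain Python. This is an illustration under stated assumptions: `update_negative_streak` is a hypothetical helper, the reward traces are made up, and the limit of 20 is taken from the task code below, which does the same thing on arrays.

```python
def update_negative_streak(streak, rewards):
    """Per environment: add one on a negative reward, otherwise reset to zero."""
    return [s + 1 if r < 0 else 0 for s, r in zip(streak, rewards)]


limit = 20
streak = [0, 0]
stopped_at = [None, None]

# Hypothetical reward traces: environment 0 only ever gets negative rewards,
# environment 1 gets a single positive reward at step 10.
for step in range(1, 26):
    rewards = [-0.1, 1.0 if step == 10 else -0.1]
    streak = update_negative_streak(streak, rewards)
    for i, s in enumerate(streak):
        if s > limit and stopped_at[i] is None:
            stopped_at[i] = step  # first step at which this rollout would end

print(stopped_at)
```

Environment 0 is cut off as soon as its streak exceeds the limit, while the positive reward in environment 1 resets its streak and keeps it running.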


In [None]:
env_id = 'CarRacing-v0'

In [None]:
def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env

In [None]:
from evojax.task.base import VectorizedTask
from evojax.task.base import TaskState
from flax.struct import dataclass
import numpy as np

@dataclass
class State(TaskState):
    obs: jnp.ndarray          # (96, 96, 3) observed image.
    acc_reward: jnp.ndarray   # Counts consecutive steps with a negative reward.
    steps: jnp.int32          # Tracks the rollout length.
    key: jnp.ndarray          # Random seed.


class Gym(VectorizedTask):

  def __init__(self, max_steps=1000, pop_size=-1, test=False):
    env = SubprocVecEnv([lambda: gym.make(env_id, verbose=False)]*pop_size)
    # For debugging you can use DummyVecEnv
    # env = DummyVecEnv([lambda: gym.make(env_id, verbose=False)]*pop_size)
    self.max_steps = max_steps
    self.obs_shape = env.observation_space.shape
    self.act_shape = env.action_space.shape
    self.env = env
    self.action_high = jnp.array([[1., 1., 1.]])
    self.action_low = jnp.array([[-1., 0., 0.]])
    self.test = test

  def reset(self, key):
    # One seed per environment; use the wrapper's own count rather than a global.
    seeds = jax.random.randint(
        key[0], shape=(self.env.num_envs,), minval=0, maxval=1<<30)
    seeds = [int(x) for x in np.array(seeds, dtype=int)]
    obs = self.env.reset(seeds)
    return State(
        obs=jnp.array(obs)[:, None],
        acc_reward=jnp.array([0]*obs.shape[0]),
        steps=jnp.zeros((), dtype=int),
        key=key,
    )

  def step(self, state, action):
    action = (action * (self.action_high - self.action_low) / 2. +
        (self.action_high + self.action_low) / 2.)

    obs, reward, done, info = self.env.step(np.array(action))
    
    # Carry the counter over unchanged so that it is also defined in test mode.
    acc_reward = state.acc_reward
    if not self.test:
      # Reset the counter wherever the reward is not negative.
      acc_reward = state.acc_reward * (reward < 0)
      # Add one to the counter wherever the reward is negative.
      acc_reward += (reward < 0)
      # Stop if we get more than 20 consecutive negative rewards.
      done = done | (acc_reward > 20)

    return State(
        acc_reward=acc_reward,
        obs=obs[:, None],
        steps=state.steps+1,
        key=state.key), reward, done


# task = Gym(pop_size=32)

# Policy
The policy is a simple CNN. The final layer has 3 output neurons, one per action: steering (-1 is full left, +1 is full right), gas, and braking.
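Because the final layer is followed by `tanh`, every output lies in [-1, 1], while gas and braking only accept values in [0, 1]. The task's `step` method rescales the outputs accordingly. The sketch below reproduces that mapping in plain Python, with the bounds copied from the task code above; `rescale` is a hypothetical helper name.

```python
import math

ACTION_LOW = (-1.0, 0.0, 0.0)   # steering, gas, braking
ACTION_HIGH = (1.0, 1.0, 1.0)


def rescale(raw_action):
    """Map each component linearly from [-1, 1] to [low, high]."""
    return tuple(
        a * (hi - lo) / 2.0 + (hi + lo) / 2.0
        for a, lo, hi in zip(raw_action, ACTION_LOW, ACTION_HIGH)
    )


# Extremes and midpoint of the tanh range.
print(rescale((-1.0, -1.0, -1.0)))  # (-1.0, 0.0, 0.0)
print(rescale((0.0, 0.0, 0.0)))     # (0.0, 0.5, 0.5)
print(rescale((1.0, 1.0, 1.0)))     # (1.0, 1.0, 1.0)

# Arbitrary pre-activation values pushed through tanh, then rescaled.
raw = tuple(math.tanh(z) for z in (0.3, -2.0, 1.5))
print(rescale(raw))
```

Steering is left unchanged, since its range is already [-1, 1]. A zero output for gas or braking corresponds to half throttle or half brake.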


In [None]:
import logging
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

from jax import tree_util

from evojax.policy.base import PolicyNetwork
from evojax.policy.base import PolicyState
from evojax.task.base import TaskState
from evojax.util import create_logger
from evojax.util import get_params_format_fn

class CNN(nn.Module):
    """CNN for learning to drive."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=8, kernel_size=(5, 5), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1,))  # flatten
        x = nn.Dense(features=3)(x)
        x = nn.tanh(x)

        return x[0,:]


class ConvNetPolicy(PolicyNetwork):

    def __init__(self, logger: logging.Logger = None):
        if logger is None:
            self._logger = create_logger('ConvNetPolicy')
        else:
            self._logger = logger

        model = CNN()
        params = model.init(random.PRNGKey(0), jnp.zeros([1, 96, 96, 3]))
        self.init_params, _ = tree_util.tree_flatten(params)
        self.num_params, format_params_fn = get_params_format_fn(params)
        self._logger.info(
            'ConvNetPolicy.num_params = {}'.format(self.num_params))
        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(model.apply)

    def get_actions(self,
                    t_states: TaskState,
                    params: jnp.ndarray,
                    p_states: PolicyState) -> Tuple[jnp.ndarray, PolicyState]:
        params = self._format_params_fn(params)
        return self._forward_fn(params, t_states.obs), p_states

# Training
This code is the same as usual, with one very important difference: the `Trainer` must be created with `use_for_loop=True`.
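A task that calls into Gym cannot be traced and compiled by JAX, so its rollout presumably has to be driven from an ordinary Python loop. The sketch below illustrates that idea only. It is not EvoJAX's implementation, and `ToyBatchTask` and `rollout` are hypothetical names.

```python
class ToyBatchTask:
    """Hypothetical stand-in for a batched task that cannot be jitted."""

    def __init__(self, horizons):
        self.horizons = horizons  # episode length of each environment
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0] * len(self.horizons)

    def step(self, actions):
        self.t += 1
        obs = [float(self.t)] * len(self.horizons)
        rewards = list(actions)
        dones = [self.t >= h for h in self.horizons]
        return obs, rewards, dones


def rollout(task, policy, max_steps):
    """Step the task in a plain Python loop and sum rewards until done."""
    obs = task.reset()
    n = len(obs)
    totals = [0.0] * n
    active = [True] * n
    for _ in range(max_steps):
        actions = [policy(o) for o in obs]
        obs, rewards, dones = task.step(actions)
        for i in range(n):
            if active[i]:
                totals[i] += rewards[i]
                if dones[i]:
                    active[i] = False  # ignore rewards after the episode ends
        if not any(active):
            break  # every environment has finished
    return totals


totals = rollout(ToyBatchTask(horizons=[2, 4]), policy=lambda o: 1.0, max_steps=10)
print(totals)  # [2.0, 4.0]
```

Each pass through the loop returns control to Python, which is why a non-vectorized task runs slower than one that is compiled end to end.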



In [None]:
os.environ['XLA_FLAGS'] = 'xla_gpu_strict_conv_algorithm_picker=false'

seed = 123
pop_size = 64

task = Gym(pop_size=pop_size)

policy = ConvNetPolicy(logger=logger)
solver = PGPE(
    pop_size=pop_size,
    param_size=policy.num_params,
    optimizer='adam',
    center_learning_rate=0.05,
    seed=seed,
)
trainer = Trainer(
    policy=policy,
    solver=solver,
    train_task=task,
    test_task=task,
    max_iter=60,
    log_interval=1,
    test_interval=20,
    n_repeats=1,
    n_evaluations=pop_size,
    seed=seed,
    log_dir=log_dir,
    logger=logger,
    use_for_loop=True  # <<--- This must be set to True for unvectorized tasks.
)

absl: 2022-09-14 06:59:38,065 [DEBUG] Compiling _fold_in (140706710182072) for args (ShapedArray(uint32[2]), ShapedArray(uint32[])).
absl: 2022-09-14 06:59:38,180 [DEBUG] Compiling _truncated_normal (140706709994808) for args (ShapedArray(uint32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)).
absl: 2022-09-14 06:59:38,773 [DEBUG] Compiling _truncated_normal (140706710264568) for args (ShapedArray(uint32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)).
absl: 2022-09-14 06:59:39,165 [DEBUG] Compiling _truncated_normal (140706709869768) for args (ShapedArray(uint32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)).
EvoJAX: 2022-09-14 06:59:39,546 [INFO] ConvNetPolicy.num_params = 31475
absl: 2022-09-14 06:59:39,550 [DEBUG] Compiling init (140706709024056) for args (ShapedArray(float32[31475]),).
EvoJAX: 2022-09-14 06:59:39,564 [INFO] use_for_loop=True
absl: 2022-09-14 06:59:39,568 [DEBUG] Com

# Training
Since the task is not vectorized and runs on the CPU, training is slower: 100 steps take about four hours to run.

In [None]:
_ = trainer.run()

EvoJAX: 2022-09-14 06:59:39,587 [INFO] Start to train for 80 iterations.
absl: 2022-09-14 06:59:39,601 [DEBUG] Compiling ask_func (140702004660552) for args (ShapedArray(uint32[2]), ShapedArray(float32[31475]), ShapedArray(float32[31475])).
absl: 2022-09-14 06:59:40,027 [DEBUG] Compiling duplicate_params (140706709612440) for args (ShapedArray(float32[64,31475]),).
absl: 2022-09-14 06:59:40,050 [DEBUG] Compiling get_task_reset_keys (140706710692128) for args (ShapedArray(uint32[2]),).
absl: 2022-09-14 06:59:40,253 [DEBUG] Compiling _randint (140706709168344) for args (ShapedArray(uint32[2]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)).
absl: 2022-09-14 06:59:43,149 [DEBUG] Compiling reset (140706708809696) for args ().
absl: 2022-09-14 06:59:43,587 [DEBUG] Compiling get_actions (140706709382200) for args (ShapedArray(uint8[64,1,96,96,3]), ShapedArray(float32[64,31475]), ShapedArray(uint32[64,2])).
absl: 2022-09-14 06:59:46,086 [DEBUG] Compiling update_sc

# Visualization

In [None]:
# Let's visualize the learned policy.

def render(algo, policy):
    """Render the learned policy."""

    test_task = Gym(pop_size=1, test=False)

    test_task.env = VecVideoRecorder(
        test_task.env, video_folder='videos/',
        record_video_trigger=lambda step: step == 0,
        video_length=1000,
        name_prefix='record')

    act_fn = jax.jit(policy.get_actions)
    policy_reset_fn = jax.jit(policy.reset)
    # Don't jit the task!
    task_reset_fn = test_task.reset
    step_fn = test_task.step

    params = algo.best_params[None, :]
    task_s = test_task.reset(jax.random.PRNGKey(seed=42)[None, :])
    policy_s = policy_reset_fn(task_s)

    done = False
    step = 0
    reward = 0
    while not done:
      act, policy_s = act_fn(task_s, params, policy_s)
      task_s, r, d = test_task.step(task_s, act)
      step += 1
      reward = reward + r
      done = bool(d[0])
    print('reward={}'.format(reward))
    print(step, 'steps')

    test_task.env.close()

render(solver, policy)



reward=[478.44054054]
621 steps


In [None]:
import base64
from pathlib import Path
from IPython import display as ipythondisplay

In [None]:
mp4 = Path("videos/record-step-0-to-step-1000.mp4")
video_b64 = base64.b64encode(mp4.read_bytes())
html = '''<video alt="{}" autoplay 
              loop controls style="height: 400px;">
              <source src="data:video/mp4;base64,{}" type="video/mp4" />
          </video>'''.format(mp4, video_b64.decode('ascii'))
ipythondisplay.display(ipythondisplay.HTML(html))