# Imports

In [2]:
# Fixing the haiku problem
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Standard installs
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-acme[envs]
!pip install dm-env
!pip install dm-haiku
!pip install dm-tree
!pip install chex
!sudo apt-get install -y xvfb ffmpeg
!pip install imageio
!pip install gym
!pip install gym[classic_control]

!apt-get install -y patchelf

!apt-get install x11-utils
!pip install pyglet

!pip install gym pyvirtualdisplay

!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common
!pip install free-mujoco-py
!pip install imageio-ffmpeg

from IPython.display import clear_output
clear_output()

In [3]:
%matplotlib inline
import IPython
from IPython.display import HTML
from IPython import display as ipythondisplay

import acme
from acme import datasets
from acme import types
from acme import specs
from acme.wrappers import gym_wrapper
import base64
from base64 import b64encode
import chex
import collections
from collections import namedtuple
import dm_env
import enum
import functools
import gym
import mujoco_py
import haiku as hk
import imageio
import io
import itertools
import jax
from jax import tree_util
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import multiprocessing as mp
import multiprocessing.connection
import numpy as np
import pandas as pd
import random
import reverb
import rlax
import time
import tree
from typing import *
import warnings
# pyglet options are assigned right after importing pyglet and before
# `pyglet.window` is imported below; keep this ordering.
import pyglet
pyglet.options['search_local_libs'] = False
pyglet.options['shadow_window']=False
from pyglet.window import xlib
# Overrides a private flag of pyglet's xlib module.
# NOTE(review): presumably a workaround for rendering on a headless machine -- confirm.
xlib._have_utf8 = False

# Start a non-visible virtual display of 1400x900 for off-screen rendering.
# This rebinds the name `display` to the pyvirtualdisplay object; IPython's
# display module is available as `ipythondisplay` (imported above).
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
display.start()
 
# Print arrays with 3 decimals; `suppress=1` is a truthy value for the
# `suppress` flag.
np.set_printoptions(precision=3, suppress=1)

%matplotlib inline

# Utils

In [4]:
import abc

# Encapsulate a trajectory. Temporally, a trajectory unrolls as
# o_0, a_0, r_0, d_0, ..., o_{T-1}, a_{T-1}, r_{T-1}, d_{T-1}.
# In the shape comments below, T is the trajectory length (see the unrolling
# above); B is presumably the batch size -- TODO confirm against the code that
# builds trajectories.
@chex.dataclass
class Trajectory:
  """Observations, actions, rewards, dones and discounts of a trajectory."""
  observations: types.NestedArray  # [T, B, ...]
  actions: types.NestedArray  # [T, B, ...]
  rewards: chex.ArrayNumpy  # [T, B]
  dones: chex.ArrayNumpy  # [T, B]
  discounts: chex.ArrayNumpy # [T, B]

class Agent(abc.ABC):
  """Minimal agent interface.

  It exposes just enough to act in an environment (`batched_actor_step`) and
  to update the agent's potential parameters (`learner_step`). Both methods
  are abstract and must be provided by subclasses.
  """

  @abc.abstractmethod
  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """Performs one learning step on `trajectory`.

    Args:
      trajectory: the trajectory to learn from.

    Returns:
      A mapping that can contain various logs.
    """

  @abc.abstractmethod
  def batched_actor_step(self, observations: types.NestedArray) -> types.NestedArray:
    """Selects actions for a batch of observations.

    Args:
      observations: batched observations, i.e. typically arrays, or nests
        (think nested dictionaries) of arrays with shape (B, F_1, F_2, ...),
        where B is the batch size and F_1, F_2, ... are feature dimensions.

    Returns:
      The actions chosen in response to `observations`.
    """


In [5]:
def simple_interaction_loop(agent: Agent, environment: dm_env.Environment, max_num_steps: int = 5000) -> None:
  """Runs `agent` in `environment` for a single episode.

  The environment is reset once, then stepped until it emits a last timestep
  or until `max_num_steps` steps have been taken, whichever comes first.

  Args:
    agent: queried through `batched_actor_step` with a batch of size 1.
    environment: the environment to interact with.
    max_num_steps: upper bound on the number of environment steps.
  """

  def _add_batch_dim(leaf):
    # Prepend a leading axis of size 1 so the agent sees a batch.
    return leaf[None]

  timestep = environment.reset()
  steps_taken = 0
  while steps_taken < max_num_steps and not timestep.last():
    batch_of_one = tree.map_structure(_add_batch_dim, timestep.observation)
    batched_actions = agent.batched_actor_step(batch_of_one)
    # batch size = 1: keep the only entry.
    timestep = environment.step(batched_actions[0])
    steps_taken += 1

In [6]:
def display_video(frames, filename='temp.mp4', frame_repeat=1, fps=60):
  """Saves `frames` as a video file and returns an HTML player embedding it.

  Args:
    frames: iterable of image frames; each one is appended to the video.
    filename: path of the video file to write (overwritten if it exists).
    frame_repeat: number of times each frame is written, which slows down
      playback by that factor.
    fps: frame rate handed to the imageio writer.

  Returns:
    An `IPython.display.HTML` object holding a <video> element whose source is
    the base64-encoded content of the written file.
  """
  # Write video
  with imageio.get_writer(filename, fps=fps) as video:
    for frame in frames:
      for _ in range(frame_repeat):
        video.append_data(frame)
  # Read the file back inside a context manager so the handle is always closed.
  with open(filename, 'rb') as video_file:
    video_bytes = video_file.read()
  b64_video = base64.b64encode(video_bytes)
  # The element is explicitly closed so the emitted markup is well formed.
  video_tag = ('<video  width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}"></video>').format(b64_video.decode())
  return IPython.display.HTML(video_tag)

# Environments

### Pendulum

In [13]:
class PendulumRandomAgent(Agent):
  """Agent that ignores what it observes and samples Gaussian actions."""

  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    """Seeds the agent's random key; `environment_spec` is not used."""
    self.rng = jax.random.PRNGKey(0)

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    """Draws one standard-normal value per entry of `observation`.

    Returns:
      A one-element list holding an array of shape [len(observation)].
    """
    next_rng, sample_key = jax.random.split(self.rng)
    self.rng = next_rng  # advance the key so every call samples afresh
    num_observations = len(observation)
    gaussian_actions = jax.random.normal(sample_key, [num_observations])
    return [gaussian_actions]

  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """Does not learn; returns a constant log entry."""
    logs = {"log": "Rien à signaler, je suis un random agent."}
    return logs

In [7]:
class PendulumEnv(dm_env.Environment):
  """dm_env adapter around the gym 'Pendulum-v0' environment.

  When `for_evaluation` is True, an RGB frame is rendered after every reset
  and step and appended to `self.screens` (the attribute only exists in that
  case).
  """

  def __init__(self, for_evaluation: bool) -> None:
    """Creates the gym environment.

    Args:
      for_evaluation: whether to record rendered frames in `self.screens`.
    """
    self._env = gym.make('Pendulum-v0')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Applies `action` and converts gym's result into a dm_env TimeStep.

    A `done` caused by gym's time limit is reported as a truncation rather
    than a termination, so the last timestep keeps a discount of 1.
    """
    new_obs, reward, done, info = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      # gym's TimeLimit wrapper flags time-outs in `info`; when the key is
      # absent the episode is treated as a genuine termination.
      if info.get('TimeLimit.truncated', False):
        return dm_env.truncation(reward, new_obs)
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    """Resets the gym environment and returns the first TimeStep."""
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  def observation_spec(self) -> specs.BoundedArray:
    """Declares float32 observations of shape (3,) bounded in [-8, 8]."""
    return specs.BoundedArray(shape=(3,), minimum=-8., maximum=8., dtype=np.float32)

  def action_spec(self) -> specs.BoundedArray:
    """Declares float32 actions of shape (1,) bounded in [-2, 2]."""
    return specs.BoundedArray(shape=(1,), minimum=-2., maximum=2., dtype=np.float32)

  def close(self) -> None:
    """Closes the underlying gym environment."""
    self._env.close()

In [8]:
#pendulum_environment = PendulumEnv(for_evaluation=True)
#pendulum_environment_spec = acme.make_environment_spec(pendulum_environment)
#pendulum_random_agent = PendulumRandomAgent(pendulum_environment_spec)

#simple_interaction_loop(pendulum_random_agent, pendulum_environment, 5000)

In [9]:
#display_video(np.stack(pendulum_environment.screens, axis=0))

### CartPole

In [10]:
class CartPoleRandomAgent(Agent):
  """Agent that ignores what it observes and picks action 0 or 1 at random."""

  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    """Seeds the agent's random key; `environment_spec` is not used."""
    self.rng = jax.random.PRNGKey(0)

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    """Samples one action in {0, 1} per entry of `observation`.

    Sampling uses the agent's own JAX key, so runs are reproducible.

    Returns:
      A list of `len(observation)` Python ints.
    """
    self.rng, subkey = jax.random.split(self.rng)
    batch_size = len(observation)
    # `maxval` is exclusive, hence actions are drawn from {0, 1}.
    actions = jax.random.randint(subkey, (batch_size,), 0, 2)
    # Convert to plain Python ints, the type this method has always returned.
    return np.asarray(actions).tolist()

  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """Does not learn; returns a constant log entry."""
    return {"log": "Rien à signaler, je suis un random agent."}

In [11]:
class CartPoleEnv(dm_env.Environment):
  """dm_env adapter around the gym 'CartPole-v0' environment.

  When `for_evaluation` is True, an RGB frame is rendered after every reset
  and step and appended to `self.screens` (the attribute only exists in that
  case).
  """

  def __init__(self, for_evaluation: bool) -> None:
    """Creates the gym environment.

    Args:
      for_evaluation: whether to record rendered frames in `self.screens`.
    """
    self._env = gym.make('CartPole-v0')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Applies `action` and converts gym's result into a dm_env TimeStep.

    A `done` caused by gym's time limit is reported as a truncation rather
    than a termination, so the last timestep keeps a discount of 1.
    """
    new_obs, reward, done, info = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      # gym's TimeLimit wrapper flags time-outs in `info`; when the key is
      # absent the episode is treated as a genuine termination.
      if info.get('TimeLimit.truncated', False):
        return dm_env.truncation(reward, new_obs)
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    """Resets the gym environment and returns the first TimeStep."""
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  def observation_spec(self) -> specs.BoundedArray:
    """Declares unbounded float32 observations of shape (4,)."""
    return specs.BoundedArray(shape=(4,), minimum=-np.inf, maximum=np.inf, dtype=np.float32)

  def action_spec(self) -> specs.BoundedArray:
    """Declares int32 actions of shape (1,) taking values in [0, 1].

    NOTE(review): the action reaching `step` in this notebook appears to be a
    scalar rather than a shape-(1,) array -- confirm the declared shape.
    """
    return specs.BoundedArray(shape=(1,), minimum=0, maximum=1, dtype=np.int32)

  def close(self) -> None:
    """Closes the underlying gym environment."""
    self._env.close()

In [12]:
# Build the frame-recording CartPole environment, derive its spec, create a
# random agent, then roll out one episode capped at 5000 steps.
cartpole_environment = CartPoleEnv(for_evaluation=True)
cartpole_environment_spec = acme.make_environment_spec(cartpole_environment)
cartpole_random_agent = CartPoleRandomAgent(
    environment_spec=cartpole_environment_spec)

simple_interaction_loop(
    agent=cartpole_random_agent,
    environment=cartpole_environment,
    max_num_steps=5000,
)



In [13]:
# Stack the frames recorded by the environment along a new leading axis and
# show them as an embedded video (the cell's last expression is its output).
display_video(np.stack(cartpole_environment.screens, axis=0))



### Inverted Pendulum

In [14]:
class InvertedPendulumRandomAgent(Agent):
  """Agent that ignores what it observes and samples Gaussian actions."""

  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    """Seeds the agent's random key; `environment_spec` is not used."""
    self.rng = jax.random.PRNGKey(0)

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    """Draws one standard-normal value per entry of `observation`.

    Returns:
      A one-element list holding an array of shape [len(observation)].
    """
    next_rng, sample_key = jax.random.split(self.rng)
    self.rng = next_rng  # advance the key so every call samples afresh
    num_observations = len(observation)
    gaussian_actions = jax.random.normal(sample_key, [num_observations])
    return [gaussian_actions]

  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """Does not learn; returns a constant log entry."""
    logs = {"log": "Rien à signaler, je suis un random agent."}
    return logs

In [15]:
class InvertedPendulumEnv(dm_env.Environment):
  """dm_env adapter around the gym 'InvertedPendulum-v2' environment.

  When `for_evaluation` is True, an RGB frame is rendered after every reset
  and step and appended to `self.screens` (the attribute only exists in that
  case).
  """

  def __init__(self, for_evaluation: bool) -> None:
    """Creates the gym environment.

    Args:
      for_evaluation: whether to record rendered frames in `self.screens`.
    """
    self._env = gym.make('InvertedPendulum-v2')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Applies `action` and converts gym's result into a dm_env TimeStep.

    A `done` caused by gym's time limit is reported as a truncation rather
    than a termination, so the last timestep keeps a discount of 1.
    """
    new_obs, reward, done, info = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      # gym's TimeLimit wrapper flags time-outs in `info`; when the key is
      # absent the episode is treated as a genuine termination.
      if info.get('TimeLimit.truncated', False):
        return dm_env.truncation(reward, new_obs)
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    """Resets the gym environment and returns the first TimeStep."""
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  def observation_spec(self) -> specs.BoundedArray:
    """Declares unbounded float32 observations of shape (4,)."""
    return specs.BoundedArray(shape=(4,), minimum=-np.inf, maximum=np.inf, dtype=np.float32)

  def action_spec(self) -> specs.BoundedArray:
    """Declares float32 actions of shape (1,) bounded in [-3, 3]."""
    return specs.BoundedArray(shape=(1,), minimum=-3., maximum=3., dtype=np.float32)

  def close(self) -> None:
    """Closes the underlying gym environment."""
    self._env.close()

In [16]:
# Build the frame-recording InvertedPendulum environment, derive its spec,
# create a random agent, then roll out one episode capped at 5000 steps.
invertedpendulum_environment = InvertedPendulumEnv(for_evaluation=True)
invertedpendulum_environment_spec = acme.make_environment_spec(
    invertedpendulum_environment)
invertedpendulum_random_agent = InvertedPendulumRandomAgent(
    environment_spec=invertedpendulum_environment_spec)

simple_interaction_loop(
    agent=invertedpendulum_random_agent,
    environment=invertedpendulum_environment,
    max_num_steps=5000,
)

In [17]:
# Stack the frames recorded by the environment along a new leading axis and
# show them as an embedded video (the cell's last expression is its output).
display_video(np.stack(invertedpendulum_environment.screens, axis=0))



### Reacher

In [18]:
class ReacherRandomAgent(Agent):
  """Agent that ignores what it observes and samples Gaussian actions."""

  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    """Seeds the agent's random key; `environment_spec` is not used."""
    self.rng = jax.random.PRNGKey(0)

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    """Draws two standard-normal values per entry of `observation`.

    Returns:
      A one-element list holding an array of shape [len(observation), 2].
    """
    next_rng, sample_key = jax.random.split(self.rng)
    self.rng = next_rng  # advance the key so every call samples afresh
    num_observations = len(observation)
    gaussian_actions = jax.random.normal(sample_key, [num_observations, 2])
    return [gaussian_actions]

  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """Does not learn; returns a constant log entry."""
    logs = {"log": "Rien à signaler, je suis un random agent."}
    return logs

In [19]:
class ReacherEnv(dm_env.Environment):
  """dm_env adapter around the gym 'Reacher-v2' environment.

  When `for_evaluation` is True, an RGB frame is rendered after every reset
  and step and appended to `self.screens` (the attribute only exists in that
  case).
  """

  def __init__(self, for_evaluation: bool) -> None:
    """Creates the gym environment.

    Args:
      for_evaluation: whether to record rendered frames in `self.screens`.
    """
    self._env = gym.make('Reacher-v2')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Applies `action` and converts gym's result into a dm_env TimeStep.

    A `done` caused by gym's time limit is reported as a truncation rather
    than a termination, so the last timestep keeps a discount of 1.
    """
    new_obs, reward, done, info = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      # gym's TimeLimit wrapper flags time-outs in `info`; when the key is
      # absent the episode is treated as a genuine termination.
      if info.get('TimeLimit.truncated', False):
        return dm_env.truncation(reward, new_obs)
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    """Resets the gym environment and returns the first TimeStep."""
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  def observation_spec(self) -> specs.BoundedArray:
    """Declares unbounded float32 observations of shape (11,)."""
    return specs.BoundedArray(shape=(11,), minimum=-np.inf, maximum=np.inf, dtype=np.float32)

  def action_spec(self) -> specs.BoundedArray:
    """Declares float32 actions of shape (2,) bounded in [-1, 1]."""
    return specs.BoundedArray(shape=(2,), minimum=-1., maximum=1., dtype=np.float32)

  def close(self) -> None:
    """Closes the underlying gym environment."""
    self._env.close()

In [20]:
# Build the frame-recording Reacher environment, derive its spec, create a
# random agent, then roll out one episode capped at 5000 steps.
reacher_environment = ReacherEnv(for_evaluation=True)
reacher_environment_spec = acme.make_environment_spec(reacher_environment)
reacher_random_agent = ReacherRandomAgent(
    environment_spec=reacher_environment_spec)

simple_interaction_loop(
    agent=reacher_random_agent,
    environment=reacher_environment,
    max_num_steps=5000,
)

In [21]:
# Stack the frames recorded by the environment along a new leading axis and
# show them as an embedded video (the cell's last expression is its output).
display_video(np.stack(reacher_environment.screens, axis=0))

