# Installations

In [None]:
#install necessary libraries
# GitHub no longer serves the unauthenticated git:// protocol, so every VCS
# install below uses https://. `%pip` installs into the running kernel's env.
# NOTE(review): later cells use `jax.experimental.optimizers` and
# `jax.tree_multimap`, which only exist in older jax releases — TODO pin a
# compatible jax/jaxlib (and dm-haiku/optax/rlax) version here.
%pip install wandb
%pip install dm_env
%pip install git+https://github.com/deepmind/dm-haiku
%pip install git+https://github.com/deepmind/bsuite.git
%pip install git+https://github.com/deepmind/optax.git
%pip install git+https://github.com/deepmind/rlax.git
%pip install dm-tree
%pip install packaging
%pip install tensorflow-datasets
%pip install tensorflow
%pip install pyvirtualdisplay
# -y: apt-get aborts at the confirmation prompt when run non-interactively.
!apt-get install -y xvfb
%pip install git+https://github.com/tensorflow/docs
%pip install tqdm
%pip install chex

Collecting wandb
  Downloading wandb-0.12.0-py2.py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 14.8 MB/s 
[?25hCollecting GitPython>=1.0.0
  Downloading GitPython-3.1.18-py3-none-any.whl (170 kB)
[K     |████████████████████████████████| 170 kB 67.0 MB/s 
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.3.1-py2.py3-none-any.whl (133 kB)
[K     |████████████████████████████████| 133 kB 68.7 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting subprocess32>=3.5.3
  Downloading subprocess32-3.5.4.tar.gz (97 kB)
[K     |████████████████████████████████| 97 kB 7.5 MB/s 
Collecting configparser>=3.8.1
  Downloading configparser-5.0.2-py3-none-any.whl (19 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.1-py3-none-any.whl (7.5 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.7-py3-none-any.whl (63 kB)
[K     |

In [None]:
#import libraries
import jax
import collections
import functools
from typing import Any, Callable, Optional, Tuple, Dict
import dm_env
import haiku as hk
from examples.impala import util
import jax.numpy as jnp
import numpy as np
import chex
import itertools
import queue
import threading
import warnings
from examples.impala import util
from jax.experimental import optimizers
import optax
import rlax
from tqdm import tqdm  
from rlax._src import base
import haiku as hk
from examples.impala import util

from PIL import Image
import threading
from typing import List
import wandb
from absl import app
from bsuite.environments import cartpole
from examples.impala import util
import jax
import optax
import collections
import gym
import numpy as np
import statistics
import tensorflow as tf
import re
import matplotlib.pyplot as plt
import numpy as np
import collections

import dm_env
import haiku as hk
import jax.nn
import jax.numpy as jnp

# Agent


In [None]:
# Named container returned by Agent.step / Agent.unroll.
AgentOutput = collections.namedtuple(
    "AgentOutput", ("policy_logits", "values", "action"))

# Type aliases used in the Agent signatures.
Action = int  # flat, discrete, 0-indexed action
Nest = Any  # arbitrary pytree (observation spec, recurrent state, ...)
NetFactory = Callable[[int], hk.RNNCore]  # num_actions -> Haiku RNN core

#agent interface
class Agent:
  """Wraps a Haiku recurrent network with jitted init / step / unroll helpers."""

  def __init__(self, num_actions: int, obs_spec: Nest,
               net_factory: NetFactory):
    """
    The interface for the agent.
    Args:
      num_actions: Number of possible actions for the agent. Assumes a flat,
        discrete, 0-indexed action space.
      obs_spec: The observation spec of the environment.
      net_factory: A function from num_actions to a Haiku module representing
        the agent. This module should have an initial_state() function and an
        unroll function.
    """
    # Observation spec; used by initial_params() to build dummy inputs.
    self._obs_spec = obs_spec

    # Bind num_actions so the factory can be called with no arguments inside
    # the Haiku transforms below.
    net_factory = functools.partial(net_factory, num_actions)

    # Transform that only builds the network's initial recurrent state. Its
    # init function is discarded: building the state is expected not to need
    # parameters (initial_state() passes None as params).
    _, self._initial_state_apply_fn = hk.without_apply_rng(
        hk.transform(
            lambda batch_size: net_factory().initial_state(batch_size)))

    # Transform that runs the network over a [T, B, ...] sequence of
    # observations starting from a given recurrent state.
    self._init_fn, self._apply_fn = hk.without_apply_rng(
        hk.transform(lambda obs, state: net_factory().unroll(obs, state)))

  @functools.partial(jax.jit, static_argnums=0)
  def initial_params(self, rng_key):
    """
    Initializes the agent params.
    Args:
      rng_key: PRNG key used to initialise the network parameters.
    Returns:
      Haiku parameters for the network.
    """
    # Build a zero-filled observation matching obs_spec, wrap it in a restart
    # timestep and add [T=1, B=1] leading axes so the network can be traced.
    dummy_inputs = jax.tree_map(lambda t: np.zeros(t.shape, t.dtype),
                                self._obs_spec)
    dummy_inputs = util.preprocess_step(dm_env.restart(dummy_inputs))
    dummy_inputs = jax.tree_map(lambda t: t[None, None, ...], dummy_inputs)
    return self._init_fn(rng_key, dummy_inputs, self.initial_state(1))

  @functools.partial(jax.jit, static_argnums=(0, 1))
  def initial_state(self, batch_size: Optional[int]):
    """
    Returns the agent's initial recurrent state.
    Args:
      batch_size: batch size (None for an unbatched state, as the actors use).
    Returns:
      The initial recurrent state produced by the network.
    """
    # We expect that generating the initial_state does not require parameters.
    return self._initial_state_apply_fn(None, batch_size)

  @functools.partial(jax.jit, static_argnums=(0,))
  def step(
      self,
      rng_key,
      params: hk.Params,
      timestep: dm_env.TimeStep,
      state: Nest,
  ) -> Tuple[AgentOutput, Nest]:
    """For a given single-step, sample an action from the policy.

    Args:
      rng_key: PRNG key used to sample the action.
      params: network parameters.
      timestep: a single, unbatched (preprocessed) timestep.
      state: the unbatched recurrent state.
    Returns:
      (AgentOutput(policy_logits, values, action), next recurrent state).
    """
    # Pad timestep, state to be [T, B, ...] and [B, ...] respectively.
    timestep = jax.tree_map(lambda t: t[None, None, ...], timestep)
    state = jax.tree_map(lambda t: t[None, ...], state)

    net_out, next_state = self._apply_fn(params, timestep, state)
    # Remove the padding from above.
    net_out = jax.tree_map(lambda t: jnp.squeeze(t, axis=(0, 1)), net_out)
    next_state = jax.tree_map(lambda t: jnp.squeeze(t, axis=0), next_state)
    # Sample an action and return.
    action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1)
    action = jnp.squeeze(action, axis=-1)
    return AgentOutput(net_out.policy_logits, net_out.value, action), next_state

  def unroll(
      self,
      params: hk.Params,
      trajectory: dm_env.TimeStep,
      state: Nest,
  ) -> AgentOutput:
    """Unroll the agent along a [T, B, ...] trajectory.

    Returns:
      AgentOutput with policy logits and values for every step; `action` is
      an empty list because no actions are sampled here.
    """
    net_out, _ = self._apply_fn(params, trajectory, state)
    return AgentOutput(net_out.policy_logits, net_out.value, action=[])

# VTrace


In [None]:
# Array type used throughout the V-trace signatures.
Array = chex.Array
# Per-step V-trace results (typename kept as 'vtrace_output').
VTraceOutput = collections.namedtuple(
    'vtrace_output', 'errors pg_advantage q_estimate')


In [None]:
def vtrace_td_error_and_advantage(
    c_help: Array,
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    clip_pg_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> VTraceOutput:
  """Computes V-trace TD errors, Q estimates and policy-gradient advantages.

  Follows the IMPALA actor-critic formulation from "IMPALA: Scalable
  Distributed Deep-RL with Importance Weighted Actor-Learner Architectures"
  by Espeholt et al. (https://arxiv.org/abs/1802.01561); the TD errors come
  from this notebook's `vtrace` function.
  Args:
    c_help: forwarded unchanged to `vtrace`.
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance weights for the actions taken at time t-1.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: forwarded to `vtrace`.
    clip_pg_rho_threshold: clip threshold for policy gradient importance ratios.
    stop_target_gradients: whether or not to apply stop gradient to targets.
  Returns:
    VTraceOutput(errors, pg_advantage, q_estimate), each shaped like v_tm1.
  """
  per_step = [v_tm1, v_t, r_t, discount_t, rho_tm1]
  # Every per-step input must be a 1-D float array of the same length.
  chex.assert_rank(per_step, 1)
  chex.assert_type(per_step, float)
  chex.assert_equal_shape(per_step)

  # TD errors; adding v_tm1 back yields the V-trace targets.
  errors = vtrace(
      c_help, v_tm1, v_t, r_t, discount_t, rho_tm1,
      lambda_, clip_rho_threshold, stop_target_gradients)
  targets_tm1 = errors + v_tm1

  # Bootstrap for Q: lambda-mix of the next step's target and value, with the
  # raw final value v_t[-1] closing the sequence.
  next_mixed = lambda_ * targets_tm1[1:] + (1 - lambda_) * v_tm1[1:]
  q_bootstrap = jnp.concatenate([next_mixed, v_t[-1:]], axis=0)
  q_estimate = r_t + discount_t * q_bootstrap

  # Advantage for the policy-gradient loss, weighted by the clipped ratio.
  pg_weights = jnp.minimum(clip_pg_rho_threshold, rho_tm1)
  pg_advantages = pg_weights * (q_estimate - v_tm1)

  return VTraceOutput(
      errors=errors, pg_advantage=pg_advantages, q_estimate=q_estimate)


In [None]:
def vtrace(
    c_help: Array,
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_tm1: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
  """Calculates importance-weighted multistep TD errors (V-trace variant).

  Structured after V-trace from "IMPALA: Scalable Distributed Deep-RL with
  Importance Weighted Actor-Learner Architectures" by Espeholt et al.
  (https://arxiv.org/abs/1802.01561), but modified for experimenting with
  importance-sampling schemes. As currently written:
    * each TD error is scaled by the raw, unclipped ratio rho_tm1;
    * the backward recursion is err_i = td_i + discount_t[i] * err_{i+1},
      i.e. traces are never cut;
    * `c_help`, `lambda_` and `clip_rho_threshold` are accepted for interface
      compatibility but are not used.
  Args:
    c_help: not used by this function.
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_tm1: importance sampling ratios.
    lambda_: currently unused (see above).
    clip_rho_threshold: currently unused (see above).
    stop_target_gradients: whether or not to apply stop gradient to targets.
  Returns:
    V-Trace errors, same shape as v_tm1.
  """
  #check shapes of everything
  chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_tm1], [1, 1, 1, 1, 1])
  chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_tm1],
                   [float, float, float, float, float])
  chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_tm1])

  # Importance weights applied to the TD errors -- change depending on IS algo
  # (standard V-trace uses jnp.minimum(clip_rho_threshold, rho_tm1)).
  rhos = rho_tm1

  # Trace coefficient -- change depending on IS algo (standard V-trace uses
  # lambda_ * jnp.minimum(1.0, rho_tm1[i])).
  lambda_num = 1.

  # Compute the temporal difference errors.
  td_errors = rhos * (r_t + discount_t * v_t - v_tm1)

  # Work backwards computing the errors; append-then-reverse avoids the
  # quadratic cost of repeatedly inserting at the front of the list.
  err = 0.0
  errors = []
  for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
    err = td_errors[i] + discount_t[i] * lambda_num * err
    errors.append(err)
  errors.reverse()
  errors = jnp.array(errors)

  # Return errors, maybe disabling gradient flow through bootstrap targets.
  return jax.lax.select(
      stop_target_gradients,
      jax.lax.stop_gradient(errors + v_tm1) - v_tm1,
      errors)


# Learner

In [None]:

def policy_gradient_loss(logits, *args):
  """Policy-gradient loss summed over time and batch.

  rlax.policy_gradient_loss averages over the time axis of a single sequence,
  so it is vmapped over the batch axis (axis 1 of every argument) and each
  per-sequence mean is scaled by the sequence length to turn it into a sum.

  Args:
    logits: policy logits laid out [T, B, ...] (axis 0 = time, axis 1 = batch).
    *args: remaining rlax.policy_gradient_loss arguments (in this notebook:
      actions, advantages, mask), each batched along axis 1.
  Returns:
    Scalar loss: sum over sequences of (per-sequence mean loss * T).
  """
  # One mean loss per sequence, shape [B].
  mean_per_sequence = jax.vmap(rlax.policy_gradient_loss, in_axes=1)(
      logits, *args)
  # Undo the mean over time: mean * T equals the sum over the T timesteps.
  seq_len = logits.shape[0]
  return jnp.sum(mean_per_sequence * seq_len)


def entropy_loss(logits, *args):
  """Entropy-regularisation loss summed over time and batch.

  rlax.entropy_loss averages over the time axis of a single sequence, so it is
  vmapped over the batch axis (axis 1 of every argument) and each
  per-sequence mean is scaled by the sequence length to turn it into a sum.

  Args:
    logits: policy logits laid out [T, B, ...] (axis 0 = time, axis 1 = batch).
    *args: remaining rlax.entropy_loss arguments (in this notebook: the
      [T, B] mask), each batched along axis 1.
  Returns:
    Scalar entropy loss (not a policy-gradient loss).
  """
  # One mean entropy loss per sequence, shape [B].
  mean_per_sequence = jax.vmap(rlax.entropy_loss, in_axes=1)(logits, *args)
  # Undo the mean over time: mean * T equals the sum over the T timesteps.
  seq_len = logits.shape[0]
  return jnp.sum(mean_per_sequence * seq_len)


class Learner:
  """Manages state and performs updates for IMPALA learner.

  Actors hand trajectories to enqueue_traj(); a background thread
  (host_to_device_worker) stacks batch_size of them and moves them to the
  device; run() consumes those batches, applies optimizer updates and
  publishes (num_frames, params) through params_for_actor().
  """

  def __init__(
      self,
      agent: Agent,
      rng_key,
      opt: optax.GradientTransformation,
      batch_size: int,
      discount_factor: float,
      frames_per_iter: int,
      learnersActor,
      max_abs_reward: float = 0,
      logger=None,
  ):
    """
    Args:
      agent: Agent whose parameters are trained.
      rng_key: PRNG key used to initialise the parameters.
      opt: optax optimizer used for the updates.
      batch_size: number of trajectories stacked into one update batch.
      discount_factor: multiplied into the environment discount.
      frames_per_iter: added to the frame counter after every update.
      learnersActor: stored as self._learners_actor (not used in this class).
      max_abs_reward: if > 0, rewards are clipped to +/- this value.
      logger: logger with write()/close(); defaults to util.AbslLogger().
    """
    if jax.device_count() > 1:
      warnings.warn('Note: the impala example will only take advantage of a '
                    'single accelerator.')

    # initialize vars
    self._agent = agent
    self._opt = opt
    self._batch_size = batch_size
    self._discount_factor = discount_factor
    self._frames_per_iter = frames_per_iter
    self._max_abs_reward = max_abs_reward
    self._learners_actor = learnersActor

    # Per-update history of every logged metric (appended to in run()).
    self._all_total_loss = {"total_loss":[], "PG_loss":[], "baseline_loss":[], "entropy_loss":[],
                            "grad_norm_unclipped":[], "weight_norm":[],
                            "num_frames":[], "epoch":[], "error":[], "q":[],
                            "error_length":[], "sum":[]}

    # Data pipeline objects.
    self._done = False
    self._host_q = queue.Queue(maxsize=self._batch_size)
    self._device_q = queue.Queue(maxsize=1)
    self.logit_log = None

    # Prepare the parameters to be served to actors.
    params = agent.initial_params(rng_key)
    self._params_for_actor = (0, jax.device_get(params))

    # keep track of data-collection vars
    self._best_return = 0
    self._best_visual_return = 0
    self._tracker = 0
    self._discount_log = 0

    # Set up logging.
    if logger is None:
      logger = util.AbslLogger()
    self._logger = logger

  def _loss(
      self,
      theta: hk.Params,
      trajectories: util.Transition,
  ) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Compute vtrace-based actor-critic loss.

    Runs under jax.jit (via update), so it must not store traced values on
    self -- they would be leaked tracers once tracing finishes.
    Args:
      theta: current params
      trajectories: batch collected by actors, laid out [T, B, ...]
    Returns:
      total_loss: loss for all the updates
      logs: data logs"""
    # feed batched actor info into nn for learner outputs
    initial_state = jax.tree_map(lambda t: t[0], trajectories.agent_state)
    learner_outputs = self._agent.unroll(theta, trajectories.timestep,
                                         initial_state)
    v_t = learner_outputs.values[1:]
    # Remove bootstrap timestep from non-timesteps.
    _, actor_out, _ = jax.tree_map(lambda t: t[:-1], trajectories)
    learner_outputs = jax.tree_map(lambda t: t[:-1], learner_outputs)
    v_tm1 = learner_outputs.values

    # Get the discount, reward, step_type from the next timestep.
    timestep = jax.tree_map(lambda t: t[1:], trajectories.timestep)
    discounts = timestep.discount * self._discount_factor
    rewards = timestep.reward
    if self._max_abs_reward > 0:
      rewards = jnp.clip(rewards, -self._max_abs_reward, self._max_abs_reward)

    # Mask out steps whose next timestep is FIRST (episode boundary).
    mask = jnp.not_equal(timestep.step_type, int(dm_env.StepType.FIRST))
    mask = mask.astype(jnp.float32)

    #calculate rhos and target policy log-probabilities for importance sampling
    rhos = rlax.categorical_importance_sampling_ratios(
        learner_outputs.policy_logits, actor_out.policy_logits,
        actor_out.action)
    c_help = base.batched_index(jax.nn.log_softmax(learner_outputs.policy_logits), actor_out.action)

    # calculate vtrace, vmapped over the batch axis
    vtrace_td_error_and_advantage_map = jax.vmap(
        vtrace_td_error_and_advantage, in_axes=1, out_axes=1)
    vtrace_returns = vtrace_td_error_and_advantage_map(
        c_help, v_tm1, v_t, rewards, discounts, rhos)

    #use vtrace outputs to calculate losses
    q = vtrace_returns.q_estimate
    pg_advs = vtrace_returns.pg_advantage
    pg_loss = policy_gradient_loss(learner_outputs.policy_logits,
                                   actor_out.action, pg_advs, mask)
    baseline_loss = 0.5 * jnp.sum(jnp.square(vtrace_returns.errors) * mask)
    ent_loss = entropy_loss(learner_outputs.policy_logits, mask)

    #add together for overall loss
    total_loss = pg_loss
    total_loss += 0.5 * baseline_loss
    total_loss += 0.01 * ent_loss

    #data logs
    logs = {}
    logs['PG_loss'] = pg_loss
    logs['baseline_loss'] = baseline_loss
    logs['entropy_loss'] = ent_loss
    logs['total_loss'] = total_loss
    logs['error'] = vtrace_returns.errors
    logs['q'] = q
    logs['error_length'] = jnp.sum(vtrace_returns.errors) / len(vtrace_returns.errors)
    logs['sum'] = jnp.sum(vtrace_returns.errors)

    return total_loss, logs

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, params, opt_state, batch: util.Transition):
    """The actual update function.
    Args:
      params: current policy params
      opt_state: current optimizer state
      batch: batch from actors
    Returns:
      new params, updated optimizer state, data logs"""

    # calculate total loss
    (loss_val, logs), grads = jax.value_and_grad(
        self._loss, has_aux=True)(params, batch)

    # use total loss to update policy; optax.global_norm is the L2 norm over
    # all leaves (same quantity jax.experimental.optimizers.l2_norm gave).
    grad_norm_unclipped = optax.global_norm(grads)
    updates, updated_opt_state = self._opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    weight_norm = optax.global_norm(params)

    #data logs
    logs.update({
        'grad_norm_unclipped': grad_norm_unclipped,
        'weight_norm': weight_norm,
    })

    return params, updated_opt_state, logs

  def enqueue_traj(self, traj: util.Transition):
    """Enqueue trajectory from actor (blocks while the host queue is full).
    Args:
      traj: actor trajectory"""
    self._host_q.put(traj)

  def best_return(self, actor_return):
    """update best actor return if needed
    Args:
      actor_return: best return from actor trajectories"""
    # if this actor produced the best return, update it
    if actor_return > self._best_return:
      self._best_return = actor_return

  def best_visual_return(self, actor_return):
    """update best visual actor return if needed
    Args:
      actor_return: best return from the visual actor trajectories"""
    # if this actor produced the best return out of the visual actors, update it
    if actor_return > self._best_visual_return:
      self._best_visual_return = actor_return

  def params_for_actor(self) -> Tuple[int, hk.Params]:
    """return (num_frames, current params) for the actors"""
    return self._params_for_actor

  def host_to_device_worker(self):
    """Elementary data pipeline: stacks batch_size trajectories along a new
    axis 1 and hands the device batch to run() through the device queue."""
    batch = []
    while not self._done:
      # Try to get a batch. Skip the iteration if we couldn't.
      try:
        for _ in range(len(batch), self._batch_size):
          # As long as possible while keeping learner_test time reasonable.
          batch.append(self._host_q.get(timeout=10))
      except queue.Empty:
        continue

      assert len(batch) == self._batch_size
      # Prepare for consumption, then put batch onto device.
      stacked_batch = jax.tree_multimap(lambda *xs: np.stack(xs, axis=1),
                                        *batch)
      device_batch = jax.device_put(stacked_batch)
      # Put with a timeout so shutdown (self._done) is noticed instead of
      # blocking forever on a full queue once run() stops consuming, which
      # would hang transfer_thread.join().
      while not self._done:
        try:
          self._device_q.put(device_batch, timeout=1)
          break
        except queue.Full:
          continue

      # Clean out the built-up batch.
      batch = []

  def run(self, max_iterations: int = -1):
    """Runs the learner for max_iterations updates (forever if -1).
    Args:
      max_iterations: how many times to run learner update
    """
    # Start host-to-device transfer worker.
    transfer_thread = threading.Thread(target=self.host_to_device_worker)
    transfer_thread.start()

    # get current params and optimizer state
    (num_frames, params) = self._params_for_actor
    opt_state = self._opt.init(params)

    steps = range(max_iterations) if max_iterations != -1 else itertools.count()
    try:
      for epoch in tqdm(steps):
        # get actor trajectories
        batch = self._device_q.get()

        # update params and optimizer state
        params, opt_state, logs = self.update(params, opt_state, batch)

        # move time tracker to next frame sequence
        num_frames += self._frames_per_iter

        # Collect parameters to distribute to downstream actors
        self._params_for_actor = (num_frames, jax.device_get(params))

        # Collect and write logs out
        logs = jax.device_get(logs)
        logs.update({
            'num_frames': num_frames,
            'epoch': epoch,
        })

        # Append this update's value to the history of every tracked metric.
        for element in list(self._all_total_loss):
          self._all_total_loss[element] = np.append(
              self._all_total_loss[element], logs[element])

        self._logger.write(logs)
    finally:
      # Shut down even if an update raised, so the non-daemon worker thread
      # does not keep running after run() returns.
      self._done = True
      self._logger.close()
      transfer_thread.join()

# Actor

In [None]:
class SeparateActor:
  """Manages the state of a single agent/environment interaction loop to keep track of actor returns.

  Unlike Actor, trajectories are not sent to a learner. After every step the
  running episode return is appended to self._all_returns, with the matching
  frame count appended to self._all_frames.
  """

  def __init__(
      self,
      agent: Agent,
      env: dm_env.Environment,
      unroll_length: int,
      rng_seed: int = 42,
      logger=None,
      learner=None,
  ):
    """
    Args:
      agent: Agent used to pick actions.
      env: dm_env environment to interact with.
      unroll_length: number of steps per unroll.
      rng_seed: seed for this actor's PRNG key.
      logger: logger with a write() method; defaults to util.AbslLogger().
      learner: optional Learner; only required by pull_params().
    """
    # initialize vars
    self._agent = agent
    self._env = env
    self._unroll_length = unroll_length
    self._learner = learner
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    self._traj = []
    self._rng_key = jax.random.PRNGKey(rng_seed)
    self._best_return = 0
    self._all_returns = []
    self._all_frames = []
    self._num_runs = 0

    if logger is None:
      logger = util.AbslLogger()
    self._logger = logger

    self._episode_return = 0.
    self._tracker = 0

  def unroll(self, rng_key, frame_count: int, params: hk.Params,
             unroll_length: int) -> util.Transition:
    """Run unroll_length agent/environment steps, returning the trajectory
    Args:
      rng_key: rng_key
      frame_count: what frame you're on
      params: current actor params
      unroll_length: how far to unroll
    Returns:
      the stacked trajectory (leading time axis), including the bootstrap
      transition carried over from the previous unroll."""

    # resume from where the previous unroll stopped
    timestep = self._timestep
    agent_state = self._agent_state
    # Unroll one longer if trajectory is empty.
    num_interactions = unroll_length + int(not self._traj)
    subkeys = jax.random.split(rng_key, num_interactions)
    # data logs
    self._best_return = 0
    self._tracker += 1
    # loop through unroll length collecting data
    for i in range(num_interactions):
      # load timestep
      timestep = util.preprocess_step(timestep)
      # get next state
      agent_out, next_state = self._agent.step(subkeys[i], params, timestep,
                                               agent_state)
      # transition to that state
      transition = util.Transition(
          timestep=timestep,
          agent_out=agent_out,
          agent_state=agent_state)
      # keep track of that data
      self._traj.append(transition)
      # move to next state + timestep
      agent_state = next_state
      timestep = self._env.step(agent_out.action)

      # if it's the last timestep, save rewards and logs
      if timestep.last():
        # save reward
        self._episode_return += timestep.reward
        self._num_runs = self._num_runs + 1
        # data logs
        self._logger.write({
           'num_frames': frame_count,
           'episode_return': self._episode_return,
        })
        # if best return from this actor, save it
        if self._episode_return > self._best_return:
          self._best_return = self._episode_return
        # set reward to 0 to start again
        self._episode_return = 0.
      else:
        # accumulate the step reward (a missing reward counts as 0)
        self._episode_return += timestep.reward or 0.

      # keep track of rewards and times
      self._all_returns.append(self._episode_return)
      self._all_frames.append(frame_count)

    # Pack the trajectory and carry env/agent state into the next unroll;
    # otherwise the stale reset timestep would be re-fed on every call and
    # self._traj would grow without bound.
    trajectory = jax.device_get(self._traj)
    trajectory = jax.tree_multimap(lambda *xs: np.stack(xs), *trajectory)
    self._timestep = timestep
    self._agent_state = agent_state
    # Keep the bootstrap timestep for next trajectory.
    self._traj = self._traj[-1:]
    return trajectory

  def unroll_without_push(self, frame_count, params):
    """Run one unroll without sending the trajectory to a learner
    Args:
      frame_count: which frame/time actor is in from learner
      params: actor params from learner
      """
    params = jax.device_put(params)
    self._rng_key, subkey = jax.random.split(self._rng_key)
    # collect one unroll; the trajectory itself is not needed here
    self.unroll(
        rng_key=subkey,
        frame_count=frame_count,
        params=params,
        unroll_length=self._unroll_length)

  def pull_params(self):
    """pull the new params from the learner"""
    if self._learner is None:
      raise AttributeError('SeparateActor has no learner; pass learner=... '
                           'to the constructor to use pull_params().')
    return self._learner.params_for_actor()


In [None]:
import dm_env
import haiku as hk
from examples.impala import agent as agent_lib
from examples.impala import learner as learner_lib
from examples.impala import util
import jax
import numpy as np


class Actor:
  """Manages the state of a single agent/environment interaction loop. This is the basic actor.

  Each call to unroll_and_push() collects a trajectory with the given
  parameters and enqueues it on the learner.
  """

  def __init__(
      self,
      agent: agent_lib.Agent,
      env: dm_env.Environment,
      unroll_length: int,
      learner: learner_lib.Learner,
      rng_seed: int = 42,
      logger=None,
  ):
    # Collaborators and configuration.
    self._agent = agent
    self._env = env
    self._unroll_length = unroll_length
    self._learner = learner

    # Interaction state carried between unrolls.
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    self._traj = []
    self._rng_key = jax.random.PRNGKey(rng_seed)
    self._episode_return = 0.

    # Silent logger unless one is supplied.
    self._logger = util.NullLogger() if logger is None else logger

  def unroll(self, rng_key, frame_count: int, params: hk.Params,
             unroll_length: int) -> util.Transition:
    """Steps agent and environment unroll_length times and packs the result.

    Args:
      rng_key: key split into one subkey per interaction.
      frame_count: frame counter reported alongside episode returns.
      params: network parameters pulled from the learner.
      unroll_length: number of new transitions to collect.
    Returns:
      The trajectory stacked along a leading time axis; it starts with the
      bootstrap transition kept from the previous call (if any).
    """
    ts = self._timestep
    state = self._agent_state
    # The very first unroll has no bootstrap transition yet: take one extra step.
    steps = unroll_length + (0 if self._traj else 1)

    for step_key in jax.random.split(rng_key, steps):
      ts = util.preprocess_step(ts)
      out, new_state = self._agent.step(step_key, params, ts, state)
      self._traj.append(
          util.Transition(timestep=ts, agent_out=out, agent_state=state))
      state = new_state
      ts = self._env.step(out.action)

      if not ts.last():
        # Mid-episode: accumulate the reward (None counts as zero).
        self._episode_return += ts.reward or 0.
        continue

      # Episode over: add the final reward, report the return, then reset it.
      self._episode_return += ts.reward
      self._logger.write({
         'num_frames': frame_count,
         'episode_return': self._episode_return,
      })
      self._episode_return = 0.

    packed = jax.tree_multimap(lambda *xs: np.stack(xs),
                               *jax.device_get(self._traj))
    self._timestep = ts
    self._agent_state = state
    # Retain the last transition as the bootstrap step of the next trajectory.
    self._traj = self._traj[-1:]
    return packed

  def unroll_and_push(self, frame_count: int, params: hk.Params):
    """Collects one trajectory with `params` and enqueues it on the learner.

    Args:
      frame_count: frame counter supplied by the learner.
      params: parameters supplied by the learner.
    """
    device_params = jax.device_put(params)
    self._rng_key, subkey = jax.random.split(self._rng_key)
    trajectory = self.unroll(
        rng_key=subkey,
        frame_count=frame_count,
        params=device_params,
        unroll_length=self._unroll_length)
    self._learner.enqueue_traj(trajectory)

  def pull_params(self):
    """Returns the learner's current (num_frames, params) pair."""
    return self._learner.params_for_actor()

In [None]:
class ActorVisual:
  """Manages the state of a single agent/environment interaction loop and produces visual for CartPole"""

  def __init__(
      self,
      agent: Agent,
      env: dm_env.Environment,
      unroll_length: int,
      learner: Learner,
      rng_seed: int = 42,
      logger=None,
  ):
    #set necessary vars
    self._agent = agent
    self._env = env
    self._unroll_length = unroll_length
    self._learner = learner
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    self._traj = []
    self._rng_key = jax.random.PRNGKey(rng_seed)

    # Frames captured during the unroll in which the best episode return
    # ended (set in unroll) and the lazily created gym viewer (set in render).
    self._images = None
    self._viewer = None

    # data log
    if logger is None:
      logger = util.AbslLogger()
    self._logger = logger

    self._episode_return = 0.
    # unroll()/unroll_and_push() use self._best_return; initialise it here so
    # unroll_and_push never sees a missing attribute. self.best_return is kept
    # for backward compatibility.
    self.best_return = 0
    self._best_return = 0

  def _capture_frame(self):
    """Renders the env as a PIL image, or returns None if there is nothing to draw."""
    screen = self.render(mode='rgb_array')
    return None if screen is None else Image.fromarray(screen)

  def unroll(self, rng_key, frame_count: int, params: hk.Params,
             unroll_length: int) -> util.Transition:
    """Run unroll_length agent/environment steps, returning the trajectory
    Args:
      rng_key: rng_key
      frame_count: pulled from learner, frame we're on
      params: params pulled from learner
      unroll_length: how much data to collect
    Returns:
      collected data"""

    # start screen rendering (render() returns None before the env has state)
    first_frame = self._capture_frame()
    temp_images = [] if first_frame is None else [first_frame]

    # set current time step and state
    timestep = self._timestep
    agent_state = self._agent_state
    # Unroll one longer if trajectory is empty.
    num_interactions = unroll_length + int(not self._traj)
    subkeys = jax.random.split(rng_key, num_interactions)
    self._best_return = 0
    # loop through amount of data needed
    for i in range(num_interactions):
      # load timestep
      timestep = util.preprocess_step(timestep)
      #get next state
      agent_out, next_state = self._agent.step(subkeys[i], params, timestep,
                                               agent_state)
      # transition to next state
      transition = util.Transition(
          timestep=timestep,
          agent_out=agent_out,
          agent_state=agent_state)
      # keep track of transition data
      self._traj.append(transition)
      # move to next state + timestep
      agent_state = next_state
      timestep = self._env.step(agent_out.action)

      # every 10 steps save an image for the gif
      if i % 10 == 0:
        frame = self._capture_frame()
        if frame is not None:
          temp_images.append(frame)

      # if last timestep, save reward and start over
      if timestep.last():
        # add reward for timestep
        self._episode_return += timestep.reward
        # log data
        self._logger.write({
           'num_frames': frame_count,
           'episode_return': self._episode_return,
        })
        # if best return for this actor, save it and its gif
        if self._episode_return > self._best_return:
          self._best_return = self._episode_return
          self._images = temp_images
        # restart
        self._episode_return = 0.
      else:
        # accumulate the step reward (a missing reward counts as 0)
        self._episode_return += timestep.reward or 0.

    # Pack the trajectory and reset parent state.
    trajectory = jax.device_get(self._traj)
    trajectory = jax.tree_multimap(lambda *xs: np.stack(xs), *trajectory)
    self._timestep = timestep
    self._agent_state = agent_state
    # Keep the bootstrap timestep for next trajectory.
    self._traj = self._traj[-1:]
    return trajectory

  def unroll_and_push(self, frame_count: int, params: hk.Params):
    """Run one unroll and send trajectory to learner
    Args:
      frame_count: frame we're on
      params: params pulled from learner"""
    # save params
    params = jax.device_put(params)
    self._rng_key, subkey = jax.random.split(self._rng_key)
    # collect data
    act_out = self.unroll(
        rng_key=subkey,
        frame_count=frame_count,
        params=params,
        unroll_length=self._unroll_length)
    # send data to learner
    self._learner.enqueue_traj(act_out)
    self._learner.best_visual_return(self._best_return)

  def pull_params(self):
    """pull params from learner"""
    return self._learner.params_for_actor()

  def render(self, mode='human'):
    """Draws the current cartpole state with gym's classic-control viewer.

    Reads the env's private cartpole fields (_x_threshold, _state) and stores
    the scene's transforms on the env object the first time it is called.

    Args:
      mode: 'rgb_array' to ask the viewer for the frame as an array.
    Returns:
      The viewer's render() result, or None if the env has no state yet.
    """
    # set image size
    screen_width = 600
    screen_height = 400

    # map the cart's allowed x-range onto the screen width
    world_width = self._env._x_threshold * 2
    scale = screen_width / world_width
    carty = 100  # TOP OF CART
    polewidth = 10.0
    polelen = scale * (2 * 0.5)
    cartwidth = 50.0
    cartheight = 30.0

    if self._viewer is None:
      # build the scene once, on first use
      from gym.envs.classic_control import rendering
      self._viewer = rendering.Viewer(screen_width, screen_height)
      l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
      axleoffset = cartheight / 4.0
      cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
      self._env.carttrans = rendering.Transform()
      cart.add_attr(self._env.carttrans)
      self._viewer.add_geom(cart)
      l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
      pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
      pole.set_color(.8, .6, .4)
      self._env.poletrans = rendering.Transform(translation=(0, axleoffset))
      pole.add_attr(self._env.poletrans)
      pole.add_attr(self._env.carttrans)
      self._viewer.add_geom(pole)
      self._env.axle = rendering.make_circle(polewidth / 2)
      self._env.axle.add_attr(self._env.poletrans)
      self._env.axle.add_attr(self._env.carttrans)
      self._env.axle.set_color(.5, .5, .8)
      self._viewer.add_geom(self._env.axle)
      self._env.track = rendering.Line((0, carty), (screen_width, carty))
      self._env.track.set_color(0, 0, 0)
      self._viewer.add_geom(self._env.track)

      self._env._pole_geom = pole

    if self._env._state is None:
      return None

    # Edit the pole polygon vertex
    pole = self._env._pole_geom
    l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
    pole.v = [(l, b), (l, t), (r, t), (r, b)]

    x = self._env._state
    cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
    self._env.carttrans.set_translation(cartx, carty)
    self._env.poletrans.set_rotation(-x[2])

    return self._viewer.render(return_rgb_array=mode == 'rgb_array')

  def close(self):
    """close the render window, if one was opened"""
    if self._viewer:
      self._viewer.close()
      self._viewer = None


In [None]:
class ActorGrapher:
  """Runs one agent/environment interaction loop and records episode stats.

  Like a plain actor, it produces trajectories and pushes them to the learner.
  It also keeps a whole-run history of finished-episode returns
  (``_all_returns``), together with the learner frame count at which each
  episode ended (``_all_frames``).
  """

  def __init__(
      self,
      agent: Agent,
      env: dm_env.Environment,
      unroll_length: int,
      learner: Learner,
      rng_seed: int = 42,
      logger=None,
  ):
    """Initialises the actor and resets the environment.

    Args:
      agent: agent providing ``initial_state`` and ``step``.
      env: environment to act in; it is reset immediately.
      unroll_length: number of new steps per trajectory pushed to the learner.
      learner: learner that receives trajectories (``enqueue_traj``) and
        supplies parameters (``params_for_actor``).
      rng_seed: seed for this actor's PRNG key.
      logger: object with a ``write(dict)`` method; defaults to
        ``util.AbslLogger()``.
    """
    self._agent = agent
    self._env = env
    self._unroll_length = unroll_length
    self._learner = learner
    self._timestep = env.reset()
    self._agent_state = agent.initial_state(None)
    # Collected steps; after each unroll only the last (bootstrap) step is kept.
    self._traj = []
    self._rng_key = jax.random.PRNGKey(rng_seed)
    # Per-unroll statistics: reset at the start of every `unroll` call.
    self._best_return = 0
    # Whole-run history: one entry per finished episode.
    self._all_returns = []
    self._all_frames = []
    self._num_runs = 0

    # create logger
    if logger is None:
      logger = util.AbslLogger()
    self._logger = logger

    # Return accumulated so far in the current, unfinished episode.
    self._episode_return = 0.

  def unroll(self, rng_key, frame_count: int, params: hk.Params,
             unroll_length: int) -> util.Transition:
    """Runs ``unroll_length`` agent/environment steps and returns the trajectory.

    The last step of the previous trajectory is kept as the first step of the
    next one. On the very first call (empty ``self._traj``) one extra step is
    taken instead, so every returned trajectory has ``unroll_length + 1`` steps.

    Args:
      rng_key: PRNG key, split into one subkey per step.
      frame_count: learner frame count, recorded for each finished episode.
      params: network parameters used for acting.
      unroll_length: number of new steps to take.

    Returns:
      A ``util.Transition`` whose leaves are stacked with ``np.stack`` along
      a new leading (time) axis.
    """
    self._num_runs = 0
    self._best_return = 0
    timestep = self._timestep
    agent_state = self._agent_state
    # Unroll one longer if trajectory is empty.
    num_interactions = unroll_length + int(not self._traj)
    subkeys = jax.random.split(rng_key, num_interactions)

    for i in range(num_interactions):
      timestep = util.preprocess_step(timestep)
      agent_out, next_state = self._agent.step(subkeys[i], params, timestep,
                                               agent_state)
      self._traj.append(util.Transition(
          timestep=timestep,
          agent_out=agent_out,
          agent_state=agent_state))
      agent_state = next_state
      timestep = self._env.step(agent_out.action)

      # `reward` may be None (presumably on FIRST steps, per the dm_env
      # convention); count it as 0 on terminal and non-terminal steps alike.
      self._episode_return += timestep.reward or 0.

      if timestep.last():
        self._all_returns.append(self._episode_return)
        self._all_frames.append(frame_count)
        self._num_runs += 1
        self._logger.write({
            'num_frames': frame_count,
            'episode_return': self._episode_return,
        })
        self._best_return = max(self._best_return, self._episode_return)
        self._episode_return = 0.

    # Pack the trajectory: stack every leaf across time steps.
    # `jax.tree_multimap` no longer exists in recent JAX releases, where
    # `jax.tree_util.tree_map` accepts several trees instead. Older releases
    # only offer the multi-tree form as `tree_multimap`, so use it if present.
    tree_map_multi = (getattr(jax.tree_util, 'tree_multimap', None)
                      or jax.tree_util.tree_map)
    trajectory = jax.device_get(self._traj)
    trajectory = tree_map_multi(lambda *xs: np.stack(xs), *trajectory)

    # Persist loop state for the next call.
    self._timestep = timestep
    self._agent_state = agent_state
    # Keep the bootstrap timestep for next trajectory.
    self._traj = self._traj[-1:]
    return trajectory

  def unroll_and_push(self, frame_count: int, params: hk.Params):
    """Runs one unroll with ``params`` and enqueues it on the learner."""
    params = jax.device_put(params)
    self._rng_key, subkey = jax.random.split(self._rng_key)
    act_out = self.unroll(
        rng_key=subkey,
        frame_count=frame_count,
        params=params,
        unroll_length=self._unroll_length)
    self._learner.enqueue_traj(act_out)

  def pull_params(self):
    """Returns ``learner.params_for_actor()``.

    ``run_actor`` unpacks the result as ``(frame_count, params)``.
    """
    return self._learner.params_for_actor()


# Common Networks

In [None]:
# Output record shared by every network below: the policy logits and the
# (squeezed) value estimate.
NetOutput = collections.namedtuple('NetOutput', ('policy_logits', 'value'))

class CatchNet(hk.RNNCore):
  """Smallest IMPALA network here: an MLP torso with policy and value heads.

  It subclasses ``hk.RNNCore`` but is stateless. The recurrent state is a
  dummy zeros array that ``__call__`` hands back unchanged.
  """

  def __init__(self, num_actions, name=None):
    """Stores the size of the discrete action space.

    Args:
      num_actions: number of actions available in the environment (width of
        the policy logits).
      name: optional Haiku module name.
    """
    super(CatchNet, self).__init__(name=name)
    self._num_actions = num_actions

  def initial_state(self, batch_size):
    """Returns the dummy zero state.

    Args:
      batch_size: leading dimension of the state, or None for a scalar state.

    Returns:
      ``jnp.zeros([])`` if ``batch_size`` is None, else ``jnp.zeros([batch_size])``.
    """
    state_shape = [] if batch_size is None else [batch_size]
    return jnp.zeros(state_shape)

  def __call__(self, x: dm_env.TimeStep, state):
    """Applies the network to one timestep.

    Args:
      x: timestep; only ``x.observation`` is read.
      state: recurrent state; returned untouched.

    Returns:
      ``(NetOutput(policy_logits, value), state)``.
    """
    # Submodules are constructed in the same order as always (flatten,
    # linear, linear_1, then the two heads) so Haiku parameter names match.
    layers = [hk.Flatten()]
    for width in (128, 64):
      layers += [hk.Linear(width), jax.nn.relu]
    embedding = hk.Sequential(layers)(x.observation)

    logits = hk.Linear(self._num_actions)(embedding)
    # Value head has a single unit; drop that trailing axis.
    baseline = jnp.squeeze(hk.Linear(1)(embedding), axis=-1)
    return NetOutput(policy_logits=logits, value=baseline), state

  def unroll(self, x, state):
    """Applies ``__call__`` over the leading batch axes of ``x``.

    Args:
      x: batched timesteps (leading axes are folded by ``hk.BatchApply``).
      state: recurrent state; returned unchanged since the net is stateless.

    Returns:
      ``(NetOutput, state)``.
    """
    batched_net = hk.BatchApply(self)
    net_out, _ = batched_net(x, None)
    return net_out, state


class AtariShallowTorso(hk.Module):
  """Shallow convolutional torso for Atari, from the DQN paper."""

  def __init__(self, name=None):
    super(AtariShallowTorso, self).__init__(name=name)

  def __call__(self, x):
    """Embeds observations with three VALID convs and a 512-unit linear layer.

    Args:
      x: image observations; scaled by 1/255 first (presumably uint8 pixels
        in [0, 255] — confirm against the environment).

    Returns:
      ReLU activations of the final 512-unit linear layer.
    """
    # (output channels, square kernel size, stride) per conv layer; created in
    # this order so Haiku names them conv2_d, conv2_d_1, conv2_d_2.
    conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

    layers = [lambda frames: frames / 255.]
    for channels, kernel, stride in conv_specs:
      layers.append(hk.Conv2D(channels,
                              kernel_shape=[kernel, kernel],
                              stride=[stride, stride],
                              padding='VALID'))
      layers.append(jax.nn.relu)
    layers.extend([hk.Flatten(), hk.Linear(512), jax.nn.relu])

    return hk.Sequential(layers)(x)


class ResidualBlock(hk.Module):
  """Pre-activation residual block: x + conv(relu(conv(relu(x))))."""

  def __init__(self, num_channels, name=None):
    """Args:
      num_channels: output channels of both 3x3 convolutions; must match the
        input's channel count for the skip addition to broadcast cleanly.
      name: optional Haiku module name.
    """
    super(ResidualBlock, self).__init__(name=name)
    self._num_channels = num_channels

  def __call__(self, x):
    """Returns the residual branch applied to ``x``, plus ``x`` itself."""
    # Two (relu, 3x3 SAME conv) pairs; the convs are created in the same order
    # as before so their parameter names are unchanged.
    branch_layers = []
    for _ in range(2):
      branch_layers.append(jax.nn.relu)
      branch_layers.append(hk.Conv2D(self._num_channels,
                                     kernel_shape=[3, 3],
                                     stride=[1, 1],
                                     padding='SAME'))
    residual = hk.Sequential(branch_layers)(x)
    return residual + x


class AtariDeepTorso(hk.Module):
  """Deep residual torso for Atari, from the IMPALA paper."""

  def __init__(self, name=None):
    super(AtariDeepTorso, self).__init__(name=name)

  def __call__(self, x):
    """Embeds observations through three conv + max-pool + residual stages.

    Args:
      x: image observations; scaled by 1/255 first (presumably uint8 pixels —
        confirm against the environment).

    Returns:
      ReLU activations of a final 256-unit linear layer.
    """
    # (channels, residual blocks) for each stage.
    stages = [(16, 2), (32, 2), (32, 2)]

    features = x / 255.
    for stage_idx, (channels, blocks_per_stage) in enumerate(stages):
      stage_conv = hk.Conv2D(channels, kernel_shape=[3, 3], stride=[1, 1],
                             padding='SAME')
      features = stage_conv(features)
      features = hk.max_pool(features,
                             window_shape=[1, 3, 1],
                             strides=[1, 2, 1],
                             padding='SAME')
      for block_idx in range(blocks_per_stage):
        # Explicit names keep residual parameters addressable per stage/block.
        res_block = ResidualBlock(
            channels, name='residual_{}_{}'.format(stage_idx, block_idx))
        features = res_block(features)

    features = hk.Flatten()(jax.nn.relu(features))
    return jax.nn.relu(hk.Linear(256)(features))


class AtariNet(hk.RNNCore):
  """Actor-critic network for Atari: conv torso, optional LSTM, linear heads."""

  def __init__(self, num_actions, use_resnet, use_lstm, name=None):
    """Configures the network.

    Args:
      num_actions: size of the discrete action space (policy logits width).
      use_resnet: use ``AtariDeepTorso`` if True, else ``AtariShallowTorso``.
      use_lstm: if True, torso features pass through a 256-unit LSTM that is
        reset on FIRST steps; otherwise they go straight to the heads.
      name: optional Haiku module name.
    """
    super(AtariNet, self).__init__(name=name)
    self._num_actions = num_actions
    self._use_resnet = use_resnet
    self._use_lstm = use_lstm
    # Built unconditionally: `initial_state` delegates to it even when the
    # LSTM is not used in `unroll`.
    self._core = hk.ResetCore(hk.LSTM(256))

  def initial_state(self, batch_size):
    """Returns the LSTM core's initial state for ``batch_size``."""
    return self._core.initial_state(batch_size)

  def __call__(self, x: dm_env.TimeStep, state):
    """Single-step application: adds a length-1 leading axis and unrolls.

    The returned outputs keep that length-1 leading axis.
    """
    # `jax.tree_util.tree_map` exists in every JAX version, whereas the
    # top-level `jax.tree_map` alias was deprecated and later removed.
    x = jax.tree_util.tree_map(lambda t: t[None, ...], x)
    return self.unroll(x, state)

  def unroll(self, x, state):
    """Unrolls more efficiently than dynamic_unroll.

    The torso and heads are applied with ``hk.BatchApply``, which folds the
    two leading axes (assumed to be time and batch — confirm with callers).
    The LSTM, when enabled, is run with ``hk.dynamic_unroll``.

    Args:
      x: timesteps; ``x.observation`` and, for the LSTM, ``x.step_type``.
      state: recurrent state (passed through unchanged without the LSTM).

    Returns:
      ``(NetOutput, new_state)``.
    """
    if self._use_resnet:
      torso = AtariDeepTorso()
    else:
      torso = AtariShallowTorso()

    torso_output = hk.BatchApply(torso)(x.observation)
    if self._use_lstm:
      # Reset the LSTM state wherever a new episode begins.
      should_reset = jnp.equal(x.step_type, int(dm_env.StepType.FIRST))
      core_input = (torso_output, should_reset)
      core_output, state = hk.dynamic_unroll(self._core, core_input, state)
    else:
      core_output = torso_output
      # state passes through.

    return hk.BatchApply(self._head)(core_output), state

  def _head(self, core_output):
    """Maps core features to policy logits and a squeezed scalar value."""
    policy_logits = hk.Linear(self._num_actions)(core_output)
    value = hk.Linear(1)(core_output)
    value = jnp.squeeze(value, axis=-1)
    return NetOutput(policy_logits=policy_logits, value=value)

# Main

In [None]:
# Training configuration for the CartPole IMPALA run.
ACTION_REPEAT = 4        # multiplier in the per-update frame budget below
BATCH_SIZE = 32          # batch size handed to the Learner
UNROLL_LENGTH = 20       # new environment steps per actor unroll
DISCOUNT_FACTOR = 0.99   # discount handed to the Learner
NUM_ACTORS = 3           # actor threads in total (the grapher + regular actors)
MAX_TIME = 10            # `max_time` argument for every CartPole instance

# Frames consumed per learner update, and the overall budget of 400 updates.
FRAMES_PER_ITER = BATCH_SIZE * UNROLL_LENGTH * ACTION_REPEAT
MAX_ENV_FRAMES = 400 * FRAMES_PER_ITER



def run_actor(actor, stop_signal: List[bool]):
  """Keeps an actor generating trajectories until told to stop.

  Each iteration asks the actor for the learner's latest ``(frame_count,
  params)`` pair. It then performs one unroll with those parameters and pushes
  the result to the learner.

  Args:
    actor: object exposing ``pull_params()`` and
      ``unroll_and_push(frame_count, params)``.
    stop_signal: one-element list shared with the main thread; the loop exits
      as soon as ``stop_signal[0]`` is truthy (checked before each iteration).
  """
  while True:
    if stop_signal[0]:
      break
    latest_frame_count, latest_params = actor.pull_params()
    actor.unroll_and_push(latest_frame_count, latest_params)

# build cartpole environment
build_env = cartpole.Cartpole

# construct agent; this environment instance is only used to read the specs
env_for_spec = build_env(max_time=MAX_TIME)
num_actions = env_for_spec.action_spec().num_values
agent = Agent(num_actions, env_for_spec.observation_spec(), CatchNet)

# Number of learner updates. Floor division keeps this an int (plain `/`
# produced a float such as 400.0).
max_updates = MAX_ENV_FRAMES // FRAMES_PER_ITER
print("Running ", max_updates, " iterations for learner")
# Construct optimizer
opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

# Extra actor handed to the learner; its episode history
# (`_all_returns` / `_all_frames`) is analysed in a later cell.
learnersActor = SeparateActor(
    agent,
    build_env(max_time=MAX_TIME),
    UNROLL_LENGTH,
    rng_seed=0,
    logger=util.AbslLogger(),
)

# Construct the learner
print("making learner now...")
learner = Learner(
    agent,
    jax.random.PRNGKey(428),
    opt,
    BATCH_SIZE,
    DISCOUNT_FACTOR,
    FRAMES_PER_ITER,
    learnersActor,
    max_abs_reward=1.,
    logger=util.AbslLogger(),  # Provide your own logger here.
)

# Construct the actors, one thread each
print("making actors now...")
actor_threads = []

# Shared one-element list: setting stop_signal[0] = True ends every
# run_actor loop.
stop_signal = [False]

# The first actor is an ActorGrapher, which also records episode returns.
actorVisual = ActorGrapher(
    agent,
    build_env(max_time=MAX_TIME),
    UNROLL_LENGTH,
    learner,
    rng_seed=0,
    logger=util.AbslLogger(),
)
args = (actorVisual, stop_signal)
actor_threads.append(threading.Thread(target=run_actor, args=args))

# The remaining NUM_ACTORS - 1 threads run regular actors with distinct seeds.
for i in range(1, NUM_ACTORS):
  actor = Actor(
      agent,
      build_env(max_time=MAX_TIME),
      UNROLL_LENGTH,
      learner,
      rng_seed=i,
      logger=util.AbslLogger(),  # Provide your own logger here.
  )
  args = (actor, stop_signal)
  actor_threads.append(threading.Thread(target=run_actor, args=args))

# Start the actors, then run the learner on this thread.
print("starting the learning!")
for t in actor_threads:
  t.start()
try:
  learner.run(max_updates)
finally:
  # Signal and join the actors even if the learner raises or the cell is
  # interrupted; otherwise their run_actor loops would never be told to stop.
  stop_signal[0] = True
  for t in actor_threads:
    t.join()

Running  400.0  iterations for learner
making learner now...
making actors now...
starting the learning!


100%|██████████| 400/400 [1:03:19<00:00,  9.50s/it]


In [None]:
# Reconstruct per-episode returns from the tracking actor's history.
# NOTE(review): this cell treats `learnersActor._all_returns` as a running
# return that drops back to 0.0 when a new episode starts. The value just
# before each such reset is taken as that episode's return. Confirm this
# against SeparateActor.
current_frame = 0
current_num = 0.0  # count of consecutive non-zero entries since the last reset

all_averages = []  # per-episode returns (name kept for downstream cells)
all_frames = []    # frame count recorded alongside each of those returns

# whenever the tracking actor restarts an episode, log its episode return
for frame_num in range(len(learnersActor._all_frames)):
  if learnersActor._all_returns[frame_num] == 0.0:
    current_num = 0.0
    # Index 0 has no predecessor: without the `frame_num > 0` guard,
    # `frame_num - 1 == -1` would wrap around to the *last* recorded entry.
    if frame_num > 0 and learnersActor._all_returns[frame_num - 1] != 0.0:
      # if episode just ended, look back 1 step to see highest value
      all_averages.append(learnersActor._all_returns[frame_num - 1])
      all_frames.append(learnersActor._all_frames[frame_num - 1])
  else:
    current_num += 1

# print episode returns and corresponding frame numbers
print("Episode Returns: ")
print(all_averages)
print("Frame numbers: ")
print(all_frames)

Averages: 
[72.0, 27.0, 30.0, 29.0, 30.0, 29.0, 27.0, 29.0, 27.0, 28.0, 28.0, 29.0, 43.0, 30.0, 28.0, 27.0, 30.0, 29.0, 28.0, 30.0, 29.0, 29.0, 29.0, 27.0, 29.0, 29.0, 27.0, 28.0, 30.0, 29.0, 28.0, 28.0, 29.0, 29.0, 28.0, 29.0, 29.0, 30.0, 29.0, 27.0, 30.0, 29.0, 28.0, 28.0, 29.0, 28.0, 29.0, 30.0, 28.0, 31.0, 31.0, 32.0, 34.0, 65.0, 28.0, 30.0, 28.0, 27.0, 28.0, 28.0, 30.0, 28.0, 30.0, 30.0, 29.0, 30.0, 28.0, 28.0, 30.0, 29.0, 30.0, 28.0, 28.0, 30.0, 29.0, 33.0, 84.0, 44.0, 72.0, 118.0, 76.0, 56.0, 61.0, 74.0, 57.0, 66.0, 113.0, 149.0, 156.0, 171.0, 172.0, 192.0, 148.0, 150.0, 185.0, 351.0, 245.0, 211.0, 217.0, 234.0, 447.0, 270.0, 309.0, 279.0, 403.0, 328.0]
Frame numbers: 
[7680, 10240, 15360, 20480, 23040, 28160, 30720, 35840, 38400, 43520, 46080, 51200, 56320, 61440, 64000, 69120, 71680, 76800, 79360, 84480, 87040, 92160, 94720, 99840, 102400, 107520, 110080, 115200, 120320, 122880, 128000, 130560, 135680, 138240, 143360, 145920, 151040, 153600, 158720, 161280, 166400, 171520, 174