New LocalLayout may deadlock block on sample #233

Closed

ethanluoyc opened this issue May 24, 2022 · 36 comments

@ethanluoyc
Contributor

ethanluoyc commented May 24, 2022

Hi,

I recently started migrating my JAX agents to use the new LocalLayout, which incorporates the changes that simplify the setup so that running non-distributed agents does not block. With my old parameters, however, I have started to experience deadlocks.

I have a pixel-based agent similar to D4PG. For Reverb, I use a batch size of 256, a samples_per_insert of 128.0, and a samples_per_insert_tolerance_rate of 0.1 (following D4PG, with the exception of SPI, which is now 128). I use num_sgd_steps_per_step = 1. I have noticed that in this case I can experience a deadlock. The way I set up the data iterator and prefetching follows exactly what's on the master branch.

My understanding is that with the new API, things should not block on insert, since the rate limiter has been adjusted to ensure that this cannot happen. For some reason, however, I found that things may block on sample. With some debugging, it looks like the deadlock happens here:

https://github.com/deepmind/acme/blob/ad073a85319435246a5fac21978e1be655191778/acme/agents/agent.py#L106

I have tried to debug this problem without any luck. @qstanczyk I noticed that you made these changes recently; do you have any idea what may be causing this issue? AFAIK, the difference compared to the previous version is that the local agent now uses the table's rate-limiting behavior to control the actor/learner stepping frequency, instead of doing it manually in the Agent, but I suspect that because of prefetching something does not work out right.

Any tips for debugging this would be greatly appreciated! Happy to include more details on the setup!

Many thanks in advance!

@qstanczyk
Collaborator

Hi, the only scenario that comes to my mind is that the Learner's step samples more than once from the iterator. Is that the case?

@ethanluoyc
Contributor Author

@qstanczyk I don't think I sample more than once, though. The deadlock also does not happen at a fixed time. Sometimes it happens sooner and sometimes later.

@qstanczyk
Collaborator

Can you check exactly where the agent is getting stuck? Is it on the next() call on the iterator that LocalLayout controls? If so, can you debug what happens in the _has_data_for_training method? It must be returning true while the iterator apparently doesn't have more data to sample.

@ethanluoyc
Contributor Author

@qstanczyk
OK, I think I hit a case where the iterator is not ready but there is still data for sampling. Not sure why this can happen, though.

  def _has_data_for_training(self):
    if self._iterator.ready():
      return True
    for (table, batch_size) in zip(self._replay_tables,
                                   self._batch_size_upper_bounds):
      if not table.can_sample(batch_size):
        return False
    return True


  def update(self):
    # super().update()
    if self._iterator:
      # Perform learner steps as long as iterator has data.
      update_actor = False
      while self._has_data_for_training():
        # Run learner steps (usually means gradient steps).
        iterator_ready = self._iterator.ready()
        table_can_sample = True
        t = self._replay_tables[0]
        bsz_ub = self._batch_size_upper_bounds[0]
        table_can_sample = t.can_sample(bsz_ub)
        print(iterator_ready, table_can_sample, bsz_ub)
        # on blocking gets [False True 261], I am using a batch size of 256
        self._learner_steps += 1
        self._batch_size_upper_bounds = [
            math.ceil(t.info.rate_limiter_info.sample_stats.completed /
                      self._learner_steps) for t in self._replay_tables
        ]
        self._learner.step()
        update_actor = True
      if update_actor:
        # Update the actor weights only when learner was updated.
        self._actor.update()
      return

@qstanczyk
Collaborator

This is expected (and that is why _has_data_for_training checks both conditions). It can happen when the iterator hasn't fetched the data from the table yet, but the Reverb table has data to be sampled. In such a case the call to next() will block for some time until data is fetched from the table (but it should not hang).

@ethanluoyc
Contributor Author

But it's hanging in my case somehow (or maybe just blocking for a very long time). Are there some Reverb stats that I can look into to understand what's going on?

@qstanczyk
Collaborator

Here is a small change to verify that the iterator is accessed correctly. Can you try it and see if the assertion is hit?

@ethanluoyc
Contributor Author

ethanluoyc commented May 25, 2022

I just tested. I do not hit the assert, but it still appears to deadlock.

Let me grab some statistics from the rate limiter to see what's going on.

@ethanluoyc
Contributor Author

ethanluoyc commented May 25, 2022

I added the following to update:

  def update(self):
    if self._iterator:
      # Perform learner steps as long as iterator has data.
      update_actor = False
      while self._has_data_for_training():
        ## Added for DEBUG
        while not self._iterator.ready():
          t = self._replay_tables[0]
          num_inserts = t.info.rate_limiter_info.insert_stats.completed
          num_samples = t.info.rate_limiter_info.sample_stats.completed
          samples_per_insert = t.info.rate_limiter_info.samples_per_insert
          # min_size_to_sample = t.info.rate_limiter_info.min_size_to_sample
          print(
            t.info.rate_limiter_info.min_diff,
            num_inserts * samples_per_insert - num_samples,
            t.info.rate_limiter_info.max_diff,
            t.can_sample(self._batch_size_upper_bounds[0]),
          )
          import time
          time.sleep(0.01)
        ## END DEBUG
        # Run learner steps (usually means gradient steps).
        batches_processed = self._iterator.retrieved_elements()
        self._learner.step()
        assert self._iterator.retrieved_elements() == batches_processed + 1, (
            'Learner step must retrieve exactly one '
            'element from the iterator. Otherwise agent can deadlock.')
        self._batch_size_upper_bounds = [
            math.ceil(t.info.rate_limiter_info.sample_stats.completed /
                      (batches_processed + 1)) for t in self._replay_tables
        ]
        update_actor = True
      if update_actor:
        # Update the actor weights only when learner was updated.
        self._actor.update()
      return

At blocking time I get:

230400.0 230400.0 1.7976931348623157e+308 False

I did some calculation by hand and the numbers seem to match my expectation. I use an SPI of 128.0 and an error tolerance rate of 0.1. The min_replay_size is 2000. With that, min_diff should be 2000 * (1 - 0.1) * 128 = 230400. max_diff is changed in the local layout, so that seems alright. num_inserts * spi - num_samples is equal to min_diff in this case, so I should actually expect that Reverb can be sampled (based on https://github.com/deepmind/reverb/blob/master/reverb/cc/rate_limiter.cc#L112). Yet the iterator never gets unblocked.
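
For reference, here is a small standalone sketch of the arithmetic above (illustrative only, not agent code). It uses the SampleToInsertRatio parameters from the builder's make_replay_tables pasted further down in this thread; the variable names are mine:

# Reproduces the rate-limiter numbers quoted above from the builder's config.
min_replay_size = 2_000
samples_per_insert = 128.0
tolerance_rate = 0.1

# error_buffer exactly as constructed in make_replay_tables.
error_buffer = min_replay_size * (tolerance_rate * samples_per_insert)  # 25600.0

# SampleToInsertRatio derives
#   min_diff = min_size_to_sample * samples_per_insert - error_buffer,
# which is the first value printed by rate_limiter_info above.
min_diff = min_replay_size * samples_per_insert - error_buffer
assert min_diff == 230400.0  # equivalently 2000 * (1 - 0.1) * 128, as quoted above

# At the hang, num_inserts * samples_per_insert - num_samples also printed 230400.0,
# i.e. the limiter sat exactly at min_diff while can_sample(261) returned False.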

@qstanczyk
Collaborator

How is the iterator constructed? Do you use multiple workers per iterator?

@ethanluoyc
Contributor Author

ethanluoyc commented May 25, 2022

It's basically the same as all the other JAX agents. The agent uses a single worker, same as e.g. JAX D4PG, and I also only use a single GPU for training.

Here are the parts that are probably relevant.

"""DrQ-v2 builder"""
from typing import Callable, Iterator, List, Optional

from acme import adders
from acme import core
from acme import datasets
from acme import specs
from acme.adders import reverb as adders_reverb
from acme.agents.jax import builders
from acme.jax import networks as networks_lib
from acme.jax import utils
from acme.jax import variable_utils
from acme.utils import counting
from acme.utils import loggers
import jax
import optax
import reverb
from reverb import rate_limiters

from ilax.agents.drq_v2 import acting as acting_lib
from ilax.agents.drq_v2 import config as drq_v2_config
from ilax.agents.drq_v2 import learning as learning_lib
from ilax.agents.drq_v2 import networks as drq_v2_networks


class DrQV2Builder(builders.ActorLearnerBuilder):
  """DrQ-v2 Builder."""

  def __init__(
      self,
      config: drq_v2_config.DrQV2Config,
  ):
    self._config = config

  def make_replay_tables(
      self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
    """Create tables to insert data into."""
    samples_per_insert_tolerance = (
        self._config.samples_per_insert_tolerance_rate *
        self._config.samples_per_insert)
    error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
    limiter = rate_limiters.SampleToInsertRatio(
        min_size_to_sample=self._config.min_replay_size,
        samples_per_insert=self._config.samples_per_insert,
        error_buffer=error_buffer,
    )
    replay_table = reverb.Table(
        name=self._config.replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=self._config.max_replay_size,
        rate_limiter=limiter,
        signature=adders_reverb.NStepTransitionAdder.signature(
            environment_spec=environment_spec),
    )
    return [replay_table]

  def make_dataset_iterator(
      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
    """Create a dataset iterator to use for learning/updating the agent."""
    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=self._config.batch_size,
        prefetch_size=self._config.prefetch_size,
    )
    return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])

  def make_adder(self, replay_client: reverb.Client) -> Optional[adders.Adder]:
    """Create an adder which records data generated by the actor/environment.

        Args:
          replay_client: Reverb Client which points to the replay server.
        """
    return adders_reverb.NStepTransitionAdder(
        client=replay_client,
        n_step=self._config.n_step,
        discount=self._config.discount,
    )

  def make_actor(
      self,
      random_key: networks_lib.PRNGKey,
      policy_network: drq_v2_networks.DrQV2PolicyNetwork,
      adder: Optional[adders.Adder] = None,
      variable_source: Optional[core.VariableSource] = None) -> core.Actor:
    """Create an actor instance.
      Args:
        random_key: A key for random number generation.
        policy_network: Instance of a policy network; this should be a callable
          which takes as input observations and returns actions.
        adder: How data is recorded (e.g. added to replay).
        variable_source: A source providing the necessary actor parameters.
    """
    assert variable_source is not None
    device = "cpu"
    variable_client = variable_utils.VariableClient(
        variable_source, "policy", device=device)
    variable_client.update_and_wait()

    return acting_lib.DrQV2Actor(
        policy_network,
        random_key,
        variable_client=variable_client,
        adder=adder,
        backend=device,
    )

  def make_learner(self,
                   random_key: networks_lib.PRNGKey,
                   networks: drq_v2_networks.DrQV2Networks,
                   dataset: Iterator[reverb.ReplaySample],
                   logger: Optional[loggers.Logger] = None,
                   replay_client: Optional[reverb.Client] = None,
                   counter: Optional[counting.Counter] = None) -> core.Learner:
    """Creates an instance of the learner.

        Args:
          random_key: A key for random number generation.
          networks: struct describing the networks needed by the learner; this can
            be specific to the learner in question.
          dataset: iterator over samples from replay.
          replay_client: client which allows communication with replay, e.g. in
            order to update priorities.
          counter: a Counter which allows for recording of counts (learner steps,
            actor steps, etc.) distributed throughout the agent.
          checkpoint: bool controlling whether the learner checkpoints itself.
        """
    del replay_client
    config = self._config
    critic_optimizer = optax.adam(config.learning_rate)
    policy_optimizer = optax.adam(config.learning_rate)
    encoder_optimizer = optax.adam(config.learning_rate)

    sigma_start, sigma_end, sigma_schedule_steps = config.sigma
    observations_per_step = int(config.batch_size / config.samples_per_insert)
    if hasattr(config, "min_observations"):
      min_observations = config.min_observations
    else:
      min_observations = config.min_replay_size
    # Compute the schedule for the learner
    # Learner only starts updating after min_observations number of steps
    sigma_schedule = lambda step: optax.linear_schedule(  # noqa
        sigma_start, sigma_end, sigma_schedule_steps)((step + max(
            min_observations, config.batch_size)) * observations_per_step)

    return learning_lib.DrQV2Learner(
        random_key=random_key,
        dataset=dataset,
        networks=networks,
        sigma_schedule=sigma_schedule,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        encoder_optimizer=encoder_optimizer,
        augmentation=config.augmentation,
        critic_soft_update_rate=config.critic_q_soft_update_rate,
        discount=config.discount,
        noise_clip=config.noise_clip,
        logger=logger,
        counter=counter,
    )

The agent definition, based on LocalLayout:

from typing import Optional

from acme import specs
from acme.jax.layouts import local_layout
from acme.utils import counting
from acme.utils import loggers
import optax

from ilax.agents.drq_v2 import builder
from ilax.agents.drq_v2 import config as drq_v2_config
# from ilax.agents.drq_v2 import local_layout
from ilax.agents.drq_v2 import networks as drq_v2_networks


class DrQV2(local_layout.LocalLayout):
  """Data-regularized Q agent version 2."""

  builder: builder.DrQV2Builder

  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      networks: drq_v2_networks.DrQV2Networks,
      config: drq_v2_config.DrQV2Config,
      seed: int,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
  ):
    drq_v2_builder = builder.DrQV2Builder(config)
    policy_network = drq_v2_networks.get_default_behavior_policy(
        networks, environment_spec.actions,
        optax.linear_schedule(*config.sigma))
    self.builder = drq_v2_builder
    super().__init__(
        seed=seed,
        environment_spec=environment_spec,
        builder=drq_v2_builder,
        networks=networks,
        policy_network=policy_network,
        # min_replay_size=min_replay_size,
        batch_size=config.batch_size,
        workdir=None,
        num_sgd_steps_per_step=1,
        learner_logger=logger,
        counter=counter,
        checkpoint=False)

The config:

@dataclasses.dataclass
class DrQV2Config:
  """Configuration parameters for DrQ."""

  augmentation: augmentations.DataAugmentation = augmentations.batched_random_crop

  min_replay_size: int = 2_000
  max_replay_size: int = 1_000_000
  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
  prefetch_size: int = 1

  discount: float = 0.99
  batch_size: int = 256
  n_step: int = 3

  critic_q_soft_update_rate: float = 0.01
  learning_rate: float = 1e-4
  noise_clip: float = 0.3
  sigma: Tuple[float, float, int] = (1.0, 0.1, 500000)

  samples_per_insert: float = 128
  samples_per_insert_tolerance_rate: float = 0.1

@qstanczyk
Collaborator

Then I don't know ;-(

@ethanluoyc
Contributor Author

Then I don't know ;-(

Me neither. Thanks a lot for checking this with me. It's really helpful!
I have a feeling that there is some problem with the prefetch thread not being scheduled, so although there are samples, the prefetching thread is not putting them onto the queue. I don't see how that could possibly happen, and there's no way we can control it, given that scheduling is done by the OS.

@ethanluoyc
Contributor Author

@qstanczyk It seems that the problem may have gone away when I changed the default transition adder's max_in_flight_items from 5 to 2.
I am not sure why this helps, though. I am working with dm_control but apply an action repeat of 2 (which means that the episode length is now 500 instead of 1000), and I use n_step = 3. Maybe this somehow interferes with the rate limiter's tolerance. I have never fully understood some of Reverb's parameters (e.g. max_in_flight_items). Maybe it's my combination of setup + single-process execution that is causing the issues.
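
For context, the change described above would look roughly like the following in the builder's make_adder. This is a sketch that assumes the NStepTransitionAdder in the Acme version used here exposes max_in_flight_items as a keyword argument (the default of 5 mentioned above suggests it does, but it is worth verifying against your version):

  def make_adder(self, replay_client: reverb.Client) -> Optional[adders.Adder]:
    """Same adder as in the builder above, with fewer in-flight items per step."""
    return adders_reverb.NStepTransitionAdder(
        client=replay_client,
        n_step=self._config.n_step,
        discount=self._config.discount,
        # Assumed keyword: caps how many items may be pending insertion into
        # Reverb at any time (default 5 in the version discussed here).
        max_in_flight_items=2,
    )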

@qstanczyk
Collaborator

A smaller value of max_in_flight_items reduces the number of elements that can be inserted into Reverb in one actor step. But since the rate limiter has insert blocking disabled for the single-process agent, this setting shouldn't affect hangs.

@ethanluoyc
Contributor Author

@qstanczyk does the insertion happen in the thread where the actor interacts with the environment, or is it done in a background thread? In the case where writes happen in the agent's thread, it would be possible to deadlock if not all items have been written to Reverb, right?

I still get the deadlock (in htop all threads are suspended), but I was unable to reproduce it every time. Sometimes it hangs after 10K steps, sometimes it runs fine even after 1M steps. Is it possible that the problem is caused by the fact that I am using pixel observations (84 x 84 x 9 images)?

I have checked over and over again, and to me it does seem that the prefetching is handled correctly; I don't see anything that could go wrong. I wonder if in my case there is some thread starvation going on that's causing the problem.

@qstanczyk
Collaborator

Inserts happen in the background, but at the end of the episode there is a flush (which makes sure all pending items are written to Reverb before continuing). However, with a non-distributed agent, Reverb's rate limiter should not block inserts.

@ethanluoyc
Contributor Author

Thanks for the clarification! Sorry for keeping the thread open, but I would really like to figure out the problem. If the flush only happens at the end of an episode, does that mean the agent may deadlock in the middle of an episode because not all writes have been processed by Reverb?

@qstanczyk
Collaborator

That shouldn't happen either - writers don't block (due to the rate limiter setup), while sampling should happen only when there is data in the iterator. That is the theory... but it seems like there must be an issue somewhere.

@ethanluoyc
Contributor Author

ethanluoyc commented May 30, 2022

Yeah. There's either some subtle thing that's not handled, or I will blame my operating system's scheduler for not scheduling the threads that should be processing the items :)

Is there anything else that you would like me to check? I wanted to create a minimal example, but my agent has several components and I haven't been able to consistently find the minimal setup that triggers this. I noticed that table.info contains worker stats. Would those be useful for figuring out what's going wrong? I can produce some logs from those stats if they would help. Otherwise, I suppose we can keep the issue open and maybe I will have a more minimal, deterministic example at some point.

There are some discussions on python-dev which might be related. I don't yet see how similar the issues are, but I'm posting it here in case it's relevant: https://bugs.python.org/issue46812. I'm wondering if the behavior of the hangs might be similar to the issue discussed in sqlalchemy/sqlalchemy#7679. Maybe it's worth mentioning that my experiment spawns a lot of threads (>500). I have no idea why there are so many of them, but I am using a machine with 48 CPUs and two 3080 GPUs (only one is used for training).

@qstanczyk
Collaborator

A simple repro I could try would be best. Otherwise it is hard to guess what could be wrong.

@ethanluoyc
Contributor Author

ethanluoyc commented May 30, 2022

@qstanczyk I managed to create a relatively minimal example and have attached it below as a zip archive (deadlock.zip). I have also pasted the code here; it consists of a few files.

You can run the example with MUJOCO_GL=egl XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 python run.py. As I mentioned previously, the blocking doesn't happen deterministically, so it's worth trying a few times. I have seen cases where running for >1M steps works fine, so even reproducing the bug is not trivial...

drq_v2.py: an implementation of DrQ-v2 in JAX. I have consolidated the individual components into a single file.

"""Learner component for DrQV2."""
import dataclasses
from functools import partial
import time
from typing import Iterator, List, NamedTuple, Optional, Callable

from acme import adders
from acme import core
from acme import specs
from acme import types
from acme.adders import reverb as adders_reverb
from acme.jax import networks as networks_lib
from acme.jax import types as jax_types
from acme.jax import utils
from acme.jax import variable_utils
from acme import datasets
from acme.agents.jax import builders
from acme.utils import counting
from acme.utils import loggers
from acme.agents.jax import actor_core
from acme.agents.jax import actors
from reverb import rate_limiters
import jax
import jax.numpy as jnp
import optax
import reverb

import networks as drq_v2_networks

DataAugmentation = Callable[[jax_types.PRNGKey, types.NestedArray],
                            types.NestedArray]


# From https://github.com/ikostrikov/jax-rl/blob/main/jax_rl/agents/drq/augmentations.py
def random_crop(key: jax_types.PRNGKey, img, padding):
  crop_from = jax.random.randint(key, (2,), 0, 2 * padding + 1)
  crop_from = jnp.concatenate([crop_from, jnp.zeros((1,), dtype=jnp.int32)])
  padded_img = jnp.pad(
      img, ((padding, padding), (padding, padding), (0, 0)), mode="edge")
  return jax.lax.dynamic_slice(padded_img, crop_from, img.shape)


def batched_random_crop(key, imgs, padding=4):
  keys = jax.random.split(key, imgs.shape[0])
  return jax.vmap(random_crop, (0, 0, None))(keys, imgs, padding)


@dataclasses.dataclass
class DrQV2Config:
  """Configuration parameters for DrQ."""

  augmentation: DataAugmentation = batched_random_crop

  min_replay_size: int = 2_000
  max_replay_size: int = 1_000_000
  replay_table_name: str = adders_reverb.DEFAULT_PRIORITY_TABLE
  prefetch_size: int = 4

  discount: float = 0.99
  batch_size: int = 256
  n_step: int = 3

  critic_q_soft_update_rate: float = 0.01
  learning_rate: float = 1e-4
  noise_clip: float = 0.3
  sigma: float = 0.2

  samples_per_insert: float = 128.0
  samples_per_insert_tolerance_rate: float = 0.1
  num_sgd_steps_per_step: int = 1


def _soft_update(
    target_params: networks_lib.Params,
    online_params: networks_lib.Params,
    tau: float,
) -> networks_lib.Params:
  """
    Update target network using Polyak-Ruppert Averaging.
    """
  return jax.tree_multimap(lambda t, s: (1 - tau) * t + tau * s, target_params,
                           online_params)


class TrainingState(NamedTuple):
  """Holds training state for the DrQ learner."""

  policy_params: networks_lib.Params
  policy_opt_state: optax.OptState

  encoder_params: networks_lib.Params
  # There are no target encoder parameters in v2.
  encoder_opt_state: optax.OptState

  critic_params: networks_lib.Params
  critic_target_params: networks_lib.Params
  critic_opt_state: optax.OptState

  key: jax_types.PRNGKey
  steps: int


class DrQV2Learner(core.Learner):
  """Learner for DrQ-v2"""

  def __init__(
      self,
      random_key: jax_types.PRNGKey,
      dataset: Iterator[reverb.ReplaySample],
      networks: drq_v2_networks.DrQV2Networks,
      sigma_schedule: optax.Schedule,
      augmentation: DataAugmentation,
      policy_optimizer: optax.GradientTransformation,
      critic_optimizer: optax.GradientTransformation,
      encoder_optimizer: optax.GradientTransformation,
      noise_clip: float = 0.3,
      critic_soft_update_rate: float = 0.005,
      discount: float = 0.99,
      num_sgd_steps_per_step: int = 1,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
  ):

    def critic_loss_fn(
        critic_params: networks_lib.Params,
        encoder_params: networks_lib.Params,
        critic_target_params: networks_lib.Params,
        policy_params: networks_lib.Params,
        transitions: types.Transition,
        key: jax_types.PRNGKey,
        sigma: jnp.ndarray,
    ):
      next_encoded = networks.encoder_network.apply(
          encoder_params, transitions.next_observation)
      next_action = networks.policy_network.apply(policy_params, next_encoded)
      next_action = networks.add_policy_noise(next_action, key, sigma,
                                              noise_clip)
      next_q1, next_q2 = networks.critic_network.apply(critic_target_params,
                                                       next_encoded,
                                                       next_action)
      # Calculate q target values
      next_q = jnp.minimum(next_q1, next_q2)
      target_q = transitions.reward + transitions.discount * discount * next_q
      target_q = jax.lax.stop_gradient(target_q)
      # Calculate predicted Q
      encoded = networks.encoder_network.apply(encoder_params,
                                               transitions.observation)
      q1, q2 = networks.critic_network.apply(critic_params, encoded,
                                             transitions.action)
      loss_critic = (jnp.square(target_q - q1) +
                     jnp.square(target_q - q2)).mean(axis=0)
      return loss_critic, {"q1": q1.mean(), "q2": q2.mean()}

    def policy_loss_fn(
        policy_params: networks_lib.Params,
        critic_params: networks_lib.Params,
        encoder_params: networks_lib.Params,
        observation: types.Transition,
        sigma: jnp.ndarray,
        key,
    ):
      encoded = networks.encoder_network.apply(encoder_params, observation)
      action = networks.policy_network.apply(policy_params, encoded)
      action = networks.add_policy_noise(action, key, sigma, noise_clip)
      q1, q2 = networks.critic_network.apply(critic_params, encoded, action)
      q = jnp.minimum(q1, q2)
      policy_loss = -q.mean()
      return policy_loss, {}

    policy_grad_fn = jax.value_and_grad(policy_loss_fn, has_aux=True)
    critic_grad_fn = jax.value_and_grad(
        critic_loss_fn, argnums=(0, 1), has_aux=True)

    def update_step(
        state: TrainingState,
        transitions: types.Transition,
    ):
      key_aug1, key_aug2, key_policy, key_critic, key = jax.random.split(
          state.key, 5)
      sigma = sigma_schedule(state.steps)
      # Perform data augmentation on o_tm1 and o_t
      observation_aug = augmentation(key_aug1, transitions.observation)
      next_observation_aug = augmentation(key_aug2,
                                          transitions.next_observation)
      transitions = transitions._replace(
          observation=observation_aug,
          next_observation=next_observation_aug,
      )
      # Update critic
      (critic_loss, critic_aux), (critic_grad, encoder_grad) = critic_grad_fn(
          state.critic_params,
          state.encoder_params,
          state.critic_target_params,
          state.policy_params,
          transitions,
          key_critic,
          sigma,
      )
      encoder_update, encoder_opt_state = encoder_optimizer.update(
          encoder_grad, state.encoder_opt_state)
      critic_update, critic_opt_state = critic_optimizer.update(
          critic_grad, state.critic_opt_state)
      encoder_params = optax.apply_updates(state.encoder_params, encoder_update)
      critic_params = optax.apply_updates(state.critic_params, critic_update)
      # Update policy
      (policy_loss, policy_aux), actor_grad = policy_grad_fn(
          state.policy_params,
          critic_params,
          encoder_params,
          observation_aug,
          sigma,
          key_policy,
      )
      policy_update, policy_opt_state = policy_optimizer.update(
          actor_grad, state.policy_opt_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      # Update target parameters
      polyak_update_fn = partial(_soft_update, tau=critic_soft_update_rate)

      critic_target_params = polyak_update_fn(
          state.critic_target_params,
          critic_params,
      )
      metrics = {
          "policy_loss": policy_loss,
          "critic_loss": critic_loss,
          "sigma": sigma,
          **critic_aux,
          **policy_aux,
      }
      new_state = TrainingState(
          policy_params=policy_params,
          policy_opt_state=policy_opt_state,
          encoder_params=encoder_params,
          encoder_opt_state=encoder_opt_state,
          critic_params=critic_params,
          critic_target_params=critic_target_params,
          critic_opt_state=critic_opt_state,
          key=key,
          steps=state.steps + 1,
      )
      return new_state, metrics

    self._iterator = dataset
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(
        label="learner",
        save_data=False,
        asynchronous=True,
        serialize_fn=utils.fetch_devicearray,
    )
    self._update_step = utils.process_multiple_batches(update_step,
                                                       num_sgd_steps_per_step)
    self._update_step = jax.jit(self._update_step)

    # Initialize training state
    def make_initial_state(key: jax_types.PRNGKey):
      key_encoder, key_critic, key_policy, key = jax.random.split(key, 4)
      encoder_init_params = networks.encoder_network.init(key_encoder)
      encoder_init_opt_state = encoder_optimizer.init(encoder_init_params)

      critic_init_params = networks.critic_network.init(key_critic)
      critic_init_opt_state = critic_optimizer.init(critic_init_params)

      policy_init_params = networks.policy_network.init(key_policy)
      policy_init_opt_state = policy_optimizer.init(policy_init_params)

      return TrainingState(
          policy_params=policy_init_params,
          policy_opt_state=policy_init_opt_state,
          encoder_params=encoder_init_params,
          critic_params=critic_init_params,
          critic_target_params=critic_init_params,
          encoder_opt_state=encoder_init_opt_state,
          critic_opt_state=critic_init_opt_state,
          key=key,
          steps=0,
      )

    # Create initial state.
    self._state = make_initial_state(random_key)

    # Do not record timestamps until after the first learning step is done.
    # This is to avoid including the time it takes for actors to come online and
    # fill the replay buffer.
    self._timestamp = None

  def step(self):
    # Get the next batch from the replay iterator
    sample = next(self._iterator)
    transitions = types.Transition(*sample.data)

    # Perform a single learner step
    self._state, metrics = self._update_step(self._state, transitions)

    # Compute elapsed time
    timestamp = time.time()
    elapsed_time = timestamp - self._timestamp if self._timestamp else 0
    self._timestamp = timestamp

    # Increment counts and record the current time
    counts = self._counter.increment(steps=1, walltime=elapsed_time)

    # Attempts to write the logs.
    self._logger.write({**metrics, **counts})

  def get_variables(self, names):
    variables = {
        "policy": {
            "encoder": self._state.encoder_params,
            "policy": self._state.policy_params,
        },
    }
    return [variables[name] for name in names]

  def save(self) -> TrainingState:
    return self._state

  def restore(self, state: TrainingState) -> None:
    self._state = state


class DrQV2Builder(builders.ActorLearnerBuilder):
  """DrQ-v2 Builder."""

  def __init__(self, config: DrQV2Config):
    self._config = config

  def make_replay_tables(
      self, environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
    """Create tables to insert data into."""
    samples_per_insert_tolerance = (
        self._config.samples_per_insert_tolerance_rate *
        self._config.samples_per_insert)
    error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
    limiter = rate_limiters.SampleToInsertRatio(
        min_size_to_sample=self._config.min_replay_size,
        samples_per_insert=self._config.samples_per_insert,
        error_buffer=error_buffer,
    )
    replay_table = reverb.Table(
        name=self._config.replay_table_name,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=self._config.max_replay_size,
        rate_limiter=limiter,
        signature=adders_reverb.NStepTransitionAdder.signature(
            environment_spec=environment_spec),
    )
    return [replay_table]

  def make_dataset_iterator(
      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
    """Create a dataset iterator to use for learning/updating the agent."""
    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=self._config.batch_size *
        self._config.num_sgd_steps_per_step,
        prefetch_size=self._config.prefetch_size,
    )
    iterator = dataset.as_numpy_iterator()
    return utils.device_put(iterator, jax.devices()[0])

  def make_adder(self, replay_client: reverb.Client) -> Optional[adders.Adder]:
    """Create an adder which records data generated by the actor/environment.
        Args:
          replay_client: Reverb Client which points to the replay server.
    """
    return adders_reverb.NStepTransitionAdder(
        client=replay_client,
        n_step=self._config.n_step,
        discount=self._config.discount,
    )

  def make_actor(
      self,
      random_key: networks_lib.PRNGKey,
      policy_network: drq_v2_networks.DrQV2PolicyNetwork,
      adder: Optional[adders.Adder] = None,
      variable_source: Optional[core.VariableSource] = None) -> core.Actor:
    """Create an actor instance.
      Args:
        random_key: A key for random number generation.
        policy_network: Instance of a policy network; this should be a callable
          which takes as input observations and returns actions.
        adder: How data is recorded (e.g. added to replay).
        variable_source: A source providing the necessary actor parameters.
    """
    assert variable_source is not None
    variable_client = variable_utils.VariableClient(
        variable_source, "policy", device='cpu')
    variable_client.update_and_wait()
    return actors.GenericActor(
        actor_core.batched_feed_forward_to_actor_core(policy_network),
        random_key=random_key,
        variable_client=variable_client,
        adder=adder,
        backend='cpu')

  def make_learner(self,
                   random_key: networks_lib.PRNGKey,
                   networks: drq_v2_networks.DrQV2Networks,
                   dataset: Iterator[reverb.ReplaySample],
                   logger: Optional[loggers.Logger] = None,
                   replay_client: Optional[reverb.Client] = None,
                   counter: Optional[counting.Counter] = None) -> core.Learner:
    """Creates an instance of the learner.

        Args:
          random_key: A key for random number generation.
          networks: struct describing the networks needed by the learner; this can
            be specific to the learner in question.
          dataset: iterator over samples from replay.
          replay_client: client which allows communication with replay, e.g. in
            order to update priorities.
          counter: a Counter which allows for recording of counts (learner steps,
            actor steps, etc.) distributed throughout the agent.
          checkpoint: bool controlling whether the learner checkpoints itself.
        """
    del replay_client
    config = self._config
    critic_optimizer = optax.adam(config.learning_rate)
    policy_optimizer = optax.adam(config.learning_rate)
    encoder_optimizer = optax.adam(config.learning_rate)

    return DrQV2Learner(
        random_key=random_key,
        dataset=dataset,
        networks=networks,
        sigma_schedule=optax.constant_schedule(config.sigma),
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        encoder_optimizer=encoder_optimizer,
        augmentation=config.augmentation,
        critic_soft_update_rate=config.critic_q_soft_update_rate,
        discount=config.discount,
        noise_clip=config.noise_clip,
        num_sgd_steps_per_step=config.num_sgd_steps_per_step,
        logger=logger,
        counter=counter,
    )

networks.py: the networks used by the agent as well as the policy.

"""Network definitions for DrQ-v2."""
import dataclasses
from typing import Callable, Optional, Union

from acme import specs
from acme import types
from acme.agents.jax import actor_core
from acme.jax import networks as networks_lib
from acme.jax import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import optax

# Unlike standard FF-policy, in our DrQ-V2 implementation we use
# scheduled stddev parameters, the pure function for the policy
# thus needs to know the current time step of the actor to calculate
# the current stddev.
_Step = int
DrQV2PolicyNetwork = Callable[
    [networks_lib.Params, networks_lib.PRNGKey, types.NestedArray, _Step],
    types.NestedArray]


class Encoder(hk.Module):
  """Encoder used by DrQ-v2."""

  def __call__(self, x):
    # Floatify the image.
    x = x.astype(jnp.float32) / 255.0 - 0.5
    conv_kwargs = dict(
        kernel_shape=3,
        output_channels=32,
        padding="VALID",
        # This follows from the reference implementation, the scale accounts for
        # using the ReLU activation.
        w_init=hk.initializers.Orthogonal(jnp.sqrt(2.0)),
    )
    return hk.Sequential([
        hk.Conv2D(stride=2, **conv_kwargs),
        jax.nn.relu,
        hk.Conv2D(stride=1, **conv_kwargs),
        jax.nn.relu,
        hk.Conv2D(stride=1, **conv_kwargs),
        jax.nn.relu,
        hk.Conv2D(stride=1, **conv_kwargs),
        jax.nn.relu,
        hk.Flatten(),
    ])(
        x)


class Actor(hk.Module):
  """Policy network used by DrQ-v2."""

  def __init__(
      self,
      action_size: int,
      latent_size: int = 50,
      hidden_size: int = 1024,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.latent_size = latent_size
    self.action_size = action_size
    self.hidden_size = hidden_size
    w_init = hk.initializers.Orthogonal(1.0)
    self._trunk = hk.Sequential([
        hk.Linear(self.latent_size, w_init=w_init),
        hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
        jnp.tanh,
    ])
    self._head = hk.Sequential([
        hk.Linear(self.hidden_size, w_init=w_init),
        jax.nn.relu,
        hk.Linear(self.hidden_size, w_init=w_init),
        jax.nn.relu,
        hk.Linear(self.action_size, w_init=w_init),
        # tanh is used to squash the actions into the canonical space.
        jnp.tanh,
    ])

  def compute_features(self, inputs):
    return self._trunk(inputs)

  def __call__(self, inputs):
    # Use orthogonal init
    # https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
    h = self.compute_features(inputs)
    mu = self._head(h)
    return mu


class Critic(hk.Module):
  """Single Critic network used by DrQ-v2."""

  def __init__(self, hidden_size: int = 1024, name: Optional[str] = None):
    super().__init__(name)
    self.hidden_size = hidden_size

  def __call__(self, observation, action):
    inputs = jnp.concatenate([observation, action], axis=-1)
    # Use orthogonal init
    # https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
    q_value = hk.nets.MLP(
        output_sizes=(self.hidden_size, self.hidden_size, 1),
        w_init=hk.initializers.Orthogonal(1.0),
        activate_final=False,
    )(inputs).squeeze(-1)
    return q_value


class DoubleCritic(hk.Module):
  """Twin critic network used by DrQ-v2.

  This is simply two identical Critic modules.
  """

  def __init__(self, latent_size: int = 50, hidden_size: int = 1024, name=None):
    super().__init__(name)
    self.hidden_size = hidden_size
    self.latent_size = latent_size

    self._trunk = hk.Sequential([
        hk.Linear(self.latent_size, w_init=hk.initializers.Orthogonal(1.0)),
        hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
        jnp.tanh,
    ])
    self._critic1 = Critic(self.hidden_size, name="critic1")
    self._critic2 = Critic(self.hidden_size, name="critic2")

  def compute_features(self, inputs):
    return self._trunk(inputs)

  def __call__(self, observation, action):
    # Use orthogonal init
    # https://github.com/facebookresearch/drqv2/blob/21e9048bf59e15f1018b49b850f727ed7b1e210d/utils.py#L54
    # The trunk is shared between the twin critics
    h = self.compute_features(observation)
    return self._critic1(h, action), self._critic2(h, action)


@dataclasses.dataclass
class DrQV2Networks:
  encoder_network: networks_lib.FeedForwardNetwork
  policy_network: networks_lib.FeedForwardNetwork
  critic_network: networks_lib.FeedForwardNetwork
  add_policy_noise: Callable[
      [types.NestedArray, networks_lib.PRNGKey, float, float],
      types.NestedArray]


def get_default_behavior_policy(
    networks: DrQV2Networks,
    action_specs: specs.BoundedArray,
    sigma,
) -> DrQV2PolicyNetwork:

  def behavior_policy(
      params: networks_lib.Params,
      key: networks_lib.PRNGKey,
      observation: types.NestedArray,
  ):
    feature_map = networks.encoder_network.apply(params["encoder"], observation)
    action = networks.policy_network.apply(params["policy"], feature_map)
    noise = jax.random.normal(key, shape=action.shape) * sigma
    noisy_action = jnp.clip(action + noise, action_specs.minimum,
                            action_specs.maximum)
    return noisy_action

  return behavior_policy


def make_networks(spec: specs.EnvironmentSpec,
                  hidden_size: int = 1024,
                  latent_size: int = 50) -> DrQV2Networks:
  """Create networks for the DrQ-v2 agent."""
  action_size = onp.prod(spec.actions.shape, dtype=int)

  def add_policy_noise(
      action: types.NestedArray,
      key: networks_lib.PRNGKey,
      sigma: float,
      noise_clip: float,
  ) -> types.NestedArray:
    """Adds action noise to bootstrapped Q-value estimate in critic loss."""
    noise = jax.random.normal(key=key, shape=spec.actions.shape) * sigma
    noise = jnp.clip(noise, -noise_clip, noise_clip)
    return jnp.clip(action + noise, spec.actions.minimum, spec.actions.maximum)

  def _critic_fn(x, a):
    return DoubleCritic(
        latent_size=latent_size,
        hidden_size=hidden_size,
    )(x, a)

  def _policy_fn(x):
    return Actor(
        action_size=action_size,
        latent_size=latent_size,
        hidden_size=hidden_size,
    )(x)

  def _encoder_fn(x):
    return Encoder()(x)


  policy = hk.without_apply_rng(hk.transform(_policy_fn, apply_rng=True))
  critic = hk.without_apply_rng(hk.transform(_critic_fn, apply_rng=True))
  encoder = hk.without_apply_rng(hk.transform(_encoder_fn, apply_rng=True))
  # policy_feature = hk.without_apply_rng(
  #     hk.transform(_policy_features_fn, apply_rng=True))

  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_encoded = hk.testing.transform_and_run(
      _encoder_fn, seed=0, jax_transform=jax.jit)(
          dummy_obs)
  return DrQV2Networks(
      encoder_network=networks_lib.FeedForwardNetwork(
          lambda key: encoder.init(key, dummy_obs), encoder.apply),
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_encoded), policy.apply),
      critic_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_encoded, dummy_action),
          critic.apply),
      add_policy_noise=add_policy_noise)

run.py is the training script for running the experiment.

from absl import app
from acme import specs
from acme import wrappers
from acme.jax import experiments
from acme.utils import loggers
from acme.wrappers import mujoco
from dm_control import suite
import dm_env
import drq_v2
import jax
import networks as networks_lib
import tensorflow as tf


def make_experiment_logger(label, steps_key, task_instance=0):
  del task_instance
  return loggers.make_default_logger(
      label, save_data=False, steps_key=steps_key)


def make_environment(domain: str,
                     task: str,
                     seed=None,
                     from_pixels: bool = False,
                     num_action_repeats: int = 1,
                     frames_to_stack: int = 0,
                     camera_id: int = 0) -> dm_env.Environment:
  """Create a dm_control suite environment."""
  environment = suite.load(domain, task, task_kwargs={"random": seed})
  if from_pixels:
    environment = mujoco.MujocoPixelWrapper(environment, camera_id=camera_id)
  else:
    environment = wrappers.ConcatObservationWrapper(environment)
  if num_action_repeats > 1:
    environment = wrappers.ActionRepeatWrapper(environment, num_action_repeats)
  if frames_to_stack > 0:
    assert from_pixels, "frame stack for state not supported"
    environment = wrappers.FrameStackingWrapper(
        environment, frames_to_stack, flatten=True)
  environment = wrappers.CanonicalSpecWrapper(environment, clip=True)
  environment = wrappers.SinglePrecisionWrapper(environment)
  return environment


def main(_):
  tf.config.set_visible_devices([], 'GPU')

  environment_factory = lambda seed: make_environment(
      domain='cheetah',
      task='run',
      seed=seed,
      from_pixels=True,
      num_action_repeats=2,
      frames_to_stack=3,
      camera_id=0)

  num_steps = int(1.5e6)

  environment = environment_factory(0)
  environment_spec = specs.make_environment_spec(environment)
  network_factory = networks_lib.make_networks

  drq_config = drq_v2.DrQV2Config()
  policy_factory = lambda n: networks_lib.get_default_behavior_policy(
      n, environment_spec.actions, drq_config.sigma)
  eval_policy_factory = lambda n: networks_lib.get_default_behavior_policy(
      n, environment_spec.actions, 0.0)

  # Construct the agent.
  builder = drq_v2.DrQV2Builder(drq_config)

  experiment = experiments.Config(
      builder=builder,
      network_factory=network_factory,
      policy_network_factory=policy_factory,
      environment_factory=environment_factory,
      eval_policy_network_factory=eval_policy_factory,
      environment_spec=environment_spec,
      observers=(),
      seed=0,
      logger_factory=make_experiment_logger,
      max_number_of_steps=num_steps)

  experiments.run_experiment(
      experiment, eval_every=int(1e4), num_eval_episodes=5)


if __name__ == '__main__':
  jax.config.config_with_absl()
  app.run(main)

@qstanczyk
Collaborator

The problem is the default value of num_parallel_calls in make_reverb_dataset. By default it creates 12 workers to fetch data, and each of them does its own batching. Hence, even if there are batch_size elements available in Reverb, it can happen that the elements are sampled by different workers and none of the batches fills up... which results in a hang. I will discuss with the team what to do about the make_reverb_dataset implementation. In the meantime you should be fine setting num_parallel_calls to 1.
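
For anyone hitting the same hang, here is a sketch of the workaround applied to the repro's builder. It assumes that make_reverb_dataset in this Acme version exposes the num_parallel_calls argument (consistent with the 12 workers mentioned above); everything else mirrors make_dataset_iterator from the repro:

  def make_dataset_iterator(
      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
    """As in the repro above, but with a single dataset worker per batch."""
    dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=self._config.batch_size *
        self._config.num_sgd_steps_per_step,
        prefetch_size=self._config.prefetch_size,
        # Workaround from this thread: with a single worker, a batch can no
        # longer be split across workers that each batch independently.
        num_parallel_calls=1,
    )
    return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])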

@ethanluoyc
Contributor Author

@qstanczyk Yes! I tried setting it to smaller values and that seems to help.

I also found that tf.data.AUTOTUNE breaks things. This is true even when I completely remove the training (e.g., comment out the update step) in my learner.

I was just about to post here that I had found that num_parallel_calls seems to play a role. I suspected that things may not work if num_parallel_calls does not divide the batch size, but it seems that even setting num_parallel_calls to something like 8 still results in a deadlock.

Looking forward to the fixes. I will test whether num_parallel_calls unblocks my issue. Thanks again for the help; you may have saved me a few more days of trying to figure out what's wrong.

@ethanluoyc
Contributor Author

ethanluoyc commented May 30, 2022

OK, so now that the deadlock behavior is gone, I still get the issue mentioned in #235.

Attaching the latest stack trace:

  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/environment_loop.py", line 176, in run
    result = self.run_episode()
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/environment_loop.py", line 115, in run_episode
    self._actor.update()
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/jax/layouts/local_layout.py", line 140, in update
    super().update()
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/agents/agent.py", line 105, in update
    self._batch_size_upper_bounds = [
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/acme/agents/agent.py", line 106, in <listcomp>
    math.ceil(t.info.rate_limiter_info.sample_stats.completed /
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/reverb/server.py", line 229, in info
    return reverb_types.TableInfo.from_serialized_proto(proto_string)
  File "/home/yicheng/virtualenvs/ot/lib/python3.8/site-packages/reverb/reverb_types.py", line 80, in from_serialized_proto
    proto = schema_pb2.TableInfo.FromString(proto_string)

BufferError: INVALID_ARGUMENT: Python buffer protocol is only defined for CPU buffers.
[reverb/cc/platform/default/server.cc:84] Shutting down replay server
E0531 00:47:26.028071 139694191777600 base.py:130] Timeout (10000 ms) exceeded when flushing the writer before deleting it. Caught Reverb exception: Flush call did not complete within provided timeout of 0:00:10

My protobuf version is 3.20.1

@tomdps

tomdps commented Jun 1, 2022

Hi,

I also experienced deadlocks while training a custom pixel-based agent, but this issue has been occurring in both the Local and Distributed Layouts, and before the recent changes in implementation. One curious thing I noticed is that the random deadlock only happens when using multiple GPUs, not with a single GPU, and it seems to be related to pmap itself.

I mentioned this problem here: google/jax#10763
And there seem to be other people encountering the issue in other JAX code as well. One of JAX's maintainers managed to reproduce this deadlock of a pmapped function with someone's code.

Maybe the pmap JAX bug on GPUs is actually the underlying problem here too?

@qstanczyk
Collaborator

This change makes the non-distributed setup equivalent to the distributed setup from the "deadlock" perspective. As long as rate limiters are configured properly, both setups should work fine.

@ethanluoyc
Contributor Author

@qstanczyk that looks good! Maybe this change should also be propagated to LocalLayout?

@qstanczyk
Collaborator

LocalLayout should go away soon; it is being replaced with run_experiment.

@ethanluoyc
Contributor Author

@qstanczyk There are sometimes good reasons for creating custom training loops, and it would be great if some way of creating a single-process agent from a builder remained available for users who want to implement single-process agents. Initially I thought LocalLayout was the way to go, but alternative approaches are also OK.

@qstanczyk
Collaborator

It is still possible to build a custom training loop for a single-process agent by cloning run_experiment. We are trying to move away from the Layout design to make the agents' code easier to understand for new users.

@ethanluoyc
Contributor Author

Right, that makes sense to me. Still, I think there can be a lot of duplication if forking run_experiment requires keeping a copy of something like _TrainingAdder in the forked copy. For my use case, I probably want to customize the training loop, but not necessarily something like _TrainingAdder.

@ethanluoyc
Contributor Author

By the way, is the plan going forward to write tests for agents through run_experiment? I saw the old agent test files removed, but I still think there is a lot of value in having those instead of just the examples. For my code, I use the tests extensively to guard against API changes in Acme and also for my own refactoring.

@qstanczyk
Collaborator

In that case it should be fine to clone the loop itself while referencing the original _TrainingAdder. We wanted to keep the entire logic of run_experiment (including _TrainingAdder) in a single file for the convenience of the reader. The other reason is that the implementation of run_experiment might change, so _TrainingAdder could be modified or removed. But if you work with a fixed version of Acme, or are fine with merging changes, it is OK to reference _TrainingAdder.

@qstanczyk
Collaborator

For tests whose logic was equivalent to that of the examples, we try to deduplicate the code. But in some cases I believe it makes sense to have a standalone test too.

@ethanluoyc
Contributor Author

@qstanczyk Thanks for the clarification. Yeah, the old tests were indeed sometimes quite duplicated. I was hoping there were some testing utilities that could help with testing an entire agent right from the builder.

I don't work with a fixed version of Acme; instead I work against HEAD. There are a lot of very useful things that get added, and I keep my code updated to follow them. I guess for now it's indeed a good idea to just fork run_experiment. I do that anyway, since I need to set up other additional things. I will merge in changes if the mechanism for preventing blocking in the single-process case changes.
