Pull request that adds more logging + the ability to push to more downstream models

EntilZha committed Apr 28, 2022
1 parent 7414bff commit 87601f4
Showing 7 changed files with 81 additions and 12 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ numpy
 opencv-python
 tabulate
 torch>=1.5.1
+rich
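
The new rich dependency supplies the two entry points this commit uses for logging: rich.console.Console.log for timestamped log lines and rich.progress.track for progress bars. A minimal standalone sketch of both (illustrative only, not code from this commit):

    import time

    from rich.console import Console
    from rich.progress import track

    console = Console()
    console.log("Warming up replay buffer...")  # prints with a timestamp and source location

    for step in track(range(100), description="Training..."):
        time.sleep(0.01)  # stand-in for one training step; track renders a progress bar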
50 changes: 42 additions & 8 deletions rlmeta/agents/dqn/apex_dqn_agent.py
@@ -9,6 +9,8 @@

 import torch
 import torch.nn as nn
+from rich.console import Console
+from rich.progress import track

 import rlmeta.utils.data_utils as data_utils
 import rlmeta_extension.nested_utils as nested_utils
@@ -21,6 +23,8 @@
 from rlmeta.core.types import NestedTensor
 from rlmeta.utils.stats_dict import StatsDict

+console = Console()
+

 class ApeXDQNAgent(Agent):

@@ -39,7 +43,8 @@ def __init__(
                 sync_every_n_steps: int = 10,
                 push_every_n_steps: int = 1,
                 collate_fn: Optional[Callable[[Sequence[NestedTensor]],
-                                              NestedTensor]] = None
+                                              NestedTensor]] = None,
+                additional_models_to_update: Optional[List[ModelLike]] = None,
     ) -> None:
         super().__init__()

@@ -57,6 +62,7 @@ def __init__(
         self.gamma = gamma
         self.learning_starts = learning_starts

+        self._additional_models_to_update = additional_models_to_update
         self.sync_every_n_steps = sync_every_n_steps
         self.push_every_n_steps = push_every_n_steps

@@ -67,6 +73,12 @@

         self.trajectory = []

+    def connect(self) -> None:
+        super().connect()
+        if self._additional_models_to_update is not None:
+            for m in self._additional_models_to_update:
+                m.connect()
+
     def act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
         action = self.model.act(obs, torch.tensor([self.eps]))
@@ -113,11 +125,15 @@ async def async_update(self) -> None:
         self.trajectory = []

     def train(self, num_steps: int) -> Optional[StatsDict]:
+        console.log(f"Training for num_steps={num_steps}")
         self.controller.set_phase(Phase.TRAIN, reset=True)

+        console.log(f"Warming up replay buffer: {self.replay_buffer}")
+
         self.replay_buffer.warm_up(self.learning_starts)
+        console.log("Replay buffer warmed up!")
         stats = StatsDict()
-        for step in range(num_steps):
+        for step in track(range(num_steps), description="Training..."):
             t0 = time.perf_counter()
             batch, weight, index = self.replay_buffer.sample(self.batch_size)
             t1 = time.perf_counter()
@@ -132,9 +148,15 @@ def train(self, num_steps: int) -> Optional[StatsDict]:

             if step % self.sync_every_n_steps == self.sync_every_n_steps - 1:
                 self.model.sync_target_net()
+                if self._additional_models_to_update is not None:
+                    for m in self._additional_models_to_update:
+                        m.sync_target_net()

             if step % self.push_every_n_steps == self.push_every_n_steps - 1:
                 self.model.push()
+                if self._additional_models_to_update is not None:
+                    for m in self._additional_models_to_update:
+                        m.push()

         episode_stats = self.controller.get_stats()
         stats.update(episode_stats)
@@ -218,7 +240,8 @@ def __init__(
                 sync_every_n_steps: int = 10,
                 push_every_n_steps: int = 1,
                 collate_fn: Optional[Callable[[Sequence[NestedTensor]],
-                                              NestedTensor]] = None
+                                              NestedTensor]] = None,
+                additional_models_to_update: Optional[List[ModelLike]] = None,
     ) -> None:
         self._model = model
         self._eps_func = eps_func
@@ -233,17 +256,28 @@ def __init__(
         self._sync_every_n_steps = sync_every_n_steps
         self._push_every_n_steps = push_every_n_steps
         self._collate_fn = collate_fn
+        self._additional_models_to_update = additional_models_to_update

     def __call__(self, index: int):
         model = self._make_arg(self._model, index)
         eps = self._eps_func(index)
         replay_buffer = self._make_arg(self._replay_buffer, index)
         controller = self._make_arg(self._controller, index)
-        return ApeXDQNAgent(model, eps, replay_buffer, controller,
-                            self._optimizer, self._batch_size, self._grad_clip,
-                            self._multi_step, self._gamma,
-                            self._learning_starts, self._sync_every_n_steps,
-                            self._push_every_n_steps, self._collate_fn)
+        return ApeXDQNAgent(
+            model,
+            eps,
+            replay_buffer,
+            controller,
+            self._optimizer,
+            self._batch_size,
+            self._grad_clip,
+            self._multi_step,
+            self._gamma,
+            self._learning_starts,
+            self._sync_every_n_steps,
+            self._push_every_n_steps,
+            self._collate_fn,
+            additional_models_to_update=self._additional_models_to_update)


 class ConstantEpsFunc:
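
With these changes, every target-network sync and weight push in train() fans out to the models registered via additional_models_to_update, and connect() wires those models up alongside the primary one. A hypothetical launch-script sketch (model, replay_buffer, controller, optimizer, and downstream_model are placeholders assumed to be constructed elsewhere; the positional argument order mirrors the attributes assigned in __init__ above, and the epsilon value is illustrative):

    from rlmeta.agents.dqn.apex_dqn_agent import (ApeXDQNAgentFactory,
                                                  ConstantEpsFunc)

    agent_factory = ApeXDQNAgentFactory(
        model,                   # trainable ModelLike
        ConstantEpsFunc(0.1),    # fixed exploration epsilon per worker
        replay_buffer,
        controller,
        optimizer,
        additional_models_to_update=[downstream_model],
    )
    agent = agent_factory(0)  # build the agent for worker index 0
    agent.connect()           # new in this commit: also connects downstream_model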
3 changes: 3 additions & 0 deletions rlmeta/core/controller.py
@@ -25,6 +25,9 @@ def __init__(self) -> None:
         self._limit = None
         self._stats = StatsDict()

+    def __repr__(self):
+        return f'Controller(phase={self._phase})'
+
     @property
     def phase(self) -> Phase:
         return self._phase
4 changes: 4 additions & 0 deletions rlmeta/core/loop.py
@@ -11,6 +11,7 @@
 import time

 from typing import Dict, List, NoReturn, Optional, Sequence, Union
+from rich.console import Console

 import torch
 import torch.multiprocessing as mp

@@ -26,6 +27,8 @@
 from rlmeta.core.launchable import Launchable
 from rlmeta.envs.env import Env, EnvFactory

+console = Console()
+

 class Loop(abc.ABC):

@@ -128,6 +131,7 @@ def init_execution(self) -> None:
             obj.init_execution()

     def run(self) -> NoReturn:
+        console.log(f"Starting async loop with: {self._controller}")
         self._loop = asyncio.get_event_loop()
         self._tasks.append(
             asycio_utils.create_task(self._loop, self._check_phase()))
7 changes: 7 additions & 0 deletions rlmeta/core/remote.py
@@ -44,6 +44,10 @@ def __init__(self,
                  server_addr: str,
                  name: Optional[str] = None,
                  timeout: float = 60) -> None:
+        self._target = target
+        self._server_name = server_name
+        self._server_addr = server_addr
+
         self._remote_methods = target.remote_methods
         self._reset(server_name, server_addr, name, timeout)
         self._client_methods = {}

@@ -58,6 +62,9 @@ def __getattribute__(self, attr: str) -> Any:
             return ret
         raise

+    def __repr__(self):
+        return f'Remote(target={self._target} server_name={self._server_name} server_addr={self._server_addr})'
+
     @property
     def name(self) -> str:
         return self._name
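
These stored fields exist so the new __repr__ can report where a remote points. A standalone illustration of the pattern this commit applies across Remote, Controller, Server, and RemoteReplayBuffer (plain Python, not rlmeta code): an object interpolated into an f-string renders via str(), which falls back to __repr__ when no __str__ is defined, so remotes now identify themselves in log lines instead of printing a default <object at 0x...> form.

    class Remote:
        # Toy stand-in mirroring the fields stored in the diff above.
        def __init__(self, server_name: str, server_addr: str) -> None:
            self._server_name = server_name
            self._server_addr = server_addr

        def __repr__(self) -> str:
            return (f'Remote(server_name={self._server_name} '
                    f'server_addr={self._server_addr})')

    controller = Remote('controller', '127.0.0.1:4411')
    print(f"Starting async loop with: {controller}")
    # Starting async loop with: Remote(server_name=controller server_addr=127.0.0.1:4411)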
13 changes: 10 additions & 3 deletions rlmeta/core/replay_buffer.py
@@ -8,7 +8,7 @@
 import logging

 from typing import Callable, Optional, Sequence, Tuple, Union
-
+from rich.console import Console
 import numpy as np
 import torch

@@ -22,6 +22,8 @@
 from rlmeta.core.types import Tensor, NestedTensor
 from rlmeta_extension import CircularBuffer

+console = Console()
+

 class ReplayBuffer(remote.Remotable, Launchable):

@@ -260,6 +262,11 @@ def __init__(self,
         super().__init__(target, server_name, server_addr, name, timeout)
         self._prefetch = prefetch
         self._futures = collections.deque()
+        self._server_name = server_name
+        self._server_addr = server_addr
+
+    def __repr__(self):
+        return f'RemoteReplayBuffer(server_name={self._server_name}, server_addr={self._server_addr})'

     @property
     def prefetch(self) -> Optional[int]:

@@ -304,8 +311,8 @@ def warm_up(self, learning_starts: Optional[int] = None) -> None:
         while cur_size < target_size:
             time.sleep(1)
             cur_size = self.get_size()
-            logging.info("Warming up replay buffer: " +
-                         f"[{cur_size: {width}d} / {capacity} ]")
+            console.log("Warming up replay buffer: " +
+                        f"[{cur_size: {width}d} / {capacity} ]")


 ReplayBufferLike = Union[ReplayBuffer, RemoteReplayBuffer]
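
One subtlety in the warm_up log line above: the format spec {cur_size: {width}d} nests a replacement field, so the field width is computed at run time (presumably from the buffer capacity; the derivation of width is not shown in this hunk), and the leading space in the spec reserves a sign column for positive numbers. A standalone illustration with made-up values:

    capacity = 10000
    width = len(str(capacity))  # assumed derivation; not shown in the diff
    cur_size = 42
    print(f"[{cur_size: {width}d} / {capacity} ]")
    # prints: [   42 / 10000 ]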
15 changes: 14 additions & 1 deletion rlmeta/core/server.py
@@ -12,6 +12,7 @@

 import torch
 import torch.multiprocessing as mp
+from rich.console import Console

 import moolib

@@ -20,6 +21,8 @@
 from rlmeta.core.launchable import Launchable
 from rlmeta.core.remote import Remotable

+console = Console()
+

 class Server(Launchable):

@@ -34,6 +37,9 @@ def __init__(self, name: str, addr: str, timeout: float = 60) -> None:
         self._loop = None
         self._tasks = None

+    def __repr__(self):
+        return f'Server(name={self._name} addr={self._addr})'
+
     @property
     def name(self) -> str:
         return self._name

@@ -82,11 +88,17 @@ def init_execution(self) -> None:
         self._server = moolib.Rpc()
         self._server.set_name(self._name)
         self._server.set_timeout(self._timeout)
-        self._server.listen(self._addr)
+        console.log(f"Server={self.name} listening to {self._addr}")
+        try:
+            self._server.listen(self._addr)
+        except Exception:
+            console.log(f"ERROR on listen({self._addr}) from: server={self}")
+            raise

     def _start_services(self) -> NoReturn:
         self._loop = asyncio.get_event_loop()
         self._tasks = []
+        console.log(f"Server={self.name} starting services: {self._services}")
         for service in self._services:
             for method in service.remote_methods:
                 method_impl = getattr(service, method)

@@ -103,6 +115,7 @@
                 task.cancel()
             self._loop.stop()
             self._loop.close()
+        console.log(f"Server={self.name} services started")

     def _add_server_task(self, func_name: str, func_impl: Callable[..., Any],
                          batch_size: Optional[int]) -> None:
