This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add more logging + ability to push to more downstream models #40

Merged · 1 commit · Apr 28, 2022
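This PR does two things: it adds rich-based console logging across the training stack (agent, loop, replay buffer, server), and it gives the Ape-X DQN trainer an optional additional_models_to_update list so that connect(), sync_target_net(), and push() also fan out to extra downstream models on the same cadence as the primary model.

A minimal sketch of the intended call pattern (the handles train_model, downstream_model, replay_buffer, controller, and optimizer are hypothetical stand-ins, and the keyword names mirror the positional arguments visible in the diff; only additional_models_to_update is new):

    from rlmeta.agents.dqn.apex_dqn_agent import ApeXDQNAgent

    # Train one model, but keep extra downstream copies in sync with it.
    agent = ApeXDQNAgent(train_model,
                         eps=0.1,
                         replay_buffer=replay_buffer,
                         controller=controller,
                         optimizer=optimizer,
                         additional_models_to_update=[downstream_model])
    agent.connect()    # also connects each downstream model
    agent.train(1000)  # sync_target_net() and push() fan out to them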
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ numpy
 opencv-python
 tabulate
 torch>=1.5.1
+rich
50 changes: 42 additions & 8 deletions rlmeta/agents/dqn/apex_dqn_agent.py
@@ -9,6 +9,8 @@

 import torch
 import torch.nn as nn
+from rich.console import Console
+from rich.progress import track

 import rlmeta.utils.data_utils as data_utils
 import rlmeta_extension.nested_utils as nested_utils
@@ -21,6 +23,8 @@
 from rlmeta.core.types import NestedTensor
 from rlmeta.utils.stats_dict import StatsDict

+console = Console()
+

 class ApeXDQNAgent(Agent):

@@ -39,7 +43,8 @@ def __init__(
         sync_every_n_steps: int = 10,
         push_every_n_steps: int = 1,
         collate_fn: Optional[Callable[[Sequence[NestedTensor]],
-                                       NestedTensor]] = None
+                                       NestedTensor]] = None,
+        additional_models_to_update: Optional[List[ModelLike]] = None,
     ) -> None:
         super().__init__()

@@ -57,6 +62,7 @@ def __init__(
         self.gamma = gamma
         self.learning_starts = learning_starts

+        self._additional_models_to_update = additional_models_to_update
         self.sync_every_n_steps = sync_every_n_steps
         self.push_every_n_steps = push_every_n_steps

@@ -67,6 +73,12 @@

         self.trajectory = []

+    def connect(self) -> None:
+        super().connect()
+        if self._additional_models_to_update is not None:
+            for m in self._additional_models_to_update:
+                m.connect()
+
     def act(self, timestep: TimeStep) -> Action:
         obs = timestep.observation
         action = self.model.act(obs, torch.tensor([self.eps]))
@@ -113,11 +125,15 @@ async def async_update(self) -> None:
         self.trajectory = []

     def train(self, num_steps: int) -> Optional[StatsDict]:
+        console.log(f"Training for num_steps={num_steps}")
         self.controller.set_phase(Phase.TRAIN, reset=True)

+        console.log(f"Warming up replay buffer: {self.replay_buffer}")
+
         self.replay_buffer.warm_up(self.learning_starts)
+        console.log("Replay buffer warmed up!")
         stats = StatsDict()
-        for step in range(num_steps):
+        for step in track(range(num_steps), description="Training..."):
             t0 = time.perf_counter()
             batch, weight, index = self.replay_buffer.sample(self.batch_size)
             t1 = time.perf_counter()
@@ -132,9 +148,15 @@ def train(self, num_steps: int) -> Optional[StatsDict]:

             if step % self.sync_every_n_steps == self.sync_every_n_steps - 1:
                 self.model.sync_target_net()
+                if self._additional_models_to_update is not None:
+                    for m in self._additional_models_to_update:
+                        m.sync_target_net()

             if step % self.push_every_n_steps == self.push_every_n_steps - 1:
                 self.model.push()
+                if self._additional_models_to_update is not None:
+                    for m in self._additional_models_to_update:
+                        m.push()

             episode_stats = self.controller.get_stats()
             stats.update(episode_stats)
@@ -218,7 +240,8 @@ def __init__(
         sync_every_n_steps: int = 10,
         push_every_n_steps: int = 1,
         collate_fn: Optional[Callable[[Sequence[NestedTensor]],
-                                       NestedTensor]] = None
+                                       NestedTensor]] = None,
+        additional_models_to_update: Optional[List[ModelLike]] = None,
     ) -> None:
         self._model = model
         self._eps_func = eps_func
@@ -233,17 +256,28 @@ def __init__(
         self._sync_every_n_steps = sync_every_n_steps
         self._push_every_n_steps = push_every_n_steps
         self._collate_fn = collate_fn
+        self._additional_models_to_update = additional_models_to_update

     def __call__(self, index: int):
         model = self._make_arg(self._model, index)
         eps = self._eps_func(index)
         replay_buffer = self._make_arg(self._replay_buffer, index)
         controller = self._make_arg(self._controller, index)
-        return ApeXDQNAgent(model, eps, replay_buffer, controller,
-                            self._optimizer, self._batch_size, self._grad_clip,
-                            self._multi_step, self._gamma,
-                            self._learning_starts, self._sync_every_n_steps,
-                            self._push_every_n_steps, self._collate_fn)
+        return ApeXDQNAgent(
+            model,
+            eps,
+            replay_buffer,
+            controller,
+            self._optimizer,
+            self._batch_size,
+            self._grad_clip,
+            self._multi_step,
+            self._gamma,
+            self._learning_starts,
+            self._sync_every_n_steps,
+            self._push_every_n_steps,
+            self._collate_fn,
+            additional_models_to_update=self._additional_models_to_update)


 class ConstantEpsFunc:
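For context, the two rich primitives adopted above behave as follows (standalone sketch, independent of rlmeta): Console.log() prints a timestamped, styled line, and rich.progress.track() wraps an iterable in a live progress bar.

    import time

    from rich.console import Console
    from rich.progress import track

    console = Console()
    console.log("Replay buffer warmed up!")  # timestamped console line
    for step in track(range(100), description="Training..."):
        time.sleep(0.01)  # stand-in for one training step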
3 changes: 3 additions & 0 deletions rlmeta/core/controller.py
@@ -25,6 +25,9 @@ def __init__(self) -> None:
         self._limit = None
         self._stats = StatsDict()

+    def __repr__(self):
+        return f'Controller(phase={self._phase})'
+
     @property
     def phase(self) -> Phase:
         return self._phase
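This __repr__ is what keeps the f-string logging added elsewhere in this PR readable: the run() method in loop.py below logs the controller object directly.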
4 changes: 4 additions & 0 deletions rlmeta/core/loop.py
@@ -11,6 +11,7 @@
 import time

 from typing import Dict, List, NoReturn, Optional, Sequence, Union
+from rich.console import Console

 import torch
 import torch.multiprocessing as mp
@@ -26,6 +27,8 @@
 from rlmeta.core.launchable import Launchable
 from rlmeta.envs.env import Env, EnvFactory

+console = Console()
+

 class Loop(abc.ABC):

@@ -128,6 +131,7 @@ def init_execution(self) -> None:
             obj.init_execution()

     def run(self) -> NoReturn:
+        console.log(f"Starting async loop with: {self._controller}")
         self._loop = asyncio.get_event_loop()
         self._tasks.append(
             asycio_utils.create_task(self._loop, self._check_phase()))
7 changes: 7 additions & 0 deletions rlmeta/core/remote.py
@@ -44,6 +44,10 @@ def __init__(self,
                 server_addr: str,
                 name: Optional[str] = None,
                 timeout: float = 60) -> None:
+        self._target = target
+        self._server_name = server_name
+        self._server_addr = server_addr
+
         self._remote_methods = target.remote_methods
         self._reset(server_name, server_addr, name, timeout)
         self._client_methods = {}
@@ -58,6 +62,9 @@ def __getattribute__(self, attr: str) -> Any:
             return ret
         raise

+    def __repr__(self):
+        return f'Remote(target={self._target} server_name={self._server_name} server_addr={self._server_addr})'
+
     @property
     def name(self) -> str:
         return self._name
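The three fields cached in __init__ above are read only by this new __repr__ (at least within this diff); RemoteReplayBuffer in replay_buffer.py below repeats the same pattern.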
13 changes: 10 additions & 3 deletions rlmeta/core/replay_buffer.py
@@ -8,7 +8,7 @@
 import logging

 from typing import Callable, Optional, Sequence, Tuple, Union
-
+from rich.console import Console
 import numpy as np
 import torch

@@ -22,6 +22,8 @@
 from rlmeta.core.types import Tensor, NestedTensor
 from rlmeta_extension import CircularBuffer

+console = Console()
+

 class ReplayBuffer(remote.Remotable, Launchable):

@@ -260,6 +262,11 @@ def __init__(self,
         super().__init__(target, server_name, server_addr, name, timeout)
         self._prefetch = prefetch
         self._futures = collections.deque()
+        self._server_name = server_name
+        self._server_addr = server_addr
+
+    def __repr__(self):
+        return f'RemoteReplayBuffer(server_name={self._server_name}, server_addr={self._server_addr})'

     @property
     def prefetch(self) -> Optional[int]:
@@ -304,8 +311,8 @@ def warm_up(self, learning_starts: Optional[int] = None) -> None:
         while cur_size < target_size:
             time.sleep(1)
             cur_size = self.get_size()
-            logging.info("Warming up replay buffer: " +
-                         f"[{cur_size: {width}d} / {capacity} ]")
+            console.log("Warming up replay buffer: " +
+                        f"[{cur_size: {width}d} / {capacity} ]")


 ReplayBufferLike = Union[ReplayBuffer, RemoteReplayBuffer]
15 changes: 14 additions & 1 deletion rlmeta/core/server.py
@@ -12,6 +12,7 @@

 import torch
 import torch.multiprocessing as mp
+from rich.console import Console

 import moolib

@@ -20,6 +21,8 @@
 from rlmeta.core.launchable import Launchable
 from rlmeta.core.remote import Remotable

+console = Console()
+

 class Server(Launchable):

@@ -34,6 +37,9 @@ def __init__(self, name: str, addr: str, timeout: float = 60) -> None:
         self._loop = None
         self._tasks = None

+    def __repr__(self):
+        return f'Server(name={self._name} addr={self._addr})'
+
     @property
     def name(self) -> str:
         return self._name
@@ -82,11 +88,17 @@ def init_execution(self) -> None:
         self._server = moolib.Rpc()
         self._server.set_name(self._name)
         self._server.set_timeout(self._timeout)
-        self._server.listen(self._addr)
+        console.log(f"Server={self.name} listening to {self._addr}")
+        try:
+            self._server.listen(self._addr)
+        except:
+            console.log(f"ERROR on listen({self._addr}) from: server={self}")
+            raise

     def _start_services(self) -> NoReturn:
         self._loop = asyncio.get_event_loop()
         self._tasks = []
+        console.log(f"Server={self.name} starting services: {self._services}")
         for service in self._services:
             for method in service.remote_methods:
                 method_impl = getattr(service, method)
@@ -103,6 +115,7 @@ def _start_services(self) -> NoReturn:
             task.cancel()
         self._loop.stop()
         self._loop.close()
+        console.log(f"Server={self.name} services started")

     def _add_server_task(self, func_name: str, func_impl: Callable[..., Any],
                          batch_size: Optional[int]) -> None:
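One note on the init_execution() change: the bare except re-raises after logging, so startup failures still propagate, but the log line now names the failing server and address. A hypothetical reproduction (assuming moolib's Rpc.listen() raises when the address is already bound, e.g. by another server):

    # Two servers configured with the same (hypothetical) address.
    server_a = Server("trainer", "127.0.0.1:4411")
    server_b = Server("inference", "127.0.0.1:4411")
    server_a.init_execution()
    server_b.init_execution()  # listen() raises; the new log line points at
                               # Server(name=inference addr=127.0.0.1:4411)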