Change default config to dqn.
cyoon1729 committed Jun 22, 2020
1 parent 1ab5e35 commit b954f4d
Showing 11 changed files with 53 additions and 54 deletions.
2 changes: 1 addition & 1 deletion configs/pong_no_frameskip_v4/apex_dqn.py
@@ -26,7 +26,7 @@
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.1, # openai baselines: 0.01
epsilon_decay=1e-7, # openai baselines: 1e-7 / 1e-1
epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1
# grad_cam
grad_cam_layer_list=[
"backbone.cnn.cnn_0.cnn",
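The only change to this config is the per-step epsilon decay, raised from 1e-7 to 1e-6. As a rough sketch of what that buys, assuming the simplest rule (epsilon is reduced by `epsilon_decay` once per step and clipped at `min_epsilon`; the repository may instead scale the decrement by `max_epsilon - min_epsilon`, which changes the absolute numbers but not the 10x ratio):

```python
# Back-of-the-envelope: how many steps until epsilon bottoms out,
# assuming a plain "epsilon -= epsilon_decay" update clipped at min_epsilon.
max_epsilon, min_epsilon = 1.0, 0.1

for epsilon_decay in (1e-7, 1e-6):  # old default vs. new default
    steps_to_min = (max_epsilon - min_epsilon) / epsilon_decay
    print(f"decay={epsilon_decay:g}: ~{steps_to_min:,.0f} steps to reach min_epsilon")

# decay=1e-07: ~9,000,000 steps
# decay=1e-06: ~900,000 steps
```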
3 changes: 0 additions & 3 deletions rl_algorithms/common/distributed/abstract/logger.py
@@ -125,7 +125,6 @@ def recv_log_info(self):

def run(self):
"""Run main logging loop; continuously receive data and log"""
# logger
if self.args.log:
self.set_wandb()

@@ -141,7 +140,6 @@ def run(self):
self.synchronize(state_dict)
avg_score = self.test(self.update_step)
log_value["avg_score"] = avg_score

self.write_log(log_value)

def test(self, update_step: int, interim_test: bool = True):
@@ -150,7 +148,6 @@

# termination
self.env.close()

return avg_score

def _test(self, update_step: int, interim_test: bool) -> float:
18 changes: 9 additions & 9 deletions rl_algorithms/common/distributed/abstract/worker.py
@@ -21,7 +21,7 @@


class BaseWorker(ABC):
"""Base class for Worker classes"""
"""Base class for Worker classes."""

@abstractmethod
def select_action(self, state: np.ndarray) -> np.ndarray:
@@ -37,14 +37,14 @@ def synchronize(self, new_params: list):

# pylint: disable=no-self-use
def _synchronize(self, network, new_params: List[np.ndarray]):
"""Copy parameters from numpy arrays"""
"""Copy parameters from numpy arrays."""
for param, new_param in zip(network.parameters(), new_params):
new_param = torch.FloatTensor(new_param).to(self.device)
param.data.copy_(new_param)


class Worker(BaseWorker):
"""Base class for all functioning RL workers
"""Base class for all functioning RL workers.
Attributes:
rank (int): rank (ID) of worker
@@ -63,7 +63,7 @@ def __init__(
hyper_params: ConfigDict,
device: str,
):
"""Initialize"""
"""Initialize."""
self.rank = rank
self.args = args
self.env_info = env_info
@@ -74,7 +74,7 @@ def __init__(

# pylint: disable=attribute-defined-outside-init, no-self-use
def _init_env(self):
"""Intialize worker local environment"""
"""Intialize worker local environment."""
if self.env_info.is_atari:
self.env = atari_env_generator(
self.env_info.name, self.args.max_episode_steps, frame_stack=True
@@ -119,7 +119,7 @@ def _preprocess_state(state: np.ndarray, device: torch.device) -> torch.Tensor:


class DistributedWorkerWrapper(BaseWorker):
"""Base wrapper class for distributed worker wrappers"""
"""Base wrapper class for distributed worker wrappers."""

def __init__(self, worker: Worker, args: argparse.Namespace, comm_cfg: ConfigDict):
self.worker = worker
@@ -135,11 +135,11 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
return self.worker.select_action(state)

def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]:
"""Take an action and return the response of the env"""
"""Take an action and return the response of the env."""
return self.worker.step(action)

def synchronize(self, new_params: list):
"""Synchronize worker brain with learner brain"""
"""Synchronize worker brain with learner brain."""
self.worker.synchronize(new_params)

@abstractmethod
@@ -151,7 +151,7 @@ def run(self):
pass

def preprocess_nstep(self, nstepqueue: Deque) -> Tuple[np.ndarray, ...]:
"""Return n-step transition with discounted reward"""
"""Return n-step transition with discounted reward."""
discounted_reward = 0
_, _, _, last_state, done = nstepqueue[-1]
for transition in list(reversed(nstepqueue)):
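`preprocess_nstep` (bottom of the hunk above) folds a queue of 1-step transitions into a single n-step transition with a discounted reward. A self-contained sketch of that reduction, assuming transitions are `(state, action, reward, next_state, done)` tuples as the unpacking of `nstepqueue[-1]` suggests, with `gamma` passed explicitly here for illustration:

```python
from collections import deque
from typing import Deque, Tuple

import numpy as np


def preprocess_nstep(nstepqueue: Deque, gamma: float = 0.99) -> Tuple:
    """Collapse queued 1-step transitions into one n-step transition.

    Assumes each entry is (state, action, reward, next_state, done),
    matching the unpacking shown in the diff above.
    """
    discounted_reward = 0.0
    _, _, _, last_state, done = nstepqueue[-1]
    # Walk backwards so each earlier reward is discounted once more.
    for _, _, reward, _, _ in reversed(nstepqueue):
        discounted_reward = reward + gamma * discounted_reward
    state, action, _, _, _ = nstepqueue[0]
    return state, action, discounted_reward, last_state, done


# Tiny usage example: three dummy transitions with reward 1.0 each.
queue = deque((np.zeros(4), 0, 1.0, np.zeros(4), False) for _ in range(3))
print(preprocess_nstep(queue)[2])  # 1 + 0.99 + 0.99**2 = 2.9701
```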
10 changes: 5 additions & 5 deletions rl_algorithms/common/distributed/apex/architecture.py
@@ -19,7 +19,7 @@

@AGENTS.register_module
class ApeX(Architecture):
"""General Ape-X architecture for distributed training
"""General Ape-X architecture for distributed training.
Attributes:
rank (int): rank (ID) of worker
@@ -65,7 +65,7 @@ def __init__(

# pylint: disable=attribute-defined-outside-init
def _organize_configs(self):
"""Organize configs for initializing components from registry"""
"""Organize configs for initializing components from registry."""
# organize learner configs
self.learner_cfg.args = self.args
self.learner_cfg.env_info = self.env_info
@@ -89,7 +89,7 @@ def _organize_configs(self):
self.logger_cfg.head = self.learner_cfg.head

def _spawn(self):
"""Intialize distributed worker, learner and centralized replay buffer"""
"""Intialize distributed worker, learner and centralized replay buffer."""
replay_buffer = ReplayBuffer(
self.hyper_params.buffer_size, self.hyper_params.batch_size,
)
@@ -119,7 +119,7 @@ def _spawn(self):
self.processes = self.workers + [self.learner, self.global_buffer, self.logger]

def train(self):
"""Spawn processes and run training loop"""
"""Spawn processes and run training loop."""
print("Spawning and initializing communication...")
# Spawn processes:
self._spawn()
@@ -134,7 +134,7 @@ def train(self):
print("Exiting training...")

def test(self):
"""Load model from checkpoint and run logger for testing"""
"""Load model from checkpoint and run logger for testing."""
# NOTE: You could also load the Ape-X trained model on the single agent DQN
self.logger = build_logger(self.logger_cfg)
self.logger.load_params.remote(self.args.load_from)
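For context on the methods touched above: `_spawn` builds the workers, learner, global buffer, and logger as Ray actors, and `train` then launches all of their `run` loops. A stripped-down sketch of that orchestration pattern; the `Loop` actor and its names are placeholders, not the repository's API:

```python
import ray


@ray.remote
class Loop:
    """Placeholder actor standing in for a worker/learner/buffer/logger."""

    def __init__(self, name: str):
        self.name = name

    def run(self):
        # Real actors run their communication/training loop here.
        return f"{self.name} finished"


ray.init()

# _spawn: create one actor per role, then keep the handles in one list.
processes = [Loop.remote(n) for n in ("worker_0", "learner", "buffer", "logger")]

# train: launch every run() loop and block until all of them exit.
futures = [p.run.remote() for p in processes]
print(ray.get(futures))

ray.shutdown()
```

In the real architecture each actor's `run` is a long-lived communication loop rather than a returning call; the sketch only shows the spawn-then-run shape.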
6 changes: 3 additions & 3 deletions rl_algorithms/common/distributed/apex/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@ray.remote
class ApeXBufferWrapper(BufferWrapper):
"""Wrapper for Ape-X global buffer
"""Wrapper for Ape-X global buffer.
Attributes:
per_buffer (ReplayBuffer): prioritized replay buffer
@@ -43,7 +43,7 @@ def __init__(

# pylint: disable=attribute-defined-outside-init
def init_communication(self):
"""Initialize sockets for communication"""
"""Initialize sockets for communication."""
ctx = zmq.Context()
self.req_socket = ctx.socket(zmq.REQ)
self.req_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_buffer_port}")
@@ -75,7 +75,7 @@ def recv_worker_data(self):
self.buffer.update_priorities([len(self.buffer) - 1], priorities[idx])

def send_batch_to_learner(self):
"""Send batch to learner and receive priorities"""
"""Send batch to learner and receive priorities."""
# Send batch and request priorities (blocking recv)
batch = self.buffer.sample(self.per_beta)
batch_id = pa.serialize(batch).to_buffer()
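`send_batch_to_learner` is the request half of a ZeroMQ REQ/REP exchange: the buffer serializes a sampled batch with pyarrow, sends it on the REQ socket, and blocks until the learner replies with updated priorities. A minimal sketch of that request side; the port number and payload are invented for illustration, and `pa.serialize`/`pa.deserialize` (used throughout the diff) are deprecated in newer pyarrow releases:

```python
import pyarrow as pa
import zmq

# Buffer side of the batch-for-priorities exchange (REQ half).
ctx = zmq.Context()
req_socket = ctx.socket(zmq.REQ)
req_socket.connect("tcp://127.0.0.1:6554")  # hypothetical learner_buffer_port

batch = {"indices": [3, 7], "weights": [1.0, 0.5]}  # stand-in for buffer.sample(beta)
req_socket.send(pa.serialize(batch).to_buffer())    # request: the serialized batch

# Blocks until a learner-side REP socket answers with new priorities.
new_priorities = pa.deserialize(req_socket.recv())
print(new_priorities)
```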
14 changes: 7 additions & 7 deletions rl_algorithms/common/distributed/apex/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@ray.remote(num_gpus=1)
class ApeXLearnerWrapper(DistributedLearnerWrapper):
"""Learner Wrapper to enable Ape-X distributed training
"""Learner Wrapper to enable Ape-X distributed training.
Attributes:
learner (Learner): learner
@@ -48,7 +48,7 @@ def __init__(self, learner: Learner, comm_cfg: ConfigDict):

# pylint: disable=attribute-defined-outside-init
def init_communication(self):
"""Initialize sockets for communication"""
"""Initialize sockets for communication."""
ctx = zmq.Context()
# Socket to send updated network parameters to worker
self.pub_socket = ctx.socket(zmq.PUB)
@@ -63,34 +63,34 @@ def init_communication(self):
self.push_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}")

def recv_replay_data(self):
"""Receive replay data from gloal buffer"""
"""Receive replay data from gloal buffer."""
replay_data_id = self.rep_socket.recv()
replay_data = pa.deserialize(replay_data_id)
return replay_data

def send_new_priorities(self, indices: np.ndarray, priorities: np.ndarray):
"""Send new priority values and corresponding indices to buffer"""
"""Send new priority values and corresponding indices to buffer."""
new_priors = [indices, priorities]
new_priors_id = pa.serialize(new_priors).to_buffer()
self.rep_socket.send(new_priors_id)

def publish_params(self, update_step: int, np_state_dict: List[np.ndarray]):
"""Broadcast updated params to all workers"""
"""Broadcast updated params to all workers."""
param_info = [update_step, np_state_dict]
new_params_id = pa.serialize(param_info).to_buffer()
self.pub_socket.send(new_params_id)

def send_info_to_logger(
self, np_state_dict: List[np.ndarray], step_info: list,
):
"""Send new params and log info to logger"""
"""Send new params and log info to logger."""
log_value = dict(update_step=self.update_step, step_info=step_info)
log_info = dict(log_value=log_value, state_dict=np_state_dict)
log_info_id = pa.serialize(log_info).to_buffer()
self.push_socket.send(log_info_id)

def run(self):
"""Run main training loop"""
"""Run main training loop."""
self.telapsed = 0
while self.update_step < self.max_update_step:
replay_data = self.recv_replay_data()
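`publish_params` relies on ZeroMQ PUB/SUB so that a single learner broadcast reaches every worker, however many were spawned. A minimal sketch of the publishing side, with an invented port and a stand-in state dict:

```python
import numpy as np
import pyarrow as pa
import zmq

# Learner side: broadcast (update_step, flattened params) to every subscriber.
ctx = zmq.Context()
pub_socket = ctx.socket(zmq.PUB)
pub_socket.bind("tcp://127.0.0.1:6555")  # hypothetical learner_worker_port

update_step = 42
np_state_dict = [np.zeros((4, 4), dtype=np.float32)]  # stand-in for the real weights

param_info = [update_step, np_state_dict]
pub_socket.send(pa.serialize(param_info).to_buffer())
# Subscribers that SUBSCRIBE to "" (all topics) receive this message; late
# joiners miss anything published before they connect (normal PUB/SUB behavior).
```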
14 changes: 7 additions & 7 deletions rl_algorithms/common/distributed/apex/worker.py
@@ -22,7 +22,7 @@

@ray.remote(num_cpus=1)
class ApeXWorkerWrapper(DistributedWorkerWrapper):
"""Wrapper class for ApeX based distributed workers
"""Wrapper class for ApeX based distributed workers.
Attributes:
hyper_params (ConfigDict): worker hyper_params
@@ -43,7 +43,7 @@ def __init__(self, worker: Worker, args: argparse.Namespace, comm_cfg: ConfigDict):

# pylint: disable=attribute-defined-outside-init
def init_communication(self):
"""Initialize sockets connecting worker-learner, worker-buffer"""
"""Initialize sockets connecting worker-learner, worker-buffer."""
# for receiving params from learner
ctx = zmq.Context()
self.sub_socket = ctx.socket(zmq.SUB)
@@ -56,12 +56,12 @@ def init_communication(self):
self.push_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.worker_buffer_port}")

def send_data_to_buffer(self, replay_data):
"""Send replay data to global buffer"""
"""Send replay data to global buffer."""
replay_data_id = pa.serialize(replay_data).to_buffer()
self.push_socket.send(replay_data_id)

def recv_params_from_learner(self):
"""Get new params and sync. return True if success, False otherwise"""
"""Get new params and sync. return True if success, False otherwise."""
received = False
try:
new_params_id = self.sub_socket.recv(zmq.DONTWAIT)
@@ -76,11 +76,11 @@
self.worker.synchronize(new_params)

def compute_priorities(self, experience: Dict[str, np.ndarray]):
"""Compute priority values (TD error) of collected experience"""
"""Compute priority values (TD error) of collected experience."""
return self.worker.compute_priorities(experience)

def collect_data(self) -> dict:
"""Fill and return local buffer"""
"""Fill and return local buffer."""
local_memory = [0]
local_memory = dict(states=[], actions=[], rewards=[], next_states=[], dones=[])
local_memory_keys = local_memory.keys()
@@ -126,7 +126,7 @@ def collect_data(self) -> dict:
return local_memory

def run(self):
"""Run main worker loop"""
"""Run main worker loop."""
while self.update_step < self.args.max_update_step:
experience = self.collect_data()
priority_values = self.compute_priorities(experience)
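`recv_params_from_learner` polls the SUB socket without blocking: `zmq.DONTWAIT` makes `recv` raise `zmq.Again` when no broadcast is waiting, so the worker simply keeps acting with its current parameters. A minimal sketch of that polling pattern (port invented, pairs with the PUB sketch above):

```python
import pyarrow as pa
import zmq

ctx = zmq.Context()
sub_socket = ctx.socket(zmq.SUB)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"")   # receive every broadcast
sub_socket.connect("tcp://127.0.0.1:6555")  # hypothetical learner_worker_port


def try_recv_params():
    """Return (update_step, params) if a broadcast is waiting, else None."""
    try:
        msg = sub_socket.recv(zmq.DONTWAIT)  # non-blocking receive
    except zmq.Again:                        # nothing published yet
        return None
    update_step, new_params = pa.deserialize(msg)
    return update_step, new_params


print(try_recv_params())  # None until the learner publishes something
```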
6 changes: 3 additions & 3 deletions rl_algorithms/dqn/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _init_network(self):
def update_model(
self, experience: Union[TensorTuple, Tuple[TensorTuple]]
) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]: # type: ignore
"""Update dqn and dqn target"""
"""Update dqn and dqn target."""

if self.use_n_step:
experience_1, experience_n = experience
@@ -160,10 +160,10 @@ def load_params(self, path: str):
print("[INFO] loaded the model and optimizer from", path)

def get_state_dict(self) -> OrderedDict:
"""Return state dicts, mainly for distributed worker"""
"""Return state dicts, mainly for distributed worker."""
dqn = deepcopy(self.dqn)
return dqn.cpu().state_dict()

def get_policy(self) -> nn.Module:
"""Return model (policy) used for action selection, used only in grad cam"""
"""Return model (policy) used for action selection, used only in grad cam."""
return self.dqn
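A note on `get_state_dict` above: it deep-copies the network before calling `.cpu()` because `nn.Module.cpu()` moves parameters in place and returns `self`; copying first leaves the learner's live network on its device while workers get a CPU state dict. A tiny sketch of the difference:

```python
from copy import deepcopy

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
if torch.cuda.is_available():  # guard so the sketch also runs on CPU-only machines
    net = net.cuda()

# Calling net.cpu() directly would move the live network itself off the GPU.
# As in the diff: copy first, then move only the copy.
cpu_state_dict = deepcopy(net).cpu().state_dict()

print(next(net.parameters()).device)    # device of the live network, unchanged
print(cpu_state_dict["weight"].device)  # always cpu
```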
13 changes: 7 additions & 6 deletions rl_algorithms/dqn/logger.py
@@ -17,7 +17,7 @@

@LOGGERS.register_module
class DQNLogger(Logger):
"""DQN Logger for distributed training"""
"""DQN Logger for distributed training."""

def __init__(
self,
@@ -40,15 +40,16 @@ def load_params(self, path: str):
print("[INFO] loaded the model and optimizer from", path)

def select_action(self, state: np.ndarray):
"""Select action to be executed at given state"""
state = self._preprocess_state(state, self.device)
selected_action = self.brain(state).argmax()
selected_action = selected_action.detach().cpu().numpy()
"""Select action to be executed at given state."""
with torch.no_grad():
state = self._preprocess_state(state, self.device)
selected_action = self.brain(state).argmax()
selected_action = selected_action.cpu().numpy()

return selected_action

def write_log(self, log_value: dict):
"""Write log about loss and score"""
"""Write log about loss and score."""
print(
"[INFO] update_step %d, average score: %f, "
"loss: %f, avg q-value: %f"
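The `select_action` change above (and the matching one in `dqn/worker.py` below) wraps inference in `torch.no_grad()`, which is also why the old `.detach()` call can be dropped: no autograd graph is recorded, so the outputs are already detached. A tiny self-contained illustration with a stand-in network:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)      # stand-in for the DQN "brain"
state = torch.randn(1, 4)

with torch.no_grad():
    q_values = net(state)
    action = q_values.argmax()

print(q_values.requires_grad)  # False -> .cpu().numpy() is safe without .detach()
print(action.cpu().numpy())
```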
15 changes: 8 additions & 7 deletions rl_algorithms/dqn/worker.py
@@ -57,7 +57,7 @@ def __init__(

# pylint: disable=attribute-defined-outside-init
def _init_networks(self, state_dict: OrderedDict):
"""Initialize DQN policy with learner state dict"""
"""Initialize DQN policy with learner state dict."""
self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
self.dqn.load_state_dict(state_dict)

@@ -76,9 +76,10 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
if self.epsilon > np.random.random():
selected_action = np.array(self.env.action_space.sample())
else:
state = self._preprocess_state(state, self.device)
selected_action = self.dqn(state).argmax()
selected_action = selected_action.detach().cpu().numpy()
with torch.no_grad():
state = self._preprocess_state(state, self.device)
selected_action = self.dqn(state).argmax()
selected_action = selected_action.cpu().numpy()

# Decay epsilon
self.epsilon = max(
@@ -90,12 +91,12 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
return selected_action

def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]:
"""Take an action and return the response of the env"""
"""Take an action and return the response of the env."""
next_state, reward, done, info = self.env.step(action)
return next_state, reward, done, info

def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray:
"""Compute initial priority values of experiences in local memory"""
"""Compute initial priority values of experiences in local memory."""
states = torch.FloatTensor(memory["states"]).to(self.device)
actions = torch.LongTensor(memory["actions"]).to(self.device)
rewards = torch.FloatTensor(memory["rewards"].reshape(-1, 1)).to(self.device)
@@ -112,5 +113,5 @@ def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray:
return new_priorities

def synchronize(self, new_params: List[np.ndarray]):
"""Synchronize worker dqn with learner dqn"""
"""Synchronize worker dqn with learner dqn."""
self._synchronize(self.dqn, new_params)
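`compute_priorities` seeds each collected transition with a priority derived from its TD error before it is shipped to the global buffer. A self-contained sketch of the usual 1-step rule, |Q(s,a) − (r + γ·max Q(s′,·))| + ε; the tail of the real method is not shown in the diff and it may use the n-step return or a target network, so treat the details here as assumptions:

```python
import numpy as np
import torch
import torch.nn as nn


def compute_priorities(dqn: nn.Module, memory: dict, gamma: float = 0.99,
                       eps: float = 1e-6) -> np.ndarray:
    """1-step TD-error priorities: |Q(s,a) - (r + gamma * max_a' Q(s',a'))| + eps."""
    states = torch.FloatTensor(memory["states"])
    actions = torch.LongTensor(memory["actions"])
    rewards = torch.FloatTensor(memory["rewards"].reshape(-1, 1))
    next_states = torch.FloatTensor(memory["next_states"])
    dones = torch.FloatTensor(memory["dones"].reshape(-1, 1))

    with torch.no_grad():
        q = dqn(states).gather(1, actions.unsqueeze(1))        # Q(s, a)
        next_q = dqn(next_states).max(dim=1, keepdim=True)[0]  # max_a' Q(s', a')
        target = rewards + gamma * next_q * (1.0 - dones)
        td_error = (q - target).abs() + eps

    return td_error.squeeze(1).numpy()


# Usage with a toy network and random transitions.
dqn = nn.Linear(4, 2)
memory = {
    "states": np.random.randn(8, 4).astype(np.float32),
    "actions": np.random.randint(0, 2, size=8),
    "rewards": np.random.randn(8).astype(np.float32),
    "next_states": np.random.randn(8, 4).astype(np.float32),
    "dones": np.zeros(8, dtype=np.float32),
}
print(compute_priorities(dqn, memory))
```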
