
Commit e0c079a
Merge pull request #13 from ikamensh/py36-style
Py36 style
iffiX committed Mar 30, 2021
2 parents f2e0e5f + 9107aaf commit e0c079a
Showing 90 changed files with 349 additions and 436 deletions.
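The pattern across all 90 files is the same Python 3.6+ modernization: argument-free super() calls, f-strings in place of str.format, and dropping the explicit object base class, plus some black-style reflowing of long calls. A minimal before/after sketch of the three patterns (the Stat classes and their attribute are illustrative, not taken from the repository):

    # Before: Python 2 compatible style
    class StatPy2(object):                     # explicit object base
        def __init__(self, n):
            super(StatPy2, self).__init__()    # class name repeated in super()
            self.n = n

        def describe(self):
            return "n = {}".format(self.n)     # str.format interpolation

    # After: Python 3.6+ style, the pattern this commit applies
    class StatPy36:                            # implicit object base
        def __init__(self, n):
            super().__init__()                 # argument-free super()
            self.n = n

        def describe(self):
            return f"n = {self.n}"             # f-string interpolation

    assert StatPy2(3).describe() == StatPy36(3).describe() == "n = 3"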
2 changes: 1 addition & 1 deletion machin/auto/dataset.py
@@ -82,7 +82,7 @@ class RLDataset(IterableDataset):
"""

def __init__(self, **_kwargs):
-        super(RLDataset, self).__init__()
+        super().__init__()

def __iter__(self) -> Iterable:
return self
9 changes: 5 additions & 4 deletions machin/auto/env/openai_gym.py
@@ -66,7 +66,7 @@ def __init__(
render_every_episode: int = 100,
act_kwargs: Dict[str, Any] = None,
):
-        super(RLGymDiscActDataset, self).__init__()
+        super().__init__()
self.frame = frame
self.env = env
self.render_every_episode = render_every_episode
@@ -173,7 +173,7 @@ def __init__(
render_every_episode: int = 100,
act_kwargs: Dict[str, Any] = None,
):
-        super(RLGymContActDataset, self).__init__()
+        super().__init__()
self.frame = frame
self.env = env
self.render_every_episode = render_every_episode
@@ -297,8 +297,9 @@ def gym_env_dataset_creator(frame, env_config):
)


-def launch_gym(config: Union[Dict[str, Any], Config],
-               pl_callbacks: List[Callback] = None):
+def launch_gym(
+    config: Union[Dict[str, Any], Config], pl_callbacks: List[Callback] = None
+):
"""
Args:
config: All configs needed to launch a gym environment and initialize
6 changes: 4 additions & 2 deletions machin/auto/launcher.py
@@ -22,7 +22,7 @@ def __init__(
env_dataset_creator: A callable which accepts the algorithm frame
and env config dictionary, and outputs a environment dataset.
"""
-        super(Launcher, self).__init__()
+        super().__init__()
self.config = config
self.env_dataset_creator = env_dataset_creator
self.frame = init_algorithm_from_config(config)
@@ -106,4 +106,6 @@ def _log(self, logs: List[Dict[str, Any]]):
log_val[1](self, log_key, log_val[0])
else:
is_dist_initialized = dist.is_available() and dist.is_initialized()
-            self.log(log_key, log_val, prog_bar=True, sync_dist=is_dist_initialized)
+            self.log(
+                log_key, log_val, prog_bar=True, sync_dist=is_dist_initialized
+            )
8 changes: 3 additions & 5 deletions machin/auto/pl_logger.py
@@ -15,7 +15,7 @@ class LocalMediaLogger(LightningLoggerBase):
"""

def __init__(self, image_dir: str, artifact_dir: str):
-        super(LocalMediaLogger, self).__init__()
+        super().__init__()
self.image_dir = image_dir
self.artifact_dir = artifact_dir
self._counters = {}
@@ -86,12 +86,10 @@ def log_image(
self._counters[log_name] = 0

if not isinstance(image, str):
-            log_path = log_name + "_{}.png".format(step or self._counters[log_name])
+            log_path = log_name + f"_{step or self._counters[log_name]}.png"
else:
extension = os.path.splitext(image)[1]
-            log_path = log_name + "_{}{}".format(
-                step or self._counters[log_name], extension
-            )
+            log_path = log_name + f"_{step or self._counters[log_name]}{extension}"
self._counters[log_name] += 1

path = os.path.join(self.image_dir, log_path)
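The two conversions above work because f-strings accept arbitrary expressions inside the braces, including the `step or self._counters[log_name]` fallback, so the multi-argument .format call collapses into a single literal. A standalone sketch of the same pattern with hypothetical names (note that `or` also falls back when step is 0, not only when it is None):

    import os

    counters = {"rollout": 7}

    def image_log_path(log_name, image, step=None):
        # Same expression shape as the diff: fall back to the running
        # counter when no explicit step is given.
        extension = os.path.splitext(image)[1]
        return log_name + f"_{step or counters[log_name]}{extension}"

    assert image_log_path("rollout", "frame.png") == "rollout_7.png"
    assert image_log_path("rollout", "frame.jpg", step=42) == "rollout_42.jpg"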
8 changes: 4 additions & 4 deletions machin/auto/pl_plugin.py
@@ -51,8 +51,8 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
rank=int(global_rank),
world_size=int(world_size),
dist_backend=self.torch_distributed_backend,
-            dist_init_method="tcp://{}:{}".format(master_addr, master_port),
-            rpc_init_method="tcp://{}:{}".format(master_addr, int(master_port) + 1),
+            dist_init_method=f"tcp://{master_addr}:{master_port}",
+            rpc_init_method=f"tcp://{master_addr}:{int(master_port) + 1}",
)

def training_step(self, *args, **kwargs):
@@ -96,8 +96,8 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
rank=int(global_rank),
world_size=int(world_size),
dist_backend=self.torch_distributed_backend,
-            dist_init_method="tcp://{}:{}".format(master_addr, master_port),
-            rpc_init_method="tcp://{}:{}".format(master_addr, int(master_port) + 1),
+            dist_init_method=f"tcp://{master_addr}:{master_port}",
+            rpc_init_method=f"tcp://{master_addr}:{int(master_port) + 1}",
)

def training_step(self, *args, **kwargs):
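Both plugins keep the convention that the RPC endpoint sits one port above the distributed backend's endpoint; after the conversion the arithmetic happens directly inside the f-string braces. A tiny illustration with made-up address values:

    master_addr, master_port = "10.0.0.1", "29500"

    dist_init_method = f"tcp://{master_addr}:{master_port}"
    rpc_init_method = f"tcp://{master_addr}:{int(master_port) + 1}"

    assert dist_init_method == "tcp://10.0.0.1:29500"
    assert rpc_init_method == "tcp://10.0.0.1:29501"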
12 changes: 5 additions & 7 deletions machin/env/wrappers/openai_gym.py
@@ -15,7 +15,7 @@

class GymTerminationError(Exception):
def __init__(self):
-        super(GymTerminationError, self).__init__(
+        super().__init__(
"One or several environments have terminated, " "reset before continuing."
)

@@ -33,7 +33,7 @@ def __init__(self, env_creators: List[Callable[[int], gym.Env]]):
env_creators: List of gym environment creators, used to create
environments, accepts a index as your environment id.
"""
-        super(ParallelWrapperDummy, self).__init__()
+        super().__init__()
self._envs = [ec(i) for ec, i in zip(env_creators, range(len(env_creators)))]
self._terminal = np.zeros([len(self._envs)], dtype=np.bool)

@@ -184,7 +184,7 @@ def __init__(self, env_creators: List[Callable[[int], gym.Env]]) -> None:
environments on sub process workers, accepts a index as your
environment id.
"""
-        super(ParallelWrapperSubProc, self).__init__()
+        super().__init__()
self.workers = []

# Some environments will hang or collapse when using fork context.
@@ -369,12 +369,10 @@ def _call_gym_env_method(self, env_idxs, method, args=None, kwargs=None):
if worker.exitcode is None:
continue
if worker.exitcode == 2:
-                raise RuntimeError(
-                    "Worker {} failed to create environment.".format(worker_id)
-                )
+                raise RuntimeError(f"Worker {worker_id} failed to create environment.")
elif worker.exitcode != 0:
raise RuntimeError(
"Worker {} exited with code {}.".format(worker_id, worker.exitcode)
f"Worker {worker_id} exited with code {worker.exitcode}."
)

for env_idx, i in zip(env_idxs, range(len(env_idxs))):
2 changes: 1 addition & 1 deletion machin/frame/algorithms/a2c.py
@@ -208,7 +208,7 @@ def forward(self, state, action=None):

self.criterion = criterion

-        super(A2C, self).__init__()
+        super().__init__()

@property
def optimizers(self):
10 changes: 5 additions & 5 deletions machin/frame/algorithms/a3c.py
@@ -85,7 +85,7 @@ def __init__(
"""
# Adam is just a placeholder here, the actual optimizer is
# set in parameter servers
-        super(A3C, self).__init__(
+        super().__init__(
actor,
critic,
FakeOptimizer,
@@ -134,27 +134,27 @@ def act(self, state: Dict[str, Any], **__):
# DOC INHERITED
if self.is_syncing:
self.actor_grad_server.pull(self.actor)
-        return super(A3C, self).act(state)
+        return super().act(state)

def _eval_act(self, state: Dict[str, Any], action: Dict[str, Any], **__):
# DOC INHERITED
if self.is_syncing:
self.actor_grad_server.pull(self.actor)
-        return super(A3C, self)._eval_act(state, action)
+        return super()._eval_act(state, action)

def _criticize(self, state: Dict[str, Any], *_, **__):
# DOC INHERITED
if self.is_syncing:
self.critic_grad_server.pull(self.critic)
-        return super(A3C, self)._criticize(state)
+        return super()._criticize(state)

def update(
self, update_value=True, update_policy=True, concatenate_samples=True, **__
):
# DOC INHERITED
org_sync = self.is_syncing
self.is_syncing = False
-        super(A3C, self).update(update_value, update_policy, concatenate_samples)
+        super().update(update_value, update_policy, concatenate_samples)
self.is_syncing = org_sync
self.actor_grad_server.push(self.actor)
self.critic_grad_server.push(self.critic)
24 changes: 10 additions & 14 deletions machin/frame/algorithms/apex.py
@@ -77,7 +77,7 @@ def __init__(
gradient_max: Maximum gradient.
replay_size: Local replay buffer size of a single worker.
"""
-        super(DQNApex, self).__init__(
+        super().__init__(
qnet,
qnet_target,
optimizer,
@@ -116,7 +116,7 @@ def act_discrete(self, state: Dict[str, Any], use_target: bool = False, **__):
# DOC INHERITED
if self.is_syncing and not use_target:
self.qnet_model_server.pull(self.qnet)
-        return super(DQNApex, self).act_discrete(state, use_target)
+        return super().act_discrete(state, use_target)

def act_discrete_with_noise(
self,
@@ -128,17 +128,13 @@ def act_discrete_with_noise(
# DOC INHERITED
if self.is_syncing and not use_target:
self.qnet_model_server.pull(self.qnet)
-        return super(DQNApex, self).act_discrete_with_noise(
-            state, use_target, decay_epsilon
-        )
+        return super().act_discrete_with_noise(state, use_target, decay_epsilon)

def update(
self, update_value=True, update_target=True, concatenate_samples=True, **__
):
# DOC INHERITED
-        result = super(DQNApex, self).update(
-            update_value, update_target, concatenate_samples
-        )
+        result = super().update(update_value, update_target, concatenate_samples)
if isinstance(
self.qnet, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
):
@@ -316,7 +312,7 @@ def __init__(
gradient_max: Maximum gradient.
replay_size: Local replay buffer size of a single worker.
"""
-        super(DDPGApex, self).__init__(
+        super().__init__(
actor,
actor_target,
critic,
@@ -357,7 +353,7 @@ def act(self, state: Dict[str, Any], use_target: bool = False, **__):
# DOC INHERITED
if self.is_syncing and not use_target:
self.actor_model_server.pull(self.actor)
-        return super(DDPGApex, self).act(state, use_target)
+        return super().act(state, use_target)

def act_with_noise(
self,
@@ -371,7 +367,7 @@ def act_with_noise(
# DOC INHERITED
if self.is_syncing and not use_target:
self.actor_model_server.pull(self.actor)
-        return super(DDPGApex, self).act_with_noise(
+        return super().act_with_noise(
state,
noise_param=noise_param,
ratio=ratio,
@@ -383,15 +379,15 @@ def act_discrete(self, state: Dict[str, Any], use_target: bool = False, **__):
# DOC INHERITED
if self.is_syncing and not use_target:
self.actor_model_server.pull(self.actor)
-        return super(DDPGApex, self).act_discrete(state, use_target)
+        return super().act_discrete(state, use_target)

def act_discrete_with_noise(
self, state: Dict[str, Any], use_target: bool = False, **__
):
# DOC INHERITED
if self.is_syncing and not use_target:
self.actor_model_server.pull(self.actor)
-        return super(DDPGApex, self).act_discrete_with_noise(state, use_target)
+        return super().act_discrete_with_noise(state, use_target)

def update(
self,
@@ -402,7 +398,7 @@ def update(
**__
):
# DOC INHERITED
-        result = super(DDPGApex, self).update(
+        result = super().update(
update_value, update_policy, update_target, concatenate_samples
)
if isinstance(
20 changes: 9 additions & 11 deletions machin/frame/algorithms/ars.py
@@ -21,7 +21,7 @@
)


-class RunningStat(object):
+class RunningStat:
"""
Running status estimator method by B. P. Welford
described in http://www.johndcook.com/blog/standard_deviation/
@@ -128,7 +128,7 @@ def shape(self):
return self._M.shape


-class MeanStdFilter(object):
+class MeanStdFilter:
"""Keeps track of a running mean for seen states"""

def __init__(self, shape):
@@ -240,7 +240,7 @@ def __repr__(self):
)


-class SharedNoiseSampler(object):
+class SharedNoiseSampler:
def __init__(self, noise: t.Tensor, seed: int):
"""
Args:
@@ -293,7 +293,7 @@ def __init__(
normalize_state: bool = True,
noise_seed: int = 12345,
sample_seed: int = 123,
-        **__
+        **__,
):
"""
@@ -414,7 +414,7 @@ def __init__(
self._sync_actor()
self._generate_parameter()
self._reset_reward_dict()
-        super(ARS, self).__init__()
+        super().__init__()

@property
def optimizers(self):
@@ -520,12 +520,12 @@ def update(self):

# collect result in manager process
self.ars_group.pair(
"ars/rollout_result/{}".format(self.ars_group.get_cur_name()),
f"ars/rollout_result/{self.ars_group.get_cur_name()}",
[pos_reward, neg_reward, delta_idx],
)
if self.normalize_state:
self.ars_group.pair(
"ars/filter/{}".format(self.ars_group.get_cur_name()), self.filter
f"ars/filter/{self.ars_group.get_cur_name()}", self.filter
)
self.ars_group.barrier()

@@ -587,11 +587,9 @@ def update(self):
self.filter[k].clear_local()

self.ars_group.barrier()
-        self.ars_group.unpair(
-            "ars/rollout_result/{}".format(self.ars_group.get_cur_name())
-        )
+        self.ars_group.unpair(f"ars/rollout_result/{self.ars_group.get_cur_name()}")
if self.normalize_state:
-            self.ars_group.unpair("ars/filter/{}".format(self.ars_group.get_cur_name()))
+            self.ars_group.unpair(f"ars/filter/{self.ars_group.get_cur_name()}")
self.ars_group.barrier()

# synchronize filter states across all workers (and the manager)
4 changes: 2 additions & 2 deletions machin/frame/algorithms/base.py
@@ -149,14 +149,14 @@ def save(
if r in network_map:
t.save(
getattr(self, r),
-                    join(model_dir, "{}_{}.pt".format(network_map[r], version)),
+                    join(model_dir, f"{network_map[r]}_{version}.pt"),
)
else:
default_logger.warning(
'Save name for module "{}" is not '
"specified, module name is used.".format(r)
)
-            t.save(getattr(self, r), join(model_dir, "{}_{}.pt".format(r, version)))
+            t.save(getattr(self, r), join(model_dir, f"{r}_{version}.pt"))

def visualize_model(self, final_tensor: t.Tensor, name: str, directory: str):
if name in self._visualized:
4 changes: 2 additions & 2 deletions machin/frame/algorithms/ddpg.py
@@ -173,7 +173,7 @@ def __init__(
)

self.criterion = criterion
-        super(DDPG, self).__init__()
+        super().__init__()

@property
def optimizers(self):
@@ -479,7 +479,7 @@ def load(
self, model_dir: str, network_map: Dict[str, str] = None, version: int = -1
):
# DOC INHERITED
-        super(DDPG, self).load(model_dir, network_map, version)
+        super().load(model_dir, network_map, version)
with t.no_grad():
hard_update(self.actor, self.actor_target)
hard_update(self.critic, self.critic_target)
… (remaining changed files not shown)
