Finish implementing basic skeleton of automatic launching based on pytorch lightning.

iffiX committed Mar 23, 2021
1 parent 92b1dab commit d0d6faf
Showing 26 changed files with 853 additions and 150 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -108,7 +108,7 @@ See [here](https://machin.readthedocs.io/). Examples are located in [examples](h

### Installation
---
-Machin is hosted on [PyPI](https://pypi.org/project/machin/). Python >= 3.5 and PyTorch >= 1.5.0 is required. You may install the Machin library by simply typing:
+Machin is hosted on [PyPI](https://pypi.org/project/machin/). Python >= 3.5 and PyTorch >= 1.6.0 is required. You may install the Machin library by simply typing:
```
pip install machin
```
28 changes: 21 additions & 7 deletions machin/frame/algorithms/a2c.py
@@ -173,7 +173,7 @@ def forward(self, state, action=None):
self.discount = discount
self.value_weight = value_weight
self.entropy_weight = entropy_weight
-self.grad_max = gradient_max
+self.gradient_max = gradient_max
self.gae_lambda = gae_lambda
self.normalize_advantage = normalize_advantage
self.visualize = visualize
@@ -209,6 +209,20 @@ def forward(self, state, action=None):

super(A2C, self).__init__()

@property
def optimizers(self):
return [self.actor_optim, self.critic_optim]

@optimizers.setter
def optimizers(self, optimizers):
self.actor_optim, self.critic_optim = optimizers

@property
def lr_schedulers(self):
if hasattr(self, "actor_lr_sch") and hasattr(self, "critic_lr_sch"):
return [self.actor_lr_sch, self.critic_lr_sch]
return []

def act(self, state: Dict[str, Any], *_, **__):
"""
Use actor network to give a policy to the current state.
@@ -365,9 +379,9 @@ def update(self,
# Update actor network
if update_policy:
self.actor.zero_grad()
-act_policy_loss.backward()
+self._backward(act_policy_loss)
nn.utils.clip_grad_norm_(
-self.actor.parameters(), self.grad_max
+self.actor.parameters(), self.gradient_max
)
self.actor_optim.step()

Expand Down Expand Up @@ -396,9 +410,9 @@ def update(self,
# Update critic network
if update_value:
self.critic.zero_grad()
-value_loss.backward()
+self._backward(value_loss)
nn.utils.clip_grad_norm_(
-self.critic.parameters(), self.grad_max
+self.critic.parameters(), self.gradient_max
)
self.critic_optim.step()

@@ -417,8 +431,8 @@ def update_lr_scheduler(self):
if hasattr(self, "critic_lr_sch"):
self.critic_lr_sch.step()

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"models": ["Actor", "Critic"],
"model_args": ((), ()),
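The `generate_config` change from a `@staticmethod` to a `@classmethod` recurs in every framework touched by this commit. A minimal usage sketch, as my own illustration rather than code from the diff (the import path and the empty starting dict are assumptions):

```
from machin.frame.algorithms import A2C, A3C

# Each framework layers its own defaults (e.g. the "models" and "model_args"
# entries visible above for A2C, or "grad_server_group_name" for A3C) on top
# of the dict passed in; as a classmethod it is inherited by subclasses.
a2c_config = A2C.generate_config({})
a3c_config = A3C.generate_config({})
```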
16 changes: 14 additions & 2 deletions machin/frame/algorithms/a3c.py
@@ -105,6 +105,18 @@ def __init__(self,
grad_server[0], grad_server[1]
self.is_syncing = True

@property
def optimizers(self):
return []

@optimizers.setter
def optimizers(self, optimizers):
pass

@property
def lr_schedulers(self):
return []

def set_sync(self, is_syncing):
self.is_syncing = is_syncing

@@ -147,8 +159,8 @@ def update(self,
self.actor_grad_server.push(self.actor)
self.critic_grad_server.push(self.critic)

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"grad_server_group_name": "a3c_grad_server",
"grad_server_members": "all",
16 changes: 12 additions & 4 deletions machin/frame/algorithms/apex.py
@@ -95,6 +95,10 @@ def __init__(self,
self.qnet_model_server = model_server[0]
self.is_syncing = True

@classmethod
def is_distributed(cls):
return True

def set_sync(self, is_syncing):
self.is_syncing = is_syncing

@@ -138,8 +142,8 @@ def update(self,
self.qnet_model_server.push(self.qnet)
return result

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"learner_process_ratio": 0.1,
"model_server_group_name": "dqn_apex_model_server",
@@ -314,6 +318,10 @@ def __init__(self,
self.actor_model_server = model_server[0]
self.is_syncing = True

@classmethod
def is_distributed(cls):
return True

def set_sync(self, is_syncing):
self.is_syncing = is_syncing

@@ -382,8 +390,8 @@ def update(self,
self.actor_model_server.push(self.actor)
return result

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"learner_process_ratio": 0.1,
"model_server_group_name": "ddpg_apex_model_server",
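Both Ape-X variants now report themselves as distributed. A small sketch of the check an automatic launcher could make before deciding how to spawn processes (my illustration; `DQN` is not part of this diff and is assumed to simply inherit the `False` default added to `TorchFramework` below):

```
from machin.frame.algorithms import DQN, DQNApex

def needs_process_group(framework_cls) -> bool:
    # Distributed frameworks depend on torch.distributed / RPC and cannot be
    # driven from a single ordinary process.
    return framework_cls.is_distributed()

print(needs_process_group(DQN))      # False, inherited default
print(needs_process_group(DQNApex))  # True, overridden in apex.py
```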
22 changes: 20 additions & 2 deletions machin/frame/algorithms/ars.py
@@ -410,6 +410,24 @@ def __init__(self,
self._reset_reward_dict()
super(ARS, self).__init__()

@property
def optimizers(self):
return [self.actor_optim]

@optimizers.setter
def optimizers(self, optimizers):
self.actor_optim = optimizers[0]

@property
def lr_schedulers(self):
if hasattr(self, "actor_lr_sch"):
return [self.actor_lr_sch]
return []

@classmethod
def is_distributed(cls):
return True

def get_actor_types(self) -> List[str]:
"""
Returns:
@@ -682,8 +700,8 @@ def _generate_parameter(self):
self.actor_with_delta[(r_idx, False)] = actor_negative
self.actor_with_delta[(r_idx, True)] = actor_positive

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"model_server_group_name": "ars_model_server",
"model_server_members": "all",
62 changes: 57 additions & 5 deletions machin/frame/algorithms/base.py
@@ -1,5 +1,5 @@
from os.path import join
-from typing import Dict, Any
+from typing import Dict, Any, Callable
from torchviz import make_dot
import torch as t

@@ -16,14 +16,66 @@ class TorchFramework:

def __init__(self):
self._visualized = set()
self._backward = t.autograd.backward

@property
def optimizers(self):
raise NotImplementedError

@optimizers.setter
def optimizers(self, optimizers):
raise NotImplementedError

@property
def lr_schedulers(self):
raise NotImplementedError


@property
def top_models(self):
models = []
for m in self._is_top:
models.append(getattr(self, m))
return models

@property
def restorable_models(self):
models = []
for m in self._is_restorable:
models.append(getattr(self, m))
return models

@classmethod
-def get_restorable(cls):
+def get_top_model_names(cls):
"""
-Get restorable modules.
+Get attribute name of top level nn models.
"""
return cls._is_top

@classmethod
def get_restorable_model_names(cls):
"""
Get attribute name of restorable nn models.
"""
return cls._is_restorable

@classmethod
def is_distributed(cls):
"""
Whether this framework is a distributed framework which requires
multiple processes to run, and depends on ``torch.distributed`` or
``torch.distributed.rpc``.
"""
return False

def set_backward_function(self, backward_func: Callable):
"""
Replace the default backward function with a custom function.
The default loss backward function is ``torch.autograd.backward``
"""
assert callable(backward_func), "Backward function must be callable."
self._backward = backward_func

def enable_multiprocessing(self):
"""
Enable multiprocessing for all modules.
@@ -113,8 +165,8 @@ def visualize_model(self,
cleanup=False,
quiet=True)

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
raise NotImplementedError

@classmethod
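Taken together, the new `optimizers` / `lr_schedulers` properties, `is_distributed` and `set_backward_function` are the hooks a PyTorch Lightning based launcher would need. A rough sketch of such a wrapper, purely to show how the hooks could be consumed (the wrapper class and the exact Lightning wiring are my assumptions, not code from this commit):

```
import pytorch_lightning as pl

class FrameworkModule(pl.LightningModule):
    """Drives a single-process Machin framework from a Lightning Trainer."""

    def __init__(self, framework):
        super().__init__()
        assert not type(framework).is_distributed()
        self.framework = framework
        # The framework steps its own optimizers inside update(), so Lightning's
        # automatic optimization is disabled and the framework's backward pass
        # is redirected through manual_backward().
        self.automatic_optimization = False
        framework.set_backward_function(lambda loss: self.manual_backward(loss))

    def configure_optimizers(self):
        # Both properties were added to TorchFramework in this commit.
        return self.framework.optimizers, self.framework.lr_schedulers

    def training_step(self, batch, batch_idx):
        # (sampling and storing transitions omitted; it depends on the framework)
        self.framework.update()
```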
28 changes: 21 additions & 7 deletions machin/frame/algorithms/ddpg.py
@@ -128,7 +128,7 @@ def __init__(self,
self.update_rate = update_rate
self.update_steps = update_steps
self.discount = discount
-self.grad_max = gradient_max
+self.gradient_max = gradient_max
self.visualize = visualize
self.visualize_dir = visualize_dir
self._update_counter = 0
@@ -174,6 +174,20 @@ def __init__(self,
self.criterion = criterion
super(DDPG, self).__init__()

@property
def optimizers(self):
return [self.actor_optim, self.critic_optim]

@optimizers.setter
def optimizers(self, optimizers):
self.actor_optim, self.critic_optim = optimizers

@property
def lr_schedulers(self):
if hasattr(self, "actor_lr_sch") and hasattr(self, "critic_lr_sch"):
return [self.actor_lr_sch, self.critic_lr_sch]
return []

def act(self,
state: Dict[str, Any],
use_target: bool = False,
@@ -419,9 +433,9 @@ def update(self,

if update_value:
self.critic.zero_grad()
-value_loss.backward()
+self._backward(value_loss)
nn.utils.clip_grad_norm_(
-self.critic.parameters(), self.grad_max
+self.critic.parameters(), self.gradient_max
)
self.critic_optim.step()

@@ -440,9 +454,9 @@ def update(self,

if update_policy:
self.actor.zero_grad()
-act_policy_loss.backward()
+self._backward(act_policy_loss)
nn.utils.clip_grad_norm_(
-self.actor.parameters(), self.grad_max
+self.actor.parameters(), self.gradient_max
)
self.actor_optim.step()

@@ -505,8 +519,8 @@ def reward_function(reward, discount, next_value, terminal, _):
terminal = terminal.to(reward.device)
return reward + discount * ~terminal * next_value

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
default_values = {
"models": ["Actor", "Actor", "Critic", "Critic"],
"model_args": ((), (), (), ()),
12 changes: 6 additions & 6 deletions machin/frame/algorithms/ddpg_per.py
@@ -122,9 +122,9 @@ def update(self,

if update_value:
self.critic.zero_grad()
-value_loss.backward()
+self._backward(value_loss)
nn.utils.clip_grad_norm_(
-self.critic.parameters(), self.grad_max
+self.critic.parameters(), self.gradient_max
)
self.critic_optim.step()

@@ -148,9 +148,9 @@ def update(self,

if update_policy:
self.actor.zero_grad()
-act_policy_loss.backward()
+self._backward(act_policy_loss)
nn.utils.clip_grad_norm_(
-self.actor.parameters(), self.grad_max
+self.actor.parameters(), self.gradient_max
)
self.actor_optim.step()

@@ -170,8 +170,8 @@ def update(self,
# use .item() to prevent memory leakage
return -act_policy_loss.item(), value_loss.item()

-@staticmethod
-def generate_config(config: Dict[str, Any]):
+@classmethod
+def generate_config(cls, config: Dict[str, Any]):
config = DDPG.generate_config(config)
config["frame"] = "DDPGPer"
return config
