Skip to content

Commit

Permalink
docs fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguanglei committed Dec 4, 2018
1 parent 48f4a2c commit 6b6fc1d
Showing 1 changed file with 37 additions and 9 deletions.
46 changes: 37 additions & 9 deletions jdit/parallel/parallel_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from multiprocessing import Pool
from types import FunctionType


class SupParallelTrainer(object):
""" Training parallel
Expand All @@ -16,13 +17,40 @@ class SupParallelTrainer(object):
.. note ::
You must set the value of ``task_id`` and ``gpu_ids_abs``, regardless in ``default_params`` or ``unfixed_params_list``.
You must set the value of ``task_id`` and ``gpu_ids_abs``, regardless in ``default_params`` or
``unfixed_params_list``.
``{'task_id': 1`}`` , ``{'gpu_ids_abs': [0,1]}`` .
* For the same ``task_id`` , the tasks will be executed **sequentially** on the certain devices.
* For the different ``task_id`` , the will be executed **parallelly** on the certain devices.
``{'task_id': 1`}`` , ``{'gpu_ids_abs': [0,1]}``
Example:
unfixed_params_list = [{'task_id':1, 'lr':1e-3,'gpu_ids_abs': [0] },
{'task_id':1, 'lr':1e-4,'gpu_ids_abs': [0] },
{'task_id':2, 'lr':1e-5,'gpu_ids_abs': [2,3] }]
This set of ``unfixed_params_list`` means that:
Grid table:
+------------+-----------------------+-----------------------+---------------------+
| time_step | 'task_id':1 | 'task_id':2 | |
+============+=======================+=======================+=====================+
| t | 'lr':1e-3, | 'lr':1e-5, | executed parallelly |
| | 'gpu_ids_abs': [0] | 'gpu_ids_abs': [2,3] | |
+------------+-----------------------+-----------------------+---------------------+
| t+1 | 'lr':1e-4, | \ | |
| | 'gpu_ids_abs': [0] | | |
+------------+-----------------------+-----------------------+---------------------+
| | executed sequentially | \ | |
+------------+-----------------------+-----------------------+---------------------+
"""
def __init__(self, default_params:dict, unfixed_params_list:list):

def __init__(self, default_params: dict, unfixed_params_list: list):
"""
:param default_params: a ``dict()`` like {param:v1, param:v2 ...}
Expand All @@ -46,7 +74,7 @@ def __init__(self, default_params:dict, unfixed_params_list:list):
# self.parallel_plans = {(task_id):[{param1},{param2}]}

@abstractmethod
def build_task_trainer(self, params:dict):
def build_task_trainer(self, params: dict):
"""You need to write this method to build your own ``Trainer``.
This will run in a certain subprocess.
Expand Down Expand Up @@ -119,7 +147,7 @@ def train(self, max_processes=4):
p.join()
print('All subprocesses done.')

def _start_train(self, parallel_plan:tuple, position:int):
def _start_train(self, parallel_plan: tuple, position: int):
task_id, candidate_params = parallel_plan
# task_id, candidate_params = parallel_planChild process ID:
nums_tasks = len(candidate_params)
Expand All @@ -131,7 +159,7 @@ def _start_train(self, parallel_plan:tuple, position:int):
trainer.train(process_bar_header=process_bar_header, process_bar_position=position, subbar_disable=True)
# print("<<< finish Task %d|%s" % (index, str(task_id)))

def _distribute_task_on_devices(self, candidate_params_list:list):
def _distribute_task_on_devices(self, candidate_params_list: list):
for params in candidate_params_list:
assert "gpu_ids_abs" in params and "task_id" in params, "You must pass params `gpu_ids_abs` to set device"
assert "task_id" in params, "You must pass params `task_id` to set a task ID"
Expand All @@ -147,7 +175,7 @@ def _distribute_task_on_devices(self, candidate_params_list:list):
# trainers_plan = list(gpu_used_plan.values) # [[t1,t2],[t3]...]
return tasks_plan

def _build_candidate_params(self, default_params:dict, unfixed_params_list:list):
def _build_candidate_params(self, default_params: dict, unfixed_params_list: list):
final_unfixed_params_list = self._add_logdirs_to_unfixed_params(unfixed_params_list)
total_params = []
import copy
Expand All @@ -158,7 +186,7 @@ def _build_candidate_params(self, default_params:dict, unfixed_params_list:list)
total_params.append(copy.deepcopy(params))
return total_params

def _add_logdirs_to_unfixed_params(self, unfixed_params_list:list):
def _add_logdirs_to_unfixed_params(self, unfixed_params_list: list):
import copy
final_unfixed_params_list = copy.deepcopy(unfixed_params_list)
use_auto_logdir = not "logdir" in unfixed_params_list[0]
Expand All @@ -183,7 +211,7 @@ def _add_logdirs_to_unfixed_params(self, unfixed_params_list:list):

return final_unfixed_params_list # [dir1, dir2, dir3]

def _convert_to_dirname(self, item:str):
def _convert_to_dirname(self, item: str):
dir_name = item.strip()
replace_dict = {"*": "",
">": "greater",
Expand Down

0 comments on commit 6b6fc1d

Please sign in to comment.