Skip to content

Commit

Permalink
[python] add network config api (#1019)
Browse files Browse the repository at this point in the history
* add network

* update doc
  • Loading branch information
wxchan authored and guolinke committed Oct 26, 2017
1 parent 36f4c13 commit 95519f3
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
Whether to print messages during construction.
"""
self.handle = None
self.network = False
self.__need_reload_eval_info = True
self.__train_data_name = "training"
self.__attr = {}
Expand Down Expand Up @@ -1288,6 +1289,20 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
"""set network if necessary"""
if "machines" in params:
machines = params["machines"]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
elif model_file is not None:
"""Prediction task"""
out_num_iterations = ctypes.c_int(0)
Expand All @@ -1308,6 +1323,8 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
raise TypeError('Need at least one training dataset or model file to create booster instance')

def __del__(self):
if self.network:
self.free_network()
if self.handle is not None:
_safe_call(_LIB.LGBM_BoosterFree(self.handle))

Expand Down Expand Up @@ -1351,6 +1368,32 @@ def _free_buffer(self):
self.__inner_predict_buffer = []
self.__is_predicted_cur_iter = []

def set_network(self, machines, local_listen_port=12400,
listen_time_out=120, num_machines=1):
"""Set the network configuration.
Parameters
----------
machines: list, set or string
Names of machines.
local_listen_port: int, optional (default=12400)
TCP listen port for local machines.
listen_time_out: int, optional (default=120)
Socket time-out in minutes.
num_machines: int, optional (default=1)
The number of machines for parallel learning application.
"""
_safe_call(_LIB.LGBM_NetworkInit(c_str(machines),
ctypes.c_int(local_listen_port),
ctypes.c_int(listen_time_out),
ctypes.c_int(num_machines)))
self.network = True

def free_network(self):
"""Free Network."""
_safe_call(_LIB.LGBM_NetworkFree())
self.network = False

def set_train_data_name(self, name):
"""Set the name to the training Dataset.
Expand Down

0 comments on commit 95519f3

Please sign in to comment.