-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
federated.py
45 lines (41 loc) · 1.41 KB
/
federated.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""XGBoost Federated Learning related API."""
from .core import _LIB, XGBoostError, _check_call, build_info, c_str
def run_federated_server(
    port: int,
    world_size: int,
    server_key_path: str = "",
    server_cert_path: str = "",
    client_cert_path: str = "",
) -> None:
    """Run the Federated Learning server.

    SSL is enabled only when all three of ``server_key_path``,
    ``server_cert_path`` and ``client_cert_path`` are non-empty; otherwise
    an insecure (non-SSL) server is started.

    Parameters
    ----------
    port : int
        The port to listen on.
    world_size : int
        The number of federated workers.
    server_key_path : str
        Path to the server private key file. SSL is turned off if empty.
    server_cert_path : str
        Path to the server certificate file. SSL is turned off if empty.
    client_cert_path : str
        Path to the client certificate file. SSL is turned off if empty.

    Raises
    ------
    XGBoostError
        If XGBoost was not built with the federated learning plugin.
        Errors reported by the native server call are surfaced through
        ``_check_call``.

    Warns
    -----
    UserWarning
        If only some of the SSL paths are provided. The server then falls
        back to an insecure connection, which is unlikely to be what the
        caller intended.
    """
    # Fail fast when the plugin is missing. ``.get`` keeps this explanatory
    # error (rather than a bare KeyError) even if the flag is absent from
    # the build info.
    if not build_info().get("USE_FEDERATED", False):
        raise XGBoostError(
            "XGBoost needs to be built with the federated learning plugin "
            "enabled in order to use this module"
        )

    ssl_paths = (server_key_path, server_cert_path, client_cert_path)
    if not all(ssl_paths):
        if any(ssl_paths):
            # A partial SSL configuration would otherwise silently downgrade
            # to an insecure server; make the downgrade visible to the caller.
            import warnings

            warnings.warn(
                "Not all of server_key_path, server_cert_path and "
                "client_cert_path were provided; starting the federated "
                "server WITHOUT SSL.",
                UserWarning,
                stacklevel=2,
            )
        _check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
        return

    _check_call(
        _LIB.XGBRunFederatedServer(
            port,
            world_size,
            c_str(server_key_path),
            c_str(server_cert_path),
            c_str(client_cert_path),
        )
    )