Add client controller executor (NVIDIA#2530)
* add client controller executor

* address comments

* enhance abort, set peer props

* remove asserts

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
2 people authored and nvidianz committed May 14, 2024
1 parent b43241b commit 918f248
Showing 15 changed files with 765 additions and 41 deletions.
183 changes: 183 additions & 0 deletions job_templates/sag_cse_ccwf_pt/config_fed_client.conf
@@ -0,0 +1,183 @@
{
# version of the configuration
format_version = 2

# This is the application script that will be invoked. Clients can replace this script with their own training script.
app_script = "train.py"

# Additional arguments needed by the training code. For example, with PyTorch Lightning, these can be --trainer.batch_size=xxx.
app_config = ""

# Path to defined PyTorch network
# This assumes that there will be a "net.py" file with class name "Net", please modify accordingly
model_class_path = "net.Net"

# Client Computing Executors.
executors = [
{
# tasks this executor is configured to handle
tasks = [
"train",
"validate",
"submit_model"
]

# This particular executor
executor {

# This is an executor for PyTorch + Client API. The underlying data exchange uses a Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

args {

# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "pytorch"

# if params_transfer_type is FULL, the parameters are sent directly
# if params_transfer_type is DIFF, the difference against the received
# parameters is calculated and only the difference is sent
params_transfer_type = "FULL"
# if train_with_evaluation is true, the executor expects the custom code
# to send back both the trained parameters and the evaluation metric;
# otherwise only the trained parameters are expected
train_with_evaluation = true

train_task_name = "train"
evaluate_task_name = "validate"
submit_model_task_name = "submit_model"
}
}
}
{
# All tasks prefixed with wf_ are routed to this ClientControllerExecutor
tasks = ["wf_*"]
executor {
id = "client_controller_executor"
path = "nvflare.app_common.ccwf.client_controller_executor.ClientControllerExecutor"
# ClientControllerExecutor for running controllers on the client side.
args {
# list of controller ids from components to be run in order
controller_id_list = ["sag_ctl", "cse_ctl"]
task_name_prefix = "wf"
# persistor used to distribute and save final results for clients
persistor_id = "persistor"
}
}
}
]

# Array of task data filters. If provided, they will filter the data sent from the client controller to the client executor.
# Filter direction (in, out, inout) can be set; since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction.
task_data_filters = []

# Array of task result filters. If provided, they will filter the results sent from the client executor back to the client controller.
# Filter direction (in, out, inout) can be set; since clients send tasks to each other, a task has both a sending (out) and a receiving (in) direction.
task_result_filters = []

components = [
{
id = "sag_ctl"
path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
args {
min_clients = 2
num_rounds = 3
start_round = 0
wait_time_after_min_received = 0
aggregator_id = "aggregator"
persistor_id = "persistor"
shareable_generator_id = "shareable_generator"
train_task_name = "train"
train_timeout = 0
}
}
{
id = "cse_ctl",
path = "nvflare.app_common.workflows.cross_site_model_eval.CrossSiteModelEval",
args {
model_locator_id = "model_locator",
submit_model_timeout = 600,
validation_timeout = 6000,
cleanup_models = false
}
}
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
}
{
id = "pipe"

path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"

args {
# Mode of the endpoint. A pipe has two endpoints.
# An endpoint can be either the one that initiates communication or the one listening.
# PASSIVE is the one listening.
mode = "PASSIVE"

# root_path is the directory used for the parameter exchange.
# You can also set it to an absolute path on your system.
root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
}
}
# required components for the client-controlled workflow defined on the client side
{
id = "persistor"
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
args.model.path = "{model_class_path}"
}
{
id = "shareable_generator"
path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
args = {}
}
{
# This is the aggregator that performs the weighted average aggregation.
# The aggregation is "in-time": it doesn't wait for all client results, but aggregates as soon as data is received.
id = "aggregator"
path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
args.expected_data_kind = "WEIGHTS"
},
{
id = "model_locator"
name = "PTFileModelLocator"
args {
pt_persistor_id = "persistor"
}
},
{
# This component is not directly used in the workflow.
# It selects the best model based on the incoming global validation metrics.
id = "model_selector"
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# make sure this "key_metric" matches what the server side receives
args.key_metric = "accuracy"
},
{
id = "json_generator"
name = "ValidationJsonGenerator"
args {}
}
]
}
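For orientation, the sketch below shows how the two task lists configured above could route incoming task names: the exact names train / validate / submit_model go to the launcher executor, while anything matching wf_* goes to the ClientControllerExecutor. This is an illustrative, stand-alone sketch with made-up names (route_task, EXECUTOR_ROUTES, and the example task wf_start), not NVFLARE's actual dispatch code.

```python
# Illustrative sketch of prefix-based task routing (not NVFLARE's dispatcher).
from fnmatch import fnmatch

# Task patterns as declared in config_fed_client.conf, mapped to hypothetical executor ids.
EXECUTOR_ROUTES = [
    (["train", "validate", "submit_model"], "launcher_executor"),
    (["wf_*"], "client_controller_executor"),
]

def route_task(task_name: str) -> str:
    """Return the executor id whose task patterns match the incoming task name."""
    for patterns, executor_id in EXECUTOR_ROUTES:
        if any(fnmatch(task_name, p) for p in patterns):
            return executor_id
    raise ValueError(f"no executor configured for task: {task_name}")

# Tasks sent under the "wf" prefix land on the ClientControllerExecutor.
assert route_task("train") == "launcher_executor"
assert route_task("wf_start") == "client_controller_executor"
```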
39 changes: 39 additions & 0 deletions job_templates/sag_cse_ccwf_pt/config_fed_server.conf
@@ -0,0 +1,39 @@
{
# version of the configuration
format_version = 2

# task data filters: if provided, they will filter the data flowing from the server to the clients.
task_data_filters =[]

# task result filters: if provided, they will filter the results flowing from the clients to the server.
task_result_filters = []

# This assumes that there will be a "net.py" file with class name "Net".
# If your model code is not in "net.py" or the class name is not "Net", please modify accordingly.
model_class_path = "net.Net"

# workflows: Array of workflows that control the Federated Learning job lifecycle.
# Multiple workflows can be specified; NVFLARE will run them in the order specified.
workflows = [
{
# server-side controller to manage job life cycle and configuration
id = "svr_ctl"
path = "nvflare.app_common.ccwf.server_ctl.ServerSideController"
args {
# the prefix for task names of this workflow
task_name_prefix = "wf"
# the maximum amount of time allowed for a client to miss a status report
max_status_report_interval = 300
# policy to choose which client to run the controller logic from
starting_client_policy = "random"
# timeout for the ClientControllerExecutor start task, which runs all of the controllers
start_task_timeout = 600
}
}
]

# List of components used in the server-side workflow.
components = [
]

}
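As a rough illustration of the two settings configured above, the sketch below picks a starting client at random (starting_client_policy = "random") and flags clients whose last status report is older than max_status_report_interval. The helper names and the client bookkeeping are assumptions for illustration only, not the ServerSideController API.

```python
# Illustrative sketch of the server-side policy settings (hypothetical helpers, not the NVFLARE API).
import random
import time
from typing import Dict, List

MAX_STATUS_REPORT_INTERVAL = 300  # seconds, mirrors config_fed_server.conf

def choose_starting_client(clients: List[str], policy: str = "random") -> str:
    """Pick the client that will run the client-side controllers first."""
    if policy == "random":
        return random.choice(clients)
    raise ValueError(f"unsupported starting_client_policy: {policy}")

def overdue_clients(last_report_times: Dict[str, float]) -> List[str]:
    """Return clients whose last status report is older than the allowed interval."""
    now = time.time()
    return [c for c, t in last_report_times.items() if now - t > MAX_STATUS_REPORT_INTERVAL]

# Example with two hypothetical sites, one of which has gone quiet for 10 minutes.
clients = ["site-1", "site-2"]
print(choose_starting_client(clients))
print(overdue_clients({"site-1": time.time(), "site-2": time.time() - 600}))
```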
5 changes: 5 additions & 0 deletions job_templates/sag_cse_ccwf_pt/info.conf
@@ -0,0 +1,5 @@
{
description = "Client Controller FedAvg and cross-site evaluation with PyTorch"
execution_api_type = "client_api"
controller_type = "client"
}
11 changes: 11 additions & 0 deletions job_templates/sag_cse_ccwf_pt/info.md
@@ -0,0 +1,11 @@
# Job Template Information Card

## sag_cse_ccwf_pt
name = "sag_cse_ccwf_pt"
description = "Client Controller FedAvg with scatter & gather workflow and cross-site evaluation with PyTorch"
class_name = "ClientControllerExecutor"
controller_type = "client"
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2024-04-25"
last_updated_date = "2024-04-25"
8 changes: 8 additions & 0 deletions job_templates/sag_cse_ccwf_pt/meta.conf
@@ -0,0 +1,8 @@
name = "sag_cse_ccwf_pt"
resource_spec {}
min_clients = 2
deploy_map {
app = [
"@ALL"
]
}
2 changes: 1 addition & 1 deletion nvflare/apis/controller_spec.py
@@ -88,7 +88,7 @@ def __init__(
name (str): name of the task
data (Shareable): data of the task
props: Any additional properties of the task
timeout: How long this task will last. If == 0, the task never time out.
timeout: How long this task will last. If == 0, the task never time out (WFCommServer-> never time out, WFCommClient-> time out after `max_task_timeout`).
before_task_sent_cb: If provided, this callback would be called before controller sends the tasks to clients.
It needs to follow the before_task_sent_cb_signature.
after_task_sent_cb: If provided, this callback would be called after controller sends the tasks to clients.
32 changes: 22 additions & 10 deletions nvflare/apis/impl/controller.py
@@ -60,7 +60,7 @@ def broadcast(
min_responses: int = 1,
wait_time_after_min_received: int = 0,
):
self.communicator.broadcast(task, fl_ctx, targets, min_responses, wait_time_after_min_received)
return self.communicator.broadcast(task, fl_ctx, targets, min_responses, wait_time_after_min_received)

def broadcast_and_wait(
self,
@@ -71,12 +71,12 @@ def broadcast_and_wait(
wait_time_after_min_received: int = 0,
abort_signal: Optional[Signal] = None,
):
self.communicator.broadcast_and_wait(
return self.communicator.broadcast_and_wait(
task, fl_ctx, targets, min_responses, wait_time_after_min_received, abort_signal
)

def broadcast_forever(self, task: Task, fl_ctx: FLContext, targets: Union[List[Client], List[str], None] = None):
self.communicator.broadcast_forever(task, fl_ctx, targets)
return self.communicator.broadcast_forever(task, fl_ctx, targets)

def send(
self,
@@ -86,7 +86,7 @@ def send(
send_order: SendOrder = SendOrder.SEQUENTIAL,
task_assignment_timeout: int = 0,
):
self.communicator.send(task, fl_ctx, targets, send_order, task_assignment_timeout)
return self.communicator.send(task, fl_ctx, targets, send_order, task_assignment_timeout)

def send_and_wait(
self,
@@ -97,7 +97,7 @@ def send_and_wait(
task_assignment_timeout: int = 0,
abort_signal: Signal = None,
):
self.communicator.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout, abort_signal)
return self.communicator.send_and_wait(task, fl_ctx, targets, send_order, task_assignment_timeout, abort_signal)

def relay(
self,
@@ -109,7 +109,7 @@ def relay(
task_result_timeout: int = 0,
dynamic_targets: bool = True,
):
self.communicator.relay(
return self.communicator.relay(
task, fl_ctx, targets, send_order, task_assignment_timeout, task_result_timeout, dynamic_targets
)

@@ -124,7 +124,7 @@ def relay_and_wait(
dynamic_targets: bool = True,
abort_signal: Optional[Signal] = None,
):
self.communicator.relay_and_wait(
return self.communicator.relay_and_wait(
task,
fl_ctx,
targets,
@@ -136,15 +136,22 @@
)

def get_num_standing_tasks(self) -> int:
return self.communicator.get_num_standing_tasks()
try:
return self.communicator.get_num_standing_tasks()
except Exception as e:
self.logger.warning(f"get_num_standing_tasks() is not supported by {self.communicator}: {e}")
return None

def cancel_task(
self, task: Task, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None
):
self.communicator.cancel_task(task, completion_status, fl_ctx)

def cancel_all_tasks(self, completion_status=TaskCompletionStatus.CANCELLED, fl_ctx: Optional[FLContext] = None):
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
try:
self.communicator.cancel_all_tasks(completion_status, fl_ctx)
except Exception as e:
self.log_warning(fl_ctx, f"cancel_all_tasks() is not supported by {self.communicator}: {e}")

def get_client_disconnect_time(self, client_name):
"""Get the time when the client is deemed disconnected.
@@ -157,4 +164,9 @@ def get_client_disconnect_time(self, client_name):
"""
if not self.communicator:
return None
return self.communicator.get_client_disconnect_time(client_name)

try:
return self.communicator.get_client_disconnect_time(client_name)
except Exception as e:
self.logger.warning(f"get_client_disconnect_time() is not supported by {self.communicator}: {e}")
return None
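The controller changes above follow one pattern: forward each call to the underlying communicator, return its result, and degrade gracefully when a particular communicator (for example, a client-side WFCommClient) does not support an operation. A minimal sketch of that pattern, with made-up class names, is shown below; only the pattern mirrors the diff.

```python
# Minimal sketch of the "delegate, return, and degrade gracefully" pattern used above.
# Class and method bodies here are made up for illustration.
import logging

class Communicator:
    def get_num_standing_tasks(self) -> int:
        raise NotImplementedError("not supported by this communicator")

class Controller:
    def __init__(self, communicator: Communicator):
        self.communicator = communicator
        self.logger = logging.getLogger(self.__class__.__name__)

    def get_num_standing_tasks(self):
        # Forward to the communicator, but tolerate communicators that do not implement it.
        try:
            return self.communicator.get_num_standing_tasks()
        except Exception as e:
            self.logger.warning(f"get_num_standing_tasks() is not supported by {self.communicator}: {e}")
            return None

# A communicator that lacks the call yields None instead of crashing the workflow.
print(Controller(Communicator()).get_num_standing_tasks())  # -> None (with a warning logged)
```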