Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Results Ack Sending #571

Merged
merged 10 commits into from Aug 17, 2021
26 changes: 24 additions & 2 deletions funcx_endpoint/funcx_endpoint/endpoint/interchange.py
Expand Up @@ -21,12 +21,13 @@
from parsl.version import VERSION as PARSL_VERSION

from funcx_endpoint.executors.high_throughput.messages import Message, COMMAND_TYPES, MessageType, Task
from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode
from funcx_endpoint.executors.high_throughput.messages import EPStatusReport, Heartbeat, TaskStatusCode, ResultsAck
from funcx.sdk.client import FuncXClient
from funcx import set_file_logger
from funcx_endpoint.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch
from funcx.serialize import FuncXSerializer
from funcx_endpoint.endpoint.taskqueue import TaskQueue
from funcx_endpoint.endpoint.results_ack import ResultsAckHandler
from queue import Queue

LOOP_SLOWDOWN = 0.0 # in seconds
Expand Down Expand Up @@ -175,6 +176,8 @@ def __init__(self,
self.total_pending_task_count = 0
self.fxs = FuncXClient()

self.results_ack_handler = ResultsAckHandler()

logger.info("Interchange address is {}".format(self.interchange_address))

self.endpoint_id = endpoint_id
Expand Down Expand Up @@ -288,6 +291,9 @@ def migrate_tasks_to_internal(self, kill_event, status_request):
task_counter += 1
logger.debug(f"[TASK_PULL_THREAD] Task counter:{task_counter} Pending Tasks: {self.total_pending_task_count}")

elif isinstance(msg, ResultsAck):
self.results_ack_handler.ack(msg.task_id)

else:
logger.warning(f"[TASK_PULL_THREAD] Unknown message type received: {msg}")

Expand Down Expand Up @@ -400,6 +406,16 @@ def start(self):
set_hwm=True)
self.results_outgoing.put('forwarder', pickle.dumps({"registration": self.endpoint_id}))

# TODO: this resend must happen after any endpoint re-registration to
# ensure there are not unacked results left
resend_results_messages = self.results_ack_handler.get_unacked_results_list()
if len(resend_results_messages) > 0:
logger.info(f"[MAIN] Resending {len(resend_results_messages)} previously unacked results")

# TODO: this should be a multipart send rather than a loop
for results in resend_results_messages:
self.results_outgoing.put('forwarder', results)

executor = list(self.executors.values())[0]
last = time.time()

Expand All @@ -415,6 +431,8 @@ def start(self):
logger.exception("[MAIN] Sending heartbeat to the forwarder over the results channel has failed")
raise

self.results_ack_handler.check_ack_counts()

try:
task = self.pending_task_queue.get(block=True, timeout=0.01)
executor.submit_raw(task.pack())
Expand All @@ -427,8 +445,12 @@ def start(self):
try:
results = self.results_passthrough.get(False, 0.01)

task_id = results["task_id"]
if task_id:
self.results_ack_handler.put(task_id, results["message"])

# results will be a pickled dict with task_id, container_id, and results/exception
self.results_outgoing.put('forwarder', results)
self.results_outgoing.put('forwarder', results["message"])
logger.info("Passing result to forwarder")

except queue.Empty:
Expand Down
74 changes: 74 additions & 0 deletions funcx_endpoint/funcx_endpoint/endpoint/results_ack.py
@@ -0,0 +1,74 @@
import logging
import time

logger = logging.getLogger(__name__)


class ResultsAckHandler:
    """
    Tracks task results by task ID, discarding results after they have been ack'ed.

    Results are stored (pickled) until the forwarder acknowledges receipt, so
    that unacked results can be resent after a reconnect/re-registration.
    """

    def __init__(self, log_period=60):
        """ Initialize results storage and timing for log updates

        Parameters
        ----------
        log_period : int or float
            Minimum number of seconds between periodic log lines summarizing
            acked and unacked results. Defaults to 60.
        """
        # task_id -> pickled results message, for results sent but not yet acked
        self.unacked_results = {}
        # how frequently to log info about acked and unacked results
        self.log_period = log_period
        self.last_log_timestamp = time.time()
        # number of results acked since the last periodic log line
        self.acked_count = 0

    def put(self, task_id, message):
        """ Put sent task result into Unacked Dict

        Parameters
        ----------
        task_id : str
            Task ID

        message : pickled Dict
            Results message
        """
        self.unacked_results[task_id] = message

    def ack(self, task_id):
        """ Ack a task result that was sent. Nothing happens if the task ID is not
        present in the Unacked Dict

        Parameters
        ----------
        task_id : str
            Task ID to ack
        """
        acked_task = self.unacked_results.pop(task_id, None)
        # Explicit None check: a stored-but-falsy message must still be counted
        # as acked (truthiness would silently drop it from the count).
        if acked_task is not None:
            self.acked_count += 1
            # Lazy %-style args avoid formatting cost when INFO is disabled.
            logger.info("Acked task %s, Unacked count: %s",
                        task_id, len(self.unacked_results))

    def check_ack_counts(self):
        """ Log the number of currently Unacked tasks and the tasks Acked since
        the last check, at most once per ``log_period`` seconds
        """
        now = time.time()
        if now - self.last_log_timestamp > self.log_period:
            logger.info("Unacked count: %s, Acked results since last check %s",
                        len(self.unacked_results), self.acked_count)
            self.acked_count = 0
            self.last_log_timestamp = now

    def get_unacked_results_list(self):
        """ Get a list of unacked results messages that can be used for resending

        Returns
        -------
        List of pickled Dicts
            Unacked results messages, in insertion (send) order
        """
        return list(self.unacked_results.values())

    def persist(self):
        """ Save unacked results to disk (not yet implemented)
        """
        # TODO: pickle dump unacked_results
        return
Expand Up @@ -514,7 +514,10 @@ def _queue_management_worker(self):
elif isinstance(msgs, EPStatusReport):
logger.debug("[MTHREAD] Received EPStatusReport {}".format(msgs))
if self.passthrough:
self.results_passthrough.put(pickle.dumps(msgs))
self.results_passthrough.put({
"task_id": None,
"message": pickle.dumps(msgs)
})

else:
logger.debug("[MTHREAD] Unpacking results")
Expand Down Expand Up @@ -547,7 +550,13 @@ def _queue_management_worker(self):

if self.passthrough is True:
logger.debug(f"[MTHREAD] Pushing results for task:{tid}")
x = self.results_passthrough.put(serialized_msg)
# we are only interested in actual task ids here, not identifiers
# for other message types
sent_task_id = tid if isinstance(tid, str) else None
x = self.results_passthrough.put({
"task_id": sent_task_id,
"message": serialized_msg
})
logger.debug(f"[MTHREAD] task:{tid} ret value: {x}")
logger.debug(f"[MTHREAD] task:{tid} items in queue: {self.results_passthrough.qsize()}")
continue
Expand Down
Expand Up @@ -13,6 +13,7 @@ class MessageType(Enum):
EP_STATUS_REPORT = auto()
MANAGER_STATUS_REPORT = auto()
TASK = auto()
RESULTS_ACK = auto()

def pack(self):
return MESSAGE_TYPE_FORMATTER.pack(self.value)
Expand Down Expand Up @@ -70,6 +71,8 @@ def unpack(cls, msg):
return ManagerStatusReport.unpack(remaining)
elif message_type is MessageType.TASK:
return Task.unpack(remaining)
elif message_type is MessageType.RESULTS_ACK:
return ResultsAck.unpack(remaining)

raise Exception(f"Unknown Message Type Code: {message_type}")

Expand Down Expand Up @@ -201,3 +204,22 @@ def pack(self):
# TODO: do better than JSON?
jsonified = json.dumps(self.task_statuses)
return self.type.pack() + self.container_switch_count.to_bytes(10, 'little') + jsonified.encode("ascii")


class ResultsAck(Message):
    """
    Acknowledgement that a task result was received by the forwarder.
    Travels forwarder -> interchange.
    """
    type = MessageType.RESULTS_ACK

    def __init__(self, task_id):
        """ task_id : str -- ID of the task whose result is being acknowledged """
        super().__init__()
        self.task_id = task_id

    @classmethod
    def unpack(cls, msg):
        """ Rebuild a ResultsAck from the payload bytes (an ASCII task ID). """
        task_id = msg.decode("ascii")
        return cls(task_id)

    def pack(self):
        """ Serialize as the message-type header followed by the ASCII task ID. """
        return b"".join((self.type.pack(), self.task_id.encode("ascii")))