Separate time spans from timestamps
WardLT committed Dec 29, 2023
1 parent 6ad3625 commit abe3084
Showing 7 changed files with 90 additions and 78 deletions.
93 changes: 53 additions & 40 deletions colmena/models.py
@@ -70,8 +70,8 @@ def deserialize(method: 'SerializationMethod', message: str) -> Any:


def _serialized_str_to_bytes_shim(
s: str,
method: Union[str, SerializationMethod],
) -> bytes:
"""Shim between Colmena serialized objects and bytes.
@@ -99,8 +99,8 @@ def _serialized_str_to_bytes_shim(


def _serialized_bytes_to_obj_wrapper(
b: str,
method: Union[str, SerializationMethod],
) -> Any:
"""Wrapper which converts bytes to strings before deserializing.
@@ -161,6 +161,42 @@ def total_ranks(self) -> int:
return self.node_count * self.cpu_processes


class Timestamps(BaseModel):
"""A class which records the system times at which key events in a task occurred
All should be in UTC.
"""

created: float = Field(description="Time this value object was created",
default_factory=lambda: datetime.now().timestamp())
input_received: float = Field(None, description="Time the inputs were received by the task server")
compute_started: float = Field(None, description="Time workflow process began executing a task")
compute_ended: float = Field(None, description="Time workflow process finished executing a task")
result_sent: float = Field(None, description="Time message was sent from the server")
result_received: float = Field(None, description="Time value was received by client")
start_task_submission: float = Field(None, description="Time marking the start of the task submission to workflow engine")
task_received: float = Field(None, description="Time task result received from workflow engine")


class TimeSpans(BaseModel):
"""Amount of time elapsed between major events
All are recorded in seconds
"""

running: float = Field(None, description="Runtime of the method, if available")
serialize_inputs: float = Field(None, description="Time required to serialize inputs on client")
deserialize_inputs: float = Field(None, description="Time required to deserialize inputs on worker")
serialize_results: float = Field(None, description="Time required to serialize results on worker")
deserialize_results: float = Field(None, description="Time required to deserialize results on client")
async_resolve_proxies: float = Field(None, description="Time required to start async resolves of proxies")
proxy: Dict[str, Dict[str, dict]] = Field(default_factory=dict,
description='Timings related to resolving ProxyStore proxies on the compute worker')

additional: Dict[str, float] = Field(default_factory=dict,
description="Additional timings reported by a task server")


class Result(BaseModel):
"""A class which describes the inputs and results of the calculations evaluated by the MethodServer
@@ -187,31 +223,12 @@ class Result(BaseModel):
resources: ResourceRequirements = Field(default_factory=ResourceRequirements, help='List of the resources required for a task, if desired')
failure_info: Optional[FailureInformation] = Field(None, description="Messages about task failure. Provided by Task Server")
worker_info: Optional[WorkerInformation] = Field(None, description="Information about the worker which executed a task. Provided by Task Server")

# Performance tracking
time_created: float = Field(None, description="Time this value object was created")
time_input_received: float = Field(None, description="Time the inputs was received by the task server")
time_compute_started: float = Field(None, description="Time workflow process began executing a task")
time_compute_ended: float = Field(None, description="Time workflow process finished executing a task")
time_result_sent: float = Field(None, description="Time message was sent from the server")
time_result_received: float = Field(None, description="Time value was received by client")
time_start_task_submission: float = Field(None, description="Time marking the start of the task submission to workflow engine")
time_task_received: float = Field(None, description="Time task result received from workflow engine")

time_running: float = Field(None, description="Runtime of the method, if available")
time_serialize_inputs: float = Field(None, description="Time required to serialize inputs on client")
time_deserialize_inputs: float = Field(None, description="Time required to deserialize inputs on worker")
time_serialize_results: float = Field(None, description="Time required to serialize results on worker")
time_deserialize_results: float = Field(None, description="Time required to deserialize results on client")
time_async_resolve_proxies: float = Field(None,
description="Time required to scan function inputs and start async resolves of proxies")

additional_timing: dict = Field(default_factory=dict,
description="Timings recorded by a TaskServer that are not defined by above")
proxy_timing: Dict[str, Dict[str, dict]] = Field(default_factory=dict,
description='Timings related to resolving ProxyStore proxies on the compute worker')
message_sizes: Dict[str, int] = Field(default_factory=dict, description='Sizes of the inputs and results in bytes')

# Timings
timestamp: Timestamps = Field(default_factory=Timestamps, help='Times at which major events occurred')
time: TimeSpans = Field(default_factory=TimeSpans, help='Elapsed time between major events')

# Serialization options
serialization_method: SerializationMethod = Field(SerializationMethod.JSON,
description="Method used to serialize input data")
@@ -229,10 +246,6 @@ def __init__(self, inputs: Tuple[Tuple[Any], Dict[str, Any]], **kwargs):
"""
super().__init__(inputs=inputs, **kwargs)

# Mark "created" only if the value is not already set
if 'time_created' not in kwargs:
self.time_created = datetime.now().timestamp()

@property
def args(self) -> Tuple[Any]:
return tuple(self.inputs[0])
@@ -277,31 +290,31 @@ def json(self, **kwargs: Dict[str, Any]) -> str:

def mark_result_received(self):
"""Mark that a completed computation was received by a client"""
self.time_result_received = datetime.now().timestamp()
self.timestamp.result_received = datetime.now().timestamp()

def mark_input_received(self):
"""Mark that a task server has received a value"""
self.time_input_received = datetime.now().timestamp()
self.timestamp.input_received = datetime.now().timestamp()

def mark_compute_started(self):
"""Mark that the compute for a method has started"""
self.time_compute_started = datetime.now().timestamp()
self.timestamp.compute_started = datetime.now().timestamp()

def mark_result_sent(self):
"""Mark when a result is sent from the task server"""
self.time_result_sent = datetime.now().timestamp()
self.timestamp.result_sent = datetime.now().timestamp()

def mark_start_task_submission(self):
"""Mark when the Task Server submits a task to the engine"""
self.time_start_task_submission = datetime.now().timestamp()
self.timestamp.start_task_submission = datetime.now().timestamp()

def mark_task_received(self):
"""Mark when the Task Server receives the task from the engine"""
self.time_task_received = datetime.now().timestamp()
self.timestamp.task_received = datetime.now().timestamp()

def mark_compute_ended(self):
"""Mark when the task finished executing"""
self.time_compute_ended = datetime.now().timestamp()
self.timestamp.compute_ended = datetime.now().timestamp()

def set_result(self, result: Any, runtime: float = None):
"""Set the value of this computation
Expand All @@ -318,7 +331,7 @@ def set_result(self, result: Any, runtime: float = None):
self.value = result
if not self.keep_inputs:
self.inputs = ((), {})
self.time_running = runtime
self.time.running = runtime
self.success = True

def serialize(self) -> Tuple[float, List[Proxy]]:
@@ -386,7 +399,7 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
proxies.append(value_proxy)

# Update the statistics
store_proxy_stats(value_proxy, self.proxy_timing)
store_proxy_stats(value_proxy, self.time.proxy)

# Serialize the proxy with Colmena's utilities. This is
# efficient since the proxy is just a reference and metadata
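Taken together, the changes to models.py replace the flat `time_*` fields on `Result` with two nested models: wall-clock event times live on `Result.timestamp` (a `Timestamps`) and elapsed durations on `Result.time` (a `TimeSpans`). A minimal sketch of how the new layout reads, assuming a Colmena build that includes this commit:

```python
from colmena.models import Result

# Timestamps.created is now populated by its default_factory,
# so the old bookkeeping in Result.__init__ is gone
result = Result(inputs=((1,), {}))
print(result.timestamp.created)          # event times: Result.timestamp

result.mark_compute_started()            # writes timestamp.compute_started
result.mark_compute_ended()              # writes timestamp.compute_ended

result.time.running = 0.5                # durations (seconds): Result.time
result.time.additional['my_step'] = 0.1  # free-form task-server timings
```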
3 changes: 1 addition & 2 deletions colmena/proxy.py
@@ -133,8 +133,7 @@ def store_proxy_stats(proxy: Proxy, proxy_timing: dict):
# Get the key associated with this proxy
key = get_key(proxy)

# ProxyStore keys are NamedTuples so we cast to a string
# so we can use the key as a JSON key.
# ProxyStore keys are NamedTuples, so we cast to a string to use as a JSON key.
key = str(key)

# Get the store associated with this proxy
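The rewritten comment describes a real constraint: ProxyStore keys are NamedTuples, and Python's `json` module rejects tuple dictionary keys, so the key must be cast to `str` before the timings are recorded. A small illustration, using a hypothetical stand-in for ProxyStore's key type:

```python
import json
from typing import NamedTuple

class StoreKey(NamedTuple):  # hypothetical stand-in, not ProxyStore's actual key class
    object_id: str

key = StoreKey(object_id='abc123')
# json.dumps({key: {}}) would raise TypeError: keys must be str, int, float, bool or None
timing = {str(key): {'times': {}}}  # casting to str makes the key JSON-safe
print(json.dumps(timing))
```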
4 changes: 2 additions & 2 deletions colmena/queue/base.py
@@ -163,7 +163,7 @@ def get_result(self, topic: str = 'default', timeout: Optional[float] = None) ->

# Parse the value and mark it as complete
result_obj = Result.parse_raw(message)
result_obj.time_deserialize_results = result_obj.deserialize()
result_obj.time.deserialize_results = result_obj.deserialize()
result_obj.mark_result_received()

# Some logging
@@ -238,7 +238,7 @@ def send_inputs(self,
)

# Push the serialized value to the task server
result.time_serialize_inputs, proxies = result.serialize()
result.time.serialize_inputs, proxies = result.serialize()
self._send_request(result.json(exclude_none=True), topic)
logger.info(f'Client sent a {method} task with topic {topic}. Created {len(proxies)} proxies for input values')

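On the client side, the queue now writes its serialization spans to `result.time` instead of the old flat fields. A sketch of the round trip under the new names, assuming the local Pipe-based queues from `colmena.queue.python` and a task server registered with a `triple` method:

```python
from colmena.queue.python import PipeQueues

queues = PipeQueues()
queues.send_inputs(3, method='triple')   # records result.time.serialize_inputs
# ... a task server runs the task and sends the result back ...
result = queues.get_result()             # records result.time.deserialize_results
print(result.time.serialize_inputs,
      result.time.deserialize_results,
      result.timestamp.result_received)  # stamped by mark_result_received()
```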
8 changes: 4 additions & 4 deletions colmena/task_server/base.py
@@ -177,7 +177,7 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
result.mark_compute_started()

# Unpack the inputs
result.time_deserialize_inputs = result.deserialize()
result.time.deserialize_inputs = result.deserialize()

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
@@ -186,7 +186,7 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
input_proxies.extend(resolve_proxies_async(arg))
for value in result.kwargs.values():
input_proxies.extend(resolve_proxies_async(value))
result.time_async_resolve_proxies = perf_counter() - start_time
result.time.async_resolve_proxies = perf_counter() - start_time

# Execute the function
start_time = perf_counter()
@@ -222,10 +222,10 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
result.mark_compute_ended()

# Re-pack the results. Will store the proxy statistics
result.time_serialize_results, _ = result.serialize()
result.time.serialize_results, _ = result.serialize()

# Get the statistics for the proxy resolution
for proxy in input_proxies:
store_proxy_stats(proxy, result.proxy_timing)
store_proxy_stats(proxy, result.time.proxy)

return result
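The function above shows the division of labor this commit enforces: `mark_*` calls stamp wall-clock times onto `result.timestamp`, while `perf_counter()` differences land on `result.time`. The same pattern in isolation, as a sketch rather than the library function itself:

```python
from time import perf_counter

from colmena.models import Result

def run_and_time(func, result: Result) -> Result:
    """Sketch of the timing pattern used by run_and_record_timing"""
    result.mark_compute_started()             # wall clock -> result.timestamp
    start = perf_counter()
    value = func(*result.args, **result.kwargs)
    result.set_result(value, perf_counter() - start)  # duration -> result.time.running
    result.mark_compute_ended()               # wall clock -> result.timestamp
    return result
```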
20 changes: 10 additions & 10 deletions colmena/task_server/parsl.py
@@ -52,13 +52,13 @@ def _execute_preprocess(task: ExecutableTask, result: Result) -> Tuple[Result, P
result.mark_compute_started()

# Unpack the inputs
result.time_deserialize_inputs = result.deserialize()
result.time.deserialize_inputs = result.deserialize()

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
resolve_proxies_async(result.args)
resolve_proxies_async(result.kwargs)
result.time_async_resolve_proxies = perf_counter() - start_time
result.time.async_resolve_proxies = perf_counter() - start_time

# Create a temporary directory
# TODO (wardlt): Figure out how to allow users to define a path for temporary directories
@@ -76,7 +76,7 @@
end_time = perf_counter()

# Record the time required to perform the pre-processing
result.additional_timing['exec_preprocess'] = end_time - start_time
result.time.additional['exec_preprocess'] = end_time - start_time

# Remove the inputs. We don't need to send them back to the manager (the manager already knows what it sent out)
result.inputs = ((), {})
@@ -108,23 +108,23 @@ def _execute_postprocess(task: ExecutableTask, exit_code: int, result: Result, t
result.failure_info = FailureInformation.from_exception(e)
finally:
end_time = perf_counter()
result.additional_timing['exec_postprocess'] = end_time - start_time
result.time.additional['exec_postprocess'] = end_time - start_time

# Store the results
if result.success:
result.set_result(output, datetime.now().timestamp() - result.time_compute_started)
result.set_result(output, datetime.now().timestamp() - result.timestamp.compute_started)

# Store the run time in the result object
result.additional_timing['exec_execution'] = (result.time_running -
result.additional_timing['exec_postprocess'] -
result.additional_timing['exec_preprocess'])
result.time.additional['exec_execution'] = (result.time.running -
result.time.additional['exec_postprocess'] -
result.time.additional['exec_preprocess'])

# Add the worker information into the tasks, if available
worker_info = {'hostname': platform.node()}
result.worker_info = worker_info

# Re-pack the results (will use proxystore, if able)
result.time_serialize_results, _ = result.serialize()
result.time.serialize_results, _ = result.serialize()

# Put the serialized inputs back, if desired
if result.keep_inputs:
@@ -222,7 +222,7 @@ def _preprocess_callback(
result.inputs = serialized_inputs

# Store the time it took to run the preprocessing
result.time_running = result.additional_timing.get('exec_preprocess', 0)
result.time.running = result.time.additional.get('exec_preprocess', 0)
return task_server.queues.send_result(result, topic)

# If successful, submit the execute step and pass its result to Parsl
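The `exec_execution` bookkeeping above is plain subtraction: the time spent in the executable itself is the total running span minus the pre- and post-processing spans. With illustrative numbers:

```python
time_running = 12.0      # total span stored in result.time.running
exec_preprocess = 1.5    # writing input files, composing the command line
exec_postprocess = 0.5   # parsing the outputs
exec_execution = time_running - exec_postprocess - exec_preprocess
assert exec_execution == 10.0
```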
14 changes: 7 additions & 7 deletions colmena/task_server/tests/test_base.py
@@ -73,12 +73,12 @@ def test_run_function(store):
run_and_record_timing(lambda x: x.upper(), result)

# Make sure the timings are all set
assert result.time_running > 0
assert result.time_async_resolve_proxies > 0
assert result.time_deserialize_inputs > 0
assert result.time_serialize_results > 0
assert result.time_compute_ended > result.time_compute_started
assert result.time.running > 0
assert result.time.async_resolve_proxies > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_ended > result.timestamp.compute_started

# Make sure we have stats for both proxies
assert len(result.proxy_timing) == 2
assert all('store.proxy' in v['times'] for v in result.proxy_timing.values())
assert len(result.time.proxy) == 2
assert all('store.proxy' in v['times'] for v in result.time.proxy.values())
26 changes: 13 additions & 13 deletions colmena/task_server/tests/test_parsl.py
@@ -39,7 +39,7 @@ def count_nodes(x, _resources: ResourceRequirements):
def config(tmpdir):
return Config(
executors=[
HighThroughputExecutor(max_workers=1)
HighThroughputExecutor(max_workers=1, address='127.0.0.1')
],
strategy=None,
run_dir=str(tmpdir / 'run'),
@@ -93,11 +93,11 @@ def test_run_simple(server_and_queue):
result = queue.get_result()
assert result.success
assert result.value == 2
assert result.time_running > 0
assert result.time_deserialize_inputs > 0
assert result.time_serialize_results > 0
assert result.time_compute_started is not None
assert result.time_result_sent is not None
assert result.time.running > 0
assert result.time.deserialize_inputs > 0
assert result.time.serialize_results > 0
assert result.timestamp.compute_started is not None
assert result.timestamp.result_sent is not None


@mark.timeout(30)
@@ -125,7 +125,7 @@ def test_error_handling(server_and_queue):
assert result.value is None
assert not result.success
assert result.failure_info is not None
assert result.time_running is not None
assert result.time.running is not None

# Send a task that kills the worker
queue.send_inputs(None, method='bad_task')
@@ -146,7 +146,7 @@ def test_bash(server_and_queue):
assert result.success, result.failure_info
assert result.value == '1\n'
assert result.keep_inputs
assert result.additional_timing['exec_execution'] > 0
assert result.time.additional['exec_execution'] > 0
assert result.inputs == ((1,), {})

# Send an MPI task
@@ -155,15 +155,15 @@
assert result.success, result.failure_info
assert result.value == '-N 1 -n 1 --cc depth echo -n 1\n' # We're actually testing that it makes the correct command string
assert result.keep_inputs
assert result.additional_timing['exec_execution'] > 0
assert result.time.additional['exec_execution'] > 0
assert result.inputs == ((1,), {})

# Send an MPI task
queue.send_inputs(1, method='fakempitask', keep_inputs=True, resources=ResourceRequirements(node_count=2, cpu_processes=4))
result = queue.get_result()
assert result.success, result.failure_info
assert result.value == '-N 8 -n 4 --cc depth echo -n 1\n'
assert result.additional_timing['exec_execution'] > 0
assert result.time.additional['exec_execution'] > 0
assert result.inputs == ((1,), {})


@@ -201,22 +201,22 @@ def test_proxy(server_and_queue, store):
queue.send_inputs([little_string], big_string, method='capitalize')
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 3 # There are two proxies to resolve, one is created
assert len(result.time.proxy) == 3 # There are two proxies to resolve, one is created

# Proxy the results ahead of time
little_proxy = store.proxy(little_string)

queue.send_inputs([little_proxy], big_string, method='capitalize')
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 3
assert len(result.time.proxy) == 3

# Try it with a kwarg
queue.send_inputs(['a'], big_string, input_kwargs={'little': little_proxy}, method='capitalize',
keep_inputs=False) # TODO (wardlt): test does not work with keep-inputs=True
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 3
assert len(result.time.proxy) == 3


@mark.timeout(10)
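For downstream code that reads these records, the renames in this commit map one-to-one from the old flat `Result` fields to the new nested models:

```python
# Old flat field                      -> New nested field
# result.time_created                 -> result.timestamp.created
# result.time_input_received          -> result.timestamp.input_received
# result.time_compute_started         -> result.timestamp.compute_started
# result.time_compute_ended           -> result.timestamp.compute_ended
# result.time_result_sent             -> result.timestamp.result_sent
# result.time_result_received         -> result.timestamp.result_received
# result.time_start_task_submission   -> result.timestamp.start_task_submission
# result.time_task_received           -> result.timestamp.task_received
# result.time_running                 -> result.time.running
# result.time_serialize_inputs        -> result.time.serialize_inputs
# result.time_deserialize_inputs      -> result.time.deserialize_inputs
# result.time_serialize_results       -> result.time.serialize_results
# result.time_deserialize_results     -> result.time.deserialize_results
# result.time_async_resolve_proxies   -> result.time.async_resolve_proxies
# result.additional_timing            -> result.time.additional
# result.proxy_timing                 -> result.time.proxy
```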
