Return proxies during serialization
The proxy is created and then immediately serialized,
so we need to return it so that its statistics remain accessible
WardLT committed Feb 6, 2023
1 parent fd61c13 commit b67c041
Showing 6 changed files with 64 additions and 22 deletions.
10 changes: 7 additions & 3 deletions colmena/models.py
@@ -266,15 +266,17 @@ def set_result(self, result: Any, runtime: float = None):
self.time_running = runtime
self.success = True

def serialize(self) -> float:
def serialize(self) -> Tuple[float, List[Proxy]]:
"""Stores the input and value fields as a pickled objects
Returns:
(float) Time to serialize
- (float) Time to serialize
- List of any proxies that were created
"""
start_time = perf_counter()
_value = self.value
_inputs = self.inputs
proxies = []

def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
"""Helper function for serializing and proxying
@@ -311,6 +313,8 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
)
value_proxy = store.proxy(value, evict=evict)
logger.debug(f'Proxied object of type {type(value)} with id={id(value)}')
proxies.append(value_proxy)

# Serialize the proxy with Colmena's utilities. This is
# efficient since the proxy is just a reference and metadata
value_str = SerializationMethod.serialize(
@@ -345,7 +349,7 @@ def _serialize_and_proxy(value, evict=False) -> Tuple[str, int]:
if 'value' not in self.message_sizes:
self.message_sizes['value'] = value_size

return perf_counter() - start_time
return perf_counter() - start_time, proxies
except Exception as e:
# Put the original values back
self.inputs = _inputs
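For readers skimming the commit, here is a minimal, self-contained sketch of the new calling convention introduced above. It mirrors the store setup used in the test_base.py changes below; the store name and directory are illustrative, and the snippet itself is not part of the commit.

from proxystore.store import register_store, unregister_store
from proxystore.store.file import FileStore
from proxystore.store.utils import get_key

from colmena.models import Result, SerializationMethod

# Illustrative file-backed store; the tests below use a pytest tmpdir instead of a fixed path
store = FileStore(name='store', store_dir='/tmp/colmena-proxy-demo', stats=True)
register_store(store)

# Configure a Result so values larger than 128 bytes are proxied through the store
result = Result(inputs=(('a' * 1024,), {}))
result.proxystore_name = store.name
result.proxystore_type = f'{store.__class__.__module__}.{store.__class__.__name__}'
result.proxystore_threshold = 128
result.proxystore_kwargs = store.kwargs
result.serialization_method = SerializationMethod.PICKLE

# serialize() now returns the runtime *and* any proxies it created,
# so callers can look up statistics for those objects afterwards
serialize_time, proxies = result.serialize()
print(f'Serialized in {serialize_time:.2e} s; created {len(proxies)} proxies')
for proxy in proxies:
    print('Proxy key:', get_key(proxy))

unregister_store('store')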
6 changes: 3 additions & 3 deletions colmena/queue/base.py
@@ -221,7 +221,7 @@ def send_inputs(self,
proxystore_kwargs.update({
'proxystore_name': self.proxystore_name[topic],
'proxystore_threshold': self.proxystore_threshold[topic],
# Pydantic prefers to not have types as attributes so we
# Pydantic prefers to not have types as attributes, so we
# get the string corresponding to the type of the store we use
'proxystore_type': get_class_path(type(store)),
'proxystore_kwargs': store.kwargs
@@ -239,9 +239,9 @@ def send_inputs(self,
)

# Push the serialized value to the task server
result.time_serialize_inputs = result.serialize()
result.time_serialize_inputs, proxies = result.serialize()
self._send_request(result.json(exclude_none=True), topic)
logger.info(f'Client sent a {method} task with topic {topic}')
logger.info(f'Client sent a {method} task with topic {topic}. Created {len(proxies)} proxies for input values')

# Store the task ID in the active list
with self._active_lock:
14 changes: 5 additions & 9 deletions colmena/task_server/base.py
@@ -174,11 +174,11 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
proxies = []
input_proxies = []
for arg in result.args:
proxies.extend(resolve_proxies_async(arg))
input_proxies.extend(resolve_proxies_async(arg))
for value in result.kwargs.values():
proxies.extend(resolve_proxies_async(value))
input_proxies.extend(resolve_proxies_async(value))
result.time_async_resolve_proxies = perf_counter() - start_time

# Execute the function
@@ -215,14 +215,10 @@ def run_and_record_timing(func: Callable, result: Result) -> Result:
result.mark_compute_ended()

# Re-pack the results
result.time_serialize_results = result.serialize()

# If the result was proxied, add it of the list of proxies to get stats for
if isinstance(result.value, proxystore.proxy.Proxy):
proxies.append(result.value)
result.time_serialize_results, output_proxies = result.serialize()

# Get the statistics for the proxy resolution
for proxy in proxies:
for proxy in input_proxies + output_proxies:
# Get the key associated with this proxy
key = proxystore.store.utils.get_key(proxy)

2 changes: 1 addition & 1 deletion colmena/task_server/parsl.py
@@ -124,7 +124,7 @@ def _execute_postprocess(task: ExecutableTask, exit_code: int, result: Result, t
result.worker_info = worker_info

# Re-pack the results (will use proxystore, if able)
result.time_serialize_results = result.serialize()
result.time_serialize_results, _ = result.serialize()

# Put the serialized inputs back, if desired
if result.keep_inputs:
46 changes: 44 additions & 2 deletions colmena/task_server/tests/test_base.py
@@ -1,12 +1,15 @@
from typing import Any, Dict, Tuple, List, Optional
from pathlib import Path

from colmena.models import Result, ExecutableTask
from proxystore.store import unregister_store
from proxystore.store.file import FileStore
from pytest import fixture

# TODO (wardlt): Figure how to import this from test_models
from colmena.models import Result, ExecutableTask, SerializationMethod
from colmena.task_server.base import run_and_record_timing


# TODO (wardlt): Figure how to import this from test_models
class EchoTask(ExecutableTask):
def __init__(self):
super().__init__(executable=['echo'])
@@ -37,3 +40,42 @@ def test_run_with_executable():
run_and_record_timing(func, result)
result.deserialize()
assert result.value == '1\n'


@fixture
def store(tmpdir):
store = FileStore(name='store', store_dir=tmpdir, stats=True)
yield store
unregister_store('store')


def test_run_function(store):
"""Make sure the run function behaves as expected:
- Records runtimes
- Tracks proxy statistics
"""

# Make the result and configure it to use the store
result = Result(inputs=(('a' * 1024,), {}))
result.proxystore_name = store.name
result.proxystore_type = f'{store.__class__.__module__}.{store.__class__.__name__}'
result.proxystore_threshold = 128
result.proxystore_kwargs = store.kwargs

# Serialize it
result.serialization_method = SerializationMethod.PICKLE
result.serialize()

# Run the function
run_and_record_timing(lambda x: x.upper(), result)

# Make sure the timings are all set
assert result.time_running > 0
assert result.time_async_resolve_proxies > 0
assert result.time_deserialize_inputs > 0
assert result.time_serialize_results > 0
assert result.time_compute_ended > result.time_compute_started

# Make sure we have stats for both proxies
assert len(result.proxy_timing) == 2
8 changes: 4 additions & 4 deletions colmena/task_server/tests/test_parsl.py
@@ -49,7 +49,7 @@ def store():
store = RedisStore('store', hostname='localhost', port=6379, stats=True)
proxystore.store.register_store(store)
yield store
proxystore.store.unregister_store(store)
proxystore.store.unregister_store(store.name)
store.close()


@@ -198,19 +198,19 @@ def test_proxy(server_and_queue, store):
queue.send_inputs([little_string], big_string, method='capitalize')
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 2 # There are two proxies to resolve
assert len(result.proxy_timing) == 3 # There are two proxies to resolve, one is created

# Proxy the results ahead of time
little_proxy = store.proxy(little_string)

queue.send_inputs([little_proxy], big_string, method='capitalize')
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 2
assert len(result.proxy_timing) == 3

# Try it with a kwarg
queue.send_inputs(['a'], big_string, input_kwargs={'little': little_proxy}, method='capitalize',
keep_inputs=False) # TODO (wardlt): test does not work with keep-inputs=True
result = queue.get_result()
assert result.success, result.failure_info.exception
assert len(result.proxy_timing) == 2
assert len(result.proxy_timing) == 3
