Skip to content

Commit

Permalink
Implement a basic system for reproducibly resolving references during…
Browse files Browse the repository at this point in the history
… runtime (#196)
  • Loading branch information
HughIdiyit authored Jun 18, 2024
1 parent f85c112 commit 5db6af9
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 19 deletions.
8 changes: 8 additions & 0 deletions runtime/hetdesrun/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
from copy import deepcopy
from posixpath import join as posix_urljoin
from uuid import UUID, uuid4

Expand All @@ -27,6 +28,9 @@
)
from hetdesrun.persistence.models.transformation import TransformationRevision
from hetdesrun.persistence.models.workflow import WorkflowContent
from hetdesrun.reference_context import (
set_reproducibility_reference_context,
)
from hetdesrun.runtime.logging import execution_context_filter
from hetdesrun.runtime.service import runtime_service
from hetdesrun.utils import Type
Expand Down Expand Up @@ -292,6 +296,10 @@ async def execute_transformation_revision(

execution_context_filter.bind_context(job_id=exec_by_id_input.job_id)

# Set the reproducibility reference context to the provided reference of the exec object
repr_reference = deepcopy(exec_by_id_input.resolved_reproducibility_references)
set_reproducibility_reference_context(repr_reference)

# prepare execution input

prep_exec_input_measured_step = PerformanceMeasuredStep.create_and_begin(
Expand Down
6 changes: 6 additions & 0 deletions runtime/hetdesrun/backend/models/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from hetdesrun.backend.service.utils import to_camel
from hetdesrun.datatypes import DataType
from hetdesrun.models.repr_reference import ReproducibilityReference
from hetdesrun.models.run import WorkflowExecutionInfo
from hetdesrun.persistence.models.transformation import TransformationRevision
from hetdesrun.utils import State, Type
Expand Down Expand Up @@ -48,6 +49,11 @@ class ExecutionResponseFrontendDto(WorkflowExecutionInfo):
result: str
output_results_by_output_name: dict[str, Any] = {}
output_types_by_output_name: dict[str, DataType] = {}
resolved_reproducibility_references: ReproducibilityReference = Field(
default_factory=ReproducibilityReference,
description="Resolved references to information needed to reproduce an execution result."
"The provided data can be used to replace data that would usually be produced at runtime.",
)
process_id: int | None = Field(
None,
description=(
Expand Down
9 changes: 9 additions & 0 deletions runtime/hetdesrun/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from pydantic import BaseModel, Field

from hetdesrun.models.repr_reference import ReproducibilityReference
from hetdesrun.models.wiring import WorkflowWiring
from hetdesrun.reference_context import (
get_deepcopy_of_reproducibility_reference_context,
)


class ExecByIdBase(BaseModel):
Expand All @@ -12,6 +16,11 @@ class ExecByIdBase(BaseModel):
description="The wiring to be used. "
"If no wiring is provided the stored test wiring will be used.",
)
resolved_reproducibility_references: ReproducibilityReference = Field(
default_factory=get_deepcopy_of_reproducibility_reference_context,
description="Resolved references to information needed to reproduce an execution result."
"The provided data can be used to replace data that would usually be produced at runtime.",
)
run_pure_plot_operators: bool = Field(
False, description="Whether pure plot components should be run."
)
Expand Down
20 changes: 20 additions & 0 deletions runtime/hetdesrun/models/repr_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import datetime

from pydantic import BaseModel, Field, validator


class ReproducibilityReference(BaseModel):
exec_start_timestamp: datetime.datetime | None = Field(
None, description="UTC-Timestamp referencing the start time of an execution."
)

# TODO With Pydantic V2 this can probably be solved using AwareDatetime
# instead of a custom validator
@validator("exec_start_timestamp")
def ensure_utc(cls, ts: datetime.datetime | None) -> datetime.datetime | None:
if ts is not None:
if ts.tzinfo is None:
raise ValueError("The execution start timestamp must be timezone-aware")
if ts.tzinfo != datetime.timezone.utc:
raise ValueError("The execution start timestamp must be in UTC")
return ts
60 changes: 41 additions & 19 deletions runtime/hetdesrun/models/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Models for runtime execution endpoint"""


import datetime
import traceback as tb
from enum import Enum, StrEnum
Expand All @@ -13,8 +12,12 @@
from hetdesrun.models.base import Result
from hetdesrun.models.code import CodeModule, NonEmptyValidStr, ShortNonEmptyValidStr
from hetdesrun.models.component import ComponentRevision
from hetdesrun.models.repr_reference import ReproducibilityReference
from hetdesrun.models.wiring import OutputWiring, WorkflowWiring
from hetdesrun.models.workflow import WorkflowNode
from hetdesrun.reference_context import (
get_deepcopy_of_reproducibility_reference_context,
)
from hetdesrun.runtime.exceptions import ComponentException, RuntimeExecutionError
from hetdesrun.utils import Type, check_explicit_utc

Expand Down Expand Up @@ -301,9 +304,11 @@ class WorkflowExecutionError(BaseModel):
def get_location_of_exception(exception: Exception | BaseException) -> ErrorLocation:
last_trace = tb.extract_tb(exception.__traceback__)[-1]
return ErrorLocation(
file=last_trace.filename
if last_trace.filename != "<string>"
else "COMPONENT CODE",
file=(
last_trace.filename
if last_trace.filename != "<string>"
else "COMPONENT CODE"
),
function_name=last_trace.name,
line_number=last_trace.lineno,
)
Expand All @@ -330,23 +335,31 @@ def from_exception(
) -> "WorkflowExecutionInfo":
return WorkflowExecutionInfo(
error=WorkflowExecutionError(
type=type(exception).__name__
if cause is None
else type(cause).__name__,
type=(
type(exception).__name__ if cause is None else type(cause).__name__
),
message=str(exception) if cause is None else str(cause),
extra_information=exception.extra_information
if isinstance(exception, ComponentException)
else None,
error_code=exception.error_code
if isinstance(exception, ComponentException)
else None,
extra_information=(
exception.extra_information
if isinstance(exception, ComponentException)
else None
),
error_code=(
exception.error_code
if isinstance(exception, ComponentException)
else None
),
process_stage=process_stage,
operator_info=OperatorInfo.from_runtime_execution_error(exception)
if isinstance(exception, RuntimeExecutionError)
else None,
location=get_location_of_exception(exception)
if cause is None
else get_location_of_exception(cause),
operator_info=(
OperatorInfo.from_runtime_execution_error(exception)
if isinstance(exception, RuntimeExecutionError)
else None
),
location=(
get_location_of_exception(exception)
if cause is None
else get_location_of_exception(cause)
),
),
traceback=tb.format_exc(),
output_results_by_output_name={},
Expand All @@ -371,6 +384,11 @@ class WorkflowExecutionResult(WorkflowExecutionInfo):
" set to true."
),
)
resolved_reproducibility_references: ReproducibilityReference = Field(
default_factory=get_deepcopy_of_reproducibility_reference_context,
description="Resolved references to information needed to reproduce an execution result."
"The provided data can be used to replace data that would usually be produced at runtime.",
)

@classmethod
def from_exception(
Expand All @@ -381,8 +399,12 @@ def from_exception(
cause: BaseException | None = None,
node_results: str | None = None,
) -> "WorkflowExecutionResult":
# Access the current context to retrieve resolved reproducibility references
repr_reference = get_deepcopy_of_reproducibility_reference_context()

return WorkflowExecutionResult(
**super().from_exception(exception, process_stage, job_id, cause).dict(),
result="failure",
node_results=node_results,
resolved_reproducibility_references=repr_reference,
)
26 changes: 26 additions & 0 deletions runtime/hetdesrun/reference_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from contextvars import ContextVar
from copy import deepcopy

from hetdesrun.models.repr_reference import ReproducibilityReference

reproducibility_reference_context: ContextVar[ReproducibilityReference] = ContextVar(
"reproducibility_reference_context"
)


def get_reproducibility_reference_context() -> ReproducibilityReference:
try:
return reproducibility_reference_context.get()
except LookupError:
reproducibility_reference_context.set(ReproducibilityReference())
return reproducibility_reference_context.get()


def get_deepcopy_of_reproducibility_reference_context() -> ReproducibilityReference:
return deepcopy(get_reproducibility_reference_context())


def set_reproducibility_reference_context(
new_reference: ReproducibilityReference,
) -> None:
reproducibility_reference_context.set(new_reference)
171 changes: 171 additions & 0 deletions runtime/tests/test_reproducibility_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import json
import logging
from datetime import datetime, timedelta, timezone
from uuid import uuid4

import pytest

from hdutils import DataType
from hetdesrun.backend.execution import execute_transformation_revision
from hetdesrun.backend.models.info import ExecutionResponseFrontendDto
from hetdesrun.models.execution import ExecByIdBase, ExecByIdInput
from hetdesrun.models.repr_reference import ReproducibilityReference
from hetdesrun.models.run import WorkflowExecutionResult
from hetdesrun.persistence.dbservice.revision import (
store_single_transformation_revision,
)
from hetdesrun.persistence.models.transformation import TransformationRevision
from hetdesrun.reference_context import (
get_deepcopy_of_reproducibility_reference_context,
get_reproducibility_reference_context,
set_reproducibility_reference_context,
)


def test_utc_validation():
with pytest.raises(
ValueError, match="The execution start timestamp must be timezone-aware"
):
_ = ReproducibilityReference(
exec_start_timestamp=datetime.now() # noqa: DTZ005
)

with pytest.raises(
ValueError, match="The execution start timestamp must be in UTC"
):
__ = ReproducibilityReference(
exec_start_timestamp=datetime.now(tz=timezone(timedelta(hours=1)))
)


def test_context_var_setting_and_getting():
# Test getter
assert get_reproducibility_reference_context() == ReproducibilityReference()

# Test setter
rr1 = ReproducibilityReference(
exec_start_timestamp=datetime(1949, 5, 23, tzinfo=timezone.utc)
)
set_reproducibility_reference_context(rr1)
assert get_reproducibility_reference_context() == rr1

# Test whether deepcopy getter returns an actual deepcopy
rr2 = get_deepcopy_of_reproducibility_reference_context()
assert rr2 == rr1
assert rr2 is not rr1
assert rr2.exec_start_timestamp == rr1.exec_start_timestamp


def test_default_factories():
exec_resp_frontend = ExecutionResponseFrontendDto(
result="nf",
output_results_by_output_name={"nf": 23},
output_types_by_output_name={"nf": DataType.Integer},
job_id=uuid4(),
)
exec_by_id_obj = ExecByIdBase(id=uuid4())
wf_result = WorkflowExecutionResult(
result="failure", output_results_by_output_name={"nf": 23}, job_id=uuid4()
)

# Check that at points where marshalling is done
# a (deep) copy is created.
assert exec_resp_frontend.resolved_reproducibility_references is not None
assert (
exec_resp_frontend.resolved_reproducibility_references
is not get_reproducibility_reference_context()
)

assert exec_by_id_obj.resolved_reproducibility_references is not None
assert (
exec_by_id_obj.resolved_reproducibility_references
is not get_reproducibility_reference_context()
)

assert wf_result.resolved_reproducibility_references is not None
assert (
wf_result.resolved_reproducibility_references
is not get_reproducibility_reference_context()
)


@pytest.fixture()
def _db_with_two_trafos(mocked_clean_test_db_session):
# Load a regular transformation revision with state RELEASED
with open(
"transformations/components/connectors/pass-through-string_100_2b1b474f-ddf5-1f4d-fec4-17ef9122112b.json"
) as f:
trafo_data = json.load(f)
store_single_transformation_revision(TransformationRevision(**trafo_data))

# Load a transformation revision to provoke an Error
with open(
"tests/data/components/raise-exception_010_c4dbcc42-eaec-4587-a362-ce6567f21d92.json"
) as f:
trafo_data = json.load(f)
store_single_transformation_revision(TransformationRevision(**trafo_data))


@pytest.mark.asyncio
async def test_for_reference_in_response(_db_with_two_trafos): # noqa: PT019
rr = ReproducibilityReference(
exec_start_timestamp=datetime(1949, 5, 23, tzinfo=timezone.utc)
)

exec_by_id_input = ExecByIdInput(
id="2b1b474f-ddf5-1f4d-fec4-17ef9122112b",
wiring={
"input_wirings": [
{
"adapter_id": "direct_provisioning",
"filters": {"value": "Test exec"},
"use_default_value": False,
"workflow_input_name": "input",
}
],
"output_wirings": [],
},
run_pure_plot_operators=False,
resolved_reproducibility_references=rr,
)
execution_response = await execute_transformation_revision(exec_by_id_input)

assert execution_response.result == "ok"

assert get_reproducibility_reference_context() == rr

assert (
execution_response.resolved_reproducibility_references
== get_reproducibility_reference_context()
)


@pytest.mark.asyncio
async def test_if_reference_in_response_after_exception(
_db_with_two_trafos, caplog # noqa: PT019
):
rr = ReproducibilityReference(
exec_start_timestamp=datetime(1949, 5, 23, tzinfo=timezone.utc)
)

exec_by_id_input = ExecByIdInput(
id="c4dbcc42-eaec-4587-a362-ce6567f21d92",
run_pure_plot_operators=False,
resolved_reproducibility_references=rr,
)

# Execute a transformation revision that should cause a runtime execution error
with caplog.at_level(logging.INFO):
execution_response = await execute_transformation_revision(exec_by_id_input)

# Test whether the execution failed as planned
assert any(
record.levelname == "INFO" and "Runtime Execution Error" in record.message
for record in caplog.records
)
assert execution_response.result == "failure"

assert (
execution_response.resolved_reproducibility_references
== get_reproducibility_reference_context()
)

0 comments on commit 5db6af9

Please sign in to comment.