10 changes: 2 additions & 8 deletions cadence/_internal/workflow/context.py
@@ -7,20 +7,17 @@
from cadence.api.v1.common_pb2 import ActivityType
from cadence.api.v1.decision_pb2 import ScheduleActivityTaskDecisionAttributes
from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind
from cadence.client import Client
from cadence.data_converter import DataConverter
from cadence.workflow import WorkflowContext, WorkflowInfo, ResultType, ActivityOptions


class Context(WorkflowContext):
def __init__(
self,
client: Client,
info: WorkflowInfo,
decision_helper: DecisionsHelper,
decision_manager: DecisionManager,
):
self._client = client
self._info = info
self._replay_mode = True
self._replay_current_time_milliseconds: Optional[int] = None
@@ -30,11 +27,8 @@ def __init__(
def info(self) -> WorkflowInfo:
return self._info

def client(self) -> Client:
return self._client

def data_converter(self) -> DataConverter:
return self._client.data_converter
return self.info().data_converter

async def execute_activity(
self,
@@ -80,7 +74,7 @@ async def execute_activity(
schedule_attributes = ScheduleActivityTaskDecisionAttributes(
activity_id=activity_id,
activity_type=ActivityType(name=activity),
domain=self._client.domain,
domain=self.info().workflow_domain,
task_list=TaskList(kind=TaskListKind.TASK_LIST_KIND_NORMAL, name=task_list),
input=activity_input,
schedule_to_close_timeout=_round_to_nearest_second(schedule_to_close),
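
With the Client handle removed from Context, workflow-facing code now resolves the data converter and domain through WorkflowInfo. A minimal sketch of the new access pattern, assuming ctx is any active WorkflowContext (the helper function below is illustrative only, not part of the diff):

from cadence.workflow import WorkflowContext

def converter_and_domain(ctx: WorkflowContext):
    # Both values now come from WorkflowInfo rather than a Client attached to the context.
    converter = ctx.data_converter()      # backed by ctx.info().data_converter
    domain = ctx.info().workflow_domain   # replaces the old client.domain lookup
    return converter, domain
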
149 changes: 64 additions & 85 deletions cadence/_internal/workflow/decision_events_iterator.py
@@ -11,8 +11,6 @@

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.client import Client
from cadence._internal.workflow.history_event_iterator import iterate_history_events


@dataclass
@@ -55,99 +53,36 @@ class DecisionEventsIterator:
into decision iterations for proper workflow replay and execution.
"""

def __init__(self, decision_task: PollForDecisionTaskResponse, client: Client):
self._client = client
def __init__(
self, decision_task: PollForDecisionTaskResponse, events: List[HistoryEvent]
):
self._decision_task = decision_task
self._events: List[HistoryEvent] = []
self._event_index = 0
self._events: List[HistoryEvent] = events
self._decision_task_started_event: Optional[HistoryEvent] = None
self._next_decision_event_id = 1
self._replay = True
self._replay_current_time_milliseconds: Optional[int] = None
self._initialized = False

@staticmethod
def _is_decision_task_started(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskStarted."""
return hasattr(
event, "decision_task_started_event_attributes"
) and event.HasField("decision_task_started_event_attributes")

@staticmethod
def _is_decision_task_completed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskCompleted."""
return hasattr(
event, "decision_task_completed_event_attributes"
) and event.HasField("decision_task_completed_event_attributes")

@staticmethod
def _is_decision_task_failed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskFailed."""
return hasattr(
event, "decision_task_failed_event_attributes"
) and event.HasField("decision_task_failed_event_attributes")

@staticmethod
def _is_decision_task_timed_out(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskTimedOut."""
return hasattr(
event, "decision_task_timed_out_event_attributes"
) and event.HasField("decision_task_timed_out_event_attributes")

@staticmethod
def _is_marker_recorded(event: HistoryEvent) -> bool:
"""Check if event is MarkerRecorded."""
return hasattr(event, "marker_recorded_event_attributes") and event.HasField(
"marker_recorded_event_attributes"
)

@staticmethod
def _is_decision_task_completion(event: HistoryEvent) -> bool:
"""Check if event is any kind of decision task completion."""
return (
DecisionEventsIterator._is_decision_task_completed(event)
or DecisionEventsIterator._is_decision_task_failed(event)
or DecisionEventsIterator._is_decision_task_timed_out(event)
)

async def _ensure_initialized(self):
"""Initialize events list using the existing iterate_history_events."""
if not self._initialized:
# Use existing iterate_history_events function
events_iterator = iterate_history_events(self._decision_task, self._client)
self._events = [event async for event in events_iterator]
self._initialized = True

# Find first decision task started event
for i, event in enumerate(self._events):
if self._is_decision_task_started(event):
self._event_index = i
break
self._event_index = 0
# Find first decision task started event
for i, event in enumerate(self._events):
if _is_decision_task_started(event):
self._event_index = i
break

async def has_next_decision_events(self) -> bool:
"""Check if there are more decision events to process."""
await self._ensure_initialized()

# Look for the next DecisionTaskStarted event from current position
for i in range(self._event_index, len(self._events)):
if self._is_decision_task_started(self._events[i]):
if _is_decision_task_started(self._events[i]):
return True

return False

async def next_decision_events(self) -> DecisionEvents:
"""
Get the next set of decision events.

This method processes events starting from a DecisionTaskStarted event
until the corresponding DecisionTaskCompleted/Failed/TimedOut event.
"""
await self._ensure_initialized()

# Find next DecisionTaskStarted event
start_index = None
for i in range(self._event_index, len(self._events)):
if self._is_decision_task_started(self._events[i]):
if _is_decision_task_started(self._events[i]):
start_index = i
break

@@ -182,9 +117,9 @@ async def next_decision_events(self) -> DecisionEvents:
decision_events.events.append(event)

# Categorize the event
if self._is_marker_recorded(event):
if _is_marker_recorded(event):
decision_events.markers.append(event)
elif self._is_decision_task_completion(event):
elif _is_decision_task_completion(event):
# This marks the end of this decision iteration
self._process_decision_completion_event(event, decision_events)
current_index += 1 # Move past this event
@@ -206,7 +141,7 @@ async def next_decision_events(self) -> DecisionEvents:
# Check directly without calling has_next_decision_events to avoid recursion
has_more = False
for i in range(self._event_index, len(self._events)):
if self._is_decision_task_started(self._events[i]):
if _is_decision_task_started(self._events[i]):
has_more = True
break

@@ -261,16 +196,16 @@ async def __anext__(self) -> DecisionEvents:
def is_decision_event(event: HistoryEvent) -> bool:
"""Check if an event is a decision-related event."""
return (
DecisionEventsIterator._is_decision_task_started(event)
or DecisionEventsIterator._is_decision_task_completed(event)
or DecisionEventsIterator._is_decision_task_failed(event)
or DecisionEventsIterator._is_decision_task_timed_out(event)
_is_decision_task_started(event)
or _is_decision_task_completed(event)
or _is_decision_task_failed(event)
or _is_decision_task_timed_out(event)
)


def is_marker_event(event: HistoryEvent) -> bool:
"""Check if an event is a marker event."""
return DecisionEventsIterator._is_marker_recorded(event)
return _is_marker_recorded(event)


def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]:
@@ -279,3 +214,47 @@ def extract_event_timestamp_millis(event: HistoryEvent) -> Optional[int]:
seconds = getattr(event.event_time, "seconds", 0)
return seconds * 1000 if seconds > 0 else None
return None


def _is_decision_task_started(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskStarted."""
return hasattr(event, "decision_task_started_event_attributes") and event.HasField(
"decision_task_started_event_attributes"
)


def _is_decision_task_completed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskCompleted."""
return hasattr(
event, "decision_task_completed_event_attributes"
) and event.HasField("decision_task_completed_event_attributes")


def _is_decision_task_failed(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskFailed."""
return hasattr(event, "decision_task_failed_event_attributes") and event.HasField(
"decision_task_failed_event_attributes"
)


def _is_decision_task_timed_out(event: HistoryEvent) -> bool:
"""Check if event is DecisionTaskTimedOut."""
return hasattr(
event, "decision_task_timed_out_event_attributes"
) and event.HasField("decision_task_timed_out_event_attributes")


def _is_marker_recorded(event: HistoryEvent) -> bool:
"""Check if event is MarkerRecorded."""
return hasattr(event, "marker_recorded_event_attributes") and event.HasField(
"marker_recorded_event_attributes"
)


def _is_decision_task_completion(event: HistoryEvent) -> bool:
"""Check if event is any kind of decision task completion."""
return (
_is_decision_task_completed(event)
or _is_decision_task_failed(event)
or _is_decision_task_timed_out(event)
)
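
Because the iterator now receives the fully materialized history instead of a Client, construction is synchronous, and the module-level helpers can be used directly to categorize events. A usage sketch, assuming the caller has already fetched the history (walk_decisions is an illustrative name, not part of the diff):

from typing import List

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence._internal.workflow.decision_events_iterator import (
    DecisionEventsIterator,
    extract_event_timestamp_millis,
    is_decision_event,
    is_marker_event,
)

async def walk_decisions(
    decision_task: PollForDecisionTaskResponse, history: List[HistoryEvent]
) -> None:
    iterator = DecisionEventsIterator(decision_task, history)  # no Client needed
    while await iterator.has_next_decision_events():
        decision_events = await iterator.next_decision_events()
        for event in decision_events.events:
            if is_marker_event(event):
                pass  # markers are also collected in decision_events.markers
            elif is_decision_event(event):
                _ = extract_event_timestamp_millis(event)  # may be None
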
15 changes: 5 additions & 10 deletions cadence/_internal/workflow/workflow_engine.py
@@ -8,7 +8,6 @@
from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator
from cadence._internal.workflow.statemachine.decision_manager import DecisionManager
from cadence.api.v1.decision_pb2 import Decision
from cadence.client import Client
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.workflow import WorkflowInfo

@@ -21,16 +20,14 @@ class DecisionResult:


class WorkflowEngine:
def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None):
def __init__(self, info: WorkflowInfo, workflow_definition=None):
self._workflow_definition = workflow_definition
self._workflow_instance = None
if workflow_definition:
self._workflow_instance = workflow_definition.cls()
self._decision_manager = DecisionManager()
self._decisions_helper = DecisionsHelper()
self._context = Context(
client, info, self._decisions_helper, self._decision_manager
)
self._context = Context(info, self._decisions_helper, self._decision_manager)
self._is_workflow_complete = False

async def process_decision(
@@ -65,7 +62,7 @@ async def process_decision(
with self._context._activate():
# Create DecisionEventsIterator for structured event processing
events_iterator = DecisionEventsIterator(
decision_task, self._context.client()
decision_task, self._context.info().workflow_events
)

# Process decision events using iterator-driven approach
@@ -360,10 +357,8 @@ def _extract_workflow_input(
# Deserialize the input using the client's data converter
try:
# Use from_data method with a single type hint of None (no type conversion)
input_data_list = (
self._context.client().data_converter.from_data(
started_attrs.input, [None]
)
input_data_list = self._context.data_converter().from_data(
started_attrs.input, [None]
)
input_data = input_data_list[0] if input_data_list else None
logger.debug(f"Extracted workflow input: {input_data}")
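
The engine is now constructed from WorkflowInfo alone; everything it previously pulled off the Client (history, domain, data converter) travels in the info object. A construction sketch, with build_engine as an illustrative wrapper rather than an API from the diff:

from cadence._internal.workflow.workflow_engine import WorkflowEngine
from cadence.workflow import WorkflowInfo

def build_engine(workflow_info: WorkflowInfo, workflow_definition) -> WorkflowEngine:
    # No Client argument anymore: the DecisionEventsIterator is later fed from
    # workflow_info.workflow_events, and input decoding goes through
    # workflow_info.data_converter via the workflow context.
    return WorkflowEngine(info=workflow_info, workflow_definition=workflow_definition)
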
10 changes: 9 additions & 1 deletion cadence/worker/_decision_task_handler.py
@@ -2,6 +2,7 @@
import threading
from typing import Dict, Tuple

from cadence._internal.workflow.history_event_iterator import iterate_history_events
from cadence.api.v1.common_pb2 import Payload
from cadence.api.v1.service_worker_pb2 import (
PollForDecisionTaskResponse,
@@ -102,13 +103,21 @@ async def _handle_task_implementation(
)
raise KeyError(f"Workflow type '{workflow_type_name}' not found")

# fetch full workflow history
# TODO sticky cache
workflow_events = [
event async for event in iterate_history_events(task, self._client)
]

# Create workflow info and get or create workflow engine from cache
workflow_info = WorkflowInfo(
workflow_type=workflow_type_name,
workflow_domain=self._client.domain,
workflow_id=workflow_id,
workflow_run_id=run_id,
workflow_task_list=self.task_list,
data_converter=self._client.data_converter,
workflow_events=workflow_events,
)

# Use thread-safe cache to get or create workflow engine
@@ -118,7 +127,6 @@
if workflow_engine is None:
workflow_engine = WorkflowEngine(
info=workflow_info,
client=self._client,
workflow_definition=workflow_definition,
)
self._workflow_engines[cache_key] = workflow_engine
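
The history prefetch is the only new I/O in the handler; in isolation it is just the existing async iterator drained into a list. A sketch of that step, with fetch_full_history as an illustrative name (the real logic lives inline in _handle_task_implementation):

from typing import List

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence._internal.workflow.history_event_iterator import iterate_history_events

async def fetch_full_history(task: PollForDecisionTaskResponse, client) -> List[HistoryEvent]:
    # Materialize every history page up front so the workflow engine can replay
    # without holding a Client reference (sticky caching is still a TODO here).
    return [event async for event in iterate_history_events(task, client)]
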
8 changes: 4 additions & 4 deletions cadence/workflow.py
@@ -5,6 +5,7 @@
from datetime import timedelta
from typing import (
Callable,
List,
cast,
Optional,
Union,
@@ -17,7 +18,7 @@
)
import inspect

from cadence.client import Client
from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.data_converter import DataConverter

ResultType = TypeVar("ResultType")
@@ -169,6 +170,8 @@ class WorkflowInfo:
workflow_id: str
workflow_run_id: str
workflow_task_list: str
workflow_events: List[HistoryEvent]
data_converter: DataConverter


class WorkflowContext(ABC):
@@ -177,9 +180,6 @@ class WorkflowContext(ABC):
@abstractmethod
def info(self) -> WorkflowInfo: ...

@abstractmethod
def client(self) -> Client: ...

@abstractmethod
def data_converter(self) -> DataConverter: ...

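
WorkflowInfo now carries the pre-fetched history and the data converter, which is what lets WorkflowContext drop its client() accessor. A sketch of building one directly, e.g. for tests or replay; the field set mirrors the handler above and the concrete values are placeholders:

from typing import List

from cadence.api.v1.history_pb2 import HistoryEvent
from cadence.data_converter import DataConverter
from cadence.workflow import WorkflowInfo

def make_info(events: List[HistoryEvent], converter: DataConverter) -> WorkflowInfo:
    return WorkflowInfo(
        workflow_type="GreetingWorkflow",       # placeholder
        workflow_domain="sample-domain",        # placeholder
        workflow_id="wf-id",                    # placeholder
        workflow_run_id="run-id",               # placeholder
        workflow_task_list="sample-tasklist",   # placeholder
        workflow_events=events,                 # pre-fetched history
        data_converter=converter,               # e.g. the client's converter
    )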