Skip to content
38 changes: 33 additions & 5 deletions state-manager/app/controller/executed_state.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel
from bson import ObjectId
from fastapi import HTTPException, status
from fastapi import HTTPException, status, BackgroundTasks

from app.models.db.state import State
from app.models.state_status_enum import StateStatusEnum
from app.singletons.logs_manager import LogsManager
from app.tasks.create_next_state import create_next_state

logger = LogsManager().get_logger()

async def executed_state(namespace_name: str, state_id: ObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str) -> ExecutedResponseModel:
async def executed_state(namespace_name: str, state_id: ObjectId, body: ExecutedRequestModel, x_exosphere_request_id: str, background_tasks: BackgroundTasks) -> ExecutedResponseModel:

try:
logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
Expand All @@ -20,9 +21,36 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed
if state.status != StateStatusEnum.QUEUED:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued")

await State.find_one(State.id == state_id).set(
{"status": StateStatusEnum.EXECUTED, "outputs": body.outputs}
)
if len(body.outputs) == 0:
await State.find_one(State.id == state_id).set(
{"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: ObjectId(state.id)}}
)

background_tasks.add_task(create_next_state, state)

else:
await State.find_one(State.id == state_id).set(
{"status": StateStatusEnum.EXECUTED, "outputs": body.outputs[0], "parents": {**state.parents, state.identifier: ObjectId(state.id)}}
)
background_tasks.add_task(create_next_state, state)

for output in body.outputs[1:]:
new_state = State(
node_name=state.node_name,
namespace_name=state.namespace_name,
identifier=state.identifier,
graph_name=state.graph_name,
status=StateStatusEnum.CREATED,
inputs=state.inputs,
outputs=output,
error=None,
parents={
**state.parents,
state.identifier: ObjectId(state.id)
}
)
await new_state.save()
background_tasks.add_task(create_next_state, new_state)

return ExecutedResponseModel(status=StateStatusEnum.EXECUTED)

Expand Down
4 changes: 3 additions & 1 deletion state-manager/app/models/db/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from bson import ObjectId
from .base import BaseDatabaseModel
from ..state_status_enum import StateStatusEnum
from pydantic import Field
Expand All @@ -13,4 +14,5 @@ class State(BaseDatabaseModel):
status: StateStatusEnum = Field(..., description="Status of the state")
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
outputs: dict[str, Any] = Field(..., description="Outputs of the state")
error: Optional[str] = Field(None, description="Error message")
error: Optional[str] = Field(None, description="Error message")
parents: dict[str, ObjectId] = Field(default_factory=dict, description="Parents of the state")
4 changes: 2 additions & 2 deletions state-manager/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def create_state(namespace_name: str, graph_name: str, body: CreateRequest
response_description="State executed successfully",
tags=["state"]
)
async def executed_state_route(namespace_name: str, state_id: str, body: ExecutedRequestModel, request: Request, api_key: str = Depends(check_api_key)):
async def executed_state_route(namespace_name: str, state_id: str, body: ExecutedRequestModel, request: Request, background_tasks: BackgroundTasks, api_key: str = Depends(check_api_key)):

x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4()))

Expand All @@ -91,7 +91,7 @@ async def executed_state_route(namespace_name: str, state_id: str, body: Execute
logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return await executed_state(namespace_name, ObjectId(state_id), body, x_exosphere_request_id)
return await executed_state(namespace_name, ObjectId(state_id), body, x_exosphere_request_id, background_tasks)


@router.post(
Expand Down
124 changes: 124 additions & 0 deletions state-manager/app/tasks/create_next_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import asyncio
import time

from bson import ObjectId

from app.models.db.state import State
from app.models.db.graph_template_model import GraphTemplate
from app.models.graph_template_validation_status import GraphTemplateValidationStatus
from app.models.db.registered_node import RegisteredNode
from app.models.state_status_enum import StateStatusEnum

from json_schema_to_pydantic import create_model

async def create_next_state(state: State):
graph_template = None

try:
start_time = time.time()
timeout_seconds = 300 # 5 minutes

while True:
graph_template = await GraphTemplate.find_one(GraphTemplate.name == state.graph_name, GraphTemplate.namespace == state.namespace_name)
if not graph_template:
raise Exception(f"Graph template {state.graph_name} not found")
if graph_template.validation_status == GraphTemplateValidationStatus.VALID:
break

# Check if we've exceeded the timeout
if time.time() - start_time > timeout_seconds:
raise Exception(f"Timeout waiting for graph template {state.graph_name} to become valid after {timeout_seconds} seconds")

await asyncio.sleep(1)

node_template = graph_template.get_node_by_identifier(state.identifier)
if not node_template:
raise Exception(f"Node template {state.identifier} not found")

next_node_identifier = node_template.next_nodes
if not next_node_identifier:
raise Exception(f"Node template {state.identifier} has no next nodes")

cache_states = {}

for identifier in next_node_identifier:
next_node_template = graph_template.get_node_by_identifier(identifier)
if not next_node_template:
continue

registered_node = await RegisteredNode.find_one(RegisteredNode.name == next_node_template.node_name, RegisteredNode.namespace == next_node_template.namespace)

if not registered_node:
raise Exception(f"Registered node {next_node_template.node_name} not found")

next_node_input_model = create_model(registered_node.inputs_schema)
next_node_input_data = {}

for field_name, _ in next_node_input_model.model_fields.items():
temporary_input = next_node_template.inputs[field_name]
splits = temporary_input.split("${{")

if len(splits) == 0:
next_node_input_data[field_name] = temporary_input
continue

constructed_string = ""
for split in splits:
if "}}" in split:
placeholder_content = split.split("}}")[0]
parts = [p.strip() for p in placeholder_content.split('.')]

if len(parts) != 3 or parts[1] != 'outputs':
raise Exception(f"Invalid input placeholder format: '{placeholder_content}' for field {field_name}")

input_identifier = parts[0]
input_field = parts[2]

parent_id = state.parents.get(input_identifier)

if not parent_id:
raise Exception(f"Parent identifier '{input_identifier}' not found in state parents.")

if parent_id not in cache_states:
dependent_state = await State.get(ObjectId(parent_id))
if not dependent_state:
raise Exception(f"Dependent state {input_identifier} not found")
cache_states[parent_id] = dependent_state
else:
dependent_state = cache_states[parent_id]

if input_field not in dependent_state.outputs:
raise Exception(f"Input field {input_field} not found in dependent state {input_identifier}")

constructed_string += dependent_state.outputs[input_field] + split.split("}}")[1]

else:
constructed_string += split

next_node_input_data[field_name] = constructed_string

new_state = State(
node_name=next_node_template.node_name,
namespace_name=next_node_template.namespace,
identifier=next_node_template.identifier,
graph_name=state.graph_name,
status=StateStatusEnum.CREATED,
inputs=next_node_input_data,
outputs={},
error=None,
parents={
**state.parents,
next_node_template.identifier: ObjectId(state.id)
}
)

await new_state.save()

state.status = StateStatusEnum.SUCCESS
await state.save()

except Exception as e:
state.status = StateStatusEnum.ERRORED
state.error = str(e)
await state.save()
return