Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions lib/crewai/src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,25 @@ def __new__(
continue
if attr_name in parent_fields:
annotations[attr_name] = Any
parent_field = next(
(
base.model_fields.get(attr_name)
for base in bases
if hasattr(base, "model_fields")
and base.model_fields.get(attr_name) is not None
),
None,
)
field_kwargs = {}
if parent_field is not None:
field_kwargs["exclude"] = parent_field.exclude
if isinstance(attr_value, BaseModel):
namespace[attr_name] = Field(
default_factory=lambda v=attr_value: v, exclude=True
default_factory=lambda v=attr_value: v,
**field_kwargs,
)
elif parent_field is not None:
namespace[attr_name] = Field(default=attr_value, **field_kwargs)
continue
if callable(attr_value) or isinstance(
attr_value, (*_skip_types, FlowMethod)
Expand Down Expand Up @@ -908,7 +923,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):

entity_type: Literal["flow"] = "flow"

initial_state: Any = Field(default=None)
initial_state: Any = Field(default=None, exclude=True)
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
Expand Down
135 changes: 45 additions & 90 deletions lib/crewai/src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,35 +611,7 @@ async def _aexecute_core(

self._post_agent_execution(agent)

if isinstance(result, BaseModel):
raw = result.model_dump_json()
if self.output_pydantic:
pydantic_output = result
json_output = None
elif self.output_json:
pydantic_output = None
json_output = result.model_dump()
else:
pydantic_output = None
json_output = None
elif not self._guardrails and not self._guardrail:
raw = result
pydantic_output, json_output = self._export_output(result)
else:
raw = result
pydantic_output, json_output = None, None

task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=raw,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
task_output = self._build_task_output(result=result, agent=agent)

if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
Expand Down Expand Up @@ -680,10 +652,12 @@ async def _aexecute_core(

if self.output_file:
content = (
json_output
if json_output
task_output.json_dict
if task_output.json_dict
else (
pydantic_output.model_dump_json() if pydantic_output else result
task_output.pydantic.model_dump_json()
if task_output.pydantic
else result
)
)
self._save_file(content)
Expand Down Expand Up @@ -733,35 +707,7 @@ def _execute_core(

self._post_agent_execution(agent)

if isinstance(result, BaseModel):
raw = result.model_dump_json()
if self.output_pydantic:
pydantic_output = result
json_output = None
elif self.output_json:
pydantic_output = None
json_output = result.model_dump()
else:
pydantic_output = None
json_output = None
elif not self._guardrails and not self._guardrail:
raw = result
pydantic_output, json_output = self._export_output(result)
else:
raw = result
pydantic_output, json_output = None, None

task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=raw,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
task_output = self._build_task_output(result=result, agent=agent)

if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
Expand Down Expand Up @@ -803,10 +749,12 @@ def _execute_core(

if self.output_file:
content = (
json_output
if json_output
task_output.json_dict
if task_output.json_dict
else (
pydantic_output.model_dump_json() if pydantic_output else result
task_output.pydantic.model_dump_json()
if task_output.pydantic
else result
)
)
self._save_file(content)
Expand Down Expand Up @@ -1079,6 +1027,37 @@ def _export_output(

return pydantic_output, json_output

def _build_task_output(self, result: Any, agent: BaseAgent) -> TaskOutput:
if isinstance(result, BaseModel):
raw = result.model_dump_json()
if self.output_pydantic:
pydantic_output = result
json_output = None
elif self.output_json:
pydantic_output = None
json_output = result.model_dump()
else:
pydantic_output = None
json_output = None
elif not self._guardrails and not self._guardrail:
raw = result
pydantic_output, json_output = self._export_output(result)
else:
raw = result
pydantic_output, json_output = None, None

return TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=raw,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)

def _get_output_format(self) -> OutputFormat:
if self.output_json:
return OutputFormat.JSON
Expand Down Expand Up @@ -1240,19 +1219,7 @@ def _invoke_guardrail_function(
context=context,
tools=tools,
)

pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
task_output = self._build_task_output(result=result, agent=agent)

return task_output

Expand Down Expand Up @@ -1336,18 +1303,6 @@ async def _ainvoke_guardrail_function(
context=context,
tools=tools,
)

pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
task_output = self._build_task_output(result=result, agent=agent)

return task_output
46 changes: 45 additions & 1 deletion lib/crewai/tests/task/test_async_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from crewai.agent import Agent
from crewai.task import Task
from pydantic import BaseModel
from crewai.tasks.task_output import TaskOutput
from crewai.tasks.output_format import OutputFormat

Expand Down Expand Up @@ -285,6 +286,49 @@ def guardrail_fn(output: TaskOutput) -> tuple[bool, str]:
assert result is not None
assert call_count == 2

@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_retry_preserves_pydantic_output(
self, mock_execute: AsyncMock, test_agent: Agent
) -> None:
"""Test async guardrail retries preserve structured BaseModel outputs."""

class FamilyList(BaseModel):
families: list[str]

mock_execute.side_effect = [
FamilyList(families=["first"]),
FamilyList(families=["second"]),
]
call_count = 0

def guardrail_fn(
output: TaskOutput,
) -> tuple[bool, str | TaskOutput]:
nonlocal call_count
call_count += 1
assert isinstance(output.pydantic, FamilyList)
if call_count == 1:
return False, "Try again with a corrected family list"
return True, output

task = Task(
description="Test task",
expected_output="Test output",
agent=test_agent,
output_pydantic=FamilyList,
guardrail=guardrail_fn,
guardrail_max_retries=1,
)

result = await task.aexecute_sync()

assert result is not None
assert call_count == 2
assert isinstance(result.pydantic, FamilyList)
assert result.pydantic.families == ["second"]
assert result.raw == '{"families":["second"]}'

@pytest.mark.asyncio
@patch("crewai.Agent.aexecute_task", new_callable=AsyncMock)
async def test_ainvoke_guardrail_max_retries_exceeded(
Expand Down Expand Up @@ -383,4 +427,4 @@ async def test_aexecute_sync_task_output_attributes(
assert result.description == "Test description"
assert result.expected_output == "Test expected"
assert result.raw == "Test result"
assert result.agent == "Test Agent"
assert result.agent == "Test Agent"
36 changes: 35 additions & 1 deletion lib/crewai/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
from unittest.mock import MagicMock, patch

import pytest
from pydantic import BaseModel

from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.flow.flow import Flow, start
from crewai.flow.flow import Flow, FlowState, start
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_listener import (
_find_checkpoint,
_resolve,
_SENTINEL,
_do_checkpoint,
)
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.provider.sqlite_provider import SqliteProvider
Expand Down Expand Up @@ -187,6 +189,38 @@ def test_trigger_events(self) -> None:
assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"}


def test_flow_checkpoint_handles_nested_pydantic_models(tmp_path) -> None:
class Family(BaseModel):
family_id: int
name: str

class FamilyState(FlowState):
families: list[Family] = []

class FamilyFlow(Flow[FamilyState]):
initial_state = FamilyState

@start()
def create_family(self) -> None:
self.state.families = [Family(family_id=1, name="Smith")]

checkpoint_dir = str(tmp_path / "checkpoints")
cfg = CheckpointConfig(
location=checkpoint_dir,
on_events=["method_execution_finished"],
provider=JsonProvider(),
)
flow = FamilyFlow()
flow.kickoff()

state = RuntimeState(root=[flow])
_do_checkpoint(state, cfg)
branch_dir = os.path.join(checkpoint_dir, "main")
assert os.path.isdir(branch_dir)
files = [name for name in os.listdir(branch_dir) if name.endswith(".json")]
assert files


# ---------- RuntimeState lineage ----------


Expand Down
Loading