diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 88457f7aad..45990e48dd 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -793,10 +793,25 @@ def __new__( continue if attr_name in parent_fields: annotations[attr_name] = Any + parent_field = next( + ( + base.model_fields.get(attr_name) + for base in bases + if hasattr(base, "model_fields") + and base.model_fields.get(attr_name) is not None + ), + None, + ) + field_kwargs = {} + if parent_field is not None: + field_kwargs["exclude"] = parent_field.exclude if isinstance(attr_value, BaseModel): namespace[attr_name] = Field( - default_factory=lambda v=attr_value: v, exclude=True + default_factory=lambda v=attr_value: v, + **field_kwargs, ) + elif parent_field is not None: + namespace[attr_name] = Field(default=attr_value, **field_kwargs) continue if callable(attr_value) or isinstance( attr_value, (*_skip_types, FlowMethod) @@ -908,7 +923,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): entity_type: Literal["flow"] = "flow" - initial_state: Any = Field(default=None) + initial_state: Any = Field(default=None, exclude=True) name: str | None = Field(default=None) tracing: bool | None = Field(default=None) stream: bool = Field(default=False) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index e12caa2af5..55c0ae2843 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -611,35 +611,7 @@ async def _aexecute_core( self._post_agent_execution(agent) - if isinstance(result, BaseModel): - raw = result.model_dump_json() - if self.output_pydantic: - pydantic_output = result - json_output = None - elif self.output_json: - pydantic_output = None - json_output = result.model_dump() - else: - pydantic_output = None - json_output = None - elif not self._guardrails and not self._guardrail: - raw = result - pydantic_output, json_output = self._export_output(result) - else: - raw = result - pydantic_output, json_output = None, None - - task_output = TaskOutput( - name=self.name or self.description, - description=self.description, - expected_output=self.expected_output, - raw=raw, - pydantic=pydantic_output, - json_dict=json_output, - agent=agent.role, - output_format=self._get_output_format(), - messages=agent.last_messages, # type: ignore[attr-defined] - ) + task_output = self._build_task_output(result=result, agent=agent) if self._guardrails: for idx, guardrail in enumerate(self._guardrails): @@ -680,10 +652,12 @@ async def _aexecute_core( if self.output_file: content = ( - json_output - if json_output + task_output.json_dict + if task_output.json_dict else ( - pydantic_output.model_dump_json() if pydantic_output else result + task_output.pydantic.model_dump_json() + if task_output.pydantic + else result ) ) self._save_file(content) @@ -733,35 +707,7 @@ def _execute_core( self._post_agent_execution(agent) - if isinstance(result, BaseModel): - raw = result.model_dump_json() - if self.output_pydantic: - pydantic_output = result - json_output = None - elif self.output_json: - pydantic_output = None - json_output = result.model_dump() - else: - pydantic_output = None - json_output = None - elif not self._guardrails and not self._guardrail: - raw = result - pydantic_output, json_output = self._export_output(result) - else: - raw = result - pydantic_output, json_output = None, None - - task_output = TaskOutput( - name=self.name or self.description, - description=self.description, - expected_output=self.expected_output, - raw=raw, - pydantic=pydantic_output, - json_dict=json_output, - agent=agent.role, - output_format=self._get_output_format(), - messages=agent.last_messages, # type: ignore[attr-defined] - ) + task_output = self._build_task_output(result=result, agent=agent) if self._guardrails: for idx, guardrail in enumerate(self._guardrails): @@ -803,10 +749,12 @@ def _execute_core( if self.output_file: content = ( - json_output - if json_output + task_output.json_dict + if task_output.json_dict else ( - pydantic_output.model_dump_json() if pydantic_output else result + task_output.pydantic.model_dump_json() + if task_output.pydantic + else result ) ) self._save_file(content) @@ -1079,6 +1027,37 @@ def _export_output( return pydantic_output, json_output + def _build_task_output(self, result: Any, agent: BaseAgent) -> TaskOutput: + if isinstance(result, BaseModel): + raw = result.model_dump_json() + if self.output_pydantic: + pydantic_output = result + json_output = None + elif self.output_json: + pydantic_output = None + json_output = result.model_dump() + else: + pydantic_output = None + json_output = None + elif not self._guardrails and not self._guardrail: + raw = result + pydantic_output, json_output = self._export_output(result) + else: + raw = result + pydantic_output, json_output = None, None + + return TaskOutput( + name=self.name or self.description, + description=self.description, + expected_output=self.expected_output, + raw=raw, + pydantic=pydantic_output, + json_dict=json_output, + agent=agent.role, + output_format=self._get_output_format(), + messages=agent.last_messages, # type: ignore[attr-defined] + ) + def _get_output_format(self) -> OutputFormat: if self.output_json: return OutputFormat.JSON @@ -1240,19 +1219,7 @@ def _invoke_guardrail_function( context=context, tools=tools, ) - - pydantic_output, json_output = self._export_output(result) - task_output = TaskOutput( - name=self.name or self.description, - description=self.description, - expected_output=self.expected_output, - raw=result, - pydantic=pydantic_output, - json_dict=json_output, - agent=agent.role, - output_format=self._get_output_format(), - messages=agent.last_messages, # type: ignore[attr-defined] - ) + task_output = self._build_task_output(result=result, agent=agent) return task_output @@ -1336,18 +1303,6 @@ async def _ainvoke_guardrail_function( context=context, tools=tools, ) - - pydantic_output, json_output = self._export_output(result) - task_output = TaskOutput( - name=self.name or self.description, - description=self.description, - expected_output=self.expected_output, - raw=result, - pydantic=pydantic_output, - json_dict=json_output, - agent=agent.role, - output_format=self._get_output_format(), - messages=agent.last_messages, # type: ignore[attr-defined] - ) + task_output = self._build_task_output(result=result, agent=agent) return task_output diff --git a/lib/crewai/tests/task/test_async_task.py b/lib/crewai/tests/task/test_async_task.py index 70fec377d2..46359e0d46 100644 --- a/lib/crewai/tests/task/test_async_task.py +++ b/lib/crewai/tests/task/test_async_task.py @@ -5,6 +5,7 @@ from crewai.agent import Agent from crewai.task import Task +from pydantic import BaseModel from crewai.tasks.task_output import TaskOutput from crewai.tasks.output_format import OutputFormat @@ -285,6 +286,49 @@ def guardrail_fn(output: TaskOutput) -> tuple[bool, str]: assert result is not None assert call_count == 2 + @pytest.mark.asyncio + @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) + async def test_ainvoke_guardrail_retry_preserves_pydantic_output( + self, mock_execute: AsyncMock, test_agent: Agent + ) -> None: + """Test async guardrail retries preserve structured BaseModel outputs.""" + + class FamilyList(BaseModel): + families: list[str] + + mock_execute.side_effect = [ + FamilyList(families=["first"]), + FamilyList(families=["second"]), + ] + call_count = 0 + + def guardrail_fn( + output: TaskOutput, + ) -> tuple[bool, str | TaskOutput]: + nonlocal call_count + call_count += 1 + assert isinstance(output.pydantic, FamilyList) + if call_count == 1: + return False, "Try again with a corrected family list" + return True, output + + task = Task( + description="Test task", + expected_output="Test output", + agent=test_agent, + output_pydantic=FamilyList, + guardrail=guardrail_fn, + guardrail_max_retries=1, + ) + + result = await task.aexecute_sync() + + assert result is not None + assert call_count == 2 + assert isinstance(result.pydantic, FamilyList) + assert result.pydantic.families == ["second"] + assert result.raw == '{"families":["second"]}' + @pytest.mark.asyncio @patch("crewai.Agent.aexecute_task", new_callable=AsyncMock) async def test_ainvoke_guardrail_max_retries_exceeded( @@ -383,4 +427,4 @@ async def test_aexecute_sync_task_output_attributes( assert result.description == "Test description" assert result.expected_output == "Test expected" assert result.raw == "Test result" - assert result.agent == "Test Agent" \ No newline at end of file + assert result.agent == "Test Agent" diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py index b1ad9e2df4..afdf800c62 100644 --- a/lib/crewai/tests/test_checkpoint.py +++ b/lib/crewai/tests/test_checkpoint.py @@ -11,16 +11,18 @@ from unittest.mock import MagicMock, patch import pytest +from pydantic import BaseModel from crewai.agent.core import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.crew import Crew -from crewai.flow.flow import Flow, start +from crewai.flow.flow import Flow, FlowState, start from crewai.state.checkpoint_config import CheckpointConfig from crewai.state.checkpoint_listener import ( _find_checkpoint, _resolve, _SENTINEL, + _do_checkpoint, ) from crewai.state.provider.json_provider import JsonProvider from crewai.state.provider.sqlite_provider import SqliteProvider @@ -187,6 +189,38 @@ def test_trigger_events(self) -> None: assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"} +def test_flow_checkpoint_handles_nested_pydantic_models(tmp_path) -> None: + class Family(BaseModel): + family_id: int + name: str + + class FamilyState(FlowState): + families: list[Family] = [] + + class FamilyFlow(Flow[FamilyState]): + initial_state = FamilyState + + @start() + def create_family(self) -> None: + self.state.families = [Family(family_id=1, name="Smith")] + + checkpoint_dir = str(tmp_path / "checkpoints") + cfg = CheckpointConfig( + location=checkpoint_dir, + on_events=["method_execution_finished"], + provider=JsonProvider(), + ) + flow = FamilyFlow() + flow.kickoff() + + state = RuntimeState(root=[flow]) + _do_checkpoint(state, cfg) + branch_dir = os.path.join(checkpoint_dir, "main") + assert os.path.isdir(branch_dir) + files = [name for name in os.listdir(branch_dir) if name.endswith(".json")] + assert files + + # ---------- RuntimeState lineage ---------- diff --git a/lib/crewai/tests/test_task.py b/lib/crewai/tests/test_task.py index 21356c3b42..d655a1a8fd 100644 --- a/lib/crewai/tests/test_task.py +++ b/lib/crewai/tests/test_task.py @@ -289,6 +289,55 @@ def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]: ) +def test_guardrail_retry_preserves_pydantic_output(): + class FamilyList(BaseModel): + families: list[str] + + attempts = 0 + + def guardrail_fn( + output: TaskOutput, + ) -> tuple[bool, str | TaskOutput]: + nonlocal attempts + attempts += 1 + assert isinstance(output.pydantic, FamilyList) + if attempts == 1: + return False, "Try again with a corrected family list" + return True, output + + researcher = Agent( + role="Researcher", + goal="Return structured family data", + backstory="You produce structured family data.", + allow_delegation=False, + ) + + task = Task( + description="Generate a family list.", + expected_output="A FamilyList response.", + output_pydantic=FamilyList, + agent=researcher, + guardrail=guardrail_fn, + guardrail_max_retries=1, + ) + + with patch.object( + Agent, + "execute_task", + side_effect=[ + FamilyList(families=["first"]), + FamilyList(families=["second"]), + ], + ): + result = task.execute_sync() + + assert attempts == 2 + assert isinstance(result.pydantic, FamilyList) + assert result.pydantic.families == ["second"] + assert json.loads(result.raw) == {"families": ["second"]} + assert task.retry_count == 1 + + @pytest.mark.vcr() def test_output_pydantic_sequential(): class ScoreOutput(BaseModel):