Skip to content

Commit

Permalink
refactor: optimizing act_by_order mode
Browse files Browse the repository at this point in the history
  • Loading branch information
莘权 马 committed Apr 28, 2024
1 parent d53db1e commit 6505087
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 31 deletions.
15 changes: 3 additions & 12 deletions metagpt/roles/product_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.roles.role import Role
from metagpt.roles.role import Role, RoleReactMode
from metagpt.utils.common import any_to_name


Expand All @@ -35,17 +35,8 @@ def __init__(self, **kwargs) -> None:

self.set_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)

async def _think(self) -> bool:
"""Decide what to do"""
if self.git_repo and not self.config.git_reinit:
self._set_state(1)
else:
self._set_state(0)
self.config.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)
self.rc.react_mode = RoleReactMode.BY_ORDER
self.todo_action = any_to_name(WritePRD)

async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)
23 changes: 9 additions & 14 deletions metagpt/roles/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,12 @@ async def _think(self) -> bool:
self.recovered = False # avoid max_react_loop out of work
return True

if self.rc.react_mode == RoleReactMode.BY_ORDER:
if self.rc.max_react_loop != len(self.actions):
self.rc.max_react_loop = len(self.actions)
self._set_state(self.rc.state + 1)
return self.rc.state >= 0 and self.rc.state < len(self.actions)

prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
history=self.rc.history,
Expand Down Expand Up @@ -455,24 +461,15 @@ async def _react(self) -> Message:
rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act
while actions_taken < self.rc.max_react_loop:
# think
await self._think()
if self.rc.todo is None:
todo = await self._think()
if not todo:
break
# act
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
rsp = await self._act()
actions_taken += 1
return rsp # return output from the last action

async def _act_by_order(self) -> Message:
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if actions=[]
for i in range(start_idx, len(self.states)):
self._set_state(i)
rsp = await self._act()
return rsp # return output from the last action

async def _plan_and_act(self) -> Message:
"""first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically."""

Expand Down Expand Up @@ -513,10 +510,8 @@ async def _act_on_task(self, current_task: Task) -> TaskResult:

async def react(self) -> Message:
"""Entry to one of three strategies by which Role reacts to the observed Message"""
if self.rc.react_mode == RoleReactMode.REACT:
if self.rc.react_mode == RoleReactMode.REACT or self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._react()
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._act_by_order()
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
else:
Expand Down
5 changes: 0 additions & 5 deletions tests/metagpt/roles/test_product_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pytest

from metagpt.actions import WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.context import Context
from metagpt.logs import logger
Expand All @@ -30,11 +29,7 @@ async def test_product_manager(new_filename):
rsp = await product_manager.run(MockMessages.req)
assert context.git_repo
assert context.repo
assert rsp.cause_by == any_to_str(PrepareDocuments)
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files

# write prd
rsp = await product_manager.run(rsp)
assert rsp.cause_by == any_to_str(WritePRD)
logger.info(rsp)
assert len(rsp.content) > 0
Expand Down

0 comments on commit 6505087

Please sign in to comment.