Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions src/agent/custom_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
ActionResult,
AgentHistoryList,
AgentOutput,
AgentHistory,
)
from browser_use.browser.browser import Browser
from browser_use.browser.context import BrowserContext
from browser_use.browser.views import BrowserStateHistory
from browser_use.controller.service import Controller
from browser_use.telemetry.views import (
AgentEndTelemetryEvent,
Expand All @@ -34,6 +36,7 @@
from langchain_core.messages import (
BaseMessage,
)
from src.utils.agent_state import AgentState

from .custom_massage_manager import CustomMassageManager
from .custom_views import CustomAgentOutput, CustomAgentStepInfo
Expand Down Expand Up @@ -72,6 +75,7 @@ def __init__(
max_error_length: int = 400,
max_actions_per_step: int = 10,
tool_call_in_content: bool = True,
agent_state: AgentState = None,
):
super().__init__(
task=task,
Expand All @@ -92,6 +96,7 @@ def __init__(
tool_call_in_content=tool_call_in_content,
)
self.add_infos = add_infos
self.agent_state = agent_state
self.message_manager = CustomMassageManager(
llm=self.llm,
task=self.task,
Expand Down Expand Up @@ -367,9 +372,21 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:
)

for step in range(max_steps):
# 1) Check if stop requested
if self.agent_state and self.agent_state.is_stop_requested():
logger.info("🛑 Stop requested by user")
self._create_stop_history_item()
break

# 2) Store last valid state before step
if self.browser_context and self.agent_state:
state = await self.browser_context.get_state(use_vision=self.use_vision)
self.agent_state.set_last_valid_state(state)

if self._too_many_failures():
break

# 3) Do the step
await self.step(step_info)

if self.history.is_done():
Expand Down Expand Up @@ -403,3 +420,61 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList:

if self.generate_gif:
self.create_history_gif()

def _create_stop_history_item(self):
"""Create a history item for when the agent is stopped."""
try:
# Attempt to retrieve the last valid state from agent_state
state = None
if self.agent_state:
last_state = self.agent_state.get_last_valid_state()
if last_state:
# Convert to BrowserStateHistory
state = BrowserStateHistory(
url=getattr(last_state, 'url', ""),
title=getattr(last_state, 'title', ""),
tabs=getattr(last_state, 'tabs', []),
interacted_element=[None],
screenshot=getattr(last_state, 'screenshot', None)
)
else:
state = self._create_empty_state()
else:
state = self._create_empty_state()

# Create a final item in the agent history indicating done
stop_history = AgentHistory(
model_output=None,
state=state,
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
)
self.history.history.append(stop_history)

except Exception as e:
logger.error(f"Error creating stop history item: {e}")
# Create empty state as fallback
state = self._create_empty_state()
stop_history = AgentHistory(
model_output=None,
state=state,
result=[ActionResult(extracted_content=None, error=None, is_done=True)]
)
self.history.history.append(stop_history)

def _convert_to_browser_state_history(self, browser_state):
return BrowserStateHistory(
url=getattr(browser_state, 'url', ""),
title=getattr(browser_state, 'title', ""),
tabs=getattr(browser_state, 'tabs', []),
interacted_element=[None],
screenshot=getattr(browser_state, 'screenshot', None)
)

def _create_empty_state(self):
return BrowserStateHistory(
url="",
title="",
tabs=[],
interacted_element=[None],
screenshot=None
)
30 changes: 30 additions & 0 deletions src/utils/agent_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio

class AgentState:
_instance = None

def __init__(self):
if not hasattr(self, '_stop_requested'):
self._stop_requested = asyncio.Event()
self.last_valid_state = None # store the last valid browser state

def __new__(cls):
if cls._instance is None:
cls._instance = super(AgentState, cls).__new__(cls)
return cls._instance

def request_stop(self):
self._stop_requested.set()

def clear_stop(self):
self._stop_requested.clear()
self.last_valid_state = None

def is_stop_requested(self):
return self._stop_requested.is_set()

def set_last_valid_state(self, state):
self.last_valid_state = state

def get_last_valid_state(self):
return self.last_valid_state
Loading