diff --git a/src/agent/custom_agent.py b/src/agent/custom_agent.py index f4c1df5a..ff8908c8 100644 --- a/src/agent/custom_agent.py +++ b/src/agent/custom_agent.py @@ -20,9 +20,11 @@ ActionResult, AgentHistoryList, AgentOutput, + AgentHistory, ) from browser_use.browser.browser import Browser from browser_use.browser.context import BrowserContext +from browser_use.browser.views import BrowserStateHistory from browser_use.controller.service import Controller from browser_use.telemetry.views import ( AgentEndTelemetryEvent, @@ -34,6 +36,7 @@ from langchain_core.messages import ( BaseMessage, ) +from src.utils.agent_state import AgentState from .custom_massage_manager import CustomMassageManager from .custom_views import CustomAgentOutput, CustomAgentStepInfo @@ -72,6 +75,7 @@ def __init__( max_error_length: int = 400, max_actions_per_step: int = 10, tool_call_in_content: bool = True, + agent_state: AgentState = None, ): super().__init__( task=task, @@ -92,6 +96,7 @@ def __init__( tool_call_in_content=tool_call_in_content, ) self.add_infos = add_infos + self.agent_state = agent_state self.message_manager = CustomMassageManager( llm=self.llm, task=self.task, @@ -367,9 +372,21 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList: ) for step in range(max_steps): + # 1) Check if stop requested + if self.agent_state and self.agent_state.is_stop_requested(): + logger.info("🛑 Stop requested by user") + self._create_stop_history_item() + break + + # 2) Store last valid state before step + if self.browser_context and self.agent_state: + state = await self.browser_context.get_state(use_vision=self.use_vision) + self.agent_state.set_last_valid_state(state) + if self._too_many_failures(): break + # 3) Do the step await self.step(step_info) if self.history.is_done(): @@ -403,3 +420,61 @@ async def run(self, max_steps: int = 100) -> AgentHistoryList: if self.generate_gif: self.create_history_gif() + + def _create_stop_history_item(self): + """Create a history item for 
import threading


class AgentState:
    """Process-wide singleton holding the stop flag and the last browser state.

    Every ``AgentState()`` call returns the same object, so the code that
    requests a stop and the code that polls for it always share one flag.

    The flag is a ``threading.Event`` rather than an ``asyncio.Event``: it is
    only ever set, cleared and polled (never awaited), the singleton may be
    created at import time before any event loop is running, and asyncio
    synchronization primitives are not thread-safe.
    """

    _instance = None
    # Guards first-time construction so two concurrent callers cannot each
    # build their own instance.
    _instance_lock = threading.Lock()

    def __new__(cls):
        with cls._instance_lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                # Initialise exactly once, here. Doing it in __init__ would
                # need a guard, because __init__ runs on every AgentState()
                # call and must not reset the shared state.
                instance._stop_requested = threading.Event()
                instance.last_valid_state = None  # last valid browser state
                cls._instance = instance
        return cls._instance

    def request_stop(self) -> None:
        """Raise the stop flag."""
        self._stop_requested.set()

    def clear_stop(self) -> None:
        """Lower the stop flag and forget the stored browser state."""
        self._stop_requested.clear()
        self.last_valid_state = None

    def is_stop_requested(self) -> bool:
        """Return True while a stop request is pending."""
        return self._stop_requested.is_set()

    def set_last_valid_state(self, state) -> None:
        """Remember *state* as the most recent valid browser state."""
        self.last_valid_state = state

    def get_last_valid_state(self):
        """Return the stored browser state, or None if nothing is stored."""
        return self.last_valid_state
def _find_recording_files(directory):
    """Return the set of .mp4/.webm paths (any letter case) inside *directory*."""
    return set(
        glob.glob(os.path.join(directory, "*.[mM][pP]4"))
        + glob.glob(os.path.join(directory, "*.[wW][eE][bB][mM]"))
    )


async def run_browser_agent(
        agent_type,
        llm_provider,
        llm_model_name,
        llm_temperature,
        llm_base_url,
        llm_api_key,
        use_own_browser,
        keep_browser_open,
        headless,
        disable_security,
        window_w,
        window_h,
        save_recording_path,
        save_trace_path,
        enable_recording,
        task,
        add_infos,
        max_steps,
        use_vision,
        max_actions_per_step,
        tool_call_in_content
):
    """Run the selected agent once and return the values for the UI outputs.

    Clears any pending stop request, builds the LLM, dispatches to
    ``run_org_agent`` or ``run_custom_agent`` depending on *agent_type*, and
    looks for a recording file that appeared in *save_recording_path* during
    the run.

    Returns:
        A 7-tuple ``(final_result, errors, model_actions, model_thoughts,
        latest_video, stop_button_update, run_button_update)``. The two button
        updates re-enable the Stop and Run buttons. If anything raises, the
        text fields are empty strings, ``errors`` carries the message plus
        traceback, and ``latest_video`` is None.
    """
    global _global_agent_state
    _global_agent_state.clear_stop()  # Clear any previous stop requests

    try:
        # Disable recording if the checkbox is unchecked
        if not enable_recording:
            save_recording_path = None

        # Snapshot the recordings that already exist so that files created by
        # this run can be told apart afterwards.
        existing_videos = set()
        if save_recording_path:
            os.makedirs(save_recording_path, exist_ok=True)
            existing_videos = _find_recording_files(save_recording_path)

        # Run the agent
        llm = utils.get_llm_model(
            provider=llm_provider,
            model_name=llm_model_name,
            temperature=llm_temperature,
            base_url=llm_base_url,
            api_key=llm_api_key,
        )
        if agent_type == "org":
            final_result, errors, model_actions, model_thoughts = await run_org_agent(
                llm=llm,
                use_own_browser=use_own_browser,
                keep_browser_open=keep_browser_open,
                headless=headless,
                disable_security=disable_security,
                window_w=window_w,
                window_h=window_h,
                save_recording_path=save_recording_path,
                save_trace_path=save_trace_path,
                task=task,
                max_steps=max_steps,
                use_vision=use_vision,
                max_actions_per_step=max_actions_per_step,
                tool_call_in_content=tool_call_in_content
            )
        elif agent_type == "custom":
            final_result, errors, model_actions, model_thoughts = await run_custom_agent(
                llm=llm,
                use_own_browser=use_own_browser,
                keep_browser_open=keep_browser_open,
                headless=headless,
                disable_security=disable_security,
                window_w=window_w,
                window_h=window_h,
                save_recording_path=save_recording_path,
                save_trace_path=save_trace_path,
                task=task,
                add_infos=add_infos,
                max_steps=max_steps,
                use_vision=use_vision,
                max_actions_per_step=max_actions_per_step,
                tool_call_in_content=tool_call_in_content
            )
        else:
            raise ValueError(f"Invalid agent type: {agent_type}")

        # Pick the recording produced by this run (if recording is enabled)
        latest_video = None
        if save_recording_path:
            fresh_videos = _find_recording_files(save_recording_path) - existing_videos
            if fresh_videos:
                # Sets are unordered, so indexing into one picks an arbitrary
                # file. Choose the most recently modified recording instead.
                try:
                    latest_video = max(fresh_videos, key=os.path.getmtime)
                except OSError:
                    # A file vanished between listing and stat; fall back to
                    # a deterministic choice by name.
                    latest_video = max(fresh_videos)

        return (
            final_result,
            errors,
            model_actions,
            model_thoughts,
            latest_video,
            gr.update(value="Stop", interactive=True),  # Re-enable stop button
            gr.update(value="Run", interactive=True)    # Re-enable run button
        )

    except Exception as e:
        import traceback
        traceback.print_exc()
        errors = str(e) + "\n" + traceback.format_exc()
        return (
            '',                                         # final_result
            errors,                                     # errors
            '',                                         # model_actions
            '',                                         # model_thoughts
            None,                                       # latest_video
            gr.update(value="Stop", interactive=True),  # Re-enable stop button
            gr.update(value="Run", interactive=True)    # Re-enable run button
        )
Clear any previous stop request + _global_agent_state.clear_stop() if use_own_browser: chrome_path = os.getenv("CHROME_PATH", None) @@ -287,7 +354,8 @@ async def run_custom_agent( controller=controller, system_prompt_class=CustomSystemPrompt, max_actions_per_step=max_actions_per_step, - tool_call_in_content=tool_call_in_content + tool_call_in_content=tool_call_in_content, + agent_state=_global_agent_state ) history = await agent.run(max_steps=max_steps) @@ -550,6 +618,24 @@ def create_ui(theme_name="Ocean"): label="Model Thoughts", lines=3, show_label=True ) + # Bind the stop button click event after errors_output is defined + stop_button.click( + fn=stop_agent, + inputs=[], + outputs=[errors_output, stop_button, run_button], + ) + + # Run button click handler + run_button.click( + fn=run_browser_agent, + inputs=[ + agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, + use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path, + enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content + ], + outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display, stop_button, run_button], + ) + with gr.TabItem("🎥 Recordings", id=6): def list_recordings(save_recording_path): if not os.path.exists(save_recording_path): @@ -601,17 +687,6 @@ def list_recordings(save_recording_path): use_own_browser.change(fn=close_global_browser) keep_browser_open.change(fn=close_global_browser) - # Run button click handler - run_button.click( - fn=run_browser_agent, - inputs=[ - agent_type, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, - use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h, save_recording_path, save_trace_path, - enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_call_in_content - ], - 
outputs=[final_result_output, errors_output, model_actions_output, model_thoughts_output, recording_display], - ) - return demo def main():