Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions src/utils/deep_research.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
from src.utils import utils
from src.agent.custom_agent import CustomAgent
import json
import re
from browser_use.agent.service import Agent
from browser_use.browser.browser import BrowserConfig, Browser
from langchain.schema import SystemMessage, HumanMessage
from json_repair import repair_json
from src.agent.custom_prompts import CustomSystemPrompt, CustomAgentMessagePrompt
from src.controller.custom_controller import CustomController
from src.browser.custom_browser import CustomBrowser

logger = logging.getLogger(__name__)

async def deep_research(task, llm, **kwargs):
async def deep_research(task, llm, agent_state, **kwargs):
task_id = str(uuid4())
save_dir = kwargs.get("save_dir", os.path.join(f"./tmp/deep_research/{task_id}"))
logger.info(f"Save Deep Research at: {save_dir}")
Expand Down Expand Up @@ -113,12 +115,20 @@ async def deep_research(task, llm, **kwargs):
"""
record_messages = [SystemMessage(content=record_system_prompt)]

browser = Browser(
config=BrowserConfig(
disable_security=True,
headless=kwargs.get("headless", False), # Set to False to see browser actions
)
)
use_own_browser = kwargs.get("use_own_browser", False)
extra_chromium_args = []
if use_own_browser:
# if use own browser, max query num should be 1 per iter
max_query_num = 1
chrome_path = os.getenv("CHROME_PATH", None)
if chrome_path == "":
chrome_path = None
chrome_user_data = os.getenv("CHROME_USER_DATA", None)
if chrome_user_data:
extra_chromium_args += [f"--user-data-dir={chrome_user_data}"]
else:
chrome_path = None
browser = None
controller = CustomController()

search_iteration = 0
Expand Down Expand Up @@ -151,6 +161,7 @@ async def deep_research(task, llm, **kwargs):
if not query_tasks:
break
else:
query_tasks = query_tasks[:max_query_num]
history_query.extend(query_tasks)
logger.info("Query tasks:")
logger.info(query_tasks)
Expand All @@ -159,6 +170,15 @@ async def deep_research(task, llm, **kwargs):
# Parallel BU agents
add_infos = "1. Please click on the most relevant link to get information and go deeper, instead of just staying on the search page. \n" \
"2. When opening a PDF file, please remember to extract the content using extract_content instead of simply opening it for the user to view."
if use_own_browser:
browser = CustomBrowser(
config=BrowserConfig(
headless=kwargs.get("headless", False),
disable_security=kwargs.get("disable_security", True),
chrome_instance_path=chrome_path,
extra_chromium_args=extra_chromium_args,
)
)
agents = [CustomAgent(
task=task,
llm=llm,
Expand All @@ -168,15 +188,24 @@ async def deep_research(task, llm, **kwargs):
system_prompt_class=CustomSystemPrompt,
agent_prompt_class=CustomAgentMessagePrompt,
max_actions_per_step=5,
controller=controller
controller=controller,
agent_state=agent_state
) for task in query_tasks]
query_results = await asyncio.gather(*[agent.run(max_steps=kwargs.get("max_steps", 10)) for agent in agents])

if browser:
await browser.close()
browser = None
logger.info("Browser closed.")
if agent_state and agent_state.is_stop_requested():
# Stop
break
# 3. Summarize Search Result
query_result_dir = os.path.join(save_dir, "query_results")
os.makedirs(query_result_dir, exist_ok=True)
for i in range(len(query_tasks)):
query_result = query_results[i].final_result()
if not query_result:
continue
querr_save_path = os.path.join(query_result_dir, f"{search_iteration}-{i}.md")
logger.info(f"save query: {query_tasks[i]} at {querr_save_path}")
with open(querr_save_path, "w", encoding="utf-8") as fw:
Expand Down Expand Up @@ -244,7 +273,9 @@ async def deep_research(task, llm, **kwargs):
logger.info(ai_report_msg.reasoning_content)
logger.info("🤯 End Report Deep Thinking")
report_content = ai_report_msg.content

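# Strip a leading ```markdown / ``` fence and a trailing ``` fence, allowing optional whitespace.
# NOTE(review): with re.MULTILINE, ^ and $ match at every line, so fences inside the report body are stripped too - confirm this is intended.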
report_content = re.sub(r"^```\s*markdown\s*|^\s*```|```\s*$", "", report_content, flags=re.MULTILINE)
report_content = report_content.strip()
report_file_path = os.path.join(save_dir, "final_report.md")
with open(report_file_path, "w", encoding="utf-8") as f:
f.write(report_content)
Expand All @@ -257,4 +288,5 @@ async def deep_research(task, llm, **kwargs):
finally:
if browser:
await browser.close()
browser = None
logger.info("Browser closed.")
154 changes: 97 additions & 57 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,31 @@ async def stop_agent():
gr.update(value="Stop", interactive=True),
gr.update(interactive=True)
)

async def stop_research_agent():
    """Ask the running deep-research agent to stop and update the UI buttons.

    Calls ``request_stop()`` on the module-level ``_global_agent_state``. The
    stop is only a request: this function does not wait for the research run
    to finish.

    Returns:
        A 2-tuple of ``gr.update(...)`` values, in order, for the stop button
        and the run button. On success the stop button reads "Stopping..."
        and both buttons are disabled. If the stop request raises, the stop
        button is reset to "Stop" and both buttons are re-enabled.
    """
    try:
        # Only a read of the module-level state, so no `global` declaration
        # is required here.
        _global_agent_state.request_stop()

        message = "Stop requested - the agent will halt at the next safe point"
        logger.info(f"🛑 {message}")

        return (
            gr.update(value="Stopping...", interactive=False),  # stop_button
            gr.update(interactive=False),  # run_button
        )
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback, so a
        # failed stop request can be diagnosed from the logs.
        logger.exception("Error during stop: %s", e)
        # Re-enable both buttons so the user can retry.
        return (
            gr.update(value="Stop", interactive=True),  # stop_button
            gr.update(interactive=True),  # run_button
        )

async def run_browser_agent(
agent_type,
Expand Down Expand Up @@ -598,8 +623,12 @@ async def close_global_browser():
await _global_browser.close()
_global_browser = None

async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, headless):
async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless):
from src.utils.deep_research import deep_research
global _global_agent_state

# Clear any previous stop request
_global_agent_state.clear_stop()

llm = utils.get_llm_model(
provider=llm_provider,
Expand All @@ -608,12 +637,15 @@ async def run_deep_search(research_task, max_search_iteration_input, max_query_p
base_url=llm_base_url,
api_key=llm_api_key,
)
markdown_content, file_path = await deep_research(research_task, llm,
markdown_content, file_path = await deep_research(research_task, llm, _global_agent_state,
max_search_iterations=max_search_iteration_input,
max_query_num=max_query_per_iter_input,
use_vision=use_vision,
headless=headless)
return markdown_content, file_path
headless=headless,
use_own_browser=use_own_browser
)

return markdown_content, file_path, gr.update(value="Stop", interactive=True), gr.update(interactive=True)


def create_ui(config, theme_name="Ocean"):
Expand Down Expand Up @@ -815,57 +847,17 @@ def create_ui(config, theme_name="Ocean"):
label="Live Browser View",
)

with gr.TabItem("🧐 Deep Research"):
with gr.Group():
research_task_input = gr.Textbox(label="Research Task", lines=5, value="Compose a report on the use of Reinforcement Learning for training Large Language Models, encompassing its origins, current advancements, and future prospects, substantiated with examples of relevant models and techniques. The report should reflect original insights and analysis, moving beyond mere summarization of existing literature.")
with gr.Row():
max_search_iteration_input = gr.Number(label="Max Search Iteration", value=20, precision=0) # precision=0 确保是整数
max_query_per_iter_input = gr.Number(label="Max Query per Iteration", value=5, precision=0) # precision=0 确保是整数
research_button = gr.Button("Run Deep Research")
markdown_output_display = gr.Markdown(label="Research Report")
markdown_download = gr.File(label="Download Research Report")


with gr.TabItem("📁 Configuration", id=5):
with gr.Group():
config_file_input = gr.File(
label="Load Config File",
file_types=[".pkl"],
interactive=True
)

load_config_button = gr.Button("Load Existing Config From File", variant="primary")
save_config_button = gr.Button("Save Current Config", variant="primary")

config_status = gr.Textbox(
label="Status",
lines=2,
interactive=False
)

load_config_button.click(
fn=update_ui_from_config,
inputs=[config_file_input],
outputs=[
agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, enable_recording,
window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path,
task, config_status
]
)
with gr.TabItem("🧐 Deep Research", id=5):
research_task_input = gr.Textbox(label="Research Task", lines=5, value="Compose a report on the use of Reinforcement Learning for training Large Language Models, encompassing its origins, current advancements, and future prospects, substantiated with examples of relevant models and techniques. The report should reflect original insights and analysis, moving beyond mere summarization of existing literature.")
with gr.Row():
max_search_iteration_input = gr.Number(label="Max Search Iteration", value=20, precision=0) # precision=0 确保是整数
max_query_per_iter_input = gr.Number(label="Max Query per Iteration", value=5, precision=0) # precision=0 确保是整数
with gr.Row():
research_button = gr.Button("▶️ Run Deep Research", variant="primary", scale=2)
stop_research_button = gr.Button("⏹️ Stop", variant="stop", scale=1)
markdown_output_display = gr.Markdown(label="Research Report")
markdown_download = gr.File(label="Download Research Report")

save_config_button.click(
fn=save_current_config,
inputs=[
agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security,
enable_recording, window_w, window_h, save_recording_path, save_trace_path,
save_agent_history_path, task,
],
outputs=[config_status]
)

with gr.TabItem("📊 Results", id=6):
with gr.Group():
Expand Down Expand Up @@ -929,9 +921,15 @@ def create_ui(config, theme_name="Ocean"):
# Run Deep Research
research_button.click(
fn=run_deep_search,
inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, headless],
outputs=[markdown_output_display, markdown_download]
)
inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless],
outputs=[markdown_output_display, markdown_download, stop_research_button, research_button]
)
# Bind the stop button click event for the Deep Research tab
stop_research_button.click(
fn=stop_research_agent,
inputs=[],
outputs=[stop_research_button, research_button],
)

with gr.TabItem("🎥 Recordings", id=7):
def list_recordings(save_recording_path):
Expand Down Expand Up @@ -966,6 +964,48 @@ def list_recordings(save_recording_path):
inputs=save_recording_path,
outputs=recordings_gallery
)

with gr.TabItem("📁 Configuration", id=8):
with gr.Group():
config_file_input = gr.File(
label="Load Config File",
file_types=[".pkl"],
interactive=True
)

load_config_button = gr.Button("Load Existing Config From File", variant="primary")
save_config_button = gr.Button("Save Current Config", variant="primary")

config_status = gr.Textbox(
label="Status",
lines=2,
interactive=False
)

load_config_button.click(
fn=update_ui_from_config,
inputs=[config_file_input],
outputs=[
agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, enable_recording,
window_w, window_h, save_recording_path, save_trace_path, save_agent_history_path,
task, config_status
]
)

save_config_button.click(
fn=save_current_config,
inputs=[
agent_type, max_steps, max_actions_per_step, use_vision, tool_calling_method,
llm_provider, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security,
enable_recording, window_w, window_h, save_recording_path, save_trace_path,
save_agent_history_path, task,
],
outputs=[config_status]
)


# Attach the callback to the LLM provider dropdown
llm_provider.change(
Expand Down