In [None]:
from typing import TypedDict

from browser import ChromeBrowser
from browser.search.web import BraveBrowser
from generalist.agents.core import CapabilityPlan
from generalist.tools.data_model import ContentResource, ShortAnswer, Task


class ExecutionState(TypedDict):
    """State dictionary that the graph node functions read from and write to."""
    # What the user is asking us to do for them.
    ask: str
    # Identifies the original question/task, the objective it was translated into,
    # and the plan for getting an answer.
    task: Task
    # Order index of the step of the task's plan that is being executed.
    step: int
    # Clues, findings and answers to the previous subtasks.
    # Used to execute a capability plan step given already found information.
    context: str
    # Capability plan for this task (overwritten when a new subtask from the main plan is picked up).
    capability_plan: CapabilityPlan
    # Order index of the capability plan step.
    capability_plan_step: int
    # Answers to subtasks; the last one should be the final answer to the task.
    answers: list[ShortAnswer]
    # All text resources that might be needed to execute the task.
    resources: list[ContentResource]
    # Tools that already got called.
    # TODO: see if this is needed
    tools_called: str

# Limit that the `step` counter of the execution state is compared against
# when deciding whether to stop early.
MAX_STEPS = 5
# Brave search browser built once at import time on top of a Chrome browser;
# it is handed to the deep web search capability when that capability runs.
BRAVE_SEARCH = BraveBrowser(browser=ChromeBrowser(), session_id="deep_web_search")

In [None]:
import json
from generalist.agents.core import AgentCapabilityDeepWebSearch, AgentCapabilityUnstructuredDataProcessor, \
    AgentCapabilityCodeWriterExecutor, AgentCapabilityAudioProcessor, AgentCapabilityOutput
from generalist.tools.planning import determine_capabilities, create_plan
from generalist.tools.summarisers import construct_short_answer
from langgraph.graph import StateGraph, START, END


def init_state(ask: str, resources: list[ContentResource] | None = None) -> ExecutionState:
    """Build the initial execution state for a new user request.

    Args:
        ask: The user's question or task, verbatim.
        resources: Optional resources supplied together with the request.
            The list is copied, so resources appended to the state later on
            never leak into the caller's list (or into a shared default).

    Returns:
        An ``ExecutionState`` with ``ask``, ``context``, ``answers`` and
        ``resources`` filled in. ``task`` and ``step`` start as ``None`` and
        are populated by ``set_task``.
    """
    # TODO: should I be using LLM to convert attachments/resources to acceptable format?
    # TODO: implement proper handling of attachments and resources
    return ExecutionState(
        ask=ask,
        task=None,
        step=None,
        context="",
        answers=[],
        # A fresh list per state: the previous `list()` default was created once
        # at definition time and shared by every call that omitted `resources`.
        resources=list(resources) if resources is not None else [],
    )

def set_task(state: ExecutionState) -> ExecutionState:
    """Turn the user's ask into a ``Task`` and reset the step counter.

    The planner response is parsed as JSON; its ``objective`` and ``plan``
    entries define the task. If the response also carries a truthy
    ``resource`` entry, it is appended to the state's resources as a
    user-provided ``ContentResource``.
    """
    ask = state["ask"]
    plan_payload = json.loads(create_plan(ask, state["resources"]))

    planned_task = Task(
        question=ask,
        objective=plan_payload["objective"],
        plan=plan_payload["plan"],
    )

    resource_spec = plan_payload.get("resource")
    if resource_spec:
        state["resources"].append(
            ContentResource(
                provided_by="user",
                content=resource_spec.get("content"),
                link=resource_spec.get("link"),
                metadata={},
            )
        )

    state["task"] = planned_task
    state["step"] = 0
    return state

def evaluate_task_completion(state: ExecutionState) -> str:
    """Routing function: decide whether the graph keeps working on the task.

    Args:
        state: Current execution state; reads ``step``, ``task`` and ``context``.

    Returns:
        ``"end"`` when the step budget is used up or the accumulated context
        already answers the task objective, otherwise ``"continue"``.
    """
    # Budget check first: when the limit is reached the outcome is "end"
    # regardless of the answer check, so skip that call entirely.
    # `step` is 0 after `set_task` and is incremented once per executed step,
    # so `>=` caps the run at MAX_STEPS executions (`>` allowed one extra).
    if state["step"] >= MAX_STEPS:
        return "end"

    short_answer = construct_short_answer(
        state["task"].objective,
        state["context"],
    )

    # Early stopping if an answer already exists.
    return "end" if short_answer.answered else "continue"

def plan_next_step(state: ExecutionState) -> ExecutionState:
    """Determine the capabilities needed next and store the parsed plan in the state.

    The capability plan is derived from the task and the context gathered so
    far, and replaces whatever ``capability_plan`` the state held before.
    """
    raw_plan = determine_capabilities(task=state["task"], context=state["context"])
    state["capability_plan"] = CapabilityPlan.from_json(raw_plan)
    return state

def execute(state: ExecutionState) -> ExecutionState:
    """Run the first step of the current capability plan and fold its output into the state.

    Args:
        state: Current execution state; reads ``capability_plan``, ``resources``
            and ``step``, and updates ``answers``, ``resources``, ``context``
            and ``step``.

    Returns:
        The same state object, with the step's answers/resources appended, the
        context extended by this step's answers, and ``step`` incremented.
    """
    activity, capability = state["capability_plan"].subplan[0]

    # Capabilities whose agents are built from the activity alone and run on the resources.
    resource_capabilities = (
        AgentCapabilityUnstructuredDataProcessor,
        AgentCapabilityCodeWriterExecutor,
        AgentCapabilityAudioProcessor,
    )

    if capability is AgentCapabilityDeepWebSearch:
        # Build the agent once, directly with the shared search browser,
        # instead of creating a browser-less instance first and discarding it.
        output = capability(activity=activity, search_browser=BRAVE_SEARCH).run()
    elif any(capability is known for known in resource_capabilities):
        output = capability(activity).run(state["resources"])
    else:
        print("DEBUG | run_capability | Call to unidentified agent: ", capability)
        # Empty placeholder output so the bookkeeping below still works.
        output = AgentCapabilityOutput(activity)

    step_answers = output.answers or []
    if step_answers:
        state["answers"].extend(step_answers)
    if output.resources:
        state["resources"].extend(output.resources)

    # Update context with this step's results only. The context is cumulative,
    # so appending the full `state["answers"]` list on every step repeated all
    # earlier answers again and again.
    state["context"] += f"\nStep {state['step']}: {step_answers}"

    # FIXME: after deep web search we need to pass the results to unstructured text processing.
    #  It can be a lot of text after deep web search and we probably want to analyse it once
    #  (with unstructured text processing) and then clear it out.
    # Here I assume that the unstructured text processing call comes directly after deep web search
    # and that we have finished analysing the results of deep web search.
    if capability is AgentCapabilityUnstructuredDataProcessor:
        state["resources"] = [
            r for r in state["resources"] if r.provided_by != AgentCapabilityDeepWebSearch.name
        ]

    state["step"] += 1

    return state


def _build_workflow() -> StateGraph:
    """Assemble the graph: set_task, then a plan_next_step -> execute loop.

    After ``set_task`` and after every ``execute`` the routing function
    ``evaluate_task_completion`` picks between another planning round and END.
    """
    graph = StateGraph(state_schema=ExecutionState)

    # Nodes
    graph.add_node("set_task", set_task)
    graph.add_node("plan_next_step", plan_next_step)
    graph.add_node("execute", execute)

    # Edges (same registration order as the node flow)
    graph.add_edge(START, "set_task")
    graph.add_conditional_edges(
        "set_task",
        evaluate_task_completion,
        {"continue": "plan_next_step", "end": END},
    )
    graph.add_edge("plan_next_step", "execute")
    graph.add_conditional_edges(
        "execute",
        evaluate_task_completion,
        {"continue": "plan_next_step", "end": END},
    )
    return graph


workflow = _build_workflow()
generalist_graph = workflow.compile()

# Render the compiled graph inline as a Mermaid PNG.
from IPython.display import Image, display
display(Image(generalist_graph.get_graph().draw_mermaid_png()))

In [None]:
# Test part nodes & logging
import logging

import mlflow

from generalist.models.core import MLFlowLLMWrapper
from generalist.utils import pprint
from generalist.tools import planning, web_search, text_processing, code

# Start logging everything (set up manually): root logger at INFO level.
logging.getLogger().setLevel(logging.INFO)

mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")

# Monkey patch the module-level `llm` of each tool module with the MLflow wrapper.
# The isinstance guard keeps this cell idempotent: re-running it must not wrap
# an already wrapped llm a second time.
for tool_module in (planning, web_search, text_processing, code):
    if not isinstance(tool_module.llm, MLFlowLLMWrapper):
        tool_module.llm = MLFlowLLMWrapper(tool_module.llm)

# Set to True to smoke-test `set_task` on its own at this stage.
RUN_SET_TASK_SMOKE_TEST = False
if RUN_SET_TASK_SMOKE_TEST:
    question = "Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M."
    initial_state = init_state(question)

    state = set_task(initial_state)
    pprint(str(state["task"]))

In [None]:
# NOTE(review): this cell raises unconditionally; it looks like a deliberate guard so that
# "Run All" stops here and the task cells below are only run by hand - confirm.
raise InterruptedError

In [None]:
# Task 1: multi-hop question, no attached resources.
# MLflow is good at logging the state of the index and basic calls to the LLM.
# Enable LangChain autologging, point MLflow at the local tracking server,
# select the experiment and register the compiled graph as the model.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

question = "Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M."

# Build a fresh state for the question and run the whole graph on it.
initial_state = init_state(question)
final_state = generalist_graph.invoke(initial_state)

In [None]:
# Task 2: statistics lookup question, no attached resources.
# MLflow is good at logging the state of the index and basic calls to the LLM.
# Enable LangChain autologging, point MLflow at the local tracking server,
# select the experiment and register the compiled graph as the model.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

question = "How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?"

# Build a fresh state for the question and run the whole graph on it.
initial_state = init_state(question)
final_state = generalist_graph.invoke(initial_state)

In [None]:
# Task 3: question about an attached Excel file.
# MLflow is good at logging the state of the index and basic calls to the LLM.
# Enable LangChain autologging, point MLflow at the local tracking server,
# select the experiment and register the compiled graph as the model.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

question = "The attached Excel file contains the sales of menu items for a local fast-food chain. " \
"What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places."

# The attachment is passed as a user-provided resource: empty content, only a link.
# The link is relative, so it is resolved against the kernel's working directory.
sources = [
    ContentResource(
        provided_by="user", 
        content="", 
        link="./7bd855d8-463d-4ed5-93ca-5fe35145f733.xls",
        metadata={},
    )
]
initial_state = init_state(question, resources=sources)
final_state = generalist_graph.invoke(initial_state)
# Display the final state as the cell output.
final_state

In [None]:
# Task 4: PDF-related question, run without any attached resources.
# MLflow is good at logging the state of the index and basic calls to the LLM.
# Enable LangChain autologging, point MLflow at the local tracking server,
# select the experiment and register the compiled graph as the model.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

# TODO: check if this task actually requires PDF processing
question = "Where were the Vietnamese specimens described by Kuznetzov in Nedoshivina's 2010 paper eventually deposited? " \
"Just give me the city name without abbreviations."

# No resources are attached; an explicit empty list is passed.
sources = []
initial_state = init_state(question, resources=sources)
final_state = generalist_graph.invoke(initial_state)
# Display the final state as the cell output.
final_state

In [None]:
# Task 5: search and information retrieval from text only, no attached resources.
# MLflow is good at logging the state of the index and basic calls to the LLM.
# Enable LangChain autologging, point MLflow at the local tracking server,
# select the experiment and register the compiled graph as the model.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

question =\
"""What country had the least number of athletes at the 1928 Summer Olympics? If there's a tie for a number of athletes, return the first in alphabetical order. 
Give the IOC country code as your answer."""

# Build a fresh state for the question and run the whole graph on it.
initial_state = init_state(question)
final_state = generalist_graph.invoke(initial_state)

In [None]:
# Task 6: Python code execution on an attached script.
# MLflow is good at logging the state of the index and basic calls to the LLM.
from pathlib import Path

mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

question =\
"""What is the final numeric output from the attached Python code?"""

# The script sits next to this notebook. Resolve it from the working directory
# instead of hardcoding one user's home directory, while still handing an
# absolute path to the agent.
# NOTE(review): assumes the kernel's working directory is the notebook directory,
# the same assumption the relative `.xls` link in the Excel task makes.
attached_script = Path("f918266a-b3e0-4914-865d-4faa564f1aef.py").resolve()

sources = [
    ContentResource(
        provided_by="user",
        content="",
        link=str(attached_script),
        metadata={},
    )
]
initial_state = init_state(question, resources=sources)
final_state = generalist_graph.invoke(initial_state)
# Display the final state as the cell output.
final_state

In [None]:
# Task 7: loading a video file and transcribing it.
# MLflow is good at logging the state of the index and basic calls to the LLM.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment("test_generalist")
mlflow.models.set_model(generalist_graph)

logging.getLogger().setLevel(logging.INFO)

# Written as concatenated literals: with a triple-quoted string the closing `"` of the
# quoted question was consumed by the terminator, leaving an unbalanced quote in the prompt.
question = (
    "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec. "
    "What does Teal'c say in response to the question \"Isn't that hot?\""
)

initial_state = init_state(question)
final_state = generalist_graph.invoke(initial_state)
# Display the final state as the cell output.
final_state

In [None]:
# Try evaluation with GAIA
from datasets import load_dataset
from huggingface_hub import snapshot_download

import os

# Expand "~" explicitly: it is shell syntax, and an unexpanded "~" would be treated
# as a literal directory name relative to the working directory.
gaia_path = os.path.expanduser("~/pdev/freelectron/free-generalist/evaluation/gaia")
# local_files_only=True: use the snapshot already present on disk, no download.
data_dir = snapshot_download(local_dir=gaia_path, local_files_only=True, repo_id="gaia-benchmark/GAIA", repo_type="dataset")

dataset = load_dataset(data_dir, "2023_level1", split="validation")
gaia_keys = ['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata']

# Loop-invariant tracing/logging setup, done once before iterating.
mlflow.langchain.autolog()
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.models.set_model(generalist_graph)
logging.getLogger().setLevel(logging.INFO)

for sample in dataset:
    for key in gaia_keys:
        print(key, "=", sample[key])

    # One MLflow experiment per GAIA task; dashes in the task id become underscores.
    # (`str.replace` also avoids nesting double quotes inside a double-quoted f-string,
    # which only parses on Python 3.12+.)
    mlflow.set_experiment(f"gaia_{sample['task_id'].replace('-', '_')}")

    # NOTE(review): the sample's `file_path` is printed above but not passed on as a
    # resource - confirm whether attachments should be handed to `init_state`.
    question = sample["Question"]
    initial_state = init_state(question)
    final_state = generalist_graph.invoke(initial_state)
    # Only the first sample is evaluated for now.
    break