In [1]:
from catalog_ai.agents import supervisor, search_string_extractor
from catalog_ai.tools.hugging_face import get_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import TypedDict, Annotated, Sequence
import operator

class AgentState(TypedDict):
    """Shared state passed between the nodes of the dataset-search graph.

    Each node returns a partial dict; LangGraph merges it into this state.
    """
    # The user's request; the only key supplied in the example input below.
    prompt: str
    # Set by call_search_string_extractor. call_supervisor treats None as
    # "not extracted yet".
    search_string: str
    # Set by call_get_datasets from get_datasets(...).
    # NOTE(review): the streamed output shows DatasetInfo objects here, not
    # dicts — this annotation looks inaccurate; confirm before relying on it.
    raw_datasets: list[dict]
    # operator.add is the reducer: a list returned by a node is concatenated
    # onto the existing history rather than replacing it.
    actions: Annotated[Sequence[dict], operator.add]


In [13]:
def call_supervisor(state: AgentState):
    """Decide which node the graph should run next.

    The supervisor is a deterministic router. It inspects which state keys
    have been filled in so far and emits one action whose ``next_action``
    names the next node, or ``'END'`` when the work is done.

    Routing order:
      1. no ``search_string`` yet      -> ``'search_string_extractor'``
      2. datasets not fetched yet      -> ``'get_datasets'``
      3. otherwise (even if empty)     -> ``'END'``

    Args:
        state: Current graph state. ``search_string`` and ``raw_datasets``
            may be missing or ``None`` before the nodes that set them have run.

    Returns:
        A partial state update ``{'actions': [action]}``. ``actions`` uses an
        ``operator.add`` reducer, so the action is appended to the history.
    """
    # .get() instead of indexing: only 'prompt' is supplied in the initial
    # input, so the other keys may be absent from the state dict.
    search_string = state.get('search_string')
    raw_datasets = state.get('raw_datasets')

    if search_string is None:
        next_action = 'search_string_extractor'
    elif raw_datasets is None:
        # Datasets have not been fetched yet.
        next_action = 'get_datasets'
    else:
        # Datasets were fetched. An empty result is still a final answer:
        # re-running get_datasets with the same search string would presumably
        # return the same empty result and loop until the recursion limit.
        next_action = 'END'

    return {'actions': [{'next_action': next_action, 'action_type': ''}]}
    


In [4]:
def call_search_string_extractor(state: AgentState):
    """Derive a search string from the user's prompt.

    Invokes ``search_string_extractor_runnable`` with the prompt as
    ``user_question``. The extracted ``search_string`` is stored in the state,
    and the full response is appended to the ``actions`` history.
    """
    question = state['prompt']
    result = search_string_extractor.search_string_extractor_runnable.invoke(
        {'user_question': question}
    )

    update = {'search_string': result['search_string']}
    update['actions'] = [result]
    return update

In [5]:
def call_get_datasets(state: AgentState):
    """Fetch datasets matching the extracted search string.

    Prints the search string for visibility, delegates to ``get_datasets``,
    and stores the result as ``raw_datasets``. The empty ``actions`` list adds
    nothing to the action history.
    """
    query = state['search_string']
    print(f"search_string: {query}")
    return {
        'raw_datasets': get_datasets(search_string=query),
        'actions': [],
    }

In [6]:
def routing_edge(state: AgentState):
    """Turn the supervisor's latest decision into a conditional-edge key.

    Reads ``next_action`` from the last entry of ``state['actions']``. The
    supervisor's ``'END'`` sentinel becomes the ``'end'`` key used in the
    supervisor's edge mapping; any other value is returned unchanged as a
    node name.
    """
    decision = state['actions'][-1]['next_action']
    return 'end' if decision == 'END' else decision


In [14]:
from langgraph.graph import StateGraph, END


def _build_workflow():
    """Assemble the supervisor-routed graph (uncompiled).

    The supervisor is the entry point. After each worker node runs, control
    returns to the supervisor, which uses ``routing_edge`` to decide whether
    to extract a search string, fetch datasets, or stop.
    """
    graph = StateGraph(AgentState)

    graph.add_node('supervisor', call_supervisor)
    graph.add_node('search_string_extractor', call_search_string_extractor)
    graph.add_node('get_datasets', call_get_datasets)

    # Keys are the values routing_edge returns; values are destination nodes.
    supervisor_routes = {
        'search_string_extractor': 'search_string_extractor',
        'get_datasets': 'get_datasets',
        'end': END,
    }
    graph.add_conditional_edges('supervisor', routing_edge, supervisor_routes)

    # Every worker reports back to the supervisor.
    for worker in ('search_string_extractor', 'get_datasets'):
        graph.add_edge(worker, 'supervisor')

    graph.set_entry_point('supervisor')
    return graph


workflow = _build_workflow()
app = workflow.compile()

In [15]:
inputs = {'prompt': 'Get me a dataset to train a model on computer vision.'}

# Each streamed step is a dict keyed by the node that just ran, mapping to
# the partial state update that node returned.
for step in app.stream(inputs):
    print(f"Output from the graph: {step}")
    for node_name, node_update in step.items():
        print(f"Output from node '{node_name}':")
        print("---")
        print(node_update)
    print("\n---\n")

Output from the graph: {'supervisor': {'actions': [{'next_action': 'search_string_extractor', 'action_type': ''}]}}
Output from node 'supervisor':
---
{'actions': [{'next_action': 'search_string_extractor', 'action_type': ''}]}

---

Output from the graph: {'search_string_extractor': {'search_string': 'computer vision', 'actions': [{'search_string': 'computer vision', 'next_action': 'supervisor', 'action_type': 'agent', 'action_parameters': {}}]}}
Output from node 'search_string_extractor':
---
{'search_string': 'computer vision', 'actions': [{'search_string': 'computer vision', 'next_action': 'supervisor', 'action_type': 'agent', 'action_parameters': {}}]}

---

Output from the graph: {'supervisor': {'actions': [{'next_action': 'get_datasets', 'action_type': ''}]}}
Output from node 'supervisor':
---
{'actions': [{'next_action': 'get_datasets', 'action_type': ''}]}

---

search_string: computer vision
Output from the graph: {'get_datasets': {'raw_datasets': [DatasetInfo(id='MiKAI13/aut

KeyError: 'end'