# Setup

In [None]:
import matplotlib.pyplot as plt
import pprint

from helium import helium, ops
from helium.graphs import Graph
from helium.runtime import HeliumServerConfig

# Configure a local Helium server, pointing it at the JSON file that holds the
# LLM service configurations.
server_config = HeliumServerConfig(
    llm_service_configs="../configs/llm_services.json",
    is_local=True,
)

# Alternative: a remote Helium server. Start the server in its own process
# first (scripts/serve_helium.sh is an example), then use this instead:
# server_config = HeliumServerConfig(is_local=False)

In [None]:
# Obtain a started Helium instance for the configuration chosen above
helium.get_started_instance(
    config=server_config,
)

# English Writing Simulation

In [None]:
def build_writing_simulation_graph(num_generations: int) -> Graph:
    """Construct the compute graph for the English-writing simulation.

    A simulated learner drafts a self-introduction, a simulated teacher
    critiques that draft in a separate conversation, and the learner then
    rewrites the draft by continuing its original conversation with the
    teacher's critique.

    Parameters
    ----------
    num_generations : int
        How many times the introduction request is replicated in the user
        message of the first (drafting) conversation.

    Returns
    -------
    Graph
        A graph with three outputs: ``"draft"`` (the learner's first attempt),
        ``"review"`` (the teacher's feedback) and ``"final"`` (the revised
        introduction).
    """
    learner_persona = (
        "You are a non-native English speaker learning English. "
        "You are to respond in basic and sometimes broken English."
    )
    teacher_persona = (
        "You are an English teacher. You evaluate the user's writing "
        "critically and respond with your suggestions for improvement."
    )
    revision_template = (
        "Below is the comment on your writing. Please revise your introduction "
        "accordingly and answer me with only that revised version:\n\n{review}"
    )
    intro_requests = ["Write a paragraph to introduce yourself."] * num_generations

    # Step 1: the learner writes a first draft. The full conversation is kept
    # so that step 3 can continue it.
    learner_convo = ops.llm_chat(
        ops.message_data(
            [
                ops.OpMessage(role="system", content=learner_persona),
                ops.OpMessage(role="user", content=intro_requests),
            ]
        ),
        return_history=True,
    )
    first_draft = ops.get_last_message(learner_convo)

    # Step 2: an independent teacher conversation critiques the draft.
    feedback = ops.llm_chat(
        ops.message_data(
            [
                ops.OpMessage(role="system", content=teacher_persona),
                ops.OpMessage(role="user", content=first_draft),
            ]
        ),
        return_history=False,
    )

    # Step 3: hand the critique back to the learner inside its own conversation.
    followup = ops.OpMessage(
        role="user",
        content=ops.format_op(revision_template, review=feedback),
    )
    revised = ops.llm_chat(
        ops.append_message(learner_convo, followup), return_history=False
    )

    return Graph.from_ops(
        [
            ops.as_output("draft", first_draft),
            ops.as_output("review", feedback),
            ops.as_output("final", revised),
        ]
    )

In [None]:
# Construct the writing-simulation graph, replicating the draft request 5 times
graph = build_writing_simulation_graph(5)

In [None]:
# Draw the writing-simulation graph on a fresh 10x6-inch figure
# (assigning to `_` keeps the notebook from echoing the return value)
plt.figure(figsize=(10, 6))
_ = graph.visualize()

In [None]:
# Run the writing-simulation graph through Helium and keep its outputs
result = helium.invoke(
    graph,
)

In [None]:
# Pretty-print the draft / review / final outputs
pprint.pprint(
    result,
)

# Multiagent Debate

In [None]:
def build_multiagent_debate_graph(
    num_agents: int,
    num_rounds: int,
) -> Graph:
    """Builds a compute graph for a multi-agent debate task.

    Every agent independently answers the questions supplied through the
    ``"questions"`` input placeholder. If more than one round is requested,
    each agent is then asked to revise its answer:

    * with a single agent, it is simply asked to double-check itself;
    * with several agents, it is shown the latest answers of all the other
      agents and asked to update its own.

    Parameters
    ----------
    num_agents : int
        Number of debating agents. Must be at least 1.
    num_rounds : int
        Total number of answering rounds, including the initial one. Must be
        at least 1; a value of 1 means no revision round is added.

    Returns
    -------
    Graph
        A graph with one output named ``"agent_{i}"`` per agent. With a single
        round, this is the agent's first-round conversation history. Otherwise
        it is the corresponding element produced by ``ops.loop``, presumably
        the agent's history after the final round (confirm against
        ``ops.loop``).

    Raises
    ------
    ValueError
        If ``num_agents`` or ``num_rounds`` is smaller than 1.
    """
    # Guard degenerate inputs: num_rounds < 1 would otherwise hand ops.loop a
    # negative iteration count.
    if num_agents < 1:
        raise ValueError(f"num_agents must be at least 1, got {num_agents}")
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be at least 1, got {num_rounds}")

    system_prompt = "You are a helpful AI Assistant."
    input_op = ops.input_placeholder("questions")
    revise_prompts = (
        (
            "Can you double check that your answer is correct. Put your final answer "
            "in the form (X) at the end of your response."
        ),
        (
            "Using the reasoning from other agents as additional advice, can you give "
            "an updated answer? Examine your solution and that other agents step by "
            "step. Put your answer in the form (X) at the end of your response."
        ),
    )
    # No custom generation settings; passed positionally to every llm_chat call.
    generation_config = None

    # First round: every agent answers the same questions independently.
    initial_message_list = [
        [
            ops.OpMessage(role="system", content=system_prompt),
            ops.OpMessage(role="user", content=input_op),
        ]
        for _ in range(num_agents)
    ]
    history_list = [
        ops.llm_chat(message, generation_config, return_history=True)
        for message in initial_message_list
    ]

    if num_rounds == 1:
        # No debate requested: expose the first-round histories directly.
        return Graph.from_ops(
            [
                ops.as_output(f"agent_{i}", history)
                for i, history in enumerate(history_list)
            ]
        )

    # Debate rounds
    if num_agents == 1:
        # Nobody to debate with: ask the lone agent to double-check itself.
        revise_prompt = ops.data(revise_prompts[0])
        new_convo_list = [
            ops.append_message(history, revise_prompt) for history in history_list
        ]
    else:
        # Every agent sees exactly num_agents - 1 peer answers, so the template
        # (placeholders agent_0 .. agent_{num_agents - 2}) is identical for all
        # agents; build it once instead of once per agent.
        peer_template = "\n\n ".join(
            [
                "These are the solutions to the problem from other agents: ",
                *[
                    f"One agent solution: ```{{agent_{j}}}```"
                    for j in range(num_agents - 1)
                ],
                revise_prompts[1],
            ]
        )
        last_message_list = [ops.get_last_message(history) for history in history_list]
        new_convo_list = []
        for i, history in enumerate(history_list):
            # Drop agent i's own answer; the remaining ones fill agent_0, agent_1, ...
            other_agent_answers = last_message_list[:i] + last_message_list[i + 1 :]
            revise_prompt = ops.format_op(
                peer_template,
                **{f"agent_{j}": ans for j, ans in enumerate(other_agent_answers)},
            )
            new_convo_list.append(ops.append_message(history, revise_prompt))
    revised_history_list = [
        ops.llm_chat(convo, generation_config, return_history=True)
        for convo in new_convo_list
    ]
    # Presumably repeats the revise step num_rounds - 1 times, feeding
    # revised_history_list back in as history_list (confirm against ops.loop).
    debate_loop = ops.loop(history_list, revised_history_list, num_rounds - 1)

    return Graph.from_ops(
        [
            ops.as_output(f"agent_{i}", agent_history)
            for i, agent_history in enumerate(debate_loop)
        ]
    )

In [None]:
# Build the multiagent debate graph: 3 agents, 3 rounds (initial answer + 2 revisions)
graph = build_multiagent_debate_graph(num_agents=3, num_rounds=3)

In [None]:
# Draw the debate graph on a larger 15x15-inch figure with the "spring" layout
# (assigning to `_` keeps the notebook from echoing the return value)
plt.figure(figsize=(15, 15))
_ = graph.visualize(
    layout="spring",
)

In [None]:
# Compile the graph, binding the `questions` input to a batch containing a
# single multiple-choice question.
compiled_graph = graph.compile(
    questions=[
        "Can you answer the following question as accurately as possible? "
        "You suspect that your patient has an enlarged submandibular salivary gland. "
        "You expect the enlarged gland: \n "
        "A) to be palpable intraorally. \n "
        "B) to be palpable extraorally. \n "
        "C) to be palpable both intra- and extraorally. \n "
        "D) only to be detectable by radiographical examination. \n "
        "Explain your answer, putting the answer in the form (X) at the end of your response."
    ]
)

In [None]:
# Run the compiled debate graph through Helium and keep its outputs
result = helium.invoke(
    compiled_graph,
)

In [None]:
# Pretty-print each agent's output
pprint.pprint(
    result,
)