In [None]:
import sys
from pathlib import Path
from typing import Optional
from trace_to_graph import TraceGraphBuilder
import json
from typing import List, Dict


class TraceGraphBuilder:
    def __init__(self, include_messages: bool = False):
        self.include_messages = include_messages
        self.nodes = []
        self.edges = []
        self.runs = []
    
    def load_trace(self, trace_path: str | Path) -> List[Dict]:
        """Load a trace JSON file."""
        with open(trace_path, 'r') as f:
            return json.load(f)

    def build_graph(self, trace_data: List[Dict]) -> None:
        """Build a graph from the trace data."""
        # for run in trace_data:
        #     self.runs.append(run)
        #     self.nodes.append(run)
        #     self.edges.append(run)
    

from collections import defaultdict
def build_graph(runs):
    """Index runs by id and map every parent id to the ids of its children.

    Args:
        runs: Iterable of run dictionaries. Each must have an ``id``; the
            optional ``parent_run_ids`` list names the run's parents.

    Returns:
        ``(graph, id_to_run)`` where ``graph`` is a ``defaultdict(list)`` from
        parent id to child ids (in input order) and ``id_to_run`` maps each id
        to its run dictionary.
    """
    graph = defaultdict(list)
    id_to_run = {}
    for run in runs:
        run_id = run["id"]
        id_to_run[run_id] = run
        # `or []` covers an explicit None, which find_roots treats as "no parents".
        for parent_id in run.get("parent_run_ids") or []:
            graph[parent_id].append(run_id)
    return graph, id_to_run

def find_roots(runs):
    """Return the ids of runs whose ``parent_run_ids`` is missing or empty, in input order."""
    root_ids = []
    for run in runs:
        if run.get("parent_run_ids"):
            continue
        root_ids.append(run["id"])
    return root_ids

def traverse(graph, id_to_run, node_id, depth=0, chain=None):
    """Walk the run graph depth-first (parent before children) and record one step per run.

    Args:
        graph: Mapping from a run id to the list of its child ids.
        id_to_run: Mapping from a run id to its run dictionary.
        node_id: Id of the run to start from.
        depth: Depth recorded for the starting run; children get ``depth + 1``.
        chain: Optional list to append to; a new list is created when None.

    Returns:
        The list that received the steps (``chain`` itself when one was given).
    """
    steps = [] if chain is None else chain

    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped, and therefore visited, in their original order.
    pending = [(node_id, depth)]
    while pending:
        current_id, level = pending.pop()
        record = id_to_run[current_id]
        steps.append({
            "depth": level,
            "name": record.get("name"),
            "run_type": record.get("run_type"),
            "input": record.get("inputs"),
            "output": record.get("outputs"),
            "start_time": record.get("start_time"),
        })
        child_ids = list(graph.get(current_id, []))
        for child_id in reversed(child_ids):
            pending.append((child_id, level + 1))
    return steps

def chain_runs(runs):
    """Flatten every run tree into one list of steps ordered by start time.

    Each root (a run without parents) is traversed depth-first and the
    resulting steps are concatenated, then sorted by ``start_time``.

    Args:
        runs: Iterable of run dictionaries as consumed by ``build_graph``.

    Returns:
        List of step dictionaries produced by ``traverse``. Steps whose
        ``start_time`` is None are placed after all timestamped steps; the sort
        is stable, so ties keep their traversal order.
    """
    graph, id_to_run = build_graph(runs)
    roots = find_roots(runs)
    full_chain = []
    for root in roots:
        chain = traverse(graph, id_to_run, root)
        full_chain.extend(chain)

    # Comparing None with a timestamp raises TypeError, so sort on a
    # (is-missing, value) pair: the flag decides first and the None values are
    # never ordered against anything.
    return sorted(
        full_chain,
        key=lambda step: (step["start_time"] is None, step["start_time"]),
    )

def dedupe_trace(chain):
    """Filter a step list down to the informative, non-repeated entries.

    A step is dropped when both its input and output are falsy, or when its
    ``(name, str(input), str(output))`` equals that of the last step kept.

    Args:
        chain: List of step dictionaries with a ``name`` key and optional
            ``input`` / ``output`` keys.

    Returns:
        A new list holding the surviving step dictionaries (same objects).
    """
    kept = []
    previous_key = None

    for entry in chain:
        label = entry["name"]
        payload_in, payload_out = entry.get("input"), entry.get("output")

        # Only steps carrying some I/O are candidates at all.
        if payload_in or payload_out:
            key = (label, str(payload_in), str(payload_out))
            # Skip a step identical to the one most recently kept.
            if key != previous_key:
                previous_key = key
                kept.append(entry)

    return kept

def summarize_chain(chain):
    """Print an indented, de-duplicated outline of a step chain.

    The chain is first passed through ``dedupe_trace``. Every remaining step is
    printed as ``- name (run_type)`` indented by its depth, followed by its
    input and output truncated to 120 characters. The untruncated input of the
    first step and output of the last step are printed as ``> input`` /
    ``> output`` lines.

    Args:
        chain: List of step dictionaries with ``depth``, ``name``, ``run_type``
            and optional ``input`` / ``output`` keys.
    """
    steps = dedupe_trace(chain)
    final_position = len(steps) - 1

    for position, entry in enumerate(steps):
        pad = "  " * entry["depth"]
        step_input = entry.get("input")
        step_output = entry.get("output")

        # Overall input: shown once, before the first step.
        if position == 0 and step_input:
            print(f"> input: {entry['input']}")

        print(f"{pad}- {entry['name']} ({entry['run_type']})")
        if step_input:
            print(f"{pad}  ↳ input: {str(entry['input'])[:120]}")
        if step_output:
            print(f"{pad}  ↳ output: {str(entry['output'])[:120]}")

        # Overall output: shown once, after the last step.
        if position == final_position and step_output:
            print(f"> output: {entry['output']}")



def visualize_trace(trace_path: str | Path, output_path: Optional[str] = None,
                   include_messages: bool = False, interactive: bool = True) -> None:
    """
    Load a trace file, print its run chain, and report the builder's graph size.

    Args:
        trace_path: Location of the trace JSON file.
        output_path: Accepted as part of the signature; not read by this function.
        include_messages: Forwarded to ``TraceGraphBuilder``.
        interactive: Accepted as part of the signature; not read by this function.
    """
    graph_builder = TraceGraphBuilder(include_messages=include_messages)
    runs = graph_builder.load_trace(trace_path)
    graph_builder.build_graph(runs)

    # Text rendering: steps ordered by start time, de-duplicated and indented by depth.
    ordered_steps = chain_runs(runs)
    summarize_chain(ordered_steps)

    run_total = len(graph_builder.runs)
    node_total = len(graph_builder.nodes)
    edge_total = len(graph_builder.edges)
    print(f"Loaded trace with {run_total} runs")
    print(f"Created {node_total} nodes and {edge_total} edges")

In [None]:
import os
from typing import List, Dict, Any, Optional, Tuple
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate

# Import the model helper (adjust import path as needed).
# HAS_BEDROCK records which provider import succeeded:
#   True  -> holistic_ai_bedrock.get_chat_model is available
#   False -> get_chat_model is not available, but langchain_openai.ChatOpenAI is
#   None  -> neither package could be imported (a warning is printed)
try:
    from holistic_ai_bedrock import get_chat_model
    HAS_BEDROCK = True
except ImportError:
    # Preferred provider missing; fall back to the OpenAI chat model class.
    try:
        from langchain_openai import ChatOpenAI
        HAS_BEDROCK = False
    except ImportError:
        # Keep the cell runnable without any provider; only warn here.
        print("Warning: No LLM provider found. Install holistic_ai_bedrock or langchain_openai")
        HAS_BEDROCK = None


def extract_messages_from_trace(trace_data: List[Dict]) -> List[Dict[str, Any]]:
    """
    Collect the dict-shaped messages found in each run's inputs and outputs.

    Looked-up locations, per run:
      * ``inputs["messages"]``; or, only when that key is absent,
        ``inputs["input"]["messages"]``
      * ``outputs["messages"]``

    A message list whose first element is itself a list is unwrapped to that
    first element. Entries that are not dictionaries are ignored.

    Returns:
        One record per message with the keys ``message``, ``source``
        ("input" or "output"), ``run_id``, ``run_name``, ``run_type`` and
        ``timestamp`` (the run's ``start_time``).
    """
    collected: List[Dict[str, Any]] = []

    def gather(raw_messages: Any, source: str, run_meta: Dict[str, Any]) -> None:
        # Shared by all three locations: validate, unwrap one nesting level, record.
        if not isinstance(raw_messages, list) or len(raw_messages) == 0:
            return
        batch = raw_messages[0] if isinstance(raw_messages[0], list) else raw_messages
        for item in batch:
            if isinstance(item, dict):
                collected.append({"message": item, "source": source, **run_meta})

    for run in trace_data:
        run_meta = {
            "run_id": run.get("id"),
            "run_name": run.get("name", "Unknown"),
            "run_type": run.get("run_type", "unknown"),
            "timestamp": run.get("start_time"),
        }

        run_inputs = run.get("inputs", {})
        if isinstance(run_inputs, dict):
            if "messages" in run_inputs:
                gather(run_inputs["messages"], "input", run_meta)
            else:
                # Nested form: {"input": {"messages": [...]}}
                wrapped = run_inputs.get("input")
                if isinstance(wrapped, dict) and "messages" in wrapped:
                    gather(wrapped["messages"], "input", run_meta)

        run_outputs = run.get("outputs", {})
        if isinstance(run_outputs, dict) and "messages" in run_outputs:
            gather(run_outputs["messages"], "output", run_meta)

    return collected


def format_message_for_analysis(msg_dict: Dict[str, Any]) -> str:
    """Format a message record into a single readable line for LLM analysis.

    Args:
        msg_dict: Record with a ``message`` dictionary and a ``source`` string;
            ``run_name`` is optional. The message may carry ``type``,
            ``content``, ``tool_calls`` and ``additional_kwargs``.

    Returns:
        ``"[SOURCE] TYPE from run_name: content"`` (the ``from`` part is
        omitted without a run name). Tool calls found in ``tool_calls`` and in
        ``additional_kwargs["tool_calls"]`` are each appended as a
        ``[Tool calls: ...]`` suffix.
    """
    msg = msg_dict["message"]
    msg_type = str(msg.get("type") or "unknown")

    # Content is not always a string (it can be None or a list of blocks).
    # Normalise it first: `+=` on a list would extend it character by
    # character, and on None it would raise.
    content = msg.get("content", "")
    if content is None:
        content = ""
    elif not isinstance(content, str):
        content = str(content)

    # Tool calls in the {"name": ..., "args": ...} shape
    tool_calls = msg.get("tool_calls") or []
    if tool_calls:
        tool_info = [
            f"{tc.get('name', 'unknown')}({tc.get('args', {})})"
            for tc in tool_calls
        ]
        content += f" [Tool calls: {', '.join(tool_info)}]"

    # Tool calls in the {"function": {"name": ..., "arguments": ...}} shape
    additional_kwargs = msg.get("additional_kwargs") or {}
    kwarg_tool_calls = additional_kwargs.get("tool_calls")
    if kwarg_tool_calls:
        tool_info = []
        for tc in kwarg_tool_calls:
            func = tc.get("function", {})
            tool_info.append(f"{func.get('name', 'unknown')}({func.get('arguments', '{}')})")
        content += f" [Tool calls: {', '.join(tool_info)}]"

    # Build formatted string
    formatted = f"[{msg_dict['source'].upper()}] {msg_type.upper()}"
    if msg_dict.get("run_name"):
        formatted += f" from {msg_dict['run_name']}"
    formatted += f": {content}"

    return formatted


def find_issue_origin(
    trace_data: List[Dict] | Dict[str, Any],
    user_question: str,
    llm_model: Optional[Any] = None,
    max_messages_to_check: int = 20,
    confidence_threshold: float = 0.7
) -> List[Tuple[Dict[str, Any], float, str]]:
    """
    Analyze trace messages in reverse order to find the origin of an issue.

    Args:
        trace_data: Either a list of trace run dictionaries, or a dict with 'messages' key
                   containing LangChain message objects
        user_question: The user's question about an issue (e.g., "why did they choose this room")
        llm_model: Optional LLM model. If None, get_chat_model() is used when the
                   Bedrock helper was imported, otherwise ChatOpenAI
        max_messages_to_check: Maximum number of messages to check (starting from most recent)
        confidence_threshold: Minimum confidence score (0-1) to consider a message relevant

    Returns:
        List of tuples: (message_dict, confidence_score, explanation)
        Sorted by confidence score (highest first). Empty when the trace holds
        no messages.

    Raises:
        ValueError: If no ``llm_model`` is given and no LLM provider was imported.
    """
    # Extract messages
    if isinstance(trace_data, dict) and "messages" in trace_data:
        # Handle LangChain message objects
        messages = []
        for i, msg in enumerate(trace_data["messages"]):
            if isinstance(msg, BaseMessage):
                messages.append({
                    "message": {
                        "type": msg.__class__.__name__.replace("Message", "").lower(),
                        "content": msg.content if hasattr(msg, "content") else str(msg),
                        "tool_calls": getattr(msg, "tool_calls", []),
                        "additional_kwargs": getattr(msg, "additional_kwargs", {})
                    },
                    "source": "output",
                    "run_id": f"msg_{i}",
                    "run_name": "trace",
                    "run_type": "message",
                    "timestamp": None,
                    "index": i
                })
    else:
        # Handle trace data structure
        messages = extract_messages_from_trace(trace_data)
        # Add index for reverse ordering
        for i, msg in enumerate(messages):
            msg["index"] = i

    if not messages:
        return []

    # Initialize the LLM only once there is something to evaluate.
    if llm_model is None:
        if HAS_BEDROCK:
            llm_model = get_chat_model("claude-3-5-sonnet")
        elif HAS_BEDROCK is False:
            # Only the langchain_openai fallback import succeeded.
            llm_model = ChatOpenAI(temperature=0)
        else:
            raise ValueError("No LLM model available. Please provide one or set up API keys.")

    # Most recent first. The index is a tie-breaker so that messages sharing a
    # timestamp (or having none at all) are also visited newest-first.
    messages_sorted = sorted(
        messages,
        key=lambda m: (m.get("timestamp") or "", m["index"]),
        reverse=True
    )[:max_messages_to_check]

    # Create evaluation prompt.
    # NOTE: the literal JSON example uses doubled braces; single braces would be
    # parsed as a template variable and make every invoke() call fail.
    evaluation_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at analyzing agent execution traces to find the root cause of issues.

Your task is to evaluate whether a specific message in an execution trace is the origin or cause of a user-reported issue.

For each message, provide:
1. A confidence score from 0.0 to 1.0 indicating how likely this message is the origin of the issue
2. A brief explanation (1-2 sentences) of why this message is or isn't relevant

Consider:
- Does this message contain the decision, action, or information that led to the issue?
- Is this message where the problematic choice was made?
- Does this message show incorrect reasoning or data that caused the problem?

Respond in JSON format:
{{"confidence": 0.0-1.0, "explanation": "brief explanation"}}"""),
        ("human", """User Question: {user_question}

Message to evaluate:
{message_content}

Context: This is message {message_index} of {total_messages} (evaluating in reverse chronological order, most recent first).

Provide your evaluation as JSON with 'confidence' and 'explanation' fields.""")
    ])

    # The chain does not depend on the message, so build it once.
    chain = evaluation_prompt | llm_model
    total_messages = len(messages_sorted)
    results = []

    # Evaluate each message
    for idx, msg_dict in enumerate(messages_sorted):
        message_content = format_message_for_analysis(msg_dict)

        try:
            response = chain.invoke({
                "user_question": user_question,
                "message_content": message_content,
                "message_index": idx + 1,
                "total_messages": total_messages
            })

            # Response content may be a non-string (e.g. a list of blocks);
            # the parser works on text.
            raw_response = response.content if hasattr(response, "content") else response
            response_text = raw_response if isinstance(raw_response, str) else str(raw_response)

            confidence, explanation = _parse_evaluation_response(response_text)
        except Exception as e:
            # Best effort: one failed evaluation must not abort the whole scan.
            print(f"Error evaluating message {idx + 1}: {e}")
            continue

        if confidence >= confidence_threshold:
            results.append((msg_dict, confidence, explanation))

    # Sort by confidence (highest first)
    results.sort(key=lambda x: x[1], reverse=True)

    return results


def _parse_evaluation_response(response_text: str) -> Tuple[float, str]:
    """Extract ``(confidence, explanation)`` from the text of an LLM evaluation.

    A flat JSON object containing a "confidence" key is preferred. When none is
    found, or it cannot be decoded, the first number following the word
    "confidence" is used and the explanation is the first 200 characters of the
    reply. Without any match the confidence is 0.0.
    """
    import json
    import re

    json_match = re.search(r'\{[^{}]*"confidence"[^{}]*\}', response_text, re.DOTALL)
    if json_match:
        try:
            eval_result = json.loads(json_match.group())
            confidence = float(eval_result.get("confidence", 0.0))
            explanation = eval_result.get("explanation", "No explanation provided")
            return confidence, explanation
        except (ValueError, TypeError, AttributeError):
            # Malformed JSON or a non-numeric confidence: fall through to the
            # plain-text fallback below.
            pass

    # The number pattern must end in a digit so that a sentence-ending period
    # ("confidence: 0.8.") is not captured and passed to float().
    conf_match = re.search(r'confidence["\s:]+([0-9]*\.?[0-9]+)', response_text, re.IGNORECASE)
    confidence = float(conf_match.group(1)) if conf_match else 0.0
    return confidence, response_text[:200]


def print_issue_analysis(
    trace_data: List[Dict] | Dict[str, Any],
    user_question: str,
    llm_model: Optional[Any] = None,
    max_results: int = 5
) -> None:
    """
    Convenience function to find and print issue origins.

    Runs ``find_issue_origin`` with its default message limit and confidence
    threshold, then prints up to ``max_results`` matches (highest confidence
    first) with their type, source, a 150-character content preview and the
    explanation.

    Args:
        trace_data: Trace runs or a dict with a 'messages' key, passed through
            to ``find_issue_origin``.
        user_question: The question describing the issue.
        llm_model: Optional LLM model, passed through to ``find_issue_origin``.
        max_results: Maximum number of matches to print.
    """
    results = find_issue_origin(trace_data, user_question, llm_model)

    if not results:
        print(f"No messages found that are likely origins of: '{user_question}'")
        return

    print(f"\n{'='*80}")
    print(f"ISSUE ANALYSIS: {user_question}")
    print(f"{'='*80}\n")
    print(f"Found {len(results)} relevant message(s):\n")

    for i, (msg_dict, confidence, explanation) in enumerate(results[:max_results], 1):
        msg = msg_dict["message"]
        msg_type = str(msg.get("type") or "unknown")

        # Content can be None or a list of blocks; turn it into text before
        # slicing so the preview is always a string.
        raw_content = msg.get("content", "")
        if raw_content is None:
            text = ""
        elif isinstance(raw_content, str):
            text = raw_content
        else:
            text = str(raw_content)
        # Only mark the preview as truncated when something was cut off.
        preview = text[:150] + ("..." if len(text) > 150 else "")

        print(f"{i}. [Confidence: {confidence:.2f}]")
        print(f"   Type: {msg_type.upper()}")
        print(f"   Source: {msg_dict['source']} from {msg_dict.get('run_name', 'unknown')}")
        print(f"   Content: {preview}")
        print(f"   Explanation: {explanation}")
        print()


In [56]:
# Example usage: ask why the agent chose the community center, using the
# example_output conversation.
# NOTE(review): example_output is assigned in a cell that appears further down
# in this notebook, so that cell has to be run before this one.
# No llm_model is passed here, so print_issue_analysis receives its default (None).

print_issue_analysis(
    example_output,
    user_question="why did they choose this room",
    max_results=3
)


ValueError: No LLM model available. Please provide one or set up API keys.

In [None]:
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

example_output = {
    'messages': [
        HumanMessage(
            content='\nYou are a helpful Town Hall agent. Your goal is to process requests by thinking step-by-step and using the available tools. You must use the check_calendar, check_room_rules, assign_task, and final_answer tools. You will follow this pattern:\n\nHuman: [The user\'s request]\nAgent thought: [Your reasoning and plan]\nAgent action: [The tool you will use]\nObservation: [The result from the tool]\n... (repeat thought/action/observation) ...\nAgent thought: [Your final reasoning]\nAgent action: final_answer("[Your response to the human]")\n\n---\nHere are some examples:\n---\n\nHuman: I need to book the community center for a 20-person meeting on Dec 5th.\nAgent thought: The user wants to book \'community_center\' on \'Dec 5th\' for 20 people. First, I must check if the room is available on that date.\nAgent action: check_calendar(date="2025-12-05")\nObservation: The calendar shows this date is AVAILABLE.\nAgent thought: The date is available. Now I need to check the rules for the \'community_center\' to ensure a 20-person meeting is allowed.\nAgent action: check_room_rules(room_id="community_center")\nObservation: Rules for community_center are \'max_capacity: 50, no_food_allowed\'.\nAgent thought: The 20-person meeting is within the 50-person capacity. The request is valid. I can now book it and inform the user of the \'no_food_allowed\' rule.\nAgent action: final_answer("Booking confirmed for the community center on Dec 5th. Please note: no food is allowed.")\n\n---\nHuman: Please assign the \'Setup chairs\' ticket to Bob.\nAgent thought: I need to assign the task \'Setup chairs\' to \'Bob\'. I must use the assign_task tool.\nAgent action: assign_task(staff_name="Bob", task="Setup chairs")\nObservation: Action was not successful. Reason: \'Bob\' is marked as \'On-Leave\'.\nAgent thought: I cannot assign the task to Bob because he is on leave. This is a failure[cite: 173]. 
I must inform the user that I cannot complete their request and state the reason.\nAgent action: final_answer("I\'m sorry, I cannot assign \'Setup chairs\' to Bob. My system shows he is \'On-Leave\'.")\nHuman: I would like to book the community center for a 30-person event on December 10th.', 
            additional_kwargs={}, response_metadata={}, id='2aca527f-36ec-4d17-a6a7-8daa84f44a9c'), 
        AIMessage(content='', additional_kwargs={}, response_metadata={}, 
            id='lc_run--0482436b-58f9-49e4-9fa7-b5825c1a51dd-0', 
            tool_calls=[{'name': 'check_calendar', 'args': {'date': '2023-12-10'}, 
            'id': 'toolu_bdrk_01726hSPHm9cVQ17LqHCqfPp', 'type': 'tool_call'}]), 
        ToolMessage(content='Observation:2023-12-10 is AVAILABLE', name='check_calendar', 
                    id='9e2ab60c-da09-4f96-a6a1-fe0efebe7ba4', tool_call_id='toolu_bdrk_01726hSPHm9cVQ17LqHCqfPp'), 
        AIMessage(content='', additional_kwargs={}, response_metadata={}, 
            id='lc_run--952a9fb0-8a9a-4544-8b33-f93c2d21a823-0', 
            tool_calls=[{'name': 'check_room_rules', 'args': {'room_id': 'community_center'}, 
                        'id': 'toolu_bdrk_01LYBoqJQsxMGycuDsFSRuLZ', 'type': 'tool_call'}]), 
        ToolMessage(content="Observation: Rules for community_center are 'max_capacity: 50, no_food_allowed'.", 
                    name='check_room_rules', id='b2c5834c-6ff8-4d57-a912-73267d178425', 
                    tool_call_id='toolu_bdrk_01LYBoqJQsxMGycuDsFSRuLZ'), 
        AIMessage(content='', additional_kwargs={}, response_metadata={}, 
                    id='lc_run--e177284f-30c5-4e8a-8700-e9f9e11556a9-0', 
                    tool_calls=[{'name': 'assign_task', 'args': {'staff_name': 'John', 'task': 'Prepare community center for December 10th event - 30 people'}, 
                    'id': 'toolu_bdrk_01QrZTgBShYFdnVtJRQcqYqY', 'type': 'tool_call'}]), 
        ToolMessage(
            content="Action successful: Task 'Prepare community center for December 10th event - 30 people' assigned to John.", name='assign_task', id='8cba59a0-fb01-4c23-ae81-7200bc3b4ac3', tool_call_id='toolu_bdrk_01QrZTgBShYFdnVtJRQcqYqY'), 
        AIMessage(content='', additional_kwargs={}, response_metadata={}, 
            id='lc_run--4974f8ea-df0f-44f4-a97e-9faf1150adb3-0', 
            tool_calls=[{'name': 'final_answer', 'args': {'response': "Your booking for the community center on December 10th has been confirmed. The venue can accommodate your 30-person event as it's within our maximum capacity of 50 people. Please note that no food is allowed in the community center. I've assigned staff to prepare the venue for your event."}, 'id': 'toolu_bdrk_01DpnTN47Q2e45qptFzfhtgp', 'type': 'tool_call'}]), 
        ToolMessage(content="Final Response to User: Your booking for the community center on December 10th has been confirmed. The venue can accommodate your 30-person event as it's within our maximum capacity of 50 people. Please note that no food is allowed in the community center. I've assigned staff to prepare the venue for your event.", 
            name='final_answer', 
            id='754dab08-0270-4b95-91a2-e5e796dad0a1', 
            tool_call_id='toolu_bdrk_01DpnTN47Q2e45qptFzfhtgp'), 
        AIMessage(
            content="{'content': [], 'usage': {'input_tokens': 1781, 'output_tokens': 3, 'total_tokens': 1784}, 'metadata': {'team_id': 'team_the_great_hack_2025_046', 'model': 'us.anthropic.claude-3-5-sonnet-20241022-v2:0', 'cost_usd': 0.005388, 'latency_ms': 906.67, 'remaining_quota': {'requests_today': 35, 'tokens_today': 186752, 'llm_cost': 0.628632, 'gpu_cost': 0.0, 'total_cost': 0.628632, 'budget_limit': 50.0, 'remaining_budget': 49.371368, 'budget_usage_percent': 1.257264}}}", 
            additional_kwargs={}, 
            response_metadata={}, 
            id='lc_run--ec5829ea-602c-4dbc-b855-3cdba4340b46-0')
        ], 
    # 'structured_response': TechAnalysis(
    #     thought="I'll format this API response data as a structured JSON object. The data contains usage statistics, metadata about the team, model, costs, latency, and quota information. I'll present it in a properly indented, readable JSON format.", 
    #     action='{\n  "content": [],\n  "usage": {\n    "input_tokens": 1781,\n    "output_tokens": 3,\n    "total_tokens": 1784\n  },\n  "metadata": {\n    "team_id": "team_the_great_hack_2025_046",\n    "model": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",\n    "cost_usd": 0.005388,\n    "latency_ms": 906.67,\n    "remaining_quota": {\n      "requests_today": 35,\n      "tokens_today": 186752,\n      "llm_cost": 0.628632,\n      "gpu_cost": 0.0,\n      "total_cost": 0.628632,\n      "budget_limit": 50.0,\n      "remaining_budget": 49.371368,\n      "budget_usage_percent": 1.257264\n    }\n  }\n}')
    }
from termcolor import colored

def print_message_summary(msg):
    """
    Pretty-print a one-line summary of a message object.

    The class name is printed in bold colour, followed by the most relevant
    part of the message:
      * HumanMessage: last line of the content (max 100 chars)
      * AIMessage: its tool call(s) as ``name(args)``, otherwise the content
        (max 100 chars)
      * ToolMessage: the content, prefixed with the tool name unless it starts
        with "Observation"
      * anything else: the content (max 100 chars)

    Args:
        msg: Any object; ``content``, ``tool_calls`` and ``name`` are read with
            ``getattr`` when present.
    """
    msg_type = type(msg).__name__
    raw_content = getattr(msg, 'content', '')
    # Content may be a list of blocks (or None); the string methods used below
    # need text.
    if isinstance(raw_content, str):
        content = raw_content
    elif raw_content is None:
        content = ''
    else:
        content = str(raw_content)

    if msg_type == "HumanMessage":
        tcolor = 'cyan'
        important = content.strip().split("\n")[-1][:100]  # Last line of prompt, max 100 chars
    elif msg_type == "AIMessage":
        tcolor = 'green'
        # If tool call(s), summarize; else, show content
        tool_calls = getattr(msg, 'tool_calls', None)
        if tool_calls:
            calls = [f"{tc.get('name', '')}({tc.get('args', {})})" for tc in tool_calls]
            label = "Tool call" if len(calls) == 1 else "Tool calls"
            important = f"{label}: {', '.join(calls)}"
        else:
            important = content.strip()[:100]
    elif msg_type == "ToolMessage":
        tcolor = 'magenta'
        toolname = getattr(msg, "name", "")
        if content.startswith("Observation"):
            important = content.strip()[:120]
        else:
            important = f"{toolname}: {content.strip()[:100]}"
    else:
        tcolor = 'yellow'
        important = str(raw_content)[:100]
    print(colored(f"[{msg_type:12s}] ", tcolor, attrs=["bold"]), end='')
    print(important)

# Example usage: print a one-line summary for every message in example_output.
for m in example_output['messages']:
    print_message_summary(m)


[1m[36m[HumanMessage] [0mHuman: I would like to book the community center for a 30-person event on December 10th.
[1m[32m[AIMessage   ] [0mTool call: check_calendar({'date': '2023-12-10'})
[1m[35m[ToolMessage ] [0mObservation:2023-12-10 is AVAILABLE
[1m[32m[AIMessage   ] [0mTool call: check_room_rules({'room_id': 'community_center'})
[1m[35m[ToolMessage ] [0mObservation: Rules for community_center are 'max_capacity: 50, no_food_allowed'.
[1m[32m[AIMessage   ] [0mTool call: assign_task({'staff_name': 'John', 'task': 'Prepare community center for December 10th event - 30 people'})
[1m[35m[ToolMessage ] [0massign_task: Action successful: Task 'Prepare community center for December 10th event - 30 people' assigned to J
[1m[32m[AIMessage   ] [0mTool call: final_answer({'response': "Your booking for the community center on December 10th has been confirmed. The venue can accommodate your 30-person event as it's within our maximum capacity of 50 people. Please note that

In [None]:
HERE

In [48]:

# Default: visualize the very simple trace (Jupyter-friendly).
# Defines the module-level names current_dir, trace_dir and trace_file.
import os

try:
    # Resolve the traces directory relative to the current working directory:
    # two levels up, then track_b_glass_box/traces.
    current_dir = Path().resolve()
    trace_dir = current_dir.parent.parent / "track_b_glass_box" / "traces"
except Exception as e:
    # Fall back to the current directory if the path cannot be resolved.
    print("Could not determine path to traces directory:", e)
    trace_dir = Path(".")
trace_file = trace_dir / "trace-very-simple.json"

# Only call visualize_trace when the file is present; otherwise report the path that was tried.
if trace_file.exists():
    print(f"Visualizing {trace_file}")
    visualize_trace(trace_file, interactive=True)
else:
    print(f"Trace file {trace_file} does not exist")


Visualizing /home/dennisushi/personal/hackthon-2025/track_b_glass_box/traces/trace-very-simple.json
> input: {'messages': [{'type': 'human', 'content': 'What is a GPU?\n'}]}
- ReAct Agent (chain)
  ↳ input: {'messages': [{'type': 'human', 'content': 'What is a GPU?\n'}]}
  ↳ output: {'messages': [{'content': 'What is a GPU?\n', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'id': '
  - call_model (chain)
    ↳ input: {'input': {'messages': [{'content': 'What is a GPU?\n', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human
    ↳ output: {'messages': [{'content': 'A GPU, or Graphics Processing Unit, is a specialized electronic circuit designed to accelerat
    - ChatOpenAI (llm)
      ↳ input: {'messages': [[{'lc': 1, 'type': 'constructor', 'id': ['langchain', 'schema', 'messages', 'SystemMessage'], 'kwargs': {'
      ↳ output: {'generations': [{'text': 'A GPU, or Graphics Processing Unit, is a specialized electronic circuit designed to accelerat
    - r

In [None]:
from collections import defaultdict
from typing import List, Dict, Any


def build_interaction_tree(trace_data: List[Dict]) -> Dict[str, Any]:
    """Build a hierarchical interaction tree with actor labels (human, ai, tool, agent).

    Args:
        trace_data: Run dictionaries. ``id`` is required; ``parent_run_ids``,
            ``name``, ``run_type``, ``inputs`` and ``outputs`` are optional.

    Returns:
        The node of the single root run, or ``{"roots": [...]}`` when several
        runs have no parent. Each node has the keys ``id``, ``name``, ``type``,
        ``actor``, ``depth``, ``input``, ``output`` and ``children``; input and
        output are JSON strings shortened to 200 characters (None when empty).

    Raises:
        ValueError: If no run without parents exists.
    """

    # Build graph relationships
    graph = defaultdict(list)
    id_to_run = {r["id"]: r for r in trace_data}
    for r in trace_data:
        # `or []` treats an explicit None like a missing key, matching the
        # root detection below.
        for parent in r.get("parent_run_ids") or []:
            graph[parent].append(r["id"])

    def classify_actor(run: Dict) -> str:
        """Return 'human', 'ai', 'tool', or 'agent'."""
        rt = run.get("run_type") or ""
        name = run.get("name") or ""

        # A run counts as human-initiated when the first message of
        # inputs["messages"] has type "human". Inputs that are not a dict, or
        # messages that are not a list, simply do not match.
        inputs = run.get("inputs")
        msgs = inputs.get("messages") if isinstance(inputs, dict) else None
        if isinstance(msgs, list) and msgs:
            # Handle nested list structure: [[{...}]] or [{...}]
            first_msg = msgs[0]
            if isinstance(first_msg, list) and len(first_msg) > 0:
                first_msg = first_msg[0]
            if isinstance(first_msg, dict) and first_msg.get("type") == "human":
                return "human"

        # LLM model
        if rt == "llm" or "ChatOpenAI" in name or "OpenAI" in name:
            return "ai"

        # Tool call
        if rt == "tool":
            return "tool"

        # Everything else is orchestrator
        return "agent"

    def short(obj):
        """Serialize ``obj`` to JSON, cut to 200 characters; None for empty values."""
        if not obj:
            return None
        # default=str keeps values that are not JSON-serializable from raising.
        s = json.dumps(obj, default=str)
        return s[:200] + ("..." if len(s) > 200 else "")

    def build_node(run_id: str, depth=0):
        run = id_to_run[run_id]
        return {
            "id": run_id,
            "name": run.get("name"),
            "type": run.get("run_type"),
            "actor": classify_actor(run),
            "depth": depth,
            "input": short(run.get("inputs")),
            "output": short(run.get("outputs")),
            "children": [
                build_node(child_id, depth + 1)
                for child_id in graph.get(run_id, [])
            ],
        }

    # Find roots (no parent)
    roots = [r["id"] for r in trace_data if not r.get("parent_run_ids")]
    if not roots:
        raise ValueError("No root nodes found in trace.")

    # If multiple roots, wrap them; otherwise return the only root
    return (
        build_node(roots[0])
        if len(roots) == 1
        else {"roots": [build_node(r) for r in roots]}
    )

In [46]:
# Load the trace selected by trace_file and display its interaction tree.
# NOTE(review): trace_file is assigned in another cell of this notebook; run that cell first.
include_messages = True
trace_path = trace_file
builder = TraceGraphBuilder(include_messages=include_messages)
trace_data = builder.load_trace(trace_path)
# Last expression of the cell, so the returned tree is shown as the cell output.
build_interaction_tree(trace_data)

{'id': '1f06b75b-3a54-6918-b0a3-d23cf007ec77',
 'name': 'ReAct Agent',
 'type': 'chain',
 'actor': 'human',
 'depth': 0,
 'input': '{"messages": [{"type": "human", "content": "What is a GPU?\\n"}]}',
 'output': '{"messages": [{"content": "What is a GPU?\\n", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "id": "e7849012-9dc8-4cb5-ab9a-7d3413943077", "example": false}, {"content": "A GPU, or...',
 'children': [{'id': 'dcd21300-e41c-43f1-930c-88e707a5f13c',
   'name': 'call_model',
   'type': 'chain',
   'actor': 'agent',
   'depth': 1,
   'input': '{"input": {"messages": [{"content": "What is a GPU?\\n", "additional_kwargs": {}, "response_metadata": {}, "type": "human", "id": "e7849012-9dc8-4cb5-ab9a-7d3413943077", "example": false}], "is_last_st...',
   'output': '{"messages": [{"content": "A GPU, or Graphics Processing Unit, is a specialized electronic circuit designed to accelerate the processing of images and graphics. It is highly efficient at performing pa...',