# Tool Sequence Analysis

This notebook analyzes the sequence of tool calls in RCA tasks. It ingests experiment results, extracts tool usage patterns, and visualizes the flow of tool calls using a Sankey diagram.

In [None]:
import os
import json
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from pypalettes import load_cmap
from dotenv import load_dotenv

# Static plots are written to a folder relative to the working directory.
PLOTS_DIR = "analysis_plots"

# The project's .env file is looked up two directory levels above the working directory.
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
load_dotenv(os.path.join(root_dir, '.env'))

# RESULTS_PATH points at the experiment results; RESULTS_DIR is None when it is unset.
RESULTS_DIR = os.getenv("RESULTS_PATH")
print(f"Results Directory: {RESULTS_DIR}")

# Create the plot folder up front so later save calls have a target.
os.makedirs(PLOTS_DIR, exist_ok=True)
print(f"Plots will be saved to: {PLOTS_DIR}")

# Agent configuration -> LaTeX symbol, grouped by agent family.
# Each entry pairs a symbol with the agent IDs that share it.
_AGENT_FAMILY_SYMBOLS = (
    (r'$\bigstar$', 'ABC'),        # Plain ReAct
    (r'$\clubsuit$', 'DE'),        # Conservative
    (r'$\blacksquare$', 'FG'),     # Tool-free
    (r'$\blacktriangle$', 'HI'),   # Strict Supervisor
    (r'$\spadesuit$', 'JKL'),      # Improved ReAct
)

# Maps an agent ID (e.g. 'A') to its display label (e.g. '$\bigstar$ A').
AGENT_SYMBOL_MAP = {
    agent: f'{symbol} {agent}'
    for symbol, agents in _AGENT_FAMILY_SYMBOLS
    for agent in agents
}


def get_agent_label(agent_id):
    """Return the symbol-decorated label for ``agent_id``.

    IDs that are not present in ``AGENT_SYMBOL_MAP`` are returned unchanged.
    """
    if agent_id in AGENT_SYMBOL_MAP:
        return AGENT_SYMBOL_MAP[agent_id]
    return agent_id

## Select Experiments

In [None]:
# Discover experiment batches: every sub-directory of RESULTS_DIR counts as one batch.
experiment_batches = []
if not RESULTS_DIR or not os.path.exists(RESULTS_DIR):
    print("Results directory not found or empty.")
else:
    experiment_batches = sorted(
        entry for entry in os.listdir(RESULTS_DIR)
        if os.path.isdir(os.path.join(RESULTS_DIR, entry))
    )

    # Print a 1-based menu; the numbers are what the selection cell expects as input.
    print("Available Experiment Batches:")
    for batch_no, batch_name in enumerate(experiment_batches, start=1):
        batch_entries = os.listdir(os.path.join(RESULTS_DIR, batch_name))
        json_count = len([name for name in batch_entries if name.endswith('.json')])
        print(f"{batch_no}) {batch_name} ({json_count} results)")

In [None]:
# Interactive selection of experiment batches (1-based IDs from the menu above).
# When no input mechanism exists (closed stdin, or a Jupyter frontend without stdin
# support such as a headless notebook run) input() raises instead of returning, so
# that case is treated like empty input and falls through to the default batch.
try:
    selected_indices = input("Enter experiment batch IDs (comma-separated, e.g., 1,2): ")
except (EOFError, NotImplementedError):
    # EOFError: stdin is closed. NotImplementedError: base class of the error raised
    # by Jupyter kernels whose frontend cannot answer input requests.
    print("No interactive input available.")
    selected_indices = ""

selected_dirs = []
try:
    if selected_indices.strip():
        indices = [int(x.strip()) - 1 for x in selected_indices.split(',')]
        for idx in indices:
            if 0 <= idx < len(experiment_batches):
                batch_dir = os.path.join(RESULTS_DIR, experiment_batches[idx])
                # A repeated ID would ingest the same folder twice and double-count
                # every record, so each batch is selected at most once.
                if batch_dir not in selected_dirs:
                    selected_dirs.append(batch_dir)
            else:
                print(f"Warning: Index {idx+1} out of range.")
    else:
        print("No input provided.")
except ValueError:
    # Raised by int() when any comma-separated token is not a number.
    print("Invalid input.")

# Fall back to the last batch in sorted order when nothing valid was selected.
if not selected_dirs and experiment_batches:
    print("Defaulting to the most recent batch.")
    selected_dirs = [os.path.join(RESULTS_DIR, experiment_batches[-1])]

print(f"Selected {len(selected_dirs)} folders for analysis.")

## Data Ingestion & Processing

In [None]:
def extract_tool_calls(message_history):
    """Extract the ordered list of tool names called in a message history.

    Only dict messages whose ``"type"`` is ``"AIMessage"`` are inspected. Within
    those, every dict entry of ``"tool_calls"`` that has a ``"name"`` key
    contributes that name, in order of appearance.

    Args:
        message_history: Expected to be a list of message dicts. Any other
            value (including ``None``) yields an empty result.

    Returns:
        A list of tool names in call order; empty when nothing matches.
    """
    if not isinstance(message_history, list):
        return []

    tool_sequence = []
    for message in message_history:
        if not isinstance(message, dict) or message.get("type") != "AIMessage":
            continue
        # "tool_calls" can be absent or explicitly null in the JSON; a default in
        # .get() only covers the absent case, so check the value itself.
        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue
        tool_sequence.extend(
            tc["name"] for tc in tool_calls if isinstance(tc, dict) and "name" in tc
        )
    return tool_sequence

# Build one flat record per RCA analysis found in the selected result folders.
records = []

for folder in selected_dirs:
    if not os.path.exists(folder):
        continue

    files = [f for f in os.listdir(folder) if f.endswith('.json')]
    print(f"Processing {folder}... Found {len(files)} JSON files.")

    for file_name in files:
        file_path = os.path.join(folder, file_name)
        try:
            # Read as UTF-8 explicitly rather than relying on the platform default.
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # `or {}` also covers keys that are present but null in the JSON,
            # which a .get() default alone would not.
            testbed = data.get("testbed") or {}

            agent_id = data.get("agent_id", "Unknown")

            # Prefer the top-level "app_name"; otherwise take the first entry of
            # testbed["application_name"]. A bare string is used whole (indexing it
            # would yield only its first character) and an empty or missing list
            # falls back to "Unknown" instead of failing the whole file.
            app_name = data.get("app_name")
            if not app_name:
                app_names = testbed.get("application_name")
                if isinstance(app_names, str):
                    app_name = app_names or "Unknown"
                elif isinstance(app_names, (list, tuple)) and app_names:
                    app_name = app_names[0]
                else:
                    app_name = "Unknown"
            fault_type = testbed.get("fault_name", "Unknown")

            # Evaluation Data
            evaluation = data.get("evaluation") or {}
            eval_detection = evaluation.get("detection")
            eval_localization = evaluation.get("localization")
            eval_rca_score = evaluation.get("rca_score")

            # Experiment Stats
            stats = data.get("stats") or {}
            execution_time = stats.get("execution_time_seconds")
            total_tokens = stats.get("total_tokens")

            # Process RCA Analyses
            rca_list = data.get("rca_analyses_list") or []
            if not rca_list:
                continue  # Skip if no analysis

            for i, analysis in enumerate(rca_list, 1):
                if not isinstance(analysis, dict):
                    # Keep the other analyses of this file rather than aborting it.
                    print(f"Skipping malformed RCA analysis {i} in {file_name}.")
                    continue

                message_history = analysis.get("message_history", [])
                tool_sequence = extract_tool_calls(message_history)

                # Create record
                row = {
                    "Agent ID": agent_id,
                    "Application Name": app_name,
                    "Fault Type": fault_type,
                    "Eval Detection": eval_detection,
                    "Eval Localization": eval_localization,
                    "Eval RCA Score": eval_rca_score,
                    "Execution Time (s)": execution_time,
                    "Total Tokens": total_tokens,
                    "RCA Task Index": f"RCA Task {i}",
                    "Total Tools": len(tool_sequence)
                }

                # Columns "Tool 1" .. "Tool 10": the first ten tool names, padded with
                # None. "Total Tools" above still counts the full sequence.
                for t_idx in range(10):
                    row[f"Tool {t_idx + 1}"] = (
                        tool_sequence[t_idx] if t_idx < len(tool_sequence) else None
                    )

                records.append(row)

        except Exception as e:
            # Best-effort ingestion: report the file and carry on with the next one.
            print(f"Error processing {file_name}: {e}")

df = pd.DataFrame(records)
print(f"Created DataFrame with {len(df)} rows.")
df.head()

## Visualization: Tool Sequence Sankey Diagram

In [None]:
# Prepare data for Sankey
# We strictly want sequences: Tool 1 -> Tool 2 -> Tool 3 ...

from pypalettes import load_cmap
import numpy as np

# Sankey tuning knobs
max_depth = 8        # highest step number that can appear as a node
min_link_count = 3   # transitions seen fewer times than this are dropped to declutter
links = []

# Node registry. Keys of node_label_map are "Step {step}: {tool}" strings; values are
# node ids, which double as positions in the parallel node_* lists.
node_label_map, node_labels, node_hovertext, node_steps = {}, [], [], []
node_counter = 0


def get_node_id(step, tool_name):
    """Look up the Sankey node id for (step, tool_name), registering it on first use."""
    global node_counter
    hover_label = f"Step {step}: {tool_name}"

    known_id = node_label_map.get(hover_label)
    if known_id is not None:
        return known_id

    new_id = node_counter
    node_label_map[hover_label] = new_id
    node_labels.append(tool_name)        # the node face shows only the tool name
    node_hovertext.append(hover_label)   # the hover text carries step + tool
    node_steps.append(step)
    node_counter = new_id + 1
    return new_id

# Parallel arrays describing each Sankey link
source_indices, target_indices, values = [], [], []
link_colors = []
link_tool_names = []

# Count every observed (step, tool) -> (step + 1, next tool) transition.
transitions = {}

for _, row in df.iterrows():
    # Steps 1 .. max_depth - 1 act as transition sources.
    for i in range(1, max_depth):
        current_tool = row[f"Tool {i}"]
        if pd.isna(current_tool):
            break  # the sequence ended before this step

        # A missing successor is recorded as a transition into the synthetic "End" node.
        next_tool = row[f"Tool {i+1}"]
        next_tool_label = "End" if pd.isna(next_tool) else next_tool

        # Key: (Step_i, ToolA, Step_i+1, ToolB)
        key = (i, current_tool, i + 1, next_tool_label)
        transitions[key] = transitions.get(key, 0) + 1

        if next_tool_label == "End":
            break

# Keep only transitions that occur often enough to be worth drawing.
filtered_items = [
    (transition, count)
    for transition, count in transitions.items()
    if count >= min_link_count
]

if not filtered_items:
    print(f"No links meet min_link_count={min_link_count}. Lower the threshold to see data.")

# Register both endpoints of every kept transition and append one link per transition.
for (step_from, tool_from, step_to, tool_to), count in filtered_items:
    source_indices.append(get_node_id(step_from, tool_from))
    target_indices.append(get_node_id(step_to, tool_to))
    values.append(count)
    link_tool_names.append(tool_from)  # links are colored by their source tool

# Build color palette per tool name using color-blind friendly palette (trim to avoid over-coloring)
raw_cmap = load_cmap("Color_Blind")
max_colors = 8
sample_points = np.linspace(0, 1, max_colors)
palette_rgba = raw_cmap(sample_points)  # returns RGBA
# Nodes use a fairly opaque fill; links reuse the same hues at lower opacity.
palette = [f"rgba({int(r*255)},{int(g*255)},{int(b*255)},0.85)" for r, g, b, _ in palette_rgba]
palette_link = [f"rgba({int(r*255)},{int(g*255)},{int(b*255)},0.4)" for r, g, b, _ in palette_rgba]
color_map = {}
color_map_link = {}
unique_tools = pd.Index(node_labels).unique().dropna().to_list()

# "End" is a synthetic terminal node, not a tool. It appears in node_labels, so it
# must be skipped here: otherwise it receives a palette color, consumes a palette
# slot, and the neutral grey assigned below would never apply.
palette_tools = [tool for tool in unique_tools if tool != "End"]
for idx, tool in enumerate(palette_tools):
    # Colors cycle when there are more tools than palette entries.
    color_map[tool] = palette[idx % len(palette)]
    color_map_link[tool] = palette_link[idx % len(palette_link)]
color_map["End"] = "rgba(158,158,158,0.85)"
color_map_link["End"] = "rgba(158,158,158,0.35)"

def map_color(tool_name: str) -> str:
    """Return the node color for ``tool_name`` (grey fallback for unknown names)."""
    return color_map.get(tool_name, "rgba(136,136,136,0.85)")

def map_color_link(tool_name: str) -> str:
    """Return the link color for ``tool_name`` (grey fallback for unknown names)."""
    return color_map_link.get(tool_name, "rgba(136,136,136,0.35)")

node_colors = [map_color(tool) for tool in node_labels]
link_colors = [map_color_link(tool_name) for tool_name in link_tool_names]

# Position nodes by step along the x-axis for readability:
# the lowest step maps to x=0.0, the highest to x=1.0, the rest linearly in between.
unique_steps_sorted = sorted(set(node_steps))
if len(unique_steps_sorted) > 1:
    first_step, last_step = unique_steps_sorted[0], unique_steps_sorted[-1]
    step_span = last_step - first_step
    step_positions = {step: (step - first_step) / step_span for step in unique_steps_sorted}
else:
    # Zero or one distinct step: there is no span to normalize by.
    step_positions = {step: 0.0 for step in unique_steps_sorted}

node_x = [step_positions.get(step, 0.0) for step in node_steps]

# Statistics for subtitle
total_sequences = len(df)
total_transitions = sum(values)  # only transitions that passed the min_link_count filter
# "End" is a synthetic terminal node, not a tool, so it is left out of the tool count.
unique_tool_count = len([tool for tool in unique_tools if tool != "End"])

print(f"Step range: {unique_steps_sorted[0] if unique_steps_sorted else 'N/A'} to {unique_steps_sorted[-1] if unique_steps_sorted else 'N/A'}")
print(f"Total sequences analyzed: {total_sequences}")
print(f"Total transitions: {total_transitions}")
print(f"Unique tools: {unique_tool_count}")

# Assemble the Sankey trace from its node and link specifications.
sankey_node_spec = dict(
    pad=18,
    thickness=22,
    line=dict(color="rgba(50,50,50,0.8)", width=0.8),
    label=node_labels,
    color=node_colors,
    customdata=node_hovertext,  # full "Step N: tool" text, shown on hover
    x=node_x,
    hovertemplate="<b>%{customdata}</b><extra></extra>",
)
sankey_link_spec = dict(
    source=source_indices,
    target=target_indices,
    value=values,
    color=link_colors,
    hovertemplate="<b>%{source.label}</b> → <b>%{target.label}</b><br>Occurrences: %{value}<extra></extra>",
)
fig = go.Figure(data=[go.Sankey(
    arrangement='snap',
    node=sankey_node_spec,
    link=sankey_link_spec,
)])

# Layout: centered bold title, white background, fixed rectangular canvas.
base_font_family = "Arial, sans-serif"
fig.update_layout(
    title=dict(
        text="<b>Tool Call Sequence Analysis in RCA Tasks</b>",
        x=0.5,
        xanchor='center',
        yanchor='top',
        font=dict(size=16, family=base_font_family, color="#2c3e50"),
    ),
    font=dict(family=base_font_family, size=12, color="#34495e"),
    plot_bgcolor='white',
    paper_bgcolor='white',
    hovermode="closest",
    margin=dict(l=20, r=20, t=80, b=80),
    height=750,
    width=1200,
)

# Per step: a boxed "Step N" caption below the plot area and a faint dotted
# vertical guide line, both placed at the step's x position in paper coordinates.
for step, xpos in step_positions.items():
    fig.add_annotation(
        x=xpos,
        y=-0.10,
        xref="paper",
        yref="paper",
        showarrow=False,
        text=f"<b>Step {step}</b>",
        font=dict(size=13, family=base_font_family, color="#2c3e50"),
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="rgba(44,62,80,0.3)",
        borderwidth=1,
        borderpad=4,
    )
    fig.add_shape(
        type="line",
        x0=xpos, y0=0, x1=xpos, y1=1,
        xref="paper", yref="paper",
        line=dict(color="rgba(189,195,199,0.2)", width=1, dash="dot"),
    )


fig.show()

# Save static images (PDF and PNG). Each format is attempted on its own so that a
# failure in one does not skip the other, and the warning names the format that failed.
try:
    # Save as PDF (vector format, best for papers)
    pdf_file = os.path.join(PLOTS_DIR, "tool_sequence_sankey.pdf")
    fig.write_image(pdf_file, format="pdf", width=1200, height=600)
    print(f"✓ Saved publication-quality PDF to {pdf_file}")
except Exception as e:
    print(f"⚠ Could not save PDF. Error: {e}")

try:
    # Also save as PNG for presentations/web (scale=2 doubles the pixel resolution)
    png_file = os.path.join(PLOTS_DIR, "tool_sequence_sankey.png")
    fig.write_image(png_file, format="png", width=1200, height=600, scale=2)
    print(f"✓ Saved high-resolution PNG to {png_file}")
except Exception as e:
    print(f"⚠ Could not save PNG. Error: {e}")


## Tool Distribution by Step

Stacked bar chart showing the number of calls to each tool at each diagnostic step, with color-coded tools and a legend.

In [None]:
# Tool Distribution by Step using Matplotlib
import matplotlib.pyplot as plt
from pypalettes import load_cmap
import numpy as np

# Color palette: the "Color_Blind" colormap sampled at 12 evenly spaced points (RGB only).
raw_cmap = load_cmap("Color_Blind")
max_colors = 12
sample_points = np.linspace(0, 1, max_colors)
palette_rgba = raw_cmap(sample_points)
palette_list = [(r, g, b) for r, g, b, _ in palette_rgba]

# Count tool usage per step: {step: {tool: count}}
tool_step_data = {}

for _, row in df.iterrows():
    for step in range(1, max_depth):
        tool = row[f"Tool {step}"]
        if pd.isna(tool):
            break  # this sequence has no further tool calls
        step_counts = tool_step_data.setdefault(step, {})
        step_counts[tool] = step_counts.get(tool, 0) + 1

# Steps that saw at least one call, and every tool seen at any step (sorted by name).
steps_sorted = sorted(tool_step_data)
all_tools = sorted({tool for step_counts in tool_step_data.values() for tool in step_counts})

# One color per tool; colors cycle when there are more tools than palette entries.
tool_color_map = {
    tool: palette_list[idx % len(palette_list)]
    for idx, tool in enumerate(all_tools)
}

# Raw counts for the stacked bars: one list per tool, aligned with steps_sorted,
# plus the total number of calls at each step.
step_totals = [sum(tool_step_data[step].values()) for step in steps_sorted]
data_by_tool = {
    tool: [tool_step_data[step].get(tool, 0) for step in steps_sorted]
    for tool in all_tools
}

# Create figure with matplotlib
fig, ax = plt.subplots(figsize=(14, 7), dpi=100)

# Create stacked bar chart: each tool's bars sit on top of the running total
# of the tools drawn before it.
x_pos = np.arange(len(steps_sorted))
bottom = np.zeros(len(steps_sorted))

for tool in all_tools:
    # Named differently from `values`, which the Sankey cell uses for its link
    # weights, so running this cell does not overwrite that list.
    counts_per_step = np.array(data_by_tool[tool])
    ax.bar(x_pos, counts_per_step, bottom=bottom, label=tool,
           color=tool_color_map[tool], edgecolor='white', linewidth=0.5, alpha=0.9)
    bottom += counts_per_step

# Customize plot
ax.set_xlabel('Diagnostic Step', fontsize=13, fontweight='bold', color='#2c3e50')
ax.set_ylabel('Tool Call Count', fontsize=13, fontweight='bold', color='#2c3e50')
ax.set_title('Tool Distribution by Step\nNumber of tool calls at each diagnostic step',
             fontsize=16, fontweight='bold', color='#2c3e50', pad=20)

ax.set_xticks(x_pos)
ax.set_xticklabels([f'Step {s}\n(n={step_totals[i]})' for i, s in enumerate(steps_sorted)],
                     fontsize=11, color='#34495e')
# Style the y tick labels through tick_params: re-setting the labels from
# get_yticklabels() would pin the label text to whatever the ticks are right now.
ax.tick_params(axis='y', labelsize=11, labelcolor='#34495e')

# Add grid
ax.grid(axis='y', alpha=0.3, linestyle='--', linewidth=0.7, color='#bdc3c7')
ax.set_axisbelow(True)

# Customize spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_color('#34495e')
ax.spines['bottom'].set_color('#34495e')

# Add legend (anchored just outside the right edge of the axes)
ax.legend(title='Tools', fontsize=10, title_fontsize=11,
          loc='upper left', bbox_to_anchor=(1.01, 1), framealpha=0.95,
          fancybox=True, shadow=True, edgecolor='#bdc3c7')

plt.tight_layout()

# Save as PNG and PDF (each attempted independently)
try:
    png_file = os.path.join(PLOTS_DIR, "tool_distribution_by_step.png")
    plt.savefig(png_file, dpi=200, bbox_inches='tight', facecolor='white')
    print(f"✓ Saved tool distribution PNG to {png_file}")
except Exception as e:
    print(f"⚠ Could not save PNG. Error: {e}")

try:
    pdf_file = os.path.join(PLOTS_DIR, "tool_distribution_by_step.pdf")
    plt.savefig(pdf_file, bbox_inches='tight', facecolor='white')
    print(f"✓ Saved tool distribution PDF to {pdf_file}")
except Exception as e:
    print(f"⚠ Could not save PDF. Error: {e}")

plt.show()


## Average Tool Calls Per Agent Type

In [None]:
# Per-agent summary of the "Total Tools" column: mean, std, count, min and max,
# rounded to two decimals and given readable column names.
agent_tool_stats = (
    df.groupby('Agent ID')['Total Tools']
    .agg(['mean', 'std', 'count', 'min', 'max'])
    .round(2)
    .rename(columns={
        'mean': 'Average Tool Calls',
        'std': 'Std Dev',
        'count': 'Number of RCA Tasks',
        'min': 'Min Tool Calls',
        'max': 'Max Tool Calls',
    })
)

# Drop the "Unknown" agent, then order agents from fewest to most average tool calls.
is_known_agent = agent_tool_stats.index != "Unknown"
agent_tool_stats = agent_tool_stats.loc[is_known_agent].sort_values('Average Tool Calls')

print("Average Tool Calls Per Agent Type:")
print("=" * 80)

# Show a copy whose index carries the symbol-decorated agent labels;
# agent_tool_stats itself keeps the plain agent IDs.
agent_tool_stats_display = agent_tool_stats.copy()
agent_tool_stats_display.index = list(map(get_agent_label, agent_tool_stats_display.index))
display(agent_tool_stats_display)

# Create a bar chart
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6), dpi=100)

agents = agent_tool_stats.index
avg_tools = agent_tool_stats['Average Tool Calls']
std_tools = agent_tool_stats['Std Dev']
# The sample standard deviation is NaN for an agent with a single RCA task. Plot
# those with a zero-length error bar so the value label above the bar still gets a
# finite y position instead of NaN.
std_for_plot = std_tools.fillna(0)

# Use Color_Blind palette, sampled at one point per agent
raw_cmap = load_cmap("Color_Blind")
max_colors = len(agents)
sample_points = np.linspace(0, 1, max_colors)
palette_rgba = raw_cmap(sample_points)
palette_list = [(r, g, b) for r, g, b, _ in palette_rgba]

# Create bar chart with error bars
bars = ax.bar(range(len(agents)), avg_tools, yerr=std_for_plot,
              capsize=5, alpha=0.8, edgecolor='black', linewidth=1.2)

# Color bars. Only the face color is set: set_color() would also replace the
# black edge color requested above.
for bar, color in zip(bars, palette_list):
    bar.set_facecolor(color)

# Customize plot
ax.set_xlabel('Agent configurations', fontsize=13, fontweight='bold', color='#2c3e50')
ax.set_ylabel('Average Tool Calls', fontsize=13, fontweight='bold', color='#2c3e50')
ax.set_title('Average Number of Tool Calls by Agent Type',
             fontsize=15, fontweight='bold', color='#2c3e50', pad=20)

ax.set_xticks(range(len(agents)))
agent_labels = [get_agent_label(agent) for agent in agents]
ax.set_xticklabels(agent_labels, fontsize=11, color='#34495e', rotation=45, ha='right')
# Style the y tick labels through tick_params: re-setting the labels from
# get_yticklabels() would pin the label text to whatever the ticks are right now.
ax.tick_params(axis='y', labelsize=11, labelcolor='#34495e')

# Add grid
ax.grid(axis='y', alpha=0.3, linestyle='--', linewidth=0.7, color='#bdc3c7')
ax.set_axisbelow(True)

# Customize spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_color('#34495e')
ax.spines['bottom'].set_color('#34495e')

# Add value labels just above the top of each error bar
for i, (avg, std) in enumerate(zip(avg_tools, std_for_plot)):
    ax.text(i, avg + std + 0.3, f'{avg:.1f}',
            ha='center', va='bottom', fontsize=10, fontweight='bold', color='#2c3e50')

plt.tight_layout()

# Save plot (PNG and PDF are attempted independently)
try:
    png_file = os.path.join(PLOTS_DIR, "avg_tool_calls_by_agent.png")
    plt.savefig(png_file, dpi=200, bbox_inches='tight', facecolor='white')
    print(f"\n✓ Saved average tool calls PNG to {png_file}")
except Exception as e:
    print(f"\n⚠ Could not save PNG. Error: {e}")

try:
    pdf_file = os.path.join(PLOTS_DIR, "avg_tool_calls_by_agent.pdf")
    plt.savefig(pdf_file, bbox_inches='tight', facecolor='white')
    print(f"✓ Saved average tool calls PDF to {pdf_file}")
except Exception as e:
    print(f"⚠ Could not save PDF. Error: {e}")

plt.show()


## Total Tool Distribution Pie Chart

In [None]:
# Count occurrences of each tool over the scanned step columns
# ("Tool 1" .. "Tool {max_depth - 1}"), stopping at the first empty step of a row.
tool_counts = {}
# Names that are skipped rather than counted: the "End" marker and the final
# submission call, which is excluded from the distribution.
excluded_tools = {"End", "submit_final_diagnosis"}

for _, row in df.iterrows():
    for step in range(1, max_depth):
        tool = row[f"Tool {step}"]
        if pd.isna(tool):
            break  # no further tool calls in this sequence
        if tool in excluded_tools:
            continue
        tool_counts[tool] = tool_counts.get(tool, 0) + 1

# Prepare data for pie chart
tools_list = list(tool_counts.keys())
counts_list = list(tool_counts.values())
total = sum(counts_list)
percentages = [(count / total) * 100 for count in counts_list]

# Sort by count (descending); ties keep their first-seen order.
sorted_data = sorted(zip(tools_list, counts_list, percentages), key=lambda x: x[1], reverse=True)
if sorted_data:
    tools_list, counts_list, percentages = zip(*sorted_data)
else:
    # zip(*[]) yields nothing to unpack, so handle the no-data case explicitly.
    print("No tool calls found; the tool distribution is empty.")
    tools_list, counts_list, percentages = (), (), ()

# Create color palette from the same cmap, sampled at one point per tool
raw_cmap = load_cmap("Color_Blind")
sample_points = np.linspace(0, 1, len(tools_list))
palette_rgba = raw_cmap(sample_points)
palette_colors = [(r, g, b) for r, g, b, _ in palette_rgba]

# Wedge label formatter: slices below 3% are left unlabeled to avoid clutter.
def autopct_format(pct):
    """Return ``pct`` formatted with one decimal and a percent sign, or '' when below 3%."""
    if not pct >= 3.0:
        return ''
    return f'{pct:.1f}%'

# Pie chart sized for a two-column paper layout
fig, ax = plt.subplots(figsize=(10, 6), dpi=100)

pct_text_style = dict(fontsize=11, fontweight='bold', color='white')
wedge_style = dict(edgecolor='white', linewidth=2.5, antialiased=True)
wedges, texts, autotexts = ax.pie(
    counts_list,
    labels=None,                # tool names go in the legend, not on the wedges
    colors=palette_colors,
    autopct=autopct_format,
    startangle=90,
    textprops=pct_text_style,
    wedgeprops=wedge_style,
    pctdistance=0.75,
)
ax.set_aspect('equal')
ax.margins(0, 0)

# Shrink the axes towards the top of the figure so the legend fits underneath.
ax.set_position([0.08, 0.30, 0.84, 0.64])

# Percentage labels: white, bold, size 11.
for pct_label in autotexts:
    pct_label.set(color='white', fontweight='bold', fontsize=11)

# Legend entries read "<tool>: <pct>%", with only the percentage bold (mathtext).
legend_labels = [f"{tool}: $\\mathbf{{{pct:.1f}\\%}}$" for tool, pct in zip(tools_list, percentages)]
legend_style = dict(
    title='Tools',
    fontsize=9,
    title_fontsize=11,
    loc='lower center',
    bbox_to_anchor=(0.5, 0.10),
    ncol=3,
    framealpha=0.97,
    fancybox=True,
    shadow=True,
    edgecolor='#bdc3c7',
    frameon=True,
    handlelength=2.0,
    handleheight=1.5,
    markerscale=1.2,
)
legend = fig.legend(legend_labels, **legend_style)
legend.get_title().set_fontweight('bold')

# Figure title, placed close to the pie
fig.suptitle('Total Distribution of Tool Calls in RCA Tasks',
             fontsize=16, fontweight='bold', color='black', y=0.95)

# Write PNG and PDF copies; each format is attempted independently.
try:
    png_file = os.path.join(PLOTS_DIR, "tool_distribution_pie_chart_two_column.png")
    fig.savefig(png_file, dpi=200, bbox_inches='tight', pad_inches=2, facecolor='white')
    print(f"✓ Saved two-column pie chart PNG to {png_file}")
except Exception as err:
    print(f"⚠ Could not save PNG. Error: {err}")

try:
    pdf_file = os.path.join(PLOTS_DIR, "tool_distribution_pie_chart_two_column.pdf")
    fig.savefig(pdf_file, bbox_inches='tight', facecolor='white')
    print(f"✓ Saved two-column pie chart PDF to {pdf_file}")
except Exception as err:
    print(f"⚠ Could not save PDF. Error: {err}")

plt.show()

# Text summary of the same numbers, in the same (descending) order as the chart.
print("\nTool Distribution Summary (excluding submit_final_diagnosis):")
print("=" * 50)
for tool_name, n_calls, share in zip(tools_list, counts_list, percentages):
    print(f"{tool_name}: {int(n_calls)} calls ({share:.1f}%)")