# VERITAS Trace Visualization

This notebook creates visualizations of VERITAS decision traces, illustrating the DAG structure described in **Section 2** of the paper.

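We create:
- **DAG Visualizations**: Node-edge graphs showing reasoning flow
- **Cryptographic Binding Graphs**: Merkle-DAG structure with per-node hash bindings
- **Timeline Views**: Temporal progression of reasoning *(planned; not yet implemented in this notebook)*
- **Type Distribution**: Analysis of node types and confidence scores
- **Detailed Flow Diagram**: Node content previews with confidence scores
- **Interactive Visualizations**: Explorable trace graphs *(planned; not yet implemented in this notebook)*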

In [None]:
import sys
sys.path.insert(0, '../code')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from datetime import datetime, timedelta
from typing import Dict, List

from core import DecisionTrace, NodeType, AgentManifest
from crypto import SignatureScheme, TraceVerifier
from compression import TraceCompressor

# Plot styling shared by every figure in this notebook
sns.set_style('white')
plt.rcParams.update({
    'figure.figsize': (14, 10),
    'font.size': 9,
})

print("✓ Imports successful")

## 1. Create Example Trace

Create a realistic multi-step decision trace for visualization.

In [None]:
def create_medical_diagnosis_trace() -> DecisionTrace:
    """
    Build a 13-step medical-diagnosis decision trace with branching and merging.

    The steps run from the initial presentation, through workup (history,
    exam, lumbar puncture, blood cultures) and diagnosis, to treatment and
    ICU admission. Several steps have more than one parent, so the result is
    a DAG rather than a chain.

    Returns:
        A ``DecisionTrace`` whose agent manifest is set and whose steps were
        added in the order listed in the table below.
    """
    trace = DecisionTrace()
    trace.agent_manifest = AgentManifest(
        agent_did="did:agent:medical:diagnostic-v1",
        model_version="gpt-4-turbo",
        framework="custom",
    )

    # One row per reasoning step: (content, node type, parent rows, confidence).
    # "parent rows" are 0-based indices of earlier rows in this table; an empty
    # tuple marks a root step, which is added without a ``parent_ids`` argument.
    step_table = [
        # 0: initial observation
        ("Patient presents with severe headache, fever 102°F, and neck stiffness.",
         NodeType.OBSERVATION, (), 0.95),
        # 1: memory access - patient history
        ("Accessing patient history: no prior meningitis, recent URI 2 weeks ago.",
         NodeType.MEMORY_ACCESS, (0,), 0.90),
        # 2: physical exam
        ("Physical exam: Kernig's sign positive, Brudzinski's sign positive.",
         NodeType.OBSERVATION, (0,), 0.92),
        # 3: differential diagnosis (merges 0, 1, 2)
        ("Differential: bacterial meningitis (60%), viral meningitis (30%), other (10%).",
         NodeType.REASONING, (0, 1, 2), 0.85),
        # 4: decision to order tests
        ("Decision: Order stat lumbar puncture and blood cultures.",
         NodeType.DECISION, (3,), 0.88),
        # 5: tool call - lumbar puncture
        ("Tool: order_procedure(type='lumbar_puncture', priority='stat')",
         NodeType.TOOL_CALL, (4,), 0.95),
        # 6: tool call - blood cultures
        ("Tool: order_lab(test='blood_culture', sets=2, priority='stat')",
         NodeType.TOOL_CALL, (4,), 0.95),
        # 7: CSF results
        ("CSF: cloudy, WBC 2500 (PMN 90%), protein 180, glucose 25. Gram stain: GPC in pairs.",
         NodeType.OBSERVATION, (5,), 0.98),
        # 8: diagnosis (CSF results first, then the differential)
        ("CSF findings diagnostic of bacterial meningitis. Gram stain suggests S. pneumoniae.",
         NodeType.REASONING, (7, 3), 0.95),
        # 9: treatment decision
        ("Decision: Initiate empiric antibiotics (ceftriaxone + vancomycin) immediately.",
         NodeType.DECISION, (8,), 0.92),
        # 10: tool call - ceftriaxone
        ("Tool: order_medication(drug='ceftriaxone', dose='2g IV q12h', start='now')",
         NodeType.TOOL_CALL, (9,), 0.95),
        # 11: tool call - vancomycin
        ("Tool: order_medication(drug='vancomycin', dose='15mg/kg IV q8-12h', start='now')",
         NodeType.TOOL_CALL, (9,), 0.95),
        # 12: final decision - admission
        ("Decision: Admit to ICU for close monitoring and continued treatment.",
         NodeType.DECISION, (9,), 0.90),
    ]

    added_nodes = []
    for content, node_type, parent_rows, confidence in step_table:
        step_kwargs = {
            "content": content,
            "node_type": node_type,
            "confidence": confidence,
        }
        if parent_rows:
            step_kwargs["parent_ids"] = [added_nodes[row].node_id for row in parent_rows]
        added_nodes.append(trace.add_reasoning_step(**step_kwargs))

    return trace

# Build the example trace that every visualization below draws from
trace = create_medical_diagnosis_trace()
node_count = len(trace.trace_graph.nodes)
edge_count = len(trace.trace_graph.edges)
print(f"✓ Created trace with {node_count} nodes and {edge_count} edges")
print(f"  Max depth: {trace.get_trace_depth()}")

## 2. DAG Visualization

Visualize the trace as a directed acyclic graph.

In [None]:
def visualize_trace_dag(trace: DecisionTrace, save_path: str = None,
                        title: str = 'VERITAS Decision Trace DAG\nMedical Diagnosis Example'):
    """
    Create a network graph visualization of the trace DAG.

    Nodes are coloured by node type and sized by confidence. Positions come
    from a layered layout, with each node placed one row below its deepest
    parent. When the graph cannot be layered (it contains a cycle, or has no
    nodes), a seeded spring layout is used instead.

    Args:
        trace: Decision trace whose ``trace_graph`` nodes and edges are drawn.
        save_path: Optional output path. When given, the figure is saved at
            300 dpi before being shown.
        title: Figure title (defaults to the medical-diagnosis example title).

    Returns:
        The ``networkx.DiGraph`` built from the trace.
    """
    from matplotlib.lines import Line2D

    # Build the NetworkX graph from the trace's nodes and edges
    G = nx.DiGraph()
    for node in trace.trace_graph.nodes:
        G.add_node(
            node.node_id,
            label=node.node_id[:8],
            type=node.node_type.value,
            confidence=node.metadata.confidence if node.metadata else 0.0
        )
    for edge in trace.trace_graph.edges:
        G.add_edge(edge.from_node, edge.to_node, weight=edge.weight)

    # Colour map for node types; types not listed here are drawn grey
    type_colors = {
        'REASONING': '#3498db',      # Blue
        'TOOL_CALL': '#e74c3c',      # Red
        'OBSERVATION': '#2ecc71',    # Green
        'DECISION': '#f39c12',       # Orange
        'MEMORY_ACCESS': '#9b59b6'   # Purple
    }

    # .get() because add_edge() silently creates attribute-less nodes for any
    # edge endpoint that has no node record in the trace.
    node_colors = [type_colors.get(G.nodes[n].get('type'), '#95a5a6') for n in G.nodes()]
    node_sizes = [(G.nodes[n].get('confidence') or 0.0) * 1000 + 200 for n in G.nodes()]

    # Layered layout: depth is 0 for roots, otherwise 1 + the deepest parent's depth.
    depths = {}
    try:
        for node in nx.topological_sort(G):
            depths[node] = max((depths[p] + 1 for p in G.predecessors(node)), default=0)
    except nx.NetworkXUnfeasible:
        depths = {}  # the graph has a cycle, so it cannot be layered

    if depths:
        rows = {}
        for node, depth in depths.items():
            rows.setdefault(depth, []).append(node)
        pos = {}
        for depth, row in rows.items():
            for i, node in enumerate(row):
                pos[node] = ((i - len(row) / 2) * 2, -depth * 2)
    else:
        pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    # Create figure
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))

    # Draw edges
    nx.draw_networkx_edges(
        G, pos, ax=ax,
        edge_color='gray',
        arrows=True,
        arrowsize=20,
        arrowstyle='->',
        width=2,
        alpha=0.6,
        connectionstyle='arc3,rad=0.1'
    )

    # Draw nodes
    nx.draw_networkx_nodes(
        G, pos, ax=ax,
        node_color=node_colors,
        node_size=node_sizes,
        alpha=0.9,
        edgecolors='black',
        linewidths=2
    )

    # Draw labels (truncated node ids)
    labels = {n: G.nodes[n].get('label', str(n)[:8]) for n in G.nodes()}
    nx.draw_networkx_labels(
        G, pos, labels, ax=ax,
        font_size=8,
        font_weight='bold',
        font_color='white'
    )

    # Legend built from proxy artists, so no empty artists are added to the axes.
    # markersize is in points, and scatter's s=200 is an area in points^2.
    legend_elements = [
        Line2D([0], [0], marker='o', linestyle='', markersize=200 ** 0.5,
               markerfacecolor=color, markeredgecolor='black', markeredgewidth=2,
               label=node_type)
        for node_type, color in type_colors.items()
    ]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=10, framealpha=0.9)

    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.axis('off')
    ax.margins(0.1)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"✓ Saved to {save_path}")

    plt.show()

    return G

# Draw the trace DAG; the returned graph G is reused by the Merkle-DAG cell
G = visualize_trace_dag(trace, save_path='trace_dag_visualization.png')

## 3. Cryptographic Binding Visualization

Visualize the Merkle-DAG structure with hash bindings.

In [None]:
# Finalize trace to compute bindings: compress it, then sign it with a freshly
# generated keypair (only the private key is used in this cell).
compressor = TraceCompressor()
compressor.compress_trace(trace)

private_key, public_key = SignatureScheme.generate_keypair()
TraceVerifier.finalize_trace(trace, private_key)

print("✓ Trace finalized")
print(f"  Root hash: {trace.root_hash[:40]}...")
print(f"  Signature: {trace.agent_manifest.signature[:40]}...")


def _binding_of(node_id):
    """Return the cryptographic binding of ``node_id``, or None when the trace
    has no such node or the node carries no binding."""
    trace_node = trace.trace_graph.get_node(node_id)
    return trace_node.cryptographic_binding if trace_node else None


# Look each graph node's binding up once; the colours and labels both use it.
bindings = {node: _binding_of(node) for node in G.nodes()}

# Visualize with hash information
fig, ax = plt.subplots(1, 1, figsize=(16, 10))

# Seeded spring layout. NOTE: this is not the layered layout used by the DAG
# figure, so node placement differs between the two images.
pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

# Draw base graph
nx.draw_networkx_edges(G, pos, ax=ax, edge_color='gray', arrows=True,
                       arrowsize=15, width=2, alpha=0.4)

# Green = node has a binding; red = unbound, or the id is unknown to the trace
node_colors = ['#2ecc71' if bindings[node] else '#e74c3c' for node in G.nodes()]

nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors,
                       node_size=800, alpha=0.9, edgecolors='black', linewidths=2)

# Label with the first 6 characters of the hash part of the binding (the text
# after the last ':'), or 'No bind' when the node has none.
hash_labels = {
    node: (binding.split(':')[-1][:6] + '...') if binding else 'No bind'
    for node, binding in bindings.items()
}

nx.draw_networkx_labels(G, pos, hash_labels, ax=ax, font_size=7, font_family='monospace')

ax.set_title('VERITAS Merkle-DAG Structure\nCryptographic Bindings',
             fontsize=16, fontweight='bold', pad=20)
ax.text(0.02, 0.98, f'Root Hash: {trace.root_hash[:60]}...',
        transform=ax.transAxes, fontsize=10, verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
        fontfamily='monospace')
ax.axis('off')

plt.tight_layout()
plt.savefig('trace_merkle_dag.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

print("✓ Merkle-DAG visualization saved")

## 4. Node Type Distribution

Analyze the distribution of reasoning step types.

In [None]:
# Get statistics
stats = trace.get_trace_statistics()

# Type -> colour mapping matching the DAG and flow figures, so a colour means
# the same node type in every plot.
type_colors = {
    'REASONING': '#3498db',
    'TOOL_CALL': '#e74c3c',
    'OBSERVATION': '#2ecc71',
    'DECISION': '#f39c12',
    'MEMORY_ACCESS': '#9b59b6'
}
# Positional palette, used only for keys that are not in type_colors.
fallback_palette = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6']

# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Node type distribution (pie chart)
type_counts = stats['node_types']
type_counts_filtered = {k: v for k, v in type_counts.items() if v > 0}
# Use lists, not dict views: Axes.pie converts its input with
# np.asarray(x, np.float32), which cannot build a numeric array from a dict view.
# NOTE(review): keys may be plain strings or NodeType members; `.value` is used
# when present. Confirm against get_trace_statistics().
type_labels = [str(getattr(k, 'value', k)) for k in type_counts_filtered]
type_values = list(type_counts_filtered.values())
colors_pie = [
    type_colors.get(label, fallback_palette[i % len(fallback_palette)])
    for i, label in enumerate(type_labels)
]

if type_values:
    axes[0, 0].pie(type_values, labels=type_labels,
                   autopct='%1.1f%%', colors=colors_pie, startangle=90)
else:
    axes[0, 0].text(0.5, 0.5, 'No nodes', ha='center', va='center')
axes[0, 0].set_title('Node Type Distribution', fontsize=12, fontweight='bold')

# Plot 2: Node type counts (bar)
axes[0, 1].bar(type_labels, type_values,
               color=colors_pie, alpha=0.7, edgecolor='black', linewidth=1.5)
axes[0, 1].set_title('Node Type Counts', fontsize=12, fontweight='bold')
axes[0, 1].set_ylabel('Count')
axes[0, 1].set_xlabel('Node Type')
axes[0, 1].tick_params(axis='x', rotation=45)
axes[0, 1].grid(True, alpha=0.3, axis='y')

# Plot 3: Trace statistics
stat_labels = ['Total\nNodes', 'Total\nEdges', 'Root\nNodes', 'Terminal\nNodes', 'Max\nDepth']
stat_values = [
    stats['total_nodes'],
    stats['total_edges'],
    stats['root_nodes'],
    stats['terminal_nodes'],
    stats['max_depth']
]

axes[1, 0].bar(stat_labels, stat_values, color='steelblue', alpha=0.7,
               edgecolor='black', linewidth=1.5)
axes[1, 0].set_title('Trace Statistics', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Count')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# Plot 4: Confidence distribution (nodes without metadata or a score are skipped)
confidences = [
    node.metadata.confidence
    for node in trace.trace_graph.nodes
    if node.metadata and node.metadata.confidence is not None
]
axes[1, 1].hist(confidences, bins=15, color='green', alpha=0.7, edgecolor='black')
if confidences:
    # The mean marker is only drawn when there is something to average.
    mean_confidence = np.mean(confidences)
    axes[1, 1].axvline(mean_confidence, color='red', linestyle='--',
                       linewidth=2, label=f'Mean: {mean_confidence:.3f}')
    axes[1, 1].legend()
axes[1, 1].set_title('Confidence Score Distribution', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Confidence')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('trace_statistics.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

print("✓ Statistics visualization saved")

## 5. Detailed Trace Flow Diagram

Create a detailed flow diagram with node content.

In [None]:
def create_flow_diagram(trace: DecisionTrace, max_content_len: int = 50,
                        save_path: str = 'trace_detailed_flow.png',
                        title: str = 'VERITAS Detailed Trace Flow\nMedical Diagnosis Decision Process'):
    """
    Create a detailed flow diagram with node content.

    Each node is drawn as:
    - a circle coloured by node type, with an abbreviated type label;
    - a truncated content preview below the circle;
    - its confidence score in a small badge coloured on the RdYlGn colormap.

    Nodes are placed with a layered layout (one row per depth). If the graph
    cannot be layered, a seeded spring layout is used instead.

    Args:
        trace: Decision trace whose ``trace_graph`` nodes and edges are drawn.
        max_content_len: Number of content characters kept before '...' is
            appended to longer content.
        save_path: Where the figure is saved (300 dpi). Pass None to skip saving.
        title: Figure title (defaults to the medical-diagnosis example title).
    """
    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(1, 1, figsize=(18, 14))

    # Create graph, storing a truncated content preview on each node
    G = nx.DiGraph()
    for node in trace.trace_graph.nodes:
        content_preview = node.full_content or ''
        if content_preview and len(content_preview) > max_content_len:
            content_preview = content_preview[:max_content_len] + '...'
        G.add_node(
            node.node_id,
            type=node.node_type.value,
            content=content_preview,
            confidence=node.metadata.confidence if node.metadata else 0.0
        )

    for edge in trace.trace_graph.edges:
        G.add_edge(edge.from_node, edge.to_node)

    # Layered layout: depth is 0 for roots, otherwise 1 + the deepest parent's depth.
    depths = {}
    try:
        for node in nx.topological_sort(G):
            depths[node] = max((depths[p] + 1 for p in G.predecessors(node)), default=0)
    except nx.NetworkXUnfeasible:
        depths = {}  # the graph has a cycle, so it cannot be layered

    if depths:
        rows = {}
        for node, depth in depths.items():
            rows.setdefault(depth, []).append(node)
        pos = {}
        for depth, row in rows.items():
            for i, node in enumerate(row):
                pos[node] = ((i - len(row) / 2) * 3.5, -depth * 2.5)
    else:
        # Seeded so that repeated runs produce the same picture.
        pos = nx.spring_layout(G, k=3, iterations=50, seed=42)

    # Type colors (unknown types are drawn grey)
    type_colors = {
        'REASONING': '#3498db',
        'TOOL_CALL': '#e74c3c',
        'OBSERVATION': '#2ecc71',
        'DECISION': '#f39c12',
        'MEMORY_ACCESS': '#9b59b6'
    }

    # Draw edges
    nx.draw_networkx_edges(G, pos, ax=ax, edge_color='#34495e',
                           arrows=True, arrowsize=25, width=2.5, alpha=0.6,
                           arrowstyle='->', connectionstyle='arc3,rad=0.15')

    # Draw nodes with content boxes.
    # NOTE(review): the axes aspect ratio is not fixed, so these circles may
    # render as ellipses when the x and y data ranges differ.
    for node, (x, y) in pos.items():
        node_data = G.nodes[node]
        # .get(): an edge endpoint with no node record in the trace has no attributes.
        node_type = node_data.get('type', 'UNKNOWN')
        content = node_data.get('content', '')
        confidence = node_data.get('confidence') or 0.0

        # Draw node circle
        circle = plt.Circle((x, y), 0.3,
                            color=type_colors.get(node_type, '#95a5a6'),
                            ec='black', linewidth=2, zorder=3)
        ax.add_patch(circle)

        # Type label on the node: first word of the type, at most 4 characters
        ax.text(x, y, node_type.split('_')[0][:4],
                ha='center', va='center', fontsize=8, fontweight='bold',
                color='white', zorder=4)

        # Content preview box below the node
        box_props = dict(boxstyle='round,pad=0.3', facecolor='white',
                         edgecolor='black', linewidth=1, alpha=0.9)
        ax.text(x, y - 0.6, content,
                ha='center', va='top', fontsize=7, wrap=True,
                bbox=box_props, zorder=2, style='italic')

        # Confidence badge, red (low) to green (high)
        conf_color = plt.cm.RdYlGn(confidence)
        ax.text(x + 0.25, y + 0.25, f'{confidence:.2f}',
                fontsize=6, fontweight='bold',
                bbox=dict(boxstyle='circle', facecolor=conf_color, alpha=0.8),
                zorder=5)

    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)

    # Legend built from proxy artists. markersize is in points, and
    # scatter's s=300 is an area in points^2.
    legend_elements = [
        Line2D([0], [0], marker='o', linestyle='', markersize=300 ** 0.5,
               markerfacecolor=color, markeredgecolor='black', markeredgewidth=2,
               label=node_type)
        for node_type, color in type_colors.items()
    ]
    ax.legend(handles=legend_elements, loc='upper left', fontsize=9, framealpha=0.95)

    ax.axis('off')
    # Pad the autoscaled limits so content boxes near the edges are not clipped
    ax.set_xlim(ax.get_xlim()[0] - 2, ax.get_xlim()[1] + 2)
    ax.set_ylim(ax.get_ylim()[0] - 2, ax.get_ylim()[1] + 2)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()

    if save_path:
        print("✓ Detailed flow diagram saved")

# Render the detailed flow diagram with 40-character content previews
create_flow_diagram(trace=trace, max_content_len=40)

## 6. Summary

Print a text summary of the trace and save it to `visualization_summary.txt`.

In [None]:
# Assemble a plain-text summary from the trace and the statistics cell above
# (`stats` and `confidences` are defined there).

# One line per node type that occurs, with its share of all nodes.
type_lines = "".join(
    f"  {node_type:15s}: {count:2d} ({(count / stats['total_nodes']) * 100:5.1f}%)\n"
    for node_type, count in stats['node_types'].items()
    if count > 0
)

# np.min/np.max raise on an empty sequence, so only report the numbers when
# at least one confidence score exists.
if confidences:
    confidence_section = (
        f"  Mean: {np.mean(confidences):.3f}\n"
        f"  Std:  {np.std(confidences):.3f}\n"
        f"  Min:  {np.min(confidences):.3f}\n"
        f"  Max:  {np.max(confidences):.3f}\n"
    )
else:
    confidence_section = "  (no confidence scores recorded)\n"

summary = f"""
{'='*80}
VERITAS TRACE VISUALIZATION SUMMARY
{'='*80}

Trace Information:
  ID: {trace.trace_id}
  Agent: {trace.agent_manifest.agent_did}
  Created: {trace.created_at}

Structure:
  Total nodes: {stats['total_nodes']}
  Total edges: {stats['total_edges']}
  Root nodes: {stats['root_nodes']}
  Terminal nodes: {stats['terminal_nodes']}
  Maximum depth: {stats['max_depth']}

Node Types:
{type_lines}
Cryptographic Properties:
  Root hash: {trace.root_hash[:60]}...
  Signature: {trace.agent_manifest.signature[:60]}...
  All nodes bound: {all(n.cryptographic_binding for n in trace.trace_graph.nodes)}

Confidence:
{confidence_section}
Generated Visualizations:
  1. trace_dag_visualization.png - DAG structure
  2. trace_merkle_dag.png - Cryptographic bindings
  3. trace_statistics.png - Statistical analysis
  4. trace_detailed_flow.png - Detailed flow diagram

{'='*80}
"""

print(summary)

# Explicit encoding so the file is written the same way on every platform.
with open('visualization_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)

print("✓ Summary saved to 'visualization_summary.txt'")

## Conclusion

This notebook has created comprehensive visualizations of VERITAS decision traces:

1. **DAG Structure**: Showing reasoning flow and dependencies
2. **Merkle-DAG**: Illustrating cryptographic bindings
3. **Statistical Analysis**: Node type distribution and confidence
4. **Detailed Flow**: With node content and metadata

These visualizations demonstrate the trace structure described in Section 2 of the paper.