# In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Load the per-step trajectory log; every analysis section below reads from
# this single DataFrame.
TRAJECTORY_PATH = "results/detailed_trajectory.csv"
df = pd.read_csv(TRAJECTORY_PATH)

# ============================================
# Question 1: Is the agent closing switches?
# ============================================

# na=False treats a missing action_type as "not a switch action"; without it
# the mask would contain NaN and boolean indexing would raise. regex=False
# because 'switch' is a literal substring, not a pattern.
is_switch = df['action_type'].str.contains('switch', regex=False, na=False)
switch_actions = df[is_switch]
print(f"Total switch actions: {len(switch_actions)}")
print(f"  Opens:  {(switch_actions['action_type'] == 'switch_open').sum()}")
print(f"  Closes: {(switch_actions['action_type'] == 'switch_close').sum()}")

# Timeline of switch operations: one row per episode, X = open, O = close.
fig, ax = plt.subplots(figsize=(12, 6))
for idx, ep in enumerate(df['episode'].unique()):
    ep_data = df[df['episode'] == ep]
    opens = ep_data[ep_data['action_type'] == 'switch_open']
    closes = ep_data[ep_data['action_type'] == 'switch_close']

    # Attach legend labels only on the first plotted episode so each action
    # type appears once in the legend. Keying on the first iteration (rather
    # than ep == 1) works whatever value the episode ids start from.
    first = idx == 0
    ax.scatter(opens['time_s'], [ep] * len(opens), marker='x', color='red', s=100,
               label='Open' if first else '_nolegend_')
    ax.scatter(closes['time_s'], [ep] * len(closes), marker='o', color='green', s=100,
               label='Close' if first else '_nolegend_')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Episode')
ax.set_title('Switch Operations Timeline: Opens (X) vs Closes (O)')
ax.legend()
plt.tight_layout()
# savefig does not create missing directories.
Path('figures').mkdir(parents=True, exist_ok=True)
plt.savefig('figures/switch_timeline.png', dpi=300)

# ============================================
# Question 2: Does closing restore topology?
# ============================================

# savefig does not create missing directories.
Path('figures').mkdir(parents=True, exist_ok=True)

for ep in df['episode'].unique():
    ep_data = df[df['episode'] == ep]

    fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # Top: Live lines, with the full-topology count as a dashed reference.
    ax = axes[0]
    ax.plot(ep_data['time_s'], ep_data['live_lines'], linewidth=2)
    ax.axhline(37, color='gray', linestyle='--', alpha=0.5, label='Full topology (37 lines)')
    ax.set_ylabel('Live Lines')
    ax.set_title(f'Episode {ep}: Topology Recovery')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Mark switch actions: green dotted line when the action name contains
    # 'close', red dotted line for any other switch action.
    # na=False keeps rows with a missing action_type out of the mask instead
    # of injecting NaN (which makes boolean indexing raise).
    is_switch = ep_data['action_type'].str.contains('switch', regex=False, na=False)
    switch_times = ep_data.loc[is_switch, 'time_s']
    switch_types = ep_data.loc[is_switch, 'action_type']
    for t, action in zip(switch_times, switch_types):
        color = 'green' if 'close' in action else 'red'
        ax.axvline(t, color=color, alpha=0.3, linestyle=':')

    # Bottom: Load served (fractions converted to percent).
    ax = axes[1]
    ax.plot(ep_data['time_s'], ep_data['served_critical_frac'] * 100, label='Critical', linewidth=2)
    ax.plot(ep_data['time_s'], ep_data['served_total_frac'] * 100, label='Total', linewidth=2)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Load Served (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    fig.savefig(f'figures/episode_{ep}_topology_recovery.png', dpi=300)
    # Close explicitly: one figure per episode would otherwise accumulate.
    plt.close(fig)

# ============================================
# Question 3: What's the cascade impact?
# ============================================


def _has_tripped_lines(col):
    """Return a boolean mask that is True where a cascade tripped at least one line.

    Cells are normalised to stripped strings first. This makes the check work
    when the column is entirely NaN (read_csv then infers float64 and the
    `.str` accessor raises). It also works when empty lists were serialised as
    the literal text "[]", which has length 2 and would otherwise be counted
    as a cascade.
    NOTE(review): assumes "no cascade" is logged as an empty cell or an empty
    list/tuple repr — confirm against the trajectory writer.
    """
    text = col.fillna('').astype(str).str.strip()
    return ~text.isin(['', '[]', '()'])


# Find episodes with cascades
cascade_mask = _has_tripped_lines(df['cascade_tripped_lines'])
cascade_episodes = df.loc[cascade_mask, 'episode'].unique()

for ep in cascade_episodes[:3]:  # First 3 cascade episodes
    ep_mask = df['episode'] == ep
    ep_data = df[ep_mask]

    # First logged step of this episode that recorded a cascade.
    cascade_step = df[ep_mask & cascade_mask].iloc[0]
    cascade_time = cascade_step['time_s']

    # Live-line count on the last step before the cascade. A cascade on the
    # episode's very first step has no earlier row; guard instead of raising
    # IndexError. (Assumes rows are in time order within an episode.)
    before = ep_data.loc[ep_data['time_s'] < cascade_time, 'live_lines']
    lines_before = before.iloc[-1] if len(before) else 'n/a (cascade at first step)'

    print(f"\nEpisode {ep} - Cascade at t={cascade_time}s")
    print(f"  Lines before: {lines_before}")
    print(f"  Lines after:  {cascade_step['live_lines']}")
    print(f"  Tripped:      {cascade_step['cascade_tripped_lines']}")

    # Agent response: switch actions among the next 10 logged steps.
    post_cascade = ep_data[ep_data['time_s'] > cascade_time].head(10)
    is_switch = post_cascade['action_type'].str.contains('switch', regex=False, na=False)
    agent_switches = post_cascade[is_switch]

    print("  Agent actions (next 10 steps):")
    for _, row in agent_switches.iterrows():
        print(f"    t={row['time_s']:.0f}s: {row['action_type']} {row['action_detail']}")

# ============================================
# Question 4: Strategy summary
# ============================================

print("\n" + "="*60)
print("AGENT STRATEGY SUMMARY")
print("="*60)

n_steps = len(df)

# Action distribution. value_counts drops missing action_type values, so the
# percentages sum to less than 100% if any are NaN.
action_counts = df['action_type'].value_counts()
print("\nAction Distribution:")
for action, count in action_counts.items():
    pct = count / n_steps * 100
    print(f"  {action:20s}: {count:5d} ({pct:.1f}%)")

# Average recovery metrics (fractions converted to percent).
print("\nRecovery Performance:")
print(f"  Avg critical load maintained:  {df['served_critical_frac'].mean()*100:.1f}%")
print(f"  Avg total load served:         {df['served_total_frac'].mean()*100:.1f}%")
# groupby(...).last() takes each episode's last non-null live_lines in row order.
print(f"  Avg final live lines:          {df.groupby('episode')['live_lines'].last().mean():.1f}/37")

# Safety compliance.
# `~col` and `df[col]` require a pure bool dtype. They raise when a flag
# column has missing entries (object dtype after read_csv) or is stored as
# 0/1 integers. `.eq(False)` / `.eq(True)` give identical counts on bool
# columns and stay well-defined otherwise; missing values land in neither
# count.
v_violation_steps = (df['v_violations'] > 0).sum()
pf_failure_steps = df['powerflow_success'].eq(False).sum()
masked_steps = df['action_masked'].eq(True).sum()

print("\nSafety Compliance:")
print(f"  Steps with voltage violations: {v_violation_steps} / {n_steps}")
print(f"  Steps with PF failure:         {pf_failure_steps} / {n_steps}")
print(f"  Actions masked by safety:      {masked_steps} / {n_steps}")

print("="*60)