# SMAC Counterfactual Analysis Demo

This notebook demonstrates how to use the Counterfactual Reasoning framework with the StarCraft Multi-Agent Challenge (SMAC) environment.

In [1]:
import os
import sys
import numpy as np
from smac.env import StarCraft2Env

# Add src to path
# The project sources are expected in a `src` directory that is a sibling of the
# current working directory. Guard the insert so that re-running this cell does
# not keep piling duplicate entries onto sys.path.
_src_dir = os.path.join(os.path.dirname(os.getcwd()), 'src')
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

from counterfactual_rl.environments.smac import CentralizedSmacWrapper
from counterfactual_rl.analysis.counterfactual import CounterfactualAnalyzer
from counterfactual_rl.environments import registry

# Set SC2 Path
# Honour an SC2PATH that is already set in the environment (other machines,
# other operating systems); the literal below is only a fallback and matches
# the default Windows install location used when this notebook was recorded.
os.environ.setdefault('SC2PATH', r'C:\Program Files (x86)\StarCraft II')
print(f"SC2PATH: {os.environ['SC2PATH']}")

SC2PATH: C:\Program Files (x86)\StarCraft II


## 1. Setup Environment
We use the `CentralizedSmacWrapper` with global state (like a human player).

In [2]:
# Name of the SMAC map to load; reused below when reporting the setup.
map_name = "3m"

# Create the raw SMAC environment with a 1x1 window, then wrap it.
env = StarCraft2Env(
    map_name=map_name,
    window_size_x=1,
    window_size_y=1,
)
wrapped_env = CentralizedSmacWrapper(
    env,
    use_state=True,  # Use global state
)

# Report what was created, including the size of the joint action space
# exposed by the wrapper.
print(f"Environment initialized: {map_name}")
print(f"Joint Action Space: {wrapped_env.joint_action_space_size}")

Environment initialized: 3m
Joint Action Space: 729


## 2. Define a Random Agent
Since we don't have a trained model yet, we'll use a dummy agent that takes random actions.

In [3]:
class RandomAgent:
    """Placeholder policy that picks a uniformly random action index.

    Mimics the ``predict(obs, deterministic)`` -> ``(action, state)`` call
    shape used by this notebook's analyzer so it can stand in for a trained
    model.

    Args:
        action_space_size: Number of discrete actions; sampled indices lie in
            ``[0, action_space_size)``. Must be positive.
        seed: Optional seed. When given, the agent draws from its own
            ``np.random.RandomState(seed)`` so runs are reproducible. When
            ``None`` (default) it draws from NumPy's global random state,
            exactly as before.

    Raises:
        ValueError: If ``action_space_size`` is not positive.
    """

    def __init__(self, action_space_size, seed=None):
        # Fail at construction time instead of on the first predict() call,
        # where np.random.randint would raise a less obvious "low >= high".
        if action_space_size <= 0:
            raise ValueError(
                f"action_space_size must be positive, got {action_space_size}"
            )
        self.action_space_size = action_space_size
        # The np.random module itself exposes randint(), so it can serve as
        # the unseeded source and keep the original global-state behaviour.
        self._rng = np.random if seed is None else np.random.RandomState(seed)

    def predict(self, obs, deterministic=True):
        """Return ``(action, None)`` with ``action`` sampled uniformly.

        ``obs`` and ``deterministic`` are accepted for interface compatibility
        but ignored: the agent is always random. The second tuple element is
        the (unused) recurrent state slot.
        """
        # Return random action and None (for state)
        return self._rng.randint(0, self.action_space_size), None

# Size the random agent to the wrapper's joint action space.
model = RandomAgent(action_space_size=wrapped_env.joint_action_space_size)

## 3. Run Counterfactual Analysis
We initialize the analyzer with the SMAC state manager.

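**Note**: This will be slow due to the replay strategy: for each counterfactual, the environment is reset and the actions are replayed. The analysis also needs a working StarCraft II installation at `SC2PATH`; without one, the analysis cell fails with `ConnectError: Failed to connect to the SC2 websocket`.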

In [4]:
# Get State Manager from Registry
# Derive the registry key from `map_name` (set in the environment cell) so the
# state manager cannot silently drift from the map the environment was built on.
# NOTE(review): assumes registry keys follow the "SMAC-<map_name>" pattern seen
# here for "3m" -- confirm before switching to another map.
state_manager = registry.get_state_manager(f"SMAC-{map_name}")

# Initialize analyzer (no env_config needed!)
analyzer = CounterfactualAnalyzer(
    model=model,
    env=wrapped_env,
    state_manager=state_manager,
    horizon=10,  # Shorter horizon for speed
    n_rollouts=1  # Fewer rollouts for speed
)

print("Analyzer initialized.")

Analyzer initialized.


In [6]:
# Analyze one episode. The step cap keeps the demo short, since every analyzed
# step triggers additional replays.
print("Running analysis... (this may take a while due to replays)")
records = analyzer.evaluate_episode(
    max_steps=20,  # Limit steps for speed
    verbose=True,
)

print(f"Analysis complete! Analyzed {len(records)} state-action pairs.")

Running analysis... (this may take a while due to replays)




ConnectError: Failed to connect to the SC2 websocket. Is it up?

In [None]:
# Convert to DataFrame for analysis
import pandas as pd

# Attributes copied from each record, in column order. Passing this list to the
# DataFrame constructor keeps the schema stable even when `records` is empty,
# so later cells that index df['consequence_score'] do not raise KeyError.
RECORD_COLUMNS = ['step', 'action', 'consequence_score', 'reward']

# One dict per analyzed state-action pair.
data = [
    {column: getattr(record, column) for column in RECORD_COLUMNS}
    for record in records
]

df = pd.DataFrame(data, columns=RECORD_COLUMNS)
df.head()

In [None]:
# Show high consequence states
# Single source of truth for the cut-off, used by both the filter and the
# message so the two cannot disagree.
CONSEQUENCE_THRESHOLD = 0.1

consequential_states = df[df['consequence_score'] > CONSEQUENCE_THRESHOLD]
print(f"Found {len(consequential_states)} consequential states (score > {CONSEQUENCE_THRESHOLD}).")
consequential_states

In [None]:
# Plot consequence scores over time
import matplotlib.pyplot as plt

# Explicit figure/axes handles instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df['step'], df['consequence_score'], marker='o')
ax.set(
    xlabel='Step',
    ylabel='Consequence Score',
    title='Consequence Scores Over Episode',
)
ax.grid(True)
plt.show()