# SMAC Counterfactual Analysis Demo

This notebook demonstrates how to run counterfactual analysis on SMAC (the StarCraft Multi-Agent Challenge) using a random policy.

In [1]:
import sys
import os

# Resolve the helper-code folders. The relative path is interpreted against the
# kernel's current working directory, so this cell assumes the notebook is
# started from its own folder.
smac_agents_path = os.path.abspath('../../playing-with-smac/smac/agents')
ppo_path = os.path.join(smac_agents_path, 'PPO_one_action')

# Put each folder at the front of sys.path, dropping any earlier copy first so
# that re-running this cell does not pile up duplicate entries.
# Order matters: ppo_path is inserted last and therefore takes precedence.
for _module_dir in (smac_agents_path, ppo_path):  # utils, then random_policy
    if not os.path.isdir(_module_dir):
        # Surface a wrong working directory here instead of as an ImportError later.
        print(f"Warning: {_module_dir} does not exist; later imports may fail")
    while _module_dir in sys.path:
        sys.path.remove(_module_dir)
    sys.path.insert(0, _module_dir)

print(f"Added: {smac_agents_path}")
print(f"Added: {ppo_path}")

Added: c:\Users\manchadoa\OneDrive - Milwaukee School of Engineering\ADRIAN LAPTOP SCHOOL\SEMESTERS\FALL 2025\UR Conterfactual Reasoning\playing-with-smac\smac\agents
Added: c:\Users\manchadoa\OneDrive - Milwaukee School of Engineering\ADRIAN LAPTOP SCHOOL\SEMESTERS\FALL 2025\UR Conterfactual Reasoning\playing-with-smac\smac\agents\PPO_one_action


## 1. Set Up the SMAC Environment

In [2]:
from smac.env import StarCraft2Env
from counterfactual_rl.environments.smac import CentralizedSmacWrapper, SmacStateManager

# Launch the raw multi-agent environment on the "3m" map (3 Marines vs 3 Marines).
smac_env = StarCraft2Env(map_name="3m")

# Put the centralized-control wrapper around the raw environment.
env = CentralizedSmacWrapper(smac_env, use_state=True)

# Report the wrapper's dimensions and spaces, one item per line.
print(
    f"Agents: {env.n_agents}",
    f"Actions per agent: {env.n_actions_per_agent}",
    f"Action space: {env.action_space}",
    f"Observation space: {env.observation_space}",
    sep="\n",
)

Agents: 3
Actions per agent: 9
Action space: MultiDiscrete([9 9 9])
Observation space: Box(-inf, inf, (48,), float32)


## 2. Create Random Policy

In [3]:
from random_policy import RandomPolicy

# The random policy is built on the raw SMAC env (not the wrapper) so that it
# can use the env's action masking.
policy = RandomPolicy(env=smac_env)

# Show the dimensions the policy picked up from the environment.
print(
    f"Policy n_agents: {policy.n_agents}",
    f"Policy n_actions: {policy.n_actions}",
    sep="\n",
)

Policy n_agents: 3
Policy n_actions: 9


## 3. Set Up the Counterfactual Analyzer

In [4]:
from counterfactual_rl.analysis import MultiDiscreteCounterfactualAnalyzer
from utils import get_valid_actions

def _current_valid_actions():
    """Ask the raw SMAC env for the valid actions at the moment of the call."""
    return get_valid_actions(smac_env)


# Analyzer settings gathered in one mapping so they are easy to scan and tweak.
analyzer_settings = dict(
    model=policy,
    env=env,
    state_manager=SmacStateManager,
    get_valid_actions_fn=_current_valid_actions,
    get_action_probs_fn=None,  # Uniform probability for random policy
    n_agents=policy.n_agents,
    n_actions=policy.n_actions,
    horizon=10,
    n_rollouts=5,  # Keep low for demo
    top_k=10,
    deterministic=False,
)

# Build the analyzer from the settings above.
analyzer = MultiDiscreteCounterfactualAnalyzer(**analyzer_settings)

print("Analyzer created!")

Analyzer created!


## 4. Run Counterfactual Analysis

In [5]:
import numpy as np

In [6]:
# Hard-coded illustration (not read from the environment): one 0/1 availability
# row per agent, where a 1 marks an action index that may be chosen.
avail_actions_all = np.array([[0, 1, 1, 1, 1, 1, 0, 0, 0]] * 3)
print(avail_actions_all)

# Turn each agent's mask row into the plain list of action indices flagged 1.
available_actions = [
    np.flatnonzero(mask_row == 1).tolist() for mask_row in avail_actions_all
]
print(available_actions)


[[0 1 1 1 1 1 0 0 0]
 [0 1 1 1 1 1 0 0 0]
 [0 1 1 1 1 1 0 0 0]]
[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]


In [7]:
# Evaluate a single episode.
# `records` is bound to an empty list before the long-running call: if the run
# is interrupted, the name still exists so the results cell can run instead of
# raising NameError, and records from an earlier run are not silently reused.
# The interrupt itself still propagates as usual.
print("Running counterfactual analysis...")
records = []
records = analyzer.evaluate_episode(max_steps=50, verbose=True)

print(f"\nCollected {len(records)} records")

Running counterfactual analysis...
  Starting episode (max_steps=50)

  Step 0/50
    Cloning state...
    Getting policy action...
    Action selected: (np.int64(3), np.int64(5), np.int64(2))
    Running 10 counterfactual rollouts x 5 samples...
      Getting valid actions...
      Valid actions per agent: [5, 5, 5]
      Running beam search for top-10 actions...
      Beam search returned 10 joint actions
        Action 1/10: (1, 1, 1) -> mean return: 0.208
        Action 2/10: (1, 1, 2)

KeyboardInterrupt: 

## 5. Analyze Results

In [None]:
# Summarise the consequence scores; nothing is printed when no records exist.
if records:
    scores = [rec.consequence_score for rec in records]
    print(
        f"Mean consequence score: {np.mean(scores):.4f}",
        f"Max consequence score: {np.max(scores):.4f}",
        f"Min consequence score: {np.min(scores):.4f}",
        sep="\n",
    )

    # Pick the record with the highest score (the first one wins on ties).
    most_consequential = max(records, key=lambda rec: rec.consequence_score)
    print(
        f"\nMost consequential step:",
        f"  Action: {most_consequential.action}",
        f"  Score: {most_consequential.consequence_score:.4f}",
        sep="\n",
    )

In [None]:
# Cleanup: close the raw SMAC environment created during setup, then confirm.
# NOTE(review): presumably this shuts down the underlying StarCraft II process;
# confirm against the SMAC documentation.
smac_env.close()
print("Environment closed.")