# Notebook 09: Multi-Agent Credit Assignment

## Learning Objectives
- Implement exact and approximate Shapley values
- Localize errors in agent traces
- Assign per-agent training signals
- Implement AT-GRPO advantages
- Visualize credit distribution

## The Credit Assignment Problem

**Scenario:** A Solver→Critic→Reviser pipeline produces a wrong answer.

**Question:** Which agent is responsible?

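**Shapley Value:** for a set of agents $N$ and a value function $v$ that scores every coalition $S \subseteq N$ (e.g., its accuracy), agent $i$'s Shapley value is its *average marginal contribution*: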
$$\varphi_i(v) = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N|-|S|-1)!}{|N|!} \left[v(S \cup \{i\}) - v(S)\right]$$

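Properties: **Efficiency** (credits sum to $v(N) - v(\emptyset)$, i.e. the full-team outcome when $v(\emptyset) = 0$), **Symmetry** (interchangeable agents receive equal credit), **Dummy** (an agent that never changes any coalition's value receives zero), and **Linearity** (the Shapley value of $v + w$ is the sum of the Shapley values of $v$ and $w$).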
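The formula can be evaluated directly by enumerating coalitions. The weight $\frac{|S|!\,(|N|-|S|-1)!}{|N|!}$ is the fraction of agent orderings in which exactly the members of $S$ come before $i$. The sketch below is a standalone illustration, not the `src.credit_assignment` implementation; the function name `shapley_from_formula` and the toy game are hypothetical. It checks Symmetry and Dummy on a three-agent game.

```python
from itertools import combinations
from math import factorial


def shapley_from_formula(players, value):
    """Exact Shapley values by enumerating every coalition S of N without i.

    `value` maps a frozenset of players to a number, playing the role of v(S).
    The cost grows as O(2^n), so this is only practical for a handful of agents.
    """
    n = len(players)
    phi = {}
    for i in players:
        others = [p for p in players if p != i]
        total = 0.0
        for size in range(n):  # |S| ranges over 0 .. n-1
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for members in combinations(others, size):
                S = frozenset(members)
                total += weight * (value(S | {i}) - value(S))  # marginal contribution
        phi[i] = total
    return phi


# Toy game: interchangeable agents 'a' and 'b' earn 1.0 only together,
# while 'c' never changes any coalition's value (a dummy player).
def toy_value(S):
    return 1.0 if {'a', 'b'} <= S else 0.0


phi = shapley_from_formula(['a', 'b', 'c'], toy_value)
print(phi)  # symmetry: a and b each get 0.5; dummy: c gets 0.0
```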

In [None]:
# !pip install torch matplotlib

In [None]:
import sys
sys.path.insert(0, '..')
import torch
import matplotlib.pyplot as plt
print('Ready!')

## Step 1: Exact Shapley Values

In [None]:
from src.credit_assignment.shapley import exact_shapley_values, ShapleyCalculator

agents = ['solver', 'critic', 'reviser']

# Coalition value function: accuracy achieved by each sub-team.
# The solver is essential (no solver -> 0.0), the critic adds the most on top of it,
# and the reviser helps with correction.
def coalition_value(S):
    if 'solver' not in S: return 0.0
    if len(S) == 1: return 0.4  # solver alone = 40% accuracy
    if 'critic' in S and 'reviser' in S: return 1.0  # full team
    if 'critic' in S: return 0.7  # solver + critic
    if 'reviser' in S: return 0.6  # solver + reviser
    return 0.4  # unreachable for these three agents; kept as a safe fallback

shapley = exact_shapley_values(agents, coalition_value)
print('Exact Shapley Values:')
for a, v in shapley.items():
    print(f'  {a:10s}: {v:.4f}')
print(f'\nSum of Shapley values: {sum(shapley.values()):.4f}')
print(f'v(all agents) = {coalition_value(frozenset(agents)):.4f}')
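Exact enumeration touches all $2^{|N|}$ coalitions, which quickly becomes impractical for larger teams. A common approximation samples random orderings of the agents and averages each agent's marginal contribution at the moment it joins. Because each ordering's marginals telescope to $v(N) - v(\emptyset)$, Efficiency holds for any sample count (up to floating-point rounding). By the formula, the exact values for this game are $4.1/6 \approx 0.683$ (solver), $1.1/6 \approx 0.183$ (critic) and $0.8/6 \approx 0.133$ (reviser). The helper `monte_carlo_shapley` below is a hypothetical standalone sketch, not taken from `src.credit_assignment`, and it repeats the Step 1 value function so it runs on its own.

```python
import random


def monte_carlo_shapley(players, value, n_samples=1000, seed=0):
    """Approximate Shapley values by averaging marginals over random orderings."""
    rng = random.Random(seed)  # local RNG so the estimate is reproducible
    totals = {p: 0.0 for p in players}
    for _ in range(n_samples):
        order = list(players)
        rng.shuffle(order)
        coalition = frozenset()
        prev = value(coalition)
        for p in order:  # add agents one at a time in the sampled order
            coalition = coalition | {p}
            cur = value(coalition)
            totals[p] += cur - prev  # marginal contribution v(S + {p}) - v(S)
            prev = cur
    return {p: totals[p] / n_samples for p in players}


# Same value function as Step 1, repeated so this cell is self-contained.
def coalition_value(S):
    if 'solver' not in S: return 0.0
    if len(S) == 1: return 0.4
    if 'critic' in S and 'reviser' in S: return 1.0
    if 'critic' in S: return 0.7
    if 'reviser' in S: return 0.6
    return 0.4


agents = ['solver', 'critic', 'reviser']
est = monte_carlo_shapley(agents, coalition_value, n_samples=2000)
print({a: round(v, 3) for a, v in est.items()})
```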

## Step 2: Visualize Shapley Values

In [None]:
from src.evaluation.visualization import plot_agent_contributions
fig = plot_agent_contributions(shapley, title='Shapley Value Credit Attribution')
plt.show()

## Step 3: Error Localization

In [None]:
from src.credit_assignment.error_localization import ErrorLocalizer

# Scenario: solver makes error, critic catches it, reviser fixes it
trace = [
    {'agent_id': 'solver_0', 'role': 'solver',
     'content': 'Step 1: 45 + 18 = 63. The answer is: 63'},  # WRONG
    {'agent_id': 'critic_0', 'role': 'critic',
     'content': 'VERDICT: INCORRECT — should subtract not add'},
    {'agent_id': 'reviser_0', 'role': 'reviser',
     'content': 'Step 1: 45 - 18 = 27. The answer is: 27'},  # CORRECT
]

localizer = ErrorLocalizer(ground_truth=27.0)
report = localizer.get_report(trace, final_correct=True)
print('Error Localization Report:')
for k, v in report.items():
    print(f'  {k}: {v}')
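One simple way to localize an error in a trace like this is to extract each turn's stated answer, compare it with the ground truth, and report the first turn that is wrong together with whether a later turn recovered. The sketch below assumes answers follow the `The answer is: <number>` pattern used in the trace above; `localize_first_error` and its report keys are hypothetical and may differ from what `ErrorLocalizer.get_report` returns.

```python
import re

# Matches the "The answer is: <number>" convention used in the trace.
ANSWER_RE = re.compile(r'The answer is:\s*(-?\d+(?:\.\d+)?)')


def extract_answer(text):
    """Return the numeric answer stated in a turn, or None if there is none."""
    match = ANSWER_RE.search(text)
    return float(match.group(1)) if match else None


def localize_first_error(trace, ground_truth, tol=1e-6):
    """Find the first turn whose stated answer disagrees with ground_truth."""
    first_error = None
    final = None
    for idx, turn in enumerate(trace):
        answer = extract_answer(turn['content'])
        if answer is None:
            continue  # e.g. critic turns that give a verdict, not an answer
        final = answer  # the last stated answer is the pipeline's output
        if first_error is None and abs(answer - ground_truth) > tol:
            first_error = {'turn': idx, 'agent_id': turn['agent_id'], 'answer': answer}
    # "Recovered" means an error occurred but the final answer is correct.
    recovered = first_error is not None and abs(final - ground_truth) <= tol
    return {'first_error': first_error, 'final_answer': final, 'recovered': recovered}


trace = [
    {'agent_id': 'solver_0', 'role': 'solver',
     'content': 'Step 1: 45 + 18 = 63. The answer is: 63'},
    {'agent_id': 'critic_0', 'role': 'critic',
     'content': 'VERDICT: INCORRECT, should subtract not add'},
    {'agent_id': 'reviser_0', 'role': 'reviser',
     'content': 'Step 1: 45 - 18 = 27. The answer is: 27'},
]
print(localize_first_error(trace, ground_truth=27.0))
```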

## Step 4: AT-GRPO Advantages

In [None]:
from src.credit_assignment.at_grpo import ATGRPOTrainer, ATGRPOConfig

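# Overall (agent-level) reward for each agent
agent_rewards = {'solver': 0.3, 'critic': 0.8, 'reviser': 1.0}
# Per-turn rewards for each agent (three turns each)
turn_rewards  = {
    'solver':  [0.1, 0.2, 0.3],
    'critic':  [0.5, 0.9, 0.8],
    'reviser': [0.7, 0.9, 1.0],
}
# agent_weight balances the agent-level signal against the turn-level one (see Exercise 4)
trainer = ATGRPOTrainer(agents=[], config=ATGRPOConfig(agent_weight=0.5))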
advantages = trainer.compute_combined_advantages(agent_rewards, turn_rewards)
print('AT-GRPO Combined Advantages:')
for agent, adv in advantages.items():
    print(f'  {agent}: {[round(a, 3) for a in adv.tolist()]}')
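To make the weighting concrete, here is one plausible way to blend the two signals. It is a sketch under stated assumptions, not a description of how `ATGRPOTrainer` works internally. GRPO-style group normalization $(r - \mu)/(\sigma + \epsilon)$ is applied to the agent-level rewards across agents and to each agent's turn rewards across its own turns, and the two are mixed as $A_{i,t} = \alpha\,\hat{A}^{\text{agent}}_i + (1-\alpha)\,\hat{A}^{\text{turn}}_{i,t}$, with $\alpha$ assumed to play the role of `agent_weight`. The names `normalize` and `combined_advantages` are hypothetical.

```python
from statistics import mean, pstdev


def normalize(values, eps=1e-8):
    """GRPO-style group normalization: (r - mean) / (std + eps)."""
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / (sigma + eps) for v in values]


def combined_advantages(agent_rewards, turn_rewards, alpha=0.5):
    """Assumed blend: alpha * agent-level advantage + (1 - alpha) * turn-level advantage."""
    agents = list(agent_rewards)
    # Agent-level: normalize each agent's overall reward against the other agents.
    agent_adv = dict(zip(agents, normalize([agent_rewards[a] for a in agents])))
    combined = {}
    for a in agents:
        # Turn-level: normalize an agent's turn rewards within its own trajectory.
        turn_adv = normalize(turn_rewards[a])
        combined[a] = [alpha * agent_adv[a] + (1 - alpha) * t for t in turn_adv]
    return combined


agent_rewards = {'solver': 0.3, 'critic': 0.8, 'reviser': 1.0}
turn_rewards = {
    'solver':  [0.1, 0.2, 0.3],
    'critic':  [0.5, 0.9, 0.8],
    'reviser': [0.7, 0.9, 1.0],
}
adv = combined_advantages(agent_rewards, turn_rewards, alpha=0.5)
for agent, values in adv.items():
    print(f'  {agent}: {[round(v, 3) for v in values]}')
```

Under this scheme, `alpha=1.0` gives every turn of an agent the same advantage, while `alpha=0.0` keeps only the ranking of turns within each agent. These are the two extremes that Exercise 4 asks about.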

## Step 5: Credit Heatmap

In [None]:
from src.evaluation.visualization import plot_credit_heatmap
# Illustrative, hand-specified credits (presumably rows = agents, columns = turns)
matrix = [[0.8, 0.6, 0.9], [-0.5, 0.4, 0.5], [0.2, 0.7, 0.9]]
fig = plot_credit_heatmap(
    matrix,
    agent_labels=['solver', 'critic', 'reviser'],
    turn_labels=['Turn 1', 'Turn 2', 'Turn 3'],
    title='Agent x Turn Credit Attribution Heatmap'
)
plt.show()
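The heatmap can also be summarized numerically. Assuming rows index agents and columns index turns, as the label arguments suggest, summing each row gives a per-agent credit total. Negative cells mark turns where an agent's contribution counted against the outcome. The helper `summarize_credit` below is a hypothetical plain-Python illustration.

```python
def summarize_credit(matrix, agent_labels, turn_labels):
    """Per-agent credit totals and the (agent, turn) cells with negative credit."""
    totals = {agent: sum(row) for agent, row in zip(agent_labels, matrix)}
    negatives = [
        (agent, turn, value)
        for agent, row in zip(agent_labels, matrix)
        for turn, value in zip(turn_labels, row)
        if value < 0
    ]
    return totals, negatives


matrix = [[0.8, 0.6, 0.9], [-0.5, 0.4, 0.5], [0.2, 0.7, 0.9]]
totals, negatives = summarize_credit(
    matrix, ['solver', 'critic', 'reviser'], ['Turn 1', 'Turn 2', 'Turn 3'])
print({a: round(t, 2) for a, t in totals.items()})  # per-agent row sums
print(negatives)  # [('critic', 'Turn 1', -0.5)]
```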

---

## Exercises

1. **Monte Carlo Shapley:** Compare exact vs. approximate values (n=10, 50, 200 samples). At what sample size do the estimates become effectively equivalent to the exact values?
2. **Coalition function:** Design a value function in which the critic has a negative Shapley value. When would that happen?
3. **Error cascade:** Create a trace in which the solver makes a subtle error that the critic fails to catch. What credit does the critic receive?
4. **AT-GRPO alpha:** Try alpha (`agent_weight`) = 0 (turn-only), 0.5, and 1.0 (agent-only). Which setting stabilizes training fastest?
5. **Extension:** Run credit assignment on all 20 GSM8K problems and visualize the resulting credit distribution.