# BRAC Framework Demo

This notebook demonstrates the Byzantine-Resilient Agentic Consensus (BRAC) framework for multimodal Non-Hodgkin Lymphoma (NHL) subtyping.

## Key Features
1. **Trust-Bootstrapped Reliability** - Anchors trust to pathology as the root of trust
2. **Riemannian Geometric Median** - Fisher-Rao geometry on the probability simplex
3. **Byzantine Resilience** - Tolerates faulty/adversarial agents
4. **Conformal Prediction** - Distribution-free coverage guarantees
5. **Shapley Attribution** - Axiomatic explainability

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt

# Seed both RNG libraries with one shared value so reruns are reproducible
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# BRAC imports
from brac.types import Modality, NHLSubtype, EvidenceQuality
from brac.orchestrator import BRACOrchestrator, BRACConfig
from brac.agents.mock_agent import MockAgentFactory, MockAgentConfig
from brac.attacks import ByzantineAttack, ByzantineScenario
from brac.consensus.fisher_rao import fisher_rao_distance
from brac.visualization import (
    plot_simplex, plot_shapley_bar, create_diagnostic_report
)

print("BRAC Framework loaded successfully!")

## 1. Basic Usage

Create an orchestrator and run diagnosis on a synthetic case.

In [None]:
# Single source of truth for the class count, so the orchestrator and the
# mock agents cannot silently disagree about the size of the label space.
NUM_CLASSES = 9

# Create BRAC orchestrator with an explicit config
config = BRACConfig(
    num_classes=NUM_CLASSES,
    alpha=0.05,
    compute_shapley=True,
)
orchestrator = BRACOrchestrator(config)

# Create mock agents (fixed seed keeps the demo reproducible)
factory = MockAgentFactory(num_classes=NUM_CLASSES, seed=42)
agents = factory.create_all_agents()

# Generate a case (Follicular Lymphoma)
true_label = NHLSubtype.FL.value
agent_outputs = factory.generate_case(agents, true_label=true_label)

# Report the ground truth, then each modality's top-1 prediction and confidence
print(f"True diagnosis: {NHLSubtype.from_index(true_label).name}")
print("\nAgent beliefs (top-1 prediction):")
for modality, output in agent_outputs.items():
    pred = NHLSubtype.from_index(output.predicted_class)
    print(f"  {modality.value}: {pred.name} ({output.confidence:.1%})")

In [None]:
# Run the orchestrator on the generated agent outputs; calibration happens in
# a later cell, so this first run passes calibrated=False
result = orchestrator.run(agent_outputs, calibrated=False)

# Show the orchestrator's textual summary of the result
summary_text = result.summary()
print(summary_text)

## 2. Visualization

Visualize agent beliefs on the probability simplex and Shapley attributions.

In [None]:
# Collect each modality's belief for plotting, keyed by modality
beliefs = {}
for modality, agent_output in agent_outputs.items():
    beliefs[modality] = agent_output.belief

# Build the diagnostic report figure from the consensus result and display it
fig = create_diagnostic_report(result, beliefs)
plt.show()

## 3. Byzantine Resilience Demo

Demonstrate that the geometric median resists Byzantine attacks by comparing it against a weighted-average baseline.

In [None]:
from brac.attacks import corrupt_agent_output
from brac.consensus.aggregators import weighted_average
from brac.consensus.geometric_median import riemannian_weiszfeld
from brac.types import ByzantineType

# Start from a case in which every agent is honest
true_label = NHLSubtype.DLBCL_GCB.value
honest_outputs = factory.generate_case(agents, true_label=true_label)

# Copy the honest outputs, then swap in a corrupted radiology output
# (Type II: confident but wrong); the other modalities are left untouched
byzantine_outputs = honest_outputs.copy()
honest_radiology = honest_outputs[Modality.RADIOLOGY]
byzantine_radiology = corrupt_agent_output(
    honest_radiology, attack_type=ByzantineType.TYPE_II, K=9
)
byzantine_outputs[Modality.RADIOLOGY] = byzantine_radiology

# Show the ground truth next to the radiology prediction before/after corruption
print(f"True diagnosis: {NHLSubtype.from_index(true_label).name}")
honest_name = NHLSubtype.from_index(honest_radiology.predicted_class).name
print(f"\nHonest radiology prediction: {honest_name}")
byzantine_name = NHLSubtype.from_index(byzantine_radiology.predicted_class).name
print(f"Byzantine radiology prediction: {byzantine_name}")

In [None]:
# Compare consensus methods on the honest and the Byzantine agent outputs.


def _top1_name(consensus):
    """Return the NHLSubtype name of the highest-scoring class in `consensus`."""
    return NHLSubtype.from_index(consensus.argmax().item()).name


# Stack per-agent beliefs along a new leading dimension (one row per agent).
# NOTE(review): assumes both dicts iterate their agents in the same order.
honest_beliefs = torch.stack([out.belief for out in honest_outputs.values()])
byzantine_beliefs = torch.stack([out.belief for out in byzantine_outputs.values()])

# Uniform reliabilities sized from the data rather than a hard-coded 4, so the
# cell keeps working if the factory produces a different number of agents.
num_agents = honest_beliefs.shape[0]
reliabilities = torch.ones(num_agents) / num_agents

# Geometric median (our method)
gm_honest = riemannian_weiszfeld(honest_beliefs, reliabilities).consensus
gm_byzantine = riemannian_weiszfeld(byzantine_beliefs, reliabilities).consensus

# Weighted average (baseline)
wa_honest = weighted_average(honest_beliefs, reliabilities)
wa_byzantine = weighted_average(byzantine_beliefs, reliabilities)

print("Without Byzantine agent:")
print(f"  Geometric Median: {_top1_name(gm_honest)}")
print(f"  Weighted Average: {_top1_name(wa_honest)}")

print("\nWith Byzantine agent:")
print(f"  Geometric Median: {_top1_name(gm_byzantine)}")
print(f"  Weighted Average: {_top1_name(wa_byzantine)}")

print(f"\nTrue label: {NHLSubtype.from_index(true_label).name}")

## 4. Conformal Prediction Demo

Calibrate and evaluate conformal prediction sets.

In [None]:
# Build the calibration split first, then the test split: 100 cases each,
# cycling through the 9 class indices (the test split starts 3 classes later)
cal_data, test_data = [], []
for dataset, offset in ((cal_data, 0), (test_data, 3)):
    for i in range(100):
        true_label = (i + offset) % 9
        outputs = factory.generate_case(agents, true_label)
        dataset.append((outputs, true_label))

# Calibrate the orchestrator on the calibration split and report the threshold
q_hat = orchestrator.calibrate(cal_data)
print(f"Calibrated threshold: q_hat = {q_hat:.4f}")

# Evaluate on the test split and print the returned metrics
metrics = orchestrator.evaluate(test_data)
print(f"\nTest metrics:")
print(f"  Accuracy: {metrics['accuracy']:.2%}")
print(f"  Coverage: {metrics['coverage']:.2%} (target: 95%)")
print(f"  Avg set size: {metrics['avg_set_size']:.2f}")
print(f"  Acceptance rate: {metrics['acceptance_rate']:.2%}")

## 5. Run with Calibrated Conformal Predictor

In [None]:
# Generate a fresh case whose ground truth is the MCL subtype
true_label = NHLSubtype.MCL.value
outputs = factory.generate_case(agents, true_label)

# Run the orchestrator again, this time with calibrated=True
result = orchestrator.run(outputs, calibrated=True)

# Report the ground truth next to the fields of the result
print(f"True: {NHLSubtype.from_index(true_label).name}")
print(f"Prediction: {result.diagnosis.name}")
prediction_set_names = [subtype.name for subtype in result.prediction_set]
print(f"Prediction Set: {prediction_set_names}")
print(f"Confidence: {result.confidence:.1%}")
if result.accepted:
    decision = 'ACCEPT'
else:
    decision = 'ESCALATE'
print(f"Decision: {decision}")

## 6. Fisher-Rao Distance Demo

Illustrate the Fisher-Rao geometry.

In [None]:
from brac.consensus.fisher_rao import (
    fisher_rao_distance, sqrt_embedding, sqrt_embedding_inv,
    verify_metric_axioms
)

# Three reference distributions over K classes: uniform, a one-hot on class 0,
# and a softmax that is sharply peaked on class 0
K = 9
uniform = torch.ones(K) / K
one_hot = torch.zeros(K)
one_hot[0] = 1.0
peaked = torch.softmax(torch.tensor([5.0] + [0.1] * (K-1)), dim=0)

# Print the pairwise distances; labels are left-padded to 19 characters so
# the '=' signs line up
named_distributions = {"uniform": uniform, "one_hot": one_hot, "peaked": peaked}
distance_pairs = [
    ("uniform", "uniform"),
    ("uniform", "one_hot"),
    ("uniform", "peaked"),
    ("peaked", "one_hot"),
]

print("Fisher-Rao distances:")
for left, right in distance_pairs:
    label = f"d({left}, {right})"
    distance = fisher_rao_distance(
        named_distributions[left], named_distributions[right]
    ).item()
    print(f"  {label:<19} = {distance:.4f}")

# Check the metric axioms on three random distributions (softmax of Gaussian noise)
p, q, r = (torch.softmax(torch.randn(K), dim=0) for _ in range(3))

axioms = verify_metric_axioms(p, q, r)
print(f"\nMetric axioms verified:")
for axiom, passed in axioms.items():
    mark = '✓' if passed else '✗'
    print(f"  {axiom}: {mark}")

## Summary

This demo showed:
1. How to create and use the BRAC orchestrator
2. Visualization of agent beliefs and attributions
3. Byzantine resilience of the geometric median vs. a weighted average
4. Conformal prediction calibration and evaluation
5. Fisher-Rao geometry properties

For full experiments, run the scripts in `experiments/`.