In [None]:
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# -------------------------------
# 1. Basic configuration
# -------------------------------

print("\n".join((
    "TU Q105_A - Toy systemic crash warnings",
    "----------------------------------------",
    "This notebook simulates many tiny network contagion scenarios",
    "and computes a scalar tension observable T_warning for each",
    "early-warning scheme.\n",
    "All runs are fully offline. No API key is needed.\n",
)))

# Reproducibility: a single seed drives both the stdlib and NumPy global RNGs.
RNG_SEED = 105

# Network size and Monte Carlo budget.
N_NODES = 20
N_CORE = 5
N_RUNS = 500

# Cascade physics and classification thresholds.
ALPHA_REDISPATCH = 0.8  # fraction of failed node load redistributed to neighbors
BASE_CAPACITY = 1.0
BASE_LOAD_MEAN = 0.6
BASE_LOAD_STD = 0.05
NEAR_THRESHOLD_RATIO = 0.8  # load/capacity above this counts as "near threshold"
CRASH_FRACTION_THRESHOLD = 0.4  # systemic crash if >= 40% of nodes fail

random.seed(RNG_SEED)
np.random.seed(RNG_SEED)

print("\n".join((
    "Configured world:",
    f"- Nodes: {N_NODES} (core: {N_CORE}, periphery: {N_NODES - N_CORE})",
    "- Network: core-periphery with additional ring edges in the periphery",
    f"- Crash definition: cascade size >= {CRASH_FRACTION_THRESHOLD * 100:.0f}% of nodes",
    f"- Number of scenarios per scheme: {N_RUNS}",
    "",
)))


# -------------------------------
# 2. Network construction
# -------------------------------

def build_core_periphery_network(n_nodes: int, n_core: int) -> np.ndarray:
    """
    Build a row-normalised weighted adjacency matrix for a core-periphery network.

    Unweighted, undirected topology:
    - Core nodes ``0 .. n_core-1`` form a complete subgraph (no self loops).
    - Each periphery node ``i`` (``n_core .. n_nodes-1``) links to core node
      ``i % n_core`` (round-robin).
    - Periphery nodes form a ring (each linked to the next, the last wrapping
      back to the first). The ring is only built when there are at least two
      periphery nodes, so a lone periphery node never gets a self loop.

    The 0/1 adjacency is symmetric, but each row is then divided by its sum so
    that ``sum_j w_ij == 1`` for every node with at least one neighbour. The
    returned matrix is therefore row-stochastic and in general NOT symmetric.

    Args:
        n_nodes: total number of nodes.
        n_core: number of core nodes; must satisfy ``1 <= n_core <= n_nodes``.

    Returns:
        Float array of shape ``(n_nodes, n_nodes)``.

    Raises:
        ValueError: if ``n_core`` is outside ``[1, n_nodes]``.
    """
    if not 1 <= n_core <= n_nodes:
        raise ValueError(
            f"n_core must be in [1, n_nodes]; got n_core={n_core}, n_nodes={n_nodes}"
        )

    adj = np.zeros((n_nodes, n_nodes), dtype=float)

    # Core: complete graph, then clear the diagonal (no self loops).
    core = np.arange(n_core)
    adj[np.ix_(core, core)] = 1.0
    adj[core, core] = 0.0

    n_periphery = n_nodes - n_core
    # With a single periphery node the ring would wrap onto itself (self loop).
    build_ring = n_periphery >= 2

    for i in range(n_core, n_nodes):
        # Spoke to a core node, assigned round-robin.
        core_target = i % n_core
        adj[i, core_target] = 1.0
        adj[core_target, i] = 1.0

        if build_ring:
            # Next periphery node; the last one wraps back to the first.
            ring_next = i + 1 if i < n_nodes - 1 else n_core
            adj[i, ring_next] = 1.0
            adj[ring_next, i] = 1.0

    # Row-normalise; isolated rows keep a divisor of 1 to avoid division by zero.
    row_sums = adj.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    return adj / row_sums


ADJ_MATRIX = build_core_periphery_network(N_NODES, N_CORE)
# Row sums of the row-normalised matrix: 1.0 for every node that has at least
# one neighbour (0.0 for an isolated node), so this is NOT a neighbour count.
DEGREE = ADJ_MATRIX.sum(axis=1)
CORE_INDEX = np.arange(N_CORE)
PERIPHERY_INDEX = np.arange(N_CORE, N_NODES)

print(f"Network built.\n- Average degree (by weight normalisation): {DEGREE.mean():.2f}\n")


# -------------------------------
# 3. State initialisation helpers
# -------------------------------

@dataclass
class WorldState:
    """Per-scenario node parameters: capacities and pre-shock (baseline) loads."""

    capacities: np.ndarray  # per-node capacity, shape (N_NODES,)
    base_loads: np.ndarray  # per-node load before any shock, shape (N_NODES,)


def sample_world_state() -> WorldState:
    """
    Draw one random world from NumPy's global RNG.

    - Capacities: Normal(BASE_CAPACITY, 0.05), clipped to [0.8, 1.2].
    - Baseline loads: Normal(BASE_LOAD_MEAN, BASE_LOAD_STD), clipped to [0.3, 0.9].

    Capacities are drawn before loads. Reordering the two draws would change
    the results obtained under a fixed seed.
    """
    caps = np.clip(np.random.normal(BASE_CAPACITY, 0.05, N_NODES), 0.8, 1.2)
    base = np.clip(np.random.normal(BASE_LOAD_MEAN, BASE_LOAD_STD, N_NODES), 0.3, 0.9)
    return WorldState(base_loads=base, capacities=caps)


# -------------------------------
# 4. Contagion dynamics
# -------------------------------

def run_cascade(world: WorldState, shock_nodes: List[int], shock_size: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simulate a load-redistribution cascade without mutating ``world``.

    Each entry of ``shock_nodes`` adds ``shock_size`` to that node's baseline
    load. A node listed twice is shocked twice.

    Then, round by round, every not-yet-failed node whose load exceeds its
    capacity fails:
    - its load is zeroed;
    - ALPHA_REDISPATCH times that load is pushed to other nodes via
      ``load @ ADJ_MATRIX``;
    - the remaining share is dropped.

    The loop ends at the first round with no new failures. Every earlier round
    fails at least one new node, so there are at most ``len(world.base_loads)``
    failure rounds.

    Load pushed onto a node that has already failed stays there and is never
    passed on again. This includes nodes failing in the same round.
    NOTE(review): confirm that losing this load is intended, rather than
    routing it only to surviving neighbours.

    Returns:
        failed_mask: boolean array, True where a node failed at any point.
        final_loads: load vector after the cascade settles. Failed nodes may
            hold load received after they failed.
    """
    cap = world.capacities
    load = world.base_loads.copy()

    for idx in shock_nodes:
        load[idx] += shock_size

    dead = np.zeros(len(load), dtype=bool)

    while True:
        newly_failed = np.logical_and(~dead, load > cap)
        if not newly_failed.any():
            break

        # Snapshot the failing nodes' load (zeros elsewhere) before clearing it.
        released = np.where(newly_failed, load, 0.0)
        load[newly_failed] = 0.0
        dead |= newly_failed

        if released.any():
            load += ALPHA_REDISPATCH * (released @ ADJ_MATRIX)

    return dead, load


# -------------------------------
# 5. Indicator computation
# -------------------------------

def compute_indicators(world: WorldState) -> Dict[str, float]:
    """
    Compute pre-shock stress indicators on the baseline state.

    All indicators are built from per-node utilisation ratios
    ``base_loads / capacities``:

    - ``mean_ratio``: mean ratio over all nodes.
    - ``tail_fraction``: share of nodes whose ratio exceeds NEAR_THRESHOLD_RATIO.
    - ``core_ratio``: mean ratio over the core nodes (CORE_INDEX).
    - ``centrality_weighted_ratio``: ratios weighted by each node's neighbour
      count in ADJ_MATRIX (a simple degree-centrality proxy).

    Args:
        world: sampled capacities and baseline loads.

    Returns:
        Mapping from indicator name to a Python float.
    """
    ratios = world.base_loads / world.capacities
    mean_ratio = float(np.mean(ratios))
    tail_fraction = float(np.mean(ratios > NEAR_THRESHOLD_RATIO))
    # ``ratios`` is already normalised by capacity; do not divide by it again.
    core_ratio = float(np.mean(ratios[CORE_INDEX]))

    # ADJ_MATRIX is row-normalised, so its row sums are 1.0 for every connected
    # node. Weighting by them would be uniform, making this indicator identical
    # to mean_ratio. Weight by the number of neighbours instead.
    neighbour_counts = np.count_nonzero(ADJ_MATRIX, axis=1).astype(float)
    total_degree = neighbour_counts.sum()
    if total_degree > 0:
        centrality_weights = neighbour_counts / total_degree
    else:
        # Edgeless network: fall back to equal weights.
        centrality_weights = np.full_like(ratios, 1.0 / len(ratios))
    centrality_weighted_ratio = float(np.sum(centrality_weights * ratios))

    return {
        "mean_ratio": mean_ratio,
        "tail_fraction": tail_fraction,
        "core_ratio": core_ratio,
        "centrality_weighted_ratio": centrality_weighted_ratio,
    }


# -------------------------------
# 6. Early-warning schemes
# -------------------------------

def scheme_global_mean(indicators: Dict[str, float]) -> bool:
    """
    Global-mean scheme: warn iff the average load/capacity ratio across all
    nodes is strictly above 0.70.
    """
    threshold = 0.70
    return threshold < indicators["mean_ratio"]


def scheme_tail_sensitive(indicators: Dict[str, float]) -> bool:
    """
    Tail-sensitive scheme: warn when more than 25% of nodes are near threshold,
    or otherwise when the mean ratio exceeds 0.65. That mean bar is lower than
    the global-mean scheme's 0.70.

    ``mean_ratio`` is only read when the tail test does not already fire.
    """
    crowded_tail = indicators["tail_fraction"] > 0.25
    if crowded_tail:
        return crowded_tail
    return indicators["mean_ratio"] > 0.65


def scheme_core_focused(indicators: Dict[str, float]) -> bool:
    """
    Core-focused scheme. Warn when either condition holds:
    - the core ratio exceeds 0.72; or
    - the centrality-weighted ratio exceeds 0.70 AND more than 15% of nodes
      are near threshold.

    Keys are read lazily, in the same short-circuit order as that expression.
    """
    core_stressed = indicators["core_ratio"] > 0.72
    if core_stressed:
        return core_stressed
    hubs_hot = indicators["centrality_weighted_ratio"] > 0.70
    if not hubs_hot:
        return hubs_hot
    return indicators["tail_fraction"] > 0.15


# Registry of early-warning schemes, keyed by the name used in reports and plots.
# Insertion order fixes the order in which schemes are evaluated.
SCHEMES = dict(
    global_mean=scheme_global_mean,
    tail_sensitive=scheme_tail_sensitive,
    core_focused=scheme_core_focused,
)


# -------------------------------
# 7. Scenario simulation
# -------------------------------

print("Simulating scenarios...")
records = []

for run_idx in range(N_RUNS):
    # Fresh random world and its pre-shock indicators (NumPy global RNG).
    world = sample_world_state()
    indicators = compute_indicators(world)

    # Shock placement and size come from the stdlib RNG. Keep this draw order
    # for seeded reproducibility: periphery target, core coin flip (p=0.2),
    # optional core target, then shock size.
    periphery_node = random.choice(list(PERIPHERY_INDEX))
    shock_nodes: List[int] = [periphery_node]
    if random.random() < 0.2:
        core_node = random.choice(list(CORE_INDEX))
        shock_nodes.append(core_node)
    shock_size = random.uniform(0.3, 0.6)

    failed_mask, final_loads = run_cascade(world, shock_nodes, shock_size)
    cascade_size = float(failed_mask.mean())
    crash = cascade_size >= CRASH_FRACTION_THRESHOLD

    record = dict(
        run=run_idx,
        crash=int(crash),
        cascade_size=cascade_size,
        shock_size=shock_size,
        n_shock_nodes=len(shock_nodes),
    )
    record.update(indicators)
    records.append(record)

df = pd.DataFrame.from_records(records)
base_crash_rate = df["crash"].mean()

print(f"Scenario simulation completed.\n- Observed crash rate across all runs: {base_crash_rate:.3f}\n")


# -------------------------------
# 8. Evaluate schemes and compute T_warning
# -------------------------------

def evaluate_scheme(df_scenarios: pd.DataFrame, name: str, warn_fn) -> Dict[str, float]:
    """
    Evaluate one early-warning scheme against simulated outcomes.

    ``warn_fn`` is called once per scenario row, with that row as a plain
    ``dict`` (column name -> value); its truthiness is the warning flag.
    A scenario counts as a crash where ``int(df_scenarios["crash"]) == 1``
    and as calm where it is 0.

    Tension components (lower is better):
    - ``T_FN = 3 * FN_rate``: missed crashes, weighted 3x.
    - ``T_FP = FP_rate``: false alarms.
    - ``T_cal = |P(crash | warn) - P(crash)|``.
      NOTE(review): T_cal is 0 for warnings that carry no information about
      crashes and grows as warnings become *more* predictive. Confirm this is
      the intended calibration penalty.

    Any rate whose denominator is zero is reported as 0.0, so an empty
    ``df_scenarios`` yields all-zero counts, rates and T_warning.

    Returns:
        Dict with the scheme name, derived rates, ``T_warning``
        (``T_FN + T_FP + T_cal``) and the confusion counts TP, FP, FN, TN.
    """
    def _rate(numerator: int, denominator: int) -> float:
        # Guard empty groups, e.g. a scheme that never warns.
        return numerator / denominator if denominator > 0 else 0.0

    # Evaluate the scheme on plain per-row dicts and keep flags as 1-D NumPy
    # arrays. DataFrame.apply(axis=1) returns a DataFrame for an empty frame,
    # which breaks the element-wise combinations below.
    rows = df_scenarios.to_dict("records")
    warned = np.array([bool(warn_fn(row)) for row in rows], dtype=bool)
    crash_codes = np.asarray(df_scenarios["crash"]).astype(int)
    crashed = crash_codes == 1
    calm = crash_codes == 0

    TP = int(np.count_nonzero(warned & crashed))
    FP = int(np.count_nonzero(warned & calm))
    FN = int(np.count_nonzero(~warned & crashed))
    TN = int(np.count_nonzero(~warned & calm))
    total = TP + FP + FN + TN

    FN_rate = _rate(FN, TP + FN)
    FP_rate = _rate(FP, FP + TN)
    warn_share = _rate(TP + FP, total)
    crash_rate = _rate(TP + FN, total)
    warn_crash_rate = _rate(TP, TP + FP)
    quiet_crash_rate = _rate(FN, FN + TN)

    # Tension components.
    T_FN = 3.0 * FN_rate
    T_FP = 1.0 * FP_rate
    T_cal = abs(warn_crash_rate - crash_rate)
    T_warning = T_FN + T_FP + T_cal

    return {
        "scheme_name": name,
        "FN_rate": FN_rate,
        "FP_rate": FP_rate,
        "warn_share": warn_share,
        "crash_rate": crash_rate,
        "warn_crash_rate": warn_crash_rate,
        "quiet_crash_rate": quiet_crash_rate,
        "T_warning": T_warning,
        "TP": TP,
        "FP": FP,
        "FN": FN,
        "TN": TN,
    }


summary_rows = [evaluate_scheme(df, name, fn) for name, fn in SCHEMES.items()]

summary_df = (
    pd.DataFrame(summary_rows)
    .sort_values("T_warning", ascending=True)
    .reset_index(drop=True)
)

report_columns = [
    "scheme_name",
    "FN_rate",
    "FP_rate",
    "warn_share",
    "crash_rate",
    "warn_crash_rate",
    "quiet_crash_rate",
    "T_warning",
]
print("Summary table (sorted by T_warning, lower means better-calibrated warnings):")
print(summary_df[report_columns])
print("")

print("Quick interpretation:")
for scheme in summary_df.itertuples(index=False):
    print(
        f"- {scheme.scheme_name}: T_warning ≈ {scheme.T_warning:.3f}, "
        f"FN_rate ≈ {scheme.FN_rate:.2f}, FP_rate ≈ {scheme.FP_rate:.2f}"
    )

print("")


# -------------------------------
# 9. Plots
# -------------------------------

# 9.1 Crash probability vs mean_ratio (effective indicator view)

bins = np.linspace(df["mean_ratio"].min(), df["mean_ratio"].max(), 8)
# Digitize against the interior edges only, so every sample lands in one of
# the len(bins) - 1 bins. That includes the minimum, which equals bins[0] and
# was previously given index 0 and skipped. Bin k covers (bins[k], bins[k + 1]];
# bin 0 also includes bins[0].
bin_indices = np.digitize(df["mean_ratio"].values, bins[1:-1], right=True)

bin_centers = []  # mean of the samples in each bin, not the geometric centre
crash_probs = []
counts = []

for b in range(len(bins) - 1):
    mask = bin_indices == b
    if not np.any(mask):
        continue
    bin_centers.append(df["mean_ratio"].values[mask].mean())
    crash_probs.append(df["crash"].values[mask].mean())
    counts.append(mask.sum())

plt.figure(figsize=(8, 4.5))
plt.plot(bin_centers, crash_probs, marker="o")
plt.xlabel("Mean load / capacity ratio")
plt.ylabel("Crash probability in bin")
plt.title("TU Q105_A · Crash probability vs mean load ratio")
plt.grid(True, alpha=0.3)
indicator_plot_path = "Q105A_indicator_vs_crash.png"
plt.tight_layout()
plt.savefig(indicator_plot_path, dpi=150)
plt.show()
print(f"Saved indicator vs crash plot as: {indicator_plot_path}")

# 9.2 Bar plot of T_warning per scheme

tension_bar_path = "Q105A_T_warning.png"

fig, ax = plt.subplots(figsize=(6, 4.5))
ax.bar(summary_df["scheme_name"], summary_df["T_warning"])
ax.set_xlabel("Early-warning scheme")
ax.set_ylabel("T_warning (higher = more tension)")
ax.set_title("TU Q105_A · T_warning per scheme")
fig.tight_layout()
fig.savefig(tension_bar_path, dpi=150)
plt.show()
print(f"Saved T_warning bar plot as: {tension_bar_path}")
