# Imports

In [None]:
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import pandas as pd
from typing import List, Tuple
from scipy.optimize import linear_sum_assignment

In [None]:
# -----------------------------------------------------------------------------
# Development imports with forced local package resolution and hot-reload
# -----------------------------------------------------------------------------

import sys
from pathlib import Path
import importlib

# Locate the repository root, i.e. the directory that contains "src/domino/",
# so that "import domino" resolves to the local checkout even when the
# notebook is executed from a subfolder.
REPO_ROOT = Path.cwd()
if not (REPO_ROOT / "src" / "domino").is_dir():
    # The working directory is not the root: walk up through its parents and
    # stop at the first one that contains "src/domino/".
    # NOTE(review): if no parent matches, REPO_ROOT stays at the working
    # directory and no error is raised here.
    for parent in Path.cwd().parents:
        if (parent / "src" / "domino").is_dir():
            REPO_ROOT = parent
            break

# Put "<repo>/src" at the front of sys.path (once) so it is searched first.
SRC_ROOT = REPO_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

# -----------------------------------------------------------------------------
# Robust hot-reload (avoids "module X not in sys.modules")
# -----------------------------------------------------------------------------
import importlib
import sys

def reload_module(modname: str):
    """
    Import ``modname`` (a no-op if it is already loaded) and reload it.

    Resolving the module by name through ``importlib.import_module`` first
    guarantees that the object handed to ``importlib.reload`` is the one
    registered in ``sys.modules``.

    Returns
    -------
    module
        The module object returned by ``importlib.reload``.
    """
    module = importlib.import_module(modname)
    reloaded = importlib.reload(module)
    return reloaded

# Reload in a safe order: core utilities first, then the Leiden modules, the
# BIC minimisers and the ERGM solvers, and finally the higher-level modules.
# Every module is addressed by its full dotted name.
for _modname in (
    "domino.utils.constants",
    "domino.utils.repro",
    "domino.leiden.partitions_functions",
    "domino.leiden.scoring",
    "domino.leiden.leiden_engine",
    "domino.bic_minimization.binary_bic",
    "domino.bic_minimization.signed_bic",
    "domino.bic_minimization.weighted_bic",
    "domino.ergms_solvers.binary_solvers",
    "domino.ergms_solvers.signed_solvers",
    "domino.ergms_solvers.weighted_solvers",
    "domino.represent_and_analyze",
    "domino.detect",
):
    reload_module(_modname)

from domino.detect import detect
from domino.represent_and_analyze import partition_to_dict, community_layout, process_graph

# Useful functions

In [None]:
def plot_ground_truth_with_domino(
    G: nx.Graph,
    truth: np.ndarray,
    *,
    pos: dict | None = None,
    Apos: np.ndarray | None = None,
    Aneg: np.ndarray | None = None,
    layout: str = "custom",
    title_prefix: str = "groundtruth",
) -> dict:
    """
    Draw a ground-truth partition by delegating to ``process_graph``.

    Parameters
    ----------
    G : nx.Graph
        Graph to be visualized.
    truth : np.ndarray
        Community labels; converted to an integer column vector of shape (N, 1)
        before being handed to ``process_graph``.
    pos : dict, optional
        Fixed node positions. When given, the layout is forced to "custom"
        regardless of ``layout`` and the positions are forwarded under "pos".
    Apos, Aneg : np.ndarray, optional
        Positive / negative layers, forwarded unchanged to ``process_graph``.
    layout : str
        Layout keyword placed in the visualization config when ``pos`` is None.
    title_prefix : str
        Value stored under the "prefix" key of the visualization config.

    Returns
    -------
    dict
        Whatever ``process_graph`` returns.
    """
    # Explicit positions always win over the requested layout keyword.
    effective_layout = layout if pos is None else "custom"
    viz_cfg: dict = {
        "enabled": True,
        "layout": effective_layout,
        "prefix": title_prefix,
    }
    if pos is not None:
        viz_cfg["pos"] = pos

    # Labels are passed as an (N, 1) integer column.
    labels_column = np.asarray(truth, dtype=int).reshape(-1, 1)

    return process_graph(
        G,
        labels_column,
        Apos=Apos,
        Aneg=Aneg,
        viz=viz_cfg,
        report=False,
    )

def truth_as_dict(truth: np.ndarray) -> dict[int, int]:
    """
    Build a ``{node_index: community}`` dictionary from a label array.

    The input is flattened, so position ``i`` of the flattened array becomes
    the label of node ``i``. Keys and values are plain Python ints.
    """
    flat_labels = np.asarray(truth, dtype=int).reshape(-1)
    return {node: int(label) for node, label in enumerate(flat_labels)}


def _labels_from_any(G: nx.Graph, labels_like) -> np.ndarray:
    """
    Map any supported label container into a (N,) integer array aligned with sorted(G.nodes()).

    Supported inputs:
      - dict: node -> label
      - list / np.ndarray: one label per node, assumed to already be ordered
        like sorted(G.nodes()); the input is flattened, so (N, 1) is accepted
      - Partition-like: any other iterable of communities (iterables of nodes)

    Note that a plain ``list`` is always read as per-node labels, never as a
    list of communities.

    Raises
    ------
    ValueError
        If a list/array input does not contain exactly one label per node.
    KeyError
        If a dict or partition-like input does not cover every node of ``G``.
    """
    nodes = sorted(G.nodes())

    if isinstance(labels_like, dict):
        return np.array([int(labels_like[n]) for n in nodes], dtype=int)

    if isinstance(labels_like, (list, np.ndarray)):
        y = np.asarray(labels_like, dtype=int).reshape(-1)
        # A wrong-length vector would otherwise be silently truncated by the
        # zip() calls performed downstream, producing meaningless scores.
        if len(y) != len(nodes):
            raise ValueError(
                f"expected {len(nodes)} labels (one per node), got {len(y)}"
            )
        return y

    # Partition-like: the community index becomes the label of each member.
    node_to_comm = {
        v: cid for cid, comm in enumerate(labels_like) for v in comm
    }
    return np.array([int(node_to_comm[n]) for n in nodes], dtype=int)


def _relabel_consecutive(y: np.ndarray) -> np.ndarray:
    """
    Replace integer labels by their rank among the distinct values, yielding
    labels in {0, 1, ..., K-1} with the same equivalence classes.
    """
    labels = np.asarray(y, dtype=int)
    # The inverse index of np.unique is the position of each label among the
    # sorted distinct values.
    return np.unique(labels, return_inverse=True)[1]


def _contingency(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """
    Build the contingency table of two labelings.

    Both labelings are first relabeled to consecutive integers; entry
    ``[a, b]`` of the result counts the positions whose relabeled true label
    is ``a`` and whose relabeled predicted label is ``b``.
    """
    rows = _relabel_consecutive(y_true)
    cols = _relabel_consecutive(y_pred)
    shape = (int(rows.max()) + 1, int(cols.max()) + 1)
    table = np.zeros(shape, dtype=int)
    for cell in zip(rows, cols):
        table[cell] += 1
    return table


def compare_partitions(G: nx.Graph, truth_labels, found_labels) -> dict:
    """
    Compare two partitions in a label-invariant way via optimal label matching.

    Parameters
    ----------
    G : nx.Graph
        Graph whose nodes define the alignment of both labelings.
    truth_labels, found_labels
        Labelings in any format accepted by ``_labels_from_any``.

    Returns
    -------
    dict
        - "exact_match": True when both partitions have the same number of
          communities and every node is matched.
        - "accuracy": fraction of nodes matched under the optimal assignment.
        - "mapping_pred_to_true": predicted -> true community index. Both
          indices refer to the labels after relabeling to 0..K-1 (the
          rows/columns of "contingency"), not to the original label values.
        - "ari", "nmi": Adjusted Rand Index and Normalized Mutual Information,
          or None when scikit-learn is not installed.
        - "contingency": contingency matrix (rows = true, cols = predicted).

    Raises
    ------
    ValueError
        If the two labelings do not have the same length.
    """
    y_true = _labels_from_any(G, truth_labels)
    y_pred = _labels_from_any(G, found_labels)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"label vectors differ in length: {len(y_true)} vs {len(y_pred)}"
        )
    C = _contingency(y_true, y_pred)

    # Hungarian matching to maximize matched nodes (the solver minimizes, so
    # the counts are turned into a cost by subtracting them from the maximum).
    cost = C.max() - C
    r_ind, c_ind = linear_sum_assignment(cost)
    matched = int(C[r_ind, c_ind].sum())
    acc = matched / float(len(y_true))

    mapping_pred_to_true = {int(pred): int(true) for true, pred in zip(r_ind, c_ind)}
    exact = (C.shape[0] == C.shape[1]) and (acc == 1.0)

    ari = nmi = None
    try:
        from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
    except ImportError:
        # scikit-learn is an optional dependency: report the metrics as None.
        pass
    else:
        # Only a missing package is tolerated; genuine errors raised while
        # computing the metrics are allowed to propagate.
        yt = _relabel_consecutive(y_true)
        yp = _relabel_consecutive(y_pred)
        ari = float(adjusted_rand_score(yt, yp))
        nmi = float(normalized_mutual_info_score(yt, yp))

    return {
        "exact_match": bool(exact),
        "accuracy": float(acc),
        "mapping_pred_to_true": mapping_pred_to_true,
        "ari": ari,
        "nmi": nmi,
        "contingency": C,
    }


def pretty_print_result(name: str, res: dict) -> None:
    """
    Print a human-readable report for one compare_partitions() result.

    The report contains a banner with ``name``, the exact-match flag, the node
    accuracy, ARI / NMI when they are not None, the predicted -> true label
    mapping, and the contingency matrix rendered as a labeled table.
    """
    rule = "=" * 78
    print("\n" + rule)
    print(f"{name} — Ground truth vs. detected")
    print(rule)

    print(f"Exact up to permutation : {res['exact_match']}")
    print(f"Node accuracy           : {res['accuracy'] * 100:.2f}%")
    ari, nmi = res["ari"], res["nmi"]
    if ari is not None:
        print(f"Adjusted Rand Index     : {ari:.4f}")
    if nmi is not None:
        print(f"Normalized Mutual Info  : {nmi:.4f}")

    pairs = sorted(res["mapping_pred_to_true"].items())
    mapping_str = ", ".join(f"P{p}→T{t}" for p, t in pairs)
    if not mapping_str:
        mapping_str = "(none)"
    print(f"Best label mapping      : {mapping_str}")

    table = res["contingency"]
    n_true, n_pred = table.shape[0], table.shape[1]
    frame = pd.DataFrame(
        table,
        index=[f"T{i}" for i in range(n_true)],
        columns=[f"P{j}" for j in range(n_pred)],
    )
    print("\nContingency (rows=true, cols=pred):")
    print(frame.to_string())


def recap_line(prefix: str, res: dict) -> str:
    """
    Summarize a compare_partitions() result on a single line.

    The line starts with ``prefix`` and lists the accuracy, then ARI and NMI
    when they are not None, and finally the exact-match flag.
    """
    fields = [f"acc={res['accuracy'] * 100:.1f}%"]
    for tag, key in (("ARI", "ari"), ("NMI", "nmi")):
        value = res[key]
        if value is not None:
            fields.append(f"{tag}={value:.3f}")
    fields.append(f"exact={res['exact_match']}")
    return f"{prefix}: " + ", ".join(fields)

# Binary

## Test Networks

### Create Test Networks

In [None]:
# Create networks to test the community detection algorithms

###############################################################################
# Five Disconnected Blocks (100 nodes => 5 blocks of 20 each)
###############################################################################

def create_five_disconnected_blocks(num_nodes=100, num_blocks=5, p_in=0.8):
    """
    Build an undirected graph made of equally sized, mutually disconnected blocks.

    Nodes are numbered 0..num_nodes-1 and assigned to blocks in contiguous
    runs. Every pair of nodes inside a block is linked independently with
    probability ``p_in``; no edge is ever placed between blocks. Randomness
    comes from the global ``np.random`` state.

    Returns
    -------
    G : nx.Graph
        The generated graph.
    ground_truth : np.ndarray
        Block index of every node.
    """
    assert num_nodes % num_blocks == 0, "num_nodes must be divisible by num_blocks"
    nodes_per_block = num_nodes // num_blocks

    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    labels = np.zeros(num_nodes, dtype=int)

    for block in range(num_blocks):
        lo = block * nodes_per_block
        hi = lo + nodes_per_block
        labels[lo:hi] = block
        # One uniform draw per unordered pair, visited row by row.
        for u in range(lo, hi):
            for v in range(u + 1, hi):
                if np.random.rand() < p_in:
                    graph.add_edge(u, v)
    return graph, labels

###############################################################################
# Ten-Block Connected Graph (Various Densities + Sparse Interconnections)
###############################################################################

def create_ten_blocks_connected(block_sizes=None, p_in=None, p_out=0.01):
    """
    Create a graph with 10 blocks, each with its own connection probability.

    Parameters:
        block_sizes (sequence of int, optional): Sizes of the 10 blocks (list,
            tuple or array). If None, defaults to 10 blocks of 30 nodes each.
        p_in (sequence of float, optional): Internal connection probability of
            each block. If None, defaults to predefined densities.
        p_out (float): Probability of connections between different blocks.

    Returns:
        G (networkx.Graph): The generated graph.
        ground_truth (np.array): The ground-truth labels for each node.

    Randomness comes from the global ``np.random`` state: one draw per
    intra-block pair (block by block), then one per inter-block pair.
    """
    if block_sizes is None:
        block_sizes = [30] * 10  # Default: 10 blocks of size 30

    if p_in is None:
        # Different probabilities for different densities (customizable)
        p_in = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.6, 0.5, 0.4, 0.3]

    # Coerce to plain lists. Building the block offsets by list concatenation
    # on the raw argument raised for tuples and, for NumPy arrays, broadcast
    # into wrong start offsets without any error.
    block_sizes = [int(size) for size in block_sizes]
    p_in = list(p_in)

    assert len(block_sizes) == 10, "Must have 10 block sizes"
    assert len(p_in) == 10, "Must specify exactly 10 internal connection probabilities"

    N = sum(block_sizes)
    G = nx.Graph()
    G.add_nodes_from(range(N))
    ground_truth = np.zeros(N, dtype=int)

    # Half-open [start, end) node range of every block, as Python ints.
    end_indices = np.cumsum(block_sizes).tolist()
    start_indices = [0] + end_indices[:-1]

    # Add intra-block edges
    for b in range(10):
        start, end = start_indices[b], end_indices[b]
        ground_truth[start:end] = b
        for i in range(start, end):
            for j in range(i + 1, end):
                if np.random.rand() < p_in[b]:
                    G.add_edge(i, j)

    # Add inter-block edges
    for b1 in range(10):
        start1, end1 = start_indices[b1], end_indices[b1]
        for b2 in range(b1 + 1, 10):
            start2, end2 = start_indices[b2], end_indices[b2]
            for i in range(start1, end1):
                for j in range(start2, end2):
                    if np.random.rand() < p_out:
                        G.add_edge(i, j)

    return G, ground_truth

###############################################################################
# Generate graphs
###############################################################################

# Sample the two binary benchmark graphs together with their ground truth.
# NOTE(review): no NumPy seed is set in this cell, so unless one is set
# elsewhere the sampled graphs change from run to run; only the layouts below
# are seeded.
G5, truth5 = create_five_disconnected_blocks(num_nodes=100, num_blocks=5, p_in=0.5)
G10, truth10 = create_ten_blocks_connected()

###############################################################################
# Plot graphs
###############################################################################

# Compute one fixed spring layout per graph and plot the ground-truth
# communities with it; the same positions are reused by the detection cells.
pos_G5 = nx.spring_layout(G5, seed=42)
pos_G10 = nx.spring_layout(G10, seed=42)
plot_ground_truth_with_domino(G5, truth5, pos=pos_G5, layout="custom")
plot_ground_truth_with_domino(G10, truth10, pos=pos_G10, layout="custom")

# Dense adjacency matrices of both graphs, passed to detect() later on.
A5 = nx.to_numpy_array(G5)
A10 = nx.to_numpy_array(G10)

### Community Detection

#### SBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# G5: five disconnected blocks
# ─────────────────────────────────────────────────────────────────────────────

print("\n========================================")
print("Processing G5")
print("========================================\n")

# Rebuild an undirected graph from the dense adjacency matrix.
G5_bin = nx.from_numpy_array(A5, create_using=nx.Graph)

# Binary mode without degree correction (SBM), started from the "modularity"
# initial partition; plots reuse the layout computed for the ground truth.
# NOTE(review): max_outer / target_K semantics are defined by detect() and
# are not visible here — confirm against its documentation.
res = detect(
    G5_bin,
    A5,
    mode="binary",
    degree_corrected=False,
    initial_partition="modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_G5},
    report={"print_info": True},
)

# Keep the detected partition and its BIC from the result dictionary.
sbm_part_5 = res["partition"]
sbm_bic_5 = res["bic"]
print(f"G5 best BIC: {sbm_bic_5:.2f}")

# Flattened partition converted with partition_to_dict; compared against the
# ground truth in the next cell.
sbm_labels_5 = partition_to_dict(sbm_part_5.flatten())

# ─────────────────────────────────────────────────────────────────────────────
# G10: ten-block connected
# ─────────────────────────────────────────────────────────────────────────────

print("\n========================================")
print("Processing G10")
print("========================================\n")

# Same procedure as for G5, applied to the ten-block graph.
G10_bin = nx.from_numpy_array(A10, create_using=nx.Graph)

res = detect(
    G10_bin,
    A10,
    mode="binary",
    degree_corrected=False,
    initial_partition="modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_G10},
    report={"print_info": True},
)

sbm_part_10 = res["partition"]
sbm_bic_10 = res["bic"]
print(f"G10 best BIC: {sbm_bic_10:.2f}")

sbm_labels_10 = partition_to_dict(sbm_part_10.flatten())

#### Check partitions

In [None]:
# -------- Binary: compare truth vs detected (SBM) --------
# Ground truth as a node -> community dict, scored against the SBM labels.
truth_G5 = truth_as_dict(truth5)
res_G5_sbm = compare_partitions(G5_bin, truth_G5, sbm_labels_5)
pretty_print_result("G5 (binary 5-blocks, SBM)", res_G5_sbm)

truth_G10 = truth_as_dict(truth10)
res_G10_sbm = compare_partitions(G10_bin, truth_G10, sbm_labels_10)
pretty_print_result("G10 (binary 10-blocks, SBM)", res_G10_sbm)

# One-line recap per graph.
print("\n" + "-" * 78)
print("[Recap SBM]   " + recap_line("G5", res_G5_sbm))
print("[Recap SBM]   " + recap_line("G10", res_G10_sbm))
print("-" * 78)

#### dcSBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# G5: five disconnected blocks
# ─────────────────────────────────────────────────────────────────────────────

print("\n========================================")
print("Processing G5")
print("========================================\n")

# Rebuild an undirected graph from the dense adjacency matrix.
G5_bin = nx.from_numpy_array(A5, create_using=nx.Graph)

# Same call as in the SBM cell except for degree_corrected=True (dcSBM).
res = detect(
    G5_bin,
    A5,
    mode="binary",
    degree_corrected=True,
    initial_partition="modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_G5},
    report={"print_info": True},
)

# Keep the detected partition and its BIC from the result dictionary.
dcsbm_part_5 = res["partition"]
dcsbm_bic_5 = res["bic"]
print(f"G5 best BIC: {dcsbm_bic_5:.2f}")

# Flattened partition converted with partition_to_dict; compared against the
# ground truth in the next cell.
dcsbm_labels_5 = partition_to_dict(dcsbm_part_5.flatten())

# ─────────────────────────────────────────────────────────────────────────────
# G10: ten-block connected
# ─────────────────────────────────────────────────────────────────────────────

print("\n========================================")
print("Processing G10")
print("========================================\n")

# Same procedure as for G5, applied to the ten-block graph.
G10_bin = nx.from_numpy_array(A10, create_using=nx.Graph)

res = detect(
    G10_bin,
    A10,
    mode="binary",
    degree_corrected=True,
    initial_partition="modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_G10},
    report={"print_info": True},
)

dcsbm_part_10 = res["partition"]
dcsbm_bic_10 = res["bic"]
print(f"G10 best BIC: {dcsbm_bic_10:.2f}")

dcsbm_labels_10 = partition_to_dict(dcsbm_part_10.flatten())

#### Check partitions

In [None]:
# -------- Binary: compare truth vs detected (dcSBM) --------
# Ground truth as a node -> community dict, scored against the dcSBM labels.
truth_G5 = truth_as_dict(truth5)
res_G5_dcsbm = compare_partitions(G5_bin, truth_G5, dcsbm_labels_5)
pretty_print_result("G5 (binary 5-blocks, dcSBM)", res_G5_dcsbm)

truth_G10 = truth_as_dict(truth10)
res_G10_dcsbm = compare_partitions(G10_bin, truth_G10, dcsbm_labels_10)
pretty_print_result("G10 (binary 10-blocks, dcSBM)", res_G10_dcsbm)

# One-line recap per graph.
print("\n" + "-" * 78)
print("[Recap dcSBM] " + recap_line("G5", res_G5_dcsbm))
print("[Recap dcSBM] " + recap_line("G10", res_G10_dcsbm))
print("-" * 78)

# Signed

## Test Networks

### Create Test Networks

In [None]:
# =============================================================================
# Utilities: sample signed SBM with tri-nomial dyads (+, −, 0)
# =============================================================================
def _block_ranges(sizes: List[int]) -> List[Tuple[int, int]]:
    """
    Return [(start, end), ...] ranges for each block (0-indexed, end-exclusive).

    ``sizes`` may be any integer sequence (list, tuple or NumPy array). The
    offsets are derived from the cumulative sum only, because prepending
    ``[0]`` by list concatenation fails for tuples and is broadcast
    element-wise for arrays, which silently yields wrong ranges.

    The returned bounds are plain Python ints.
    """
    size_arr = np.asarray(sizes, dtype=int).reshape(-1)
    ends = np.cumsum(size_arr)
    starts = ends - size_arr
    return [(int(lo), int(hi)) for lo, hi in zip(starts, ends)]


def _sample_signed_blockpair(
    Ap: np.ndarray,
    An: np.ndarray,
    i_range: Tuple[int, int],
    j_range: Tuple[int, int],
    p_pos: float,
    p_neg: float,
    same_block: bool,
    rng: np.random.Generator
) -> None:
    """
    Sample the dyads of one block pair in place.

    Each dyad gets one uniform draw ``u`` from ``rng``: it becomes a positive
    tie if ``u < p_pos``, a negative tie if ``p_pos <= u < p_pos + p_neg``, and
    stays empty otherwise. Ties are written symmetrically into ``Ap`` / ``An``.

    When ``same_block`` is True only the strict upper triangle of ``i_range``
    is visited (so no self-loops) and ``j_range`` is not used for iteration;
    otherwise the full ``i_range`` x ``j_range`` rectangle is visited.
    """
    row_lo, row_hi = i_range
    col_lo, col_hi = j_range

    for i in range(row_lo, row_hi):
        # Diagonal block: pairs (i, j) with j > i inside the same range.
        # Off-diagonal block: every column of the other block.
        if same_block:
            columns = range(i + 1, row_hi)
        else:
            columns = range(col_lo, col_hi)

        for j in columns:
            draw = rng.random()
            if draw < p_pos:
                layer = Ap
            elif draw < p_pos + p_neg:
                layer = An
            else:
                continue
            layer[i, j] = 1
            layer[j, i] = 1


def sample_signed_sbm(
    block_sizes: List[int],
    Ppos: np.ndarray,  # shape (B,B), symmetric, entries in [0,1]
    Pneg: np.ndarray,  # shape (B,B), symmetric, entries in [0,1], with Ppos+Pneg ≤ 1
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Draw one signed stochastic block model.

    Block pairs are visited row by row (diagonal block first, then the blocks
    to its right) and every dyad is sampled by ``_sample_signed_blockpair``
    from a single generator seeded with ``seed``.

    Returns
    -------
    Apos, Aneg, truth
      Apos/Aneg: NxN (0/1) symmetric +/− adjacency matrices, no self-loops.
      truth    : (N,) block labels in {0..B-1}.

    Raises
    ------
    AssertionError
        If the probability matrices are inconsistent with each other or with
        ``block_sizes`` (shape, symmetry, sign, or Ppos + Pneg > 1).
    """
    assert Ppos.shape == Pneg.shape
    B = Ppos.shape[0]
    assert B == len(block_sizes), "P matrices and block_sizes disagree"
    assert np.allclose(Ppos, Ppos.T) and np.allclose(Pneg, Pneg.T), "Ppos/Pneg must be symmetric"
    assert np.all(Ppos >= 0) and np.all(Pneg >= 0), "negative probabilities!"
    assert np.all(Ppos + Pneg <= 1 + 1e-12), "Ppos + Pneg must be ≤ 1"

    n_nodes = int(np.sum(block_sizes))
    Apos = np.zeros((n_nodes, n_nodes), dtype=int)
    Aneg = np.zeros((n_nodes, n_nodes), dtype=int)

    # Ground truth: node index range of each block -> block id.
    spans = _block_ranges(block_sizes)
    truth = np.empty(n_nodes, dtype=int)
    for label, (lo, hi) in enumerate(spans):
        truth[lo:hi] = label

    rng = np.random.default_rng(seed)

    # s == r is the diagonal block; s > r are the off-diagonal rectangles.
    for r in range(B):
        for s in range(r, B):
            _sample_signed_blockpair(
                Apos, Aneg, spans[r], spans[s], Ppos[r, s], Pneg[r, s], r == s, rng
            )

    # Explicitly clear the diagonal of both layers.
    for layer in (Apos, Aneg):
        np.fill_diagonal(layer, 0)
    return Apos, Aneg, truth


# =============================================================================
# Easy case: 5 disconnected “friend groups” with negative between-group ties
# =============================================================================
def create_easy_signed_5blocks(
    num_blocks: int = 5,
    block_size: int = 20,
    p_in_pos: float = 0.60,
    p_in_neg: float = 0.02,
    p_out_pos: float = 0.03,
    p_out_neg: float = 0.25,
    seed: int = 42
):
    """
    Sample an easy signed benchmark of ``num_blocks`` equal-size communities.

    Within a block, ties are positive with probability ``p_in_pos`` and
    negative with probability ``p_in_neg``; between blocks the probabilities
    are ``p_out_pos`` and ``p_out_neg``.

    Returns
    -------
    Apos, Aneg, truth
        The output of ``sample_signed_sbm`` for these probabilities.
    """
    B = num_blocks
    sizes = [block_size] * B

    # Between-block rates everywhere, within-block rates on the diagonal.
    on_diagonal = np.eye(B, dtype=bool)
    Ppos = np.where(on_diagonal, p_in_pos, p_out_pos).astype(float)
    Pneg = np.where(on_diagonal, p_in_neg, p_out_neg).astype(float)

    # Each dyad has a single outcome, so the two rates cannot exceed 1 together.
    assert np.all(Ppos + Pneg <= 1.0 + 1e-12)

    return sample_signed_sbm(sizes, Ppos, Pneg, seed=seed)


# =============================================================================
# Harder case: 8 blocks, two macro-alliances with exceptions (heterogeneous)
# =============================================================================
def create_hard_signed_8blocks(
    sizes: List[int] = None,
    seed: int = 7
):
    """
    Sample a harder signed benchmark: 8 blocks split into two macro-alliances,
    A = {0,1,2,3} and B = {4,5,6,7} (default sizes sum to 240 nodes).

    (positive, negative) tie probabilities:
      - inside a block:                        (0.55, 0.02)
      - two blocks of alliance A:              (0.18, 0.03)
      - two blocks of alliance B:              (0.17, 0.03)
      - one block of A and one of B:           (0.03, 0.18)
      - exceptions (2,5) and (1,4), cross-alliance but friendly: (0.20, 0.02)
      - exception (2,3), same alliance but hostile:              (0.06, 0.20)

    Returns
    -------
    Apos, Aneg, truth
        The output of ``sample_signed_sbm`` for these probabilities.
    """
    if sizes is None:
        sizes = [36, 30, 28, 26, 28, 30, 32, 30]  # sum = 240

    B = 8
    assert len(sizes) == B

    # Generic cross-block baseline; every off-diagonal entry is overwritten
    # by one of the alliance rules below.
    Ppos = np.full((B, B), 0.05, dtype=float)
    Pneg = np.full((B, B), 0.10, dtype=float)

    alliance_a = [0, 1, 2, 3]
    alliance_b = [4, 5, 6, 7]

    # Blocks of the same alliance: modest positive, tiny negative.
    for members, pos_rate in ((alliance_a, 0.18), (alliance_b, 0.17)):
        within = np.ix_(members, members)
        Ppos[within] = pos_rate
        Pneg[within] = 0.03

    # Blocks of opposite alliances: mostly negative rivalry (both triangles).
    for rows, cols in ((alliance_a, alliance_b), (alliance_b, alliance_a)):
        across = np.ix_(rows, cols)
        Ppos[across] = 0.03
        Pneg[across] = 0.18

    # Within-block rates go last because the alliance step also wrote the diagonal.
    np.fill_diagonal(Ppos, 0.55)
    np.fill_diagonal(Pneg, 0.02)

    # Deliberate exceptions that break the clean alliance pattern.
    exceptions = (
        ((2, 5), 0.20, 0.02),
        ((1, 4), 0.20, 0.02),
        ((2, 3), 0.06, 0.20),
    )
    for (r, s), pos_rate, neg_rate in exceptions:
        Ppos[r, s] = Ppos[s, r] = pos_rate
        Pneg[r, s] = Pneg[s, r] = neg_rate

    # sanity: Ppos + Pneg ≤ 1
    assert np.all(Ppos + Pneg <= 1.0 + 1e-12)

    return sample_signed_sbm(sizes, Ppos, Pneg, seed=seed)

# ------------------ Generate graphs ------------------

# Easy signed (5 blocks)
# NOTE: this rebinds truth5, which the binary section also uses.
Apos5, Aneg5, truth5 = create_easy_signed_5blocks(
    num_blocks=5, block_size=20,
    p_in_pos=0.60, p_in_neg=0.02,
    p_out_pos=0.03, p_out_neg=0.25,
    seed=123
)

# Hard signed (8 blocks), default sizes and seed
Apos8, Aneg8, truth8 = create_hard_signed_8blocks()

# Signed adjacency (+1 / -1 / 0) and the union graph, which has an edge
# wherever a positive or a negative tie exists.
Asign5 = Apos5 - Aneg5
Asign8 = Apos8 - Aneg8
G5_signed = nx.from_numpy_array((Asign5 != 0).astype(int), create_using=nx.Graph)
G8_signed = nx.from_numpy_array((Asign8 != 0).astype(int), create_using=nx.Graph)

# Plot the ground truth: positions come from community_layout on the true
# labels and are reused by the detection cells.
pos_S5 = community_layout(G5_signed, {i: int(truth5[i]) for i in range(len(truth5))}, seed=42)
pos_S8 = community_layout(G8_signed, {i: int(truth8[i]) for i in range(len(truth8))}, seed=7)
plot_ground_truth_with_domino(G5_signed, truth5, pos=pos_S5, Apos=Apos5, Aneg=Aneg5, layout="custom");
plot_ground_truth_with_domino(G8_signed, truth8, pos=pos_S8, Apos=Apos8, Aneg=Aneg8, layout="custom");

### Community Detection

#### sSBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# S5: five signed blocks  (signed analog of the binary G5 section)
# ─────────────────────────────────────────────────────────────────────────────
print("\n========================================")
print("Processing S5 (signed 5-blocks)")
print("========================================\n")

# Signed mode without degree correction (sSBM), started from the
# "pos_modularity" initial partition. The union graph and the signed
# adjacency are passed in; the two layers go to the plots through viz.
res = detect(
    G5_signed,
    Asign5,
    mode="signed",
    degree_corrected=False,
    initial_partition="pos_modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_S5, "Apos": Apos5, "Aneg": Aneg5},
    report={"print_info": True},
)

# Keep the detected partition and its BIC from the result dictionary.
sSBM_part_5 = res["partition"]
sSBM_bic_5 = res["bic"]
print(f"S5 best BIC (sSBM): {sSBM_bic_5:.2f}")

# Flattened partition converted with partition_to_dict; compared against the
# ground truth in the next cell.
sSBM_labels_5 = partition_to_dict(sSBM_part_5.flatten())

# ─────────────────────────────────────────────────────────────────────────────
# S8: eight signed blocks  (signed analog of the binary G10 section)
# ─────────────────────────────────────────────────────────────────────────────
print("\n========================================")
print("Processing S8 (signed 8-blocks)")
print("========================================\n")

# Same procedure as for S5, applied to the eight-block graph.
res = detect(
    G8_signed,
    Asign8,
    mode="signed",
    degree_corrected=False,
    initial_partition="pos_modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_S8, "Apos": Apos8, "Aneg": Aneg8},
    report={"print_info": True},
)

sSBM_part_8 = res["partition"]
sSBM_bic_8 = res["bic"]
print(f"S8 best BIC (sSBM): {sSBM_bic_8:.2f}")

sSBM_labels_8 = partition_to_dict(sSBM_part_8.flatten())

#### Check partitions

In [None]:
# -------- S5 comparison --------
# Ground truth as a node -> community dict, scored against the sSBM labels.
truth_S5 = truth_as_dict(truth5)
res_S5 = compare_partitions(G5_signed, truth_S5, sSBM_labels_5)
pretty_print_result("S5 (signed 5-blocks, sSBM)", res_S5)

# -------- S8 comparison --------
truth_S8 = truth_as_dict(truth8)
res_S8 = compare_partitions(G8_signed, truth_S8, sSBM_labels_8)
pretty_print_result("S8 (signed 8-blocks, sSBM)", res_S8)

# -------- Compact one-line recap --------
# Built with recap_line(), like the binary section, instead of repeating the
# formatting inline; the printed text is unchanged.
print("\n" + "-" * 78)
print("[Recap] " + recap_line("S5", res_S5))
print("[Recap] " + recap_line("S8", res_S8))
print("-" * 78)

#### sdcSBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# S5: five signed blocks  (signed analog of the binary G5 section)
# ─────────────────────────────────────────────────────────────────────────────
print("\n========================================")
print("Processing S5 (signed 5-blocks)")
print("========================================\n")

# Same call as in the sSBM cell except for degree_corrected=True (sdcSBM).
res = detect(
    G5_signed,
    Asign5,
    mode="signed",
    degree_corrected=True,
    initial_partition="pos_modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_S5, "Apos": Apos5, "Aneg": Aneg5},
    report={"print_info": True},
)

# Keep the detected partition and its BIC from the result dictionary.
sdcSBM_part_5 = res["partition"]
sdcSBM_bic_5 = res["bic"]
print(f"S5 best BIC (sdcSBM): {sdcSBM_bic_5:.2f}")

# Flattened partition converted with partition_to_dict; compared against the
# ground truth in the next cell.
sdcSBM_labels_5 = partition_to_dict(sdcSBM_part_5.flatten())

# ─────────────────────────────────────────────────────────────────────────────
# S8: eight signed blocks  (signed analog of the binary G10 section)
# ─────────────────────────────────────────────────────────────────────────────
print("\n========================================")
print("Processing S8 (signed 8-blocks)")
print("========================================\n")

# Same procedure as for S5, applied to the eight-block graph.
res = detect(
    G8_signed,
    Asign8,
    mode="signed",
    degree_corrected=True,
    initial_partition="pos_modularity",
    max_outer=10,
    target_K=None,
    viz={"layout": "custom", "pos": pos_S8, "Apos": Apos8, "Aneg": Aneg8},
    report={"print_info": True},
)

sdcSBM_part_8 = res["partition"]
sdcSBM_bic_8 = res["bic"]
print(f"S8 best BIC (sdcSBM): {sdcSBM_bic_8:.2f}")

sdcSBM_labels_8 = partition_to_dict(sdcSBM_part_8.flatten())

#### Check partitions

In [None]:
# -------- S5 (sdcSBM) --------
# Ground truth as a node -> community dict, scored against the sdcSBM labels.
truth_S5 = truth_as_dict(truth5)
res_S5_sdc = compare_partitions(G5_signed, truth_S5, sdcSBM_labels_5)
pretty_print_result("S5 (signed 5-blocks, sdcSBM)", res_S5_sdc)

# -------- S8 (sdcSBM) --------
truth_S8 = truth_as_dict(truth8)
res_S8_sdc = compare_partitions(G8_signed, truth_S8, sdcSBM_labels_8)
pretty_print_result("S8 (signed 8-blocks, sdcSBM)", res_S8_sdc)

# -------- Compact one-line recap (sdcSBM) --------
# Built with recap_line(), like the binary section, instead of repeating the
# formatting inline; the printed text is unchanged.
print("\n" + "-" * 78)
print("[Recap sdcSBM] " + recap_line("S5", res_S5_sdc))
print("[Recap sdcSBM] " + recap_line("S8", res_S8_sdc))
print("-" * 78)

# Weighted

## Test Networks

### Create Test Networks

In [None]:
###############################################################################
# Helpers
###############################################################################

def _geom_sample(mean):
    """
    Draw one sample from a geometric distribution on {0, 1, 2, ...} whose
    expected value is ``mean``.

    NumPy's geometric sampler lives on {1, 2, ...}; it is called with success
    probability p = 1 / (1 + mean) and shifted down by one. A non-positive
    ``mean`` returns 0 without consuming any randomness. Draws come from the
    global ``np.random`` state.
    """
    if mean <= 0:
        return 0
    success_prob = 1.0 / (1.0 + float(mean))
    shifted = np.random.geometric(success_prob)
    return shifted - 1

###############################################################################
# Five Disconnected Blocks (weighted)
###############################################################################

def create_five_disconnected_blocks_weighted(
    num_nodes=100, num_blocks=5, mu_in=1.5
):
    """
    Build a weighted graph made of ``num_blocks`` equal-size disconnected blocks.

    For every pair i < j inside a block a weight is drawn with
    ``_geom_sample(mu_in)``; an edge carrying that integer weight is added only
    when the draw is positive. Pairs from different blocks are never linked.

    Returns
    -------
    G : nx.Graph
        The generated graph, with edge attribute "weight".
    ground_truth : np.ndarray
        Block index of every node.
    """
    assert num_nodes % num_blocks == 0, "num_nodes must be divisible by num_blocks"
    nodes_per_block = num_nodes // num_blocks

    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    labels = np.zeros(num_nodes, dtype=int)

    for block in range(num_blocks):
        lo = block * nodes_per_block
        hi = lo + nodes_per_block
        labels[lo:hi] = block
        # One weight draw per unordered pair, visited row by row.
        for u in range(lo, hi):
            for v in range(u + 1, hi):
                weight = _geom_sample(mu_in)
                if weight > 0:
                    graph.add_edge(u, v, weight=int(weight))
    return graph, labels

###############################################################################
# Ten-Block Connected Graph (weighted; heterogeneous within means + sparse cross)
###############################################################################

def create_ten_blocks_connected_weighted(
    block_sizes=None,
    mu_in=None,
    mu_out=0.05
):
    """
    Create a weighted graph with 10 blocks.

    For i<j in block b, w_ij is drawn with `_geom_sample(mu_in[b])`.
    For cross-block pairs, w_ij is drawn with `_geom_sample(mu_out)`,
    typically very small. An edge is added only when the drawn weight is
    strictly positive.

    Parameters
    ----------
    block_sizes : sequence of int, optional
        Number of nodes in each of the 10 blocks (list, tuple or 1-D array).
        Defaults to ten blocks of 30 nodes.
    mu_in : sequence of float, optional
        Within-block mean for each of the 10 blocks. Defaults to a fixed
        heterogeneous list of means.
    mu_out : float
        Mean used for every cross-block pair.

    Returns
    -------
    (G, ground_truth)
        The `nx.Graph` on nodes 0..N-1 (N = sum of block sizes) and an
        integer array whose entry i is the block index of node i.
    """
    if block_sizes is None:
        block_sizes = [30] * 10

    if mu_in is None:
        # customizable within-block means (higher → denser/heavier)
        mu_in = [1.5, 1.2, 1.0, 0.8, 0.6, 0.4, 1.0, 0.8, 0.6, 0.4]

    # Copy to plain lists so tuples and NumPy arrays are accepted as well:
    # the boundary computation below relies on list concatenation.
    block_sizes = list(block_sizes)
    mu_in = list(mu_in)

    assert len(block_sizes) == 10, "Must have 10 block sizes"
    assert len(mu_in) == 10, "Must specify exactly 10 within-block means"

    N = sum(block_sizes)
    G = nx.Graph()
    G.add_nodes_from(range(N))
    ground_truth = np.zeros(N, dtype=int)

    # Block boundaries: block b covers nodes start_indices[b] .. end_indices[b]-1
    start_indices = np.cumsum([0] + block_sizes[:-1])
    end_indices = np.cumsum(block_sizes)

    # Intra-block weights
    for b in range(10):
        start, end = start_indices[b], end_indices[b]
        for i in range(start, end):
            ground_truth[i] = b
            for j in range(i + 1, end):
                w = _geom_sample(mu_in[b])
                if w > 0:
                    G.add_edge(i, j, weight=int(w))

    # Inter-block weights (sparse / tiny means); each block pair is visited once
    for b1 in range(10):
        start1, end1 = start_indices[b1], end_indices[b1]
        for b2 in range(b1 + 1, 10):
            start2, end2 = start_indices[b2], end_indices[b2]
            for i in range(start1, end1):
                for j in range(start2, end2):
                    w = _geom_sample(mu_out)
                    if w > 0:
                        G.add_edge(i, j, weight=int(w))

    return G, ground_truth

###############################################################################
# Generate graphs (weighted)
###############################################################################

G5w, truth5w = create_five_disconnected_blocks_weighted(
    num_nodes=100,
    num_blocks=5,
    mu_in=1.5,
)
G10w, truth10w = create_ten_blocks_connected_weighted()

###############################################################################
# Layouts and ground-truth plots (weighted)
###############################################################################

# Fixed-seed spring layouts, computed once per graph and passed to the plots.
pos_G5w, pos_G10w = (nx.spring_layout(graph, seed=42) for graph in (G5w, G10w))

plot_ground_truth_with_domino(
    G5w,
    truth5w,
    pos=pos_G5w,
    layout="custom",
    title_prefix="G5w_truth",
)
plot_ground_truth_with_domino(
    G10w,
    truth10w,
    pos=pos_G10w,
    layout="custom",
    title_prefix="G10w_truth",
)

# Dense float adjacency matrices carrying the edge weights
A5w, A10w = (
    nx.to_numpy_array(graph, weight='weight', dtype=float)
    for graph in (G5w, G10w)
)

### Community Detection

#### wSBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# Weighted SBM (no degree correction) on G5w and G10w
# ─────────────────────────────────────────────────────────────────────────────

def _fit_wsbm(graph, adjacency, pos):
    """Call `detect` in weighted mode without degree correction, drawing with the given layout."""
    return detect(
        graph,
        adjacency,
        mode="weighted",
        degree_corrected=False,
        initial_partition="modularity",
        max_outer=10,
        target_K=None,
        viz={"layout": "custom", "pos": pos},
        report={"print_info": True},
    )


# --- G5w: five disconnected blocks ---
print(
    "\n========================================",
    "Processing G5 (weighted, wSBM)",
    "========================================\n",
    sep="\n",
)

# Graph rebuilt from the adjacency matrix (weights included by default)
G5w_nx = nx.from_numpy_array(A5w)

res = _fit_wsbm(G5w_nx, A5w, pos_G5w)
wsbm_part_5, wsbm_bic_5 = res["partition"], res["bic"]
print(f"G5 (weighted) best BIC (wSBM): {wsbm_bic_5:.2f}")

# Flatten the partition before converting it to a dict of labels
wsbm_labels_5 = partition_to_dict(wsbm_part_5.flatten())


# --- G10w: ten-block connected ---
print(
    "\n========================================",
    "Processing G10 (weighted, wSBM)",
    "========================================\n",
    sep="\n",
)

G10w_nx = nx.from_numpy_array(A10w)

res = _fit_wsbm(G10w_nx, A10w, pos_G10w)
wsbm_part_10, wsbm_bic_10 = res["partition"], res["bic"]
print(f"G10 (weighted) best BIC (wSBM): {wsbm_bic_10:.2f}")

wsbm_labels_10 = partition_to_dict(wsbm_part_10.flatten())

#### Check partitions

In [None]:
def _recap_wsbm(tag, scores):
    """Format the one-line wSBM recap; ARI and NMI are omitted when they are None."""
    line = f"[Recap wSBM]  {tag}: acc={scores['accuracy']*100:.1f}%"
    if scores['ari'] is not None:
        line += f", ARI={scores['ari']:.3f}"
    if scores['nmi'] is not None:
        line += f", NMI={scores['nmi']:.3f}"
    return line + f", exact={scores['exact_match']}"


# -------- G5w (wSBM): ground truth as {node: block} vs detected labels --------
truth_G5w = {node: int(label) for node, label in enumerate(truth5w)}
res_G5w_wsbm = compare_partitions(G5w_nx, truth_G5w, wsbm_labels_5)
pretty_print_result("G5w (weighted 5-blocks, wSBM)", res_G5w_wsbm)

# -------- G10w (wSBM): same comparison on the ten-block graph --------
truth_G10w = {node: int(label) for node, label in enumerate(truth10w)}
res_G10w_wsbm = compare_partitions(G10w_nx, truth_G10w, wsbm_labels_10)
pretty_print_result("G10w (weighted 10-blocks, wSBM)", res_G10w_wsbm)

# Compact recap framed by two horizontal rules
print("\n" + "-" * 78)
print(_recap_wsbm("G5w", res_G5w_wsbm))
print(_recap_wsbm("G10w", res_G10w_wsbm))
print("-" * 78)

#### wdcSBM

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# Weighted dcSBM (degree-corrected) on G5w and G10w
# ─────────────────────────────────────────────────────────────────────────────

def _fit_wdcsbm(graph, adjacency, pos):
    """Call `detect` in weighted mode with degree correction, drawing with the given layout."""
    return detect(
        graph,
        adjacency,
        mode="weighted",
        degree_corrected=True,
        initial_partition="modularity",
        max_outer=10,
        target_K=None,
        viz={"layout": "custom", "pos": pos},
        report={"print_info": True},
    )


# --- G5w: five disconnected blocks ---
print(
    "\n========================================",
    "Processing G5 (weighted, wdcSBM)",
    "========================================\n",
    sep="\n",
)

res = _fit_wdcsbm(G5w_nx, A5w, pos_G5w)
wdcsbm_part_5, wdcsbm_bic_5 = res["partition"], res["bic"]
print(f"G5 (weighted) best BIC (wdcSBM): {wdcsbm_bic_5:.2f}")

# Flatten the partition before converting it to a dict of labels
wdcsbm_labels_5 = partition_to_dict(wdcsbm_part_5.flatten())


# --- G10w: ten-block connected ---
print(
    "\n========================================",
    "Processing G10 (weighted, wdcSBM)",
    "========================================\n",
    sep="\n",
)

res = _fit_wdcsbm(G10w_nx, A10w, pos_G10w)
wdcsbm_part_10, wdcsbm_bic_10 = res["partition"], res["bic"]
print(f"G10 (weighted) best BIC (wdcSBM): {wdcsbm_bic_10:.2f}")

wdcsbm_labels_10 = partition_to_dict(wdcsbm_part_10.flatten())

#### Check partitions

In [None]:
def _recap_wdcsbm(tag, scores):
    """Format the one-line wdcSBM recap; ARI and NMI are omitted when they are None."""
    line = f"[Recap wdcSBM] {tag}: acc={scores['accuracy']*100:.1f}%"
    if scores['ari'] is not None:
        line += f", ARI={scores['ari']:.3f}"
    if scores['nmi'] is not None:
        line += f", NMI={scores['nmi']:.3f}"
    return line + f", exact={scores['exact_match']}"


# -------- G5w (wdcSBM): detected labels vs the ground-truth dict --------
res_G5w_wdcsbm = compare_partitions(G5w_nx, truth_G5w, wdcsbm_labels_5)
pretty_print_result("G5w (weighted 5-blocks, wdcSBM)", res_G5w_wdcsbm)

# -------- G10w (wdcSBM): same comparison on the ten-block graph --------
res_G10w_wdcsbm = compare_partitions(G10w_nx, truth_G10w, wdcsbm_labels_10)
pretty_print_result("G10w (weighted 10-blocks, wdcSBM)", res_G10w_wdcsbm)

# Compact recap framed by two horizontal rules
print("\n" + "-" * 78)
print(_recap_wdcsbm("G5w", res_G5w_wdcsbm))
print(_recap_wdcsbm("G10w", res_G10w_wdcsbm))
print("-" * 78)