In [None]:
!pip install networkx matplotlib tqdm --quiet


In [1]:
# ============================================================
# Graph Coloring with GNN Guidance – "Best Results" Version
# ============================================================

import random
import time
import copy

import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------------------------------
# Global config & reproducibility
# ------------------------------------------------------------

# Seed every RNG this file draws from (Python `random`, NumPy, PyTorch CPU
# and, when present, all CUDA devices) so runs start from the same state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# ===================== USER-CONFIGURABLE =====================

# Model type: "gcn", "resgcn" (default, strongest), or "mlp" baseline
MODEL_TYPE = "resgcn"

# Whether to canonicalize colors per graph
# Keep this False so labels stay consistent across graphs.
USE_CANONICAL_COLORS = False

# Difficulty regimes.
# Each regime fixes:
#   - n_range:  inclusive (min, max) number of nodes per random graph
#   - p_range:  (min, max) Erdos-Renyi edge probability
#   - hard_min / hard_max: inclusive bounds on the number of backtracks the
#     degree-ordered labelling solver needs; graphs outside are rejected.
REGIMES = {
    "easy": {
        "n_range": (10, 18),
        "p_range": (0.20, 0.45),
        "hard_min": 5,
        "hard_max": 150,
    },
    "medium": {
        "n_range": (18, 30),
        "p_range": (0.15, 0.40),
        "hard_min": 20,
        "hard_max": 1000,
    },
    "hard": {
        "n_range": (25, 45),
        "p_range": (0.15, 0.35),
        "hard_min": 50,
        "hard_max": 5000,
    }
}

# Pick which regime to run
# For strongest “working” result, start with "easy".
REGIME = "easy"   # change to "medium" or "hard" for extra experiments

# Graph coloring / dataset config
K_COLORS = 4                 # number of colors available to the solver and the model

# Target number of graphs
NUM_TRAIN = 1000
NUM_VAL   = 200
NUM_TEST  = 200
NUM_TOTAL = NUM_TRAIN + NUM_VAL + NUM_TEST

MAX_LABEL_BACKTRACKS = 20000  # safety limit when generating labels

# Node features (in column order, as built by the graph generator):
# max-normalized degree, clustering coefficient, max-normalized count of
# length-2 walks, and degree divided by the number of nodes.
NODE_FEATURE_DIM = 4

# GNN training config
HIDDEN_DIM   = 256     # bigger model
NUM_LAYERS   = 6       # used by plain GCN; ResGCN has its own depth
LR           = 3e-3
WEIGHT_DECAY = 5e-5
EPOCHS       = 200
PATIENCE     = 30      # epochs without validation improvement before early stopping

# Violation loss: OFF by default for clean CE training,
# but you can turn this ON for ablations (especially on harder regimes).
USE_VIOLATION_LOSS = False
LAMBDA_VIOL        = 0.1   # weight of the violation term added to the CE loss

# Optional debug flag: overfit a single graph after dataset generation
RUN_SINGLE_GRAPH_DEBUG = False


# ============================================================
# 1. Classical backtracking solver
# ============================================================

def solve_graph_coloring(adj, k,
                         node_order=None,
                         init_colors=None,
                         color_order=None,
                         max_backtracks=None):
    """
    Simple recursive backtracking graph coloring solver.

    Parameters:
    - adj: [N, N] array-like adjacency; entries > 0.5 are treated as edges.
    - k: number of colors (used when color_order is None).
    - node_order: iterable of node indices giving the variable ordering;
      defaults to 0..N-1.
    - init_colors: optional array [N] of 'preferred' colors (-1 if none).
      The preferred color is tried first. Used only when color_order is None.
    - color_order: optional array [N, k] where color_order[v] is the order
      in which colors are tried at node v.
    - max_backtracks: optional limit on search effort. The search stops as
      soon as the backtrack counter reaches this value.

    Returns: (success, assignment, stats)
    - success: True when a complete conflict-free assignment was found.
    - assignment: np.ndarray[N] of int64 colors; nodes left uncolored are -1.
    - stats: dict with
        "backtracks": number of undone assignments (never exceeds
                      max_backtracks when a limit is given),
        "steps": number of (node, color) candidates examined,
        "limit_reached": True when the search stopped because of
                         max_backtracks rather than exhausting the space.
    """
    adj = np.asarray(adj, dtype=np.float32)
    n = adj.shape[0]
    neighbors = [np.flatnonzero(adj[v] > 0.5) for v in range(n)]

    node_order = list(range(n)) if node_order is None else list(node_order)

    assignment = np.full(n, -1, dtype=np.int64)
    if init_colors is not None:
        init_colors = np.asarray(init_colors, dtype=np.int64)

    backtracks = 0
    steps = 0
    limit_reached = False

    def is_consistent(v, c):
        # No already-colored neighbor may hold color c (uncolored ones are -1).
        return not (assignment[neighbors[v]] == c).any()

    def candidate_colors(v):
        # An explicit per-node ranking wins over the default ordering.
        if color_order is not None:
            return list(color_order[v])
        palette = list(range(k))
        if init_colors is not None:
            preferred = int(init_colors[v])
            if preferred != -1 and preferred in palette:
                palette.remove(preferred)
                palette.insert(0, preferred)
        return palette

    def backtrack(pos):
        nonlocal backtracks, steps, limit_reached
        if pos == len(node_order):
            return True

        v = node_order[pos]
        for c in candidate_colors(v):
            steps += 1
            if not is_consistent(v, c):
                continue
            assignment[v] = c
            if backtrack(pos + 1):
                return True
            assignment[v] = -1
            # When a deeper frame hit the limit we are only unwinding: do not
            # count these undo operations as additional backtracks.
            if limit_reached:
                return False
            backtracks += 1
            if max_backtracks is not None and backtracks >= max_backtracks:
                limit_reached = True
                return False
        return False

    success = backtrack(0)
    stats = {
        "backtracks": backtracks,
        "steps": steps,
        "limit_reached": limit_reached,
    }
    return success, assignment, stats


def check_valid_coloring(adj, colors, k):
    """
    Return True when `colors` is a proper k-coloring of the graph `adj`.

    A coloring is rejected when its length differs from the number of nodes,
    when any color lies outside [0, k), or when some edge (u, v) with u < v
    (adjacency entry > 0.5) joins two nodes of the same color.
    """
    adj = np.array(adj, dtype=np.float32)
    colors = np.array(colors, dtype=np.int64)

    if colors.shape[0] != adj.shape[0]:
        return False
    out_of_range = (colors < 0) | (colors >= k)
    if out_of_range.any():
        return False

    pairs = np.argwhere(adj > 0.5)
    if pairs.size == 0:
        return True

    # Only inspect each undirected edge once, via its (u < v) entry.
    src, dst = pairs[:, 0], pairs[:, 1]
    upper = src < dst
    clash = colors[src[upper]] == colors[dst[upper]]
    return not clash.any()


# ============================================================
# 2. GraphData container & dataset generation with hardness
# ============================================================

class GraphData:
    """
    Plain holder for a single labelled graph.

    Attributes:
    - x: float32 tensor of node features, one row per node
    - adj: float32 tensor holding the adjacency matrix
    - y: int64 tensor with one color label per node
    - difficulty: value supplied by the caller (the generator passes the
      number of backtracks the labelling solver needed)
    """
    def __init__(self, x, adj, y, difficulty):
        # torch.tensor copies its input, so the stored tensors are
        # independent of the arrays handed in.
        self.x, self.adj = (
            torch.tensor(arr, dtype=torch.float32) for arr in (x, adj)
        )
        self.y = torch.tensor(y, dtype=torch.long)
        self.difficulty = difficulty

    @property
    def num_nodes(self):
        """Number of rows in the feature matrix."""
        return self.x.shape[0]


def canonicalize_colors(colors):
    """
    Relabel colors as 0, 1, 2, ... in the order they first occur when
    scanning nodes by index, and return the result as an int64 array.
    Only used if USE_CANONICAL_COLORS = True.
    """
    first_seen = {}
    # setdefault returns the id already assigned to a color, or registers
    # the next free id (the current dict size) for a new one.
    relabelled = [first_seen.setdefault(c, len(first_seen)) for c in colors]
    return np.array(relabelled, dtype=np.int64)


def generate_colored_graph_with_solver(
    k_colors,
    n_range,
    p_range,
    min_backtracks,
    max_backtracks,
    label_max_backtracks=MAX_LABEL_BACKTRACKS
):
    """
    Rejection-sample one labelled Erdos–Renyi graph.

    Repeatedly draws a node count from n_range and an edge probability from
    p_range, colors the graph with the backtracking solver (nodes ordered by
    decreasing degree, at most label_max_backtracks backtracks) and accepts
    the first graph that
    - has at least one edge,
    - is solved successfully with k_colors, and
    - needed a number of backtracks in [min_backtracks, max_backtracks].

    The solver's assignment becomes the label (canonicalized when
    USE_CANONICAL_COLORS is set) and its backtrack count the difficulty.

    NOTE: the loop has no attempt limit; it only returns once a graph is
    accepted.

    Returns: GraphData
    """
    while True:
        num_nodes = random.randint(*n_range)
        edge_prob = random.uniform(*p_range)
        graph = nx.erdos_renyi_graph(num_nodes, edge_prob)

        if graph.number_of_edges() == 0:
            continue

        # Dense symmetric 0/1 adjacency matrix
        adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for a, b in graph.edges():
            adj[a, b] = adj[b, a] = 1.0

        # Label the graph: highest-degree nodes are colored first
        solved, assignment, stats = solve_graph_coloring(
            adj, k_colors,
            node_order=np.argsort(-adj.sum(axis=1)),
            init_colors=None,
            color_order=None,
            max_backtracks=label_max_backtracks
        )
        if not solved:
            # not k-colorable, or the labelling budget ran out
            continue

        difficulty = stats["backtracks"]
        if not (min_backtracks <= difficulty <= max_backtracks):
            # outside the requested hardness window
            continue

        if USE_CANONICAL_COLORS:
            labels = canonicalize_colors(assignment)
        else:
            labels = np.array(assignment, dtype=np.int64)

        # Feature columns, in order:
        #   0: degree / max degree
        #   1: clustering coefficient
        #   2: number of length-2 walks / max over nodes
        #   3: degree / number of nodes
        degrees = np.array(
            [graph.degree[i] for i in range(num_nodes)], dtype=np.float32
        )
        peak_degree = float(degrees.max())
        deg_norm = degrees / (peak_degree if peak_degree > 0.0 else 1.0)

        clustering_by_node = nx.clustering(graph)
        clustering = np.array(
            [clustering_by_node[i] for i in range(num_nodes)], dtype=np.float32
        )

        # Row sums of A @ A count the 2-step walks starting at each node
        two_hop = (adj @ adj).sum(axis=1)
        peak_two_hop = float(two_hop.max())
        two_hop_norm = two_hop / (peak_two_hop if peak_two_hop > 0.0 else 1.0)

        deg_over_n = degrees / (num_nodes + 1e-6)

        x = np.stack([deg_norm, clustering, two_hop_norm, deg_over_n], axis=1)  # [N, 4]

        return GraphData(x, adj, labels, difficulty=difficulty)


def generate_dataset_filtered(num_graphs, n_range, p_range, hard_min, hard_max):
    """
    Generate `num_graphs` hardness-filtered GraphData objects and return
    them in shuffled order.

    Parameters:
    - num_graphs: number of graphs to produce
    - n_range, p_range: node-count and edge-probability ranges forwarded to
      generate_colored_graph_with_solver
    - hard_min, hard_max: accepted backtrack window, forwarded as
      min_backtracks / max_backtracks

    All rejection sampling happens inside generate_colored_graph_with_solver,
    which only returns once it has an accepted graph. Every call therefore
    yields exactly one graph, and this function blocks until the requested
    number has been produced (there is no partial result).
    """
    data = []
    while len(data) < num_graphs:
        data.append(
            generate_colored_graph_with_solver(
                k_colors=K_COLORS,
                n_range=n_range,
                p_range=p_range,
                min_backtracks=hard_min,
                max_backtracks=hard_max,
                label_max_backtracks=MAX_LABEL_BACKTRACKS
            )
        )
        # Progress report every 100 graphs
        if len(data) % 100 == 0:
            print(f"Generated {len(data)} / {num_graphs} graphs...")

    random.shuffle(data)
    return data


def build_datasets():
    """
    Build the train / val / test splits for the regime selected by REGIME.

    Generates NUM_TOTAL graphs with the regime's settings, slices them into
    consecutive train, validation and test parts, prints a few sanity
    statistics about the training split, and returns the three lists.
    """
    cfg = REGIMES[REGIME]
    n_range, p_range = cfg["n_range"], cfg["p_range"]
    hard_min, hard_max = cfg["hard_min"], cfg["hard_max"]

    print(f"\n=== Building dataset for regime: {REGIME} ===")
    print(f"n_range = {n_range}, p_range = {p_range}, "
          f"hard_min = {hard_min}, hard_max = {hard_max}")

    all_graphs = generate_dataset_filtered(
        NUM_TOTAL,
        n_range=n_range,
        p_range=p_range,
        hard_min=hard_min,
        hard_max=hard_max
    )
    print("Total graphs generated:", len(all_graphs))

    # Consecutive slices of the (already shuffled) pool
    val_start = NUM_TRAIN
    test_start = NUM_TRAIN + NUM_VAL
    train_graphs = all_graphs[:val_start]
    val_graphs = all_graphs[val_start:test_start]
    test_graphs = all_graphs[test_start:NUM_TOTAL]

    print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")

    # Peek at the first ten training graphs
    preview = train_graphs[:10]
    print("Example train graph sizes:",
          [g.num_nodes for g in preview])
    print("Example train graph difficulties (backtracks):",
          [g.difficulty for g in preview])

    difficulties = np.array([g.difficulty for g in train_graphs])
    print("Train difficulty stats (backtracks):",
          "min =", difficulties.min(),
          "median =", np.median(difficulties),
          "max =", difficulties.max())

    # How often each color appears as a label across the training split
    train_labels = np.concatenate([g.y.numpy() for g in train_graphs])
    hist = np.bincount(train_labels, minlength=K_COLORS)
    print("Global train label distribution (freq):", hist)
    print("Global train label distribution (proportions):", hist / hist.sum())

    return train_graphs, val_graphs, test_graphs


# ============================================================
# 3. Models: GCN, Residual GCN, MLP
# ============================================================

class GCNLayer(nn.Module):
    """
    Single graph-convolution layer with symmetric normalization and
    self-loops:
    H' = D^{-1/2} (A + I) D^{-1/2} H W
    where D is the degree matrix of (A + I).
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        """
        x: [N, F] node features
        adj: [N, N] adjacency (0/1)
        Returns: [N, out_dim]
        """
        num_nodes = adj.size(0)
        a_hat = adj + torch.eye(num_nodes, device=adj.device)

        # d^{-1/2} per node; any infinite entry (zero degree) is zeroed out
        inv_sqrt_deg = a_hat.sum(dim=1).pow(-0.5)
        inv_sqrt_deg = inv_sqrt_deg.masked_fill(torch.isinf(inv_sqrt_deg), 0.0)

        # Scale rows and columns of A_hat, then aggregate neighbor features
        norm_adj = inv_sqrt_deg[:, None] * a_hat * inv_sqrt_deg[None, :]
        aggregated = norm_adj @ x
        return self.linear(aggregated)


class GCNColoring(nn.Module):
    """
    Stack of `num_layers` GCN layers (each followed by ReLU and dropout)
    with a final linear head that emits per-node color logits.
    """
    def __init__(self, in_dim, hidden_dim, num_colors, num_layers=4, dropout=0.05):
        super().__init__()
        assert num_layers >= 2
        # First layer maps in_dim -> hidden_dim; the rest are hidden -> hidden.
        widths = [in_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [GCNLayer(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])]
        )
        self.out_linear = nn.Linear(hidden_dim, num_colors)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        """
        x: [N, in_dim] node features, adj: [N, N] adjacency.
        Returns logits of shape [N, num_colors].
        """
        hidden = x
        for conv in self.layers:
            hidden = self.dropout(F.relu(conv(hidden, adj)))
        return self.out_linear(hidden)


class ResGCNBlock(nn.Module):
    """
    Residual GCN block: LayerNorm(ReLU(GCNLayer(x) + x)).
    Input and output share the same feature width `dim`.
    """
    def __init__(self, dim):
        super().__init__()
        self.gcn = GCNLayer(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, adj):
        # Skip connection around the graph convolution, then activation + norm
        residual_sum = self.gcn(x, adj) + x
        return self.norm(F.relu(residual_sum))


class ResGCNColoring(nn.Module):
    """
    Residual GCN for per-node color prediction:
    - linear input projection followed by ReLU
    - `num_blocks` ResGCNBlock layers, each followed by dropout
    - linear output head producing [N, num_colors] logits
    """
    def __init__(self, in_dim, hidden_dim, num_colors, num_blocks=5, dropout=0.1):
        super().__init__()
        self.in_linear = nn.Linear(in_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            [ResGCNBlock(hidden_dim) for _ in range(num_blocks)]
        )
        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(hidden_dim, num_colors)

    def forward(self, x, adj):
        hidden = F.relu(self.in_linear(x))
        for res_block in self.blocks:
            hidden = self.dropout(res_block(hidden, adj))
        return self.out_linear(hidden)


class MLPColoring(nn.Module):
    """
    Optional baseline: ignores adj, uses only node features.

    Architecture: (num_layers - 1) hidden Linear layers of width hidden_dim,
    each followed by ReLU and dropout, then a Linear output head producing
    [N, num_colors] logits. With num_layers <= 1 there are no hidden layers
    and the head maps the raw features straight to logits.
    """
    def __init__(self, in_dim, hidden_dim, num_colors, num_layers=3, dropout=0.1):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        # Size the head from the width that actually reaches it: this is
        # hidden_dim whenever a hidden layer exists and in_dim otherwise, so
        # a head-only model no longer fails with a shape mismatch.
        self.out_linear = nn.Linear(dims[-1], num_colors)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        """
        x: [N, in_dim] node features; adj is accepted for interface parity
        with the GCN models and is not used.
        """
        h = x
        for layer in self.layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        return self.out_linear(h)


# ============================================================
# 4. GNN evaluation helpers (GNN-only metrics)
# ============================================================

def evaluate_gnn(model, graphs, device=DEVICE, k_colors=K_COLORS):
    """
    Evaluate GNN alone (no solver) on a list of graphs.

    Parameters:
    - model: callable as model(x, adj) returning [N, num_colors] logits;
      it is switched to eval mode and left there.
    - graphs: iterable of GraphData-like objects exposing x, adj and y.
    - device: device the tensors are moved to before the forward pass.
    - k_colors: accepted for interface compatibility; not used here.

    Returns a dict with:
    - "node_acc": fraction of nodes whose argmax color equals the label
    - "graph_exact": fraction of graphs with every node predicted correctly
    - "edge_violation": fraction of edges whose endpoints share a color
    - "gnn_valid_colorings": fraction of graphs with no such edge
    Every ratio with an empty denominator (for example an empty `graphs`
    list) is reported as 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    total_graphs = len(graphs)
    total_nodes = 0
    correct_nodes = 0
    exact_graphs = 0

    total_edges = 0
    total_violations = 0
    valid_colorings = 0

    with torch.no_grad():
        for g in graphs:
            x = g.x.to(device)
            adj = g.adj.to(device)
            y = g.y.to(device)

            pred = model(x, adj).argmax(dim=1)    # [N]

            num_nodes = y.size(0)
            correct = int((pred == y).sum().item())
            total_nodes += num_nodes
            correct_nodes += correct
            if correct == num_nodes:
                exact_graphs += 1

            # Build the edge list on the same device as `pred`, keeping one
            # (u < v) entry per undirected edge.
            edges = (adj > 0.5).nonzero(as_tuple=False)
            edges = edges[edges[:, 0] < edges[:, 1]]
            u, v = edges[:, 0], edges[:, 1]

            violations = int((pred[u] == pred[v]).sum().item())
            total_violations += violations
            total_edges += edges.size(0)

            # A graph without edges is trivially a valid coloring.
            if violations == 0:
                valid_colorings += 1

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    return {
        "node_acc": ratio(correct_nodes, total_nodes),
        "graph_exact": ratio(exact_graphs, total_graphs),
        "edge_violation": ratio(total_violations, total_edges),
        "gnn_valid_colorings": ratio(valid_colorings, total_graphs),
    }


# ============================================================
# 5. Training loop with early stopping
# ============================================================

def _edge_violation_loss(logits, adj):
    """Mean probability that the two endpoints of an edge get the same color.

    Returns None when the graph has no (u < v) edge, so the caller can fall
    back to the plain cross-entropy loss.
    """
    edges = (adj > 0.5).nonzero(as_tuple=False)
    edges = edges[edges[:, 0] < edges[:, 1]]
    if edges.size(0) == 0:
        return None
    prob = F.softmax(logits, dim=1)
    same_prob = (prob[edges[:, 0]] * prob[edges[:, 1]]).sum(dim=1)
    return same_prob.mean()


def train_gnn(
    model,
    train_graphs,
    val_graphs,
    use_violation_loss=USE_VIOLATION_LOSS,
    lambda_viol=LAMBDA_VIOL,
    epochs=EPOCHS,
    patience=PATIENCE,
    device=DEVICE
):
    """
    Train `model` one graph at a time with Adam and early stopping.

    Parameters:
    - model: module called as model(x, adj) that returns per-node logits
    - train_graphs / val_graphs: lists of GraphData; neither is modified
    - use_violation_loss: add lambda_viol * (mean same-color probability
      over edges) to the cross-entropy loss
    - lambda_viol: weight of that extra term
    - epochs: maximum number of passes over the training graphs
    - patience: stop after this many consecutive epochs without the
      validation node accuracy improving by more than 1e-4
    - device: device the graph tensors are moved to

    The learning rate and weight decay come from the module-level LR and
    WEIGHT_DECAY. Gradients are clipped to a global norm of 5.0.

    Returns the model, restored to the best-validation checkpoint when one
    was recorded.
    """
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY
    )

    print(model)

    best_val_acc = 0.0
    best_state = None
    no_improve = 0

    # Shuffle a private copy each epoch so the caller's list keeps its order.
    epoch_order = list(train_graphs)

    for epoch in range(1, epochs + 1):
        model.train()
        random.shuffle(epoch_order)
        total_loss = 0.0
        total_nodes = 0

        for g in epoch_order:
            x = g.x.to(device)
            adj = g.adj.to(device)
            y = g.y.to(device)

            logits = model(x, adj)
            loss = F.cross_entropy(logits, y)

            if use_violation_loss:
                loss_viol = _edge_violation_loss(logits, adj)
                if loss_viol is not None:
                    loss = loss + lambda_viol * loss_viol

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Weight each graph's loss by its size for a per-node average
            total_loss += loss.item() * g.num_nodes
            total_nodes += g.num_nodes

        train_loss = total_loss / total_nodes if total_nodes > 0 else 0.0
        val_metrics = evaluate_gnn(model, val_graphs, device, K_COLORS)
        val_acc = val_metrics["node_acc"]
        val_viol = val_metrics["edge_violation"]
        val_valid = val_metrics["gnn_valid_colorings"]

        print(
            f"Epoch {epoch:03d} | train_loss = {train_loss:.4f} | "
            f"val_node_acc = {val_acc:.3f} | "
            f"val_edge_violation = {val_viol:.3f} | "
            f"val_gnn_valid = {val_valid:.3f}"
        )

        # Early stopping on validation node accuracy
        if val_acc > best_val_acc + 1e-4:
            best_val_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch}, best val_node_acc = {best_val_acc:.3f}")
                break

    # Load best model
    if best_state is not None:
        model.load_state_dict(best_state)
        print("Loaded best validation checkpoint.")

    return model


# ============================================================
# 6. Solver evaluation with/without GNN guidance
# ============================================================

def eval_solvers_on_dataset(
    graphs,
    model,
    device=DEVICE,
    k_colors=K_COLORS,
    max_backtracks=5000
):
    """
    Run the backtracking solver on every graph under several strategies and
    compare the search effort.

    Strategies (node order / color order):
    - 'random': random node order, default color order (no GNN)
    - 'degree': high-degree-first node order, default color order (no GNN)
    - 'gnn_ordering': nodes by decreasing GNN confidence, GNN color ranking
    - 'gnn_warm_start': degree order, GNN color ranking
    - 'gnn_both': GNN ordering + GNN color ranking

    Parameters:
    - graphs: iterable of GraphData
    - model: module called as model(x, adj) returning per-node logits
    - device: device used for the forward pass
    - k_colors: number of colors handed to the solver
    - max_backtracks: per-run backtrack limit for the solver

    Returns (summary, stats):
    - summary[method]: mean "success_rate", "avg_backtracks", "avg_steps"
      and "avg_runtime_sec" over all graphs
    - stats[method]: the raw per-graph lists behind those means
    """
    model.eval()

    methods = ["random", "degree", "gnn_ordering", "gnn_warm_start", "gnn_both"]
    stats = {
        m: {
            "success": [],
            "backtracks": [],
            "steps": [],
            "runtime": [],
        } for m in methods
    }

    with torch.no_grad():
        for g in graphs:
            adj_np = g.adj.cpu().numpy()
            n = g.num_nodes
            degree_order = np.argsort(-adj_np.sum(axis=1))

            logits = model(g.x.to(device), g.adj.to(device))
            prob = F.softmax(logits, dim=1).cpu().numpy()   # [N, K]
            confidence_order = np.argsort(-prob.max(axis=1))
            color_order = np.argsort(-prob, axis=1)         # [N, K]

            # method -> (node order, per-node color order or None for default)
            # NOTE(review): 'gnn_ordering' and 'gnn_both' are configured
            # identically, so they explore the same search tree. If
            # 'gnn_ordering' is meant to isolate the effect of node ordering,
            # its color order should be None; confirm the intended ablation.
            strategies = {
                "random": (np.random.permutation(n), None),
                "degree": (degree_order, None),
                "gnn_ordering": (confidence_order, color_order),
                "gnn_warm_start": (degree_order, color_order),
                "gnn_both": (confidence_order, color_order),
            }

            for method in methods:
                node_order, method_color_order = strategies[method]

                # perf_counter is monotonic and high resolution, which
                # matters for solver runs that finish in well under 1 ms.
                t0 = time.perf_counter()
                success, assignment, s = solve_graph_coloring(
                    adj_np, k_colors,
                    node_order=node_order,
                    init_colors=None,
                    color_order=method_color_order,
                    max_backtracks=max_backtracks
                )
                elapsed = time.perf_counter() - t0

                # Cross-check every claimed solution independently
                if success and not check_valid_coloring(adj_np, assignment, k_colors):
                    print("[WARNING] invalid coloring despite success for method", method)

                record = stats[method]
                record["success"].append(1 if success else 0)
                record["backtracks"].append(s["backtracks"])
                record["steps"].append(s["steps"])
                record["runtime"].append(elapsed)

    summary = {
        method: {
            "success_rate": np.mean(stats[method]["success"]),
            "avg_backtracks": np.mean(stats[method]["backtracks"]),
            "avg_steps": np.mean(stats[method]["steps"]),
            "avg_runtime_sec": np.mean(stats[method]["runtime"]),
        }
        for method in methods
    }

    return summary, stats


# ============================================================
# 7. Debug: Overfit a single graph (sanity check)
# ============================================================

def debug_overfit_single_graph(graph):
    """
    Optional: sanity check that the model *can* learn at all by overfitting 1 graph.

    Builds a fresh, dropout-free model of the configured MODEL_TYPE and
    trains it with plain cross-entropy on `graph` for up to 500 steps,
    stopping early once the training accuracy reaches 100%. Progress is
    printed; nothing is returned.

    Raises:
        ValueError: if MODEL_TYPE is not "gcn", "resgcn" or "mlp" (the same
        check main() applies, instead of silently training the MLP).
    """
    print("\n[DEBUG] Overfitting a single graph...")
    if MODEL_TYPE == "gcn":
        model = GCNColoring(
            in_dim=NODE_FEATURE_DIM,
            hidden_dim=HIDDEN_DIM,
            num_colors=K_COLORS,
            num_layers=NUM_LAYERS,
            dropout=0.0
        ).to(DEVICE)
    elif MODEL_TYPE == "resgcn":
        model = ResGCNColoring(
            in_dim=NODE_FEATURE_DIM,
            hidden_dim=HIDDEN_DIM,
            num_colors=K_COLORS,
            num_blocks=5,
            dropout=0.0
        ).to(DEVICE)
    elif MODEL_TYPE == "mlp":
        model = MLPColoring(
            in_dim=NODE_FEATURE_DIM,
            hidden_dim=HIDDEN_DIM,
            num_colors=K_COLORS,
            num_layers=3,
            dropout=0.0
        ).to(DEVICE)
    else:
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

    # The graph never changes, so move its tensors to the device once.
    x = graph.x.to(DEVICE)
    adj = graph.adj.to(DEVICE)
    y = graph.y.to(DEVICE)

    model.train()
    for epoch in range(1, 501):
        logits = model(x, adj)
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy of the predictions made before this step's update
        pred = logits.argmax(dim=1)
        acc = (pred == y).float().mean().item()

        if epoch % 50 == 0 or acc == 1.0:
            print(f"[Single-graph] Epoch {epoch}, loss={loss.item():.4f}, acc={acc:.3f}")
        if acc == 1.0:
            print("[Single-graph] Reached 100% training accuracy.")
            break


# ============================================================
# 8. Main script
# ============================================================

def main():
    """
    End-to-end experiment: build the datasets, optionally run the
    single-graph overfit check, train the configured model, then report
    GNN-only test metrics and solver performance with and without GNN
    guidance.
    """
    # 1) Build datasets
    train_graphs, val_graphs, test_graphs = build_datasets()

    if RUN_SINGLE_GRAPH_DEBUG:
        debug_overfit_single_graph(train_graphs[0])

    # 2) Create model (ResGCN / GCN / MLP). The constructors are wrapped in
    #    lambdas so that only the selected architecture is instantiated.
    shared = dict(
        in_dim=NODE_FEATURE_DIM,
        hidden_dim=HIDDEN_DIM,
        num_colors=K_COLORS,
    )
    builders = {
        "gcn": lambda: GCNColoring(num_layers=NUM_LAYERS, dropout=0.05, **shared),
        "resgcn": lambda: ResGCNColoring(num_blocks=5, dropout=0.1, **shared),
        "mlp": lambda: MLPColoring(num_layers=3, dropout=0.1, **shared),
    }
    if MODEL_TYPE not in builders:
        raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}")
    model = builders[MODEL_TYPE]().to(DEVICE)

    # 3) Train with early stopping
    model = train_gnn(
        model,
        train_graphs,
        val_graphs,
        use_violation_loss=USE_VIOLATION_LOSS,
        lambda_viol=LAMBDA_VIOL,
        epochs=EPOCHS,
        patience=PATIENCE,
        device=DEVICE
    )

    # 4) GNN-only test metrics
    test_metrics = evaluate_gnn(model, test_graphs, DEVICE, K_COLORS)
    print("\n=== GNN-only test metrics ===")
    for metric_name, metric_value in test_metrics.items():
        print(f"{metric_name}: {metric_value:.4f}")

    # 5) Solver performance with/without GNN
    solver_summary, solver_raw = eval_solvers_on_dataset(
        test_graphs,
        model,
        DEVICE,
        K_COLORS,
        max_backtracks=5000
    )

    print("\n=== Solver performance on test set (with/without GNN) ===")
    for method, method_summary in solver_summary.items():
        print(f"\nMethod: {method}")
        for key, val in method_summary.items():
            # Keys containing "time" (e.g. "avg_runtime_sec") get more digits
            rendered = f"{val:.6f}" if "time" in key else f"{val:.3f}"
            print(f"  {key}: {rendered}")

# Run the full experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Using device: cuda

=== Building dataset for regime: easy ===
n_range = (10, 18), p_range = (0.2, 0.45), hard_min = 5, hard_max = 150
Generated 100 / 1400 graphs...
Generated 200 / 1400 graphs...
Generated 300 / 1400 graphs...
Generated 400 / 1400 graphs...
Generated 500 / 1400 graphs...
Generated 600 / 1400 graphs...
Generated 700 / 1400 graphs...
Generated 800 / 1400 graphs...
Generated 900 / 1400 graphs...
Generated 1000 / 1400 graphs...
Generated 1100 / 1400 graphs...
Generated 1200 / 1400 graphs...
Generated 1300 / 1400 graphs...
Generated 1400 / 1400 graphs...
Total graphs generated: 1400
Train: 1000, Val: 200, Test: 200
Example train graph sizes: [16, 18, 18, 16, 16, 17, 13, 18, 17, 18]
Example train graph difficulties (backtracks): [75, 58, 9, 13, 28, 8, 11, 7, 13, 133]
Train difficulty stats (backtracks): min = 5 median = 15.0 max = 148
Global train label distribution (freq): [4006 4203 4149 3570]
Global train label distribution (proportions): [0.25150678 0.26387494 0.26048468