In [None]:
!pip install torch_geometric



Initial baseline update & reward

*   Read current labels from env_work.FrankenData.dist_label.
*   Convert labels to an action matrix with labels_to_action.

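Call `env_work.step(action_curr)`. This runs the environment's `step` once on the current plan, which updates the opinions and returns the reward for that plan.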


The returned reward_curr becomes the baseline reward for Metropolis comparisons. This also ensures env_work reflects the post-step opinions (so proposals are compared against a consistent oldX).

In [None]:
import numpy as np
import torch
from torch_geometric.data import Data
import gymnasium as gym
from gymnasium import spaces
import gerry_environment_13
from copy import deepcopy
from typing import Optional, Tuple, Dict, Any
import math

In [None]:
def labels_to_action(labels, num_districts, dtype=np.float32):
    """
    Build the one-hot action matrix that env.step expects from a vector of
    district labels.

    labels        : length-N sequence of district ids.
    num_districts : number of columns in the returned matrix.
    dtype         : element type of the returned matrix.

    Returns an (N, num_districts) array. Row i has a single 1.0 in column
    labels[i] when 0 <= labels[i] < num_districts; rows whose label falls
    outside that range are left all-zero (the original comment notes that
    env.step turns such rows into a -1 label, so avoid them if possible).
    """
    label_arr = np.asarray(labels)
    action = np.zeros((len(labels), num_districts), dtype=dtype)

    # Only rows with an in-range label receive a 1; everything else stays zero.
    in_range = (label_arr >= 0) & (label_arr < num_districts)
    rows = np.flatnonzero(in_range)
    cols = label_arr[rows].astype(int)
    action[rows, cols] = 1.0
    return action

In [None]:

    # Propose a new district assignment by flipping a connected cluster
    # between two neighboring districts.
    # labels : current district labels (numpy int array)
    # geo_edge : (2,E) numpy array of undirected geographic edges
    # rng : numpy.random.Generator
    # Returns: new label array (copy of labels with a cluster flipped)

def propose_flip(labels, geo_edge, rng, bond_prob=0.5):
    """
    Propose a new district assignment by moving a connected cluster of nodes
    across the border between two neighbouring districts.

    A border edge (u, v) with labels[u] != labels[v] is drawn uniformly at
    random. One endpoint is chosen (50/50) as the seed; a connected cluster is
    grown from the seed inside the seed's own district, and the cluster is
    relabelled to the district of the other endpoint.

    The cluster is deliberately limited to part of the source district.
    Flipping the whole connected component spanning both districts (which
    always contains u and v) would relabel every node of both districts to a
    single id and wipe out the other district.

    Parameters
    ----------
    labels : numpy int array, shape (N,)
        Current district label of every node. Not modified.
    geo_edge : numpy int array, shape (2, E)
        Geographic edges; column j joins geo_edge[0, j] and geo_edge[1, j].
    rng : numpy.random.Generator
        Source of randomness.
    bond_prob : float, default 0.5
        Probability of adding each not-yet-included same-district neighbour
        while growing the cluster. 0.0 yields a single-node flip.

    Returns
    -------
    numpy array
        A copy of `labels` with the cluster relabelled. An unchanged copy is
        returned when there is no border edge, or when the seed's district
        has a single node (moving it would empty that district).
    """
    labels = np.asarray(labels)
    edges = np.asarray(geo_edge)
    if edges.size == 0:
        return labels.copy()
    src, dst = edges[0], edges[1]

    # Step 1: pick a random edge whose endpoints lie in different districts.
    border = np.flatnonzero(labels[src] != labels[dst])
    if border.size == 0:
        return labels.copy()   # nothing to flip if all borders gone
    pick = border[rng.integers(0, border.size)]
    u, v = int(src[pick]), int(dst[pick])

    # Step 2: choose the direction of the move. As before, a draw below 0.5
    # means the target is v's district, so the cluster is grown around u.
    if rng.random() < 0.5:
        seed_node, target = u, labels[v]
    else:
        seed_node, target = v, labels[u]

    in_source = labels == labels[seed_node]
    # Leave at least one node behind so the source district never vanishes.
    max_cluster = int(in_source.sum()) - 1
    if max_cluster < 1:
        return labels.copy()

    # Step 3: adjacency restricted to the source district. Sets drop duplicate
    # edges (e.g. an undirected edge stored in both directions).
    inside = in_source[src] & in_source[dst]
    adj = {}
    for a, b in zip(src[inside].tolist(), dst[inside].tolist()):
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)

    # Step 4: randomised DFS growth from the seed, capped at max_cluster.
    cluster = {seed_node}
    frontier = [seed_node]
    while frontier and len(cluster) < max_cluster:
        node = frontier.pop()
        # sorted() keeps the rng consumption order deterministic for a seed.
        for nei in sorted(adj.get(node, ())):
            if len(cluster) >= max_cluster:
                break
            if nei not in cluster and rng.random() < bond_prob:
                cluster.add(nei)
                frontier.append(nei)

    # Step 5: relabel the cluster on a copy of the input.
    new_labels = labels.copy()
    new_labels[sorted(cluster)] = target
    return new_labels


# **Main SA loop (hot → anneal → cold)**
For every SA iteration:

**1- Proposal generation (propose_flip):**
Chooses a border edge, builds a two-district cluster, flips it.
This is the proposal step.

**2- Pre-checks:**

Non-empty district check: all(np.any(y_prop == d) for d in ...). If a district would become empty, reject immediately.

Contiguity check: check_contiguity(geo_edge, y_prop, num_districts). If non-contiguous, reject immediately. This prevents expensive simulation for obviously invalid proposals.

**3- Evaluation (simulation):**

Convert y_prop into an action matrix with labels_to_action.

Make tmp_env = deepcopy(env_work) and call tmp_env.step(action_prop). That runs step and reward logic to produce reward_prop and tmp_env.FrankenData (the hypothetical post-proposal state).

This tmp_env.step is where the proposal is actually evaluated (it runs opinion update, elects reps, computes augmented graph and reward).

**4- Acceptance (Metropolis-style):**

Compute diff = reward_prop - reward_curr.
Acceptance probability p_accept = min(1, exp(diff / T)), where T is the current temperature (annealed linearly from T_init → T_final).

With probability p_accept accept the proposal: commit by setting env_work = tmp_env and reward_curr = reward_prop.

If rejected, env_work remains unchanged.

In [None]:
def _env_labels(env):
    """Return env.FrankenData.dist_label as a fresh numpy int array."""
    return env.FrankenData.dist_label.clone().long().numpy().astype(int).copy()


def simulated_annealing_with_env(
    env,
    hot_steps=500,
    anneal_steps=1800,
    cold_steps=100,
    T_init=1.0,
    T_final=0.01,
    seed=None,
):
    """
    Run simulated annealing over district plans, scoring each proposal with
    the environment's own step/reward logic.

    Parameters
    ----------
    env : environment object
        Must expose `FrankenData.dist_label`, `FrankenData.geographical_edge`,
        `num_districts`, `step(action)` (a 5-tuple whose second item is the
        reward) and `check_contiguity(geo_edge, labels, num_districts)`.
        It is deep-copied and never modified.
    hot_steps, anneal_steps, cold_steps : int
        Number of iterations in each phase. The phases only partition the
        step budget: the temperature is a single linear ramp over the total
        number of steps.
    T_init, T_final : float
        Temperature at the first and last iteration; T never drops below
        T_final.
    seed : int or None
        Seed for numpy's default_rng.

    Returns
    -------
    (env_work, ensemble_labels, ensemble_rewards)
        The working environment after the last iteration, plus one label
        array and one float reward per iteration (preceded by the baseline
        entry), recording the plan held after each iteration.

    Raises
    ------
    ValueError
        If the starting labels contain a negative (unassigned) entry.
    """
    rng = np.random.default_rng(seed)
    env_work = deepcopy(env)   # work on a copy so original env remains unchanged

    # --- initialize: one step with the current plan gives the baseline reward ---
    curr_labels = _env_labels(env_work)
    if np.any(curr_labels < 0):
        raise ValueError("Current FrankenData.dist_label contains -1 (unassigned). Assign before SA.")

    action_curr = labels_to_action(curr_labels, env_work.num_districts)
    _, reward_curr, _, _, _ = env_work.step(action_curr)   # env_work is now in its post-step state
    reward_curr = float(reward_curr)

    ensemble_labels = [_env_labels(env_work)]
    ensemble_rewards = [reward_curr]

    total_steps = hot_steps + anneal_steps + cold_steps
    step_idx = 0

    # Geo-edge in numpy form for proposals and contiguity checks
    geo_edge = env_work.FrankenData.geographical_edge.clone().numpy()

    # MAIN loop: hot -> anneal -> cold
    for _phase, n_steps in (("hot", hot_steps), ("anneal", anneal_steps), ("cold", cold_steps)):
        for _ in range(n_steps):
            # Linear ramp T_init -> T_final over the whole run, floored at T_final.
            T = max(T_final, T_init + (T_final - T_init) * (step_idx / max(1, total_steps - 1)))
            step_idx += 1

            # propose (argument order matches propose_flip(labels, geo_edge, rng))
            curr_labels = _env_labels(env_work)
            y_prop = propose_flip(curr_labels, geo_edge, rng)

            # quick pre-checks: every district non-empty and contiguous.
            # Failing either rejects immediately, skipping the costly simulation.
            valid = all(np.any(y_prop == d) for d in range(env_work.num_districts))
            if valid:
                valid = env_work.check_contiguity(geo_edge, y_prop, env_work.num_districts)
            if not valid:
                ensemble_labels.append(curr_labels.copy())
                ensemble_rewards.append(reward_curr)
                continue

            # Evaluate the proposal by simulating one step on a deepcopy of the current env
            action_prop = labels_to_action(y_prop, env_work.num_districts)
            tmp_env = deepcopy(env_work)
            _, reward_prop, _, _, _ = tmp_env.step(action_prop)

            # Metropolis acceptance, treating reward as -energy:
            #   p = min(1, exp((reward_prop - reward_curr) / T))
            # Improvements are accepted outright, so exp() only ever sees a
            # non-positive exponent and cannot overflow. A NaN difference
            # yields p = NaN, which never beats rng.random(), so it is rejected.
            diff = float(reward_prop) - reward_curr
            if diff >= 0:
                p_accept = 1.0
            else:
                p_accept = math.exp(diff / max(1e-8, T))

            if rng.random() < p_accept:
                # accept: commit tmp_env as the new working environment
                env_work = tmp_env
                reward_curr = float(reward_prop)
            # reject: env_work is left untouched (it must not be reset, or the
            # current plan would be lost)

            ensemble_labels.append(_env_labels(env_work))
            ensemble_rewards.append(reward_curr)

    return env_work, ensemble_labels, ensemble_rewards
