In [1]:
pip install -U qiskit-optimization

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
from dataclasses import dataclass
import math
from typing import List, Tuple, Dict, Iterable
import random
import itertools
from typing import Optional
from collections import defaultdict

In [3]:
@dataclass(frozen=True)
class BCCDirection:
    """Immutable integer step vector ``(dx, dy, dz)`` on the lattice."""

    dx: int
    dy: int
    dz: int

    def as_tuple(self):
        """Return the components as a plain ``(dx, dy, dz)`` tuple."""
        return (self.dx, self.dy, self.dz)

    def norm(self):
        """Return the Euclidean length of the step."""
        squared = self.dx**2 + self.dy**2 + self.dz**2
        return math.sqrt(squared)

    def angle_deg(self, other):
        """Return the angle in degrees between this step and ``other``."""
        dot = self.dx*other.dx + self.dy*other.dy + self.dz*other.dz
        cosine = dot / (self.norm() * other.norm())
        # Clamp: rounding can push the cosine marginally outside acos' domain.
        cosine = max(-1, min(1, cosine))
        return math.degrees(math.acos(cosine))

class BCCLattice:
    """The 8 diagonal step directions of a BCC lattice plus walk utilities.

    Direction index ``i`` (0-7) and its 3-bit code ``(bx, by, bz)`` satisfy
    ``i = 4*bx + 2*by + bz``. A bit of 0 selects the +1 component and a bit of 1
    the -1 component, so ``i ^ 7`` is the exact reverse of direction ``i``.
    """

    # Ordered so that DIRS[4*bx + 2*by + bz] has components (+1 if bit == 0 else -1).
    DIRS = [BCCDirection(x,y,z) for x in (1,-1) for y in (1,-1) for z in (1,-1)]
    CODE_to_IDX = {(0,0,0):0,(0,0,1):1,(0,1,0):2,(0,1,1):3,
                   (1,0,0):4,(1,0,1):5,(1,1,0):6,(1,1,1):7}
    IDX_to_CODE = {v:k for k,v in CODE_to_IDX.items()}

    def step_angle_deg(self, i, j):
        """Return the angle in degrees between directions ``i`` and ``j``."""
        return self.DIRS[i].angle_deg(self.DIRS[j])

    def build_walk(self, start=(0,0,0), idxs: Optional[Iterable[int]] = None):
        """Return the list of sites visited from ``start`` following ``idxs``.

        Args:
            start: ``(x, y, z)`` of the first site.
            idxs: direction indices (0-7), one per step. ``None`` (the default)
                means no steps; a ``None`` sentinel replaces the former shared
                mutable ``[]`` default.

        Returns:
            ``len(idxs) + 1`` sites, beginning with ``start``.
        """
        x, y, z = start
        pts = [start]
        for i in (idxs if idxs is not None else ()):
            d = self.DIRS[i]
            x += d.dx
            y += d.dy
            z += d.dz
            pts.append((x, y, z))
        return pts

    @staticmethod
    def self_avoiding(path):
        """Return True when no site occurs twice in ``path``."""
        return len(set(path)) == len(path)

    @staticmethod
    def no_backtracking(idxs):
        """Return True when no step is the exact reverse (``i ^ 7``) of the previous one."""
        return all(idxs[i]^7 != idxs[i+1] for i in range(len(idxs)-1))

    def successive_angles(self, idxs):
        """Return the angles in degrees between each pair of consecutive steps."""
        return [self.step_angle_deg(idxs[i], idxs[i+1]) for i in range(len(idxs)-1)]

In [None]:
# Sanity checks on the lattice: which bond angles can occur, and a small example walk.
l = BCCLattice()
# Angles between every ordered pair of distinct directions, excluding exact reversals.
angles = sorted({ round(l.step_angle_deg(i,j),2)
                  for i in range(8) for j in range(8)
                  if i!=j and (i^7)!=j })
print("Distinct non-backtracking angles:", angles)
dirs = [0,2,1,4,6]
path = l.build_walk((0,0,0), dirs)
print("Path:", path)
# These two checks were commented out although the recorded cell output contains them.
print("Self-avoiding?", l.self_avoiding(path))
print("No backtracking?", l.no_backtracking(dirs))
print("Angles:", [round(a,2) for a in l.successive_angles(dirs)])


Distinct non-backtracking angles: [70.53, 109.47]
Path: [(0, 0, 0), (1, 1, 1), (2, 0, 2), (3, 1, 1), (2, 2, 2), (1, 1, 3)]
Self-avoiding? True
No backtracking? True
Angles: [70.53, 109.47, 109.47, 70.53]


In [None]:
from qiskit import QuantumCircuit, QuantumRegister

def bcc_registers(chain_length:int):
    """Create an empty circuit holding one 3-qubit register per bond.

    A chain of ``chain_length`` residues has ``chain_length - 1`` bonds; the
    registers are named ``b0``, ``b1``, ... in bond order.

    Returns:
        ``(circuit, registers)``.
    """
    bond_count = chain_length - 1
    qc = QuantumCircuit()
    regs = [QuantumRegister(3, name=f"b{bond}") for bond in range(bond_count)]
    for reg in regs:
        qc.add_register(reg)
    return qc, regs

# Demo: a 6-residue chain has 5 bonds, so registers b0..b4 are created.
qc, regs = bcc_registers(6)
print([r.name for r in regs])

['b0', 'b1', 'b2', 'b3', 'b4']


In [None]:
from typing import List, Tuple, Optional, Dict
import random

# One-letter amino-acid code -> 'H' or 'P' class used by the toy HP contact count.
HP_MAP: Dict[str, str] = {  # Hydrophobic/Polar toy mapping for short sequences
    **{a:'H' for a in list('AVLIMFWYCGP')},  # treated as hydrophobic; NOTE(review): class membership unverified, confirm
    **{a:'P' for a in list('STNQDEKRH')}     # treated as polar/charged; NOTE(review): class membership unverified, confirm
}

def tokenize_sequence(seq: str) -> List[str]:
    """Upper-case ``seq`` and return its alphabetic characters as a list."""
    return [symbol for symbol in seq.upper() if symbol.isalpha()]

class TurnEncoding:
    """Turn (bond-direction) encoding of a chain on the BCC lattice.

    A chain of ``n`` residues has ``n - 1`` bonds. Each bond stores one of the
    8 lattice direction indices, equivalently a 3-bit code (3 qubits per bond).
    """

    def __init__(self, lattice: BCCLattice, seq: List[str]):
        """Bind the encoding to a lattice and a residue sequence.

        Raises:
            ValueError: if ``seq`` has fewer than 2 residues.
        """
        self.lattice = lattice
        self.seq = seq
        self.n = len(seq)
        self.n_bonds = self.n - 1
        if self.n_bonds <= 0:
            raise ValueError("Sequence must have length >= 2.")

        # Classical fold state; stays None until one of the setters/sampler succeeds.
        self.dir_indices: Optional[List[int]] = None
        self.dir_bits: Optional[List[Tuple[int,int,int]]] = None
        # Populated by allocate_qiskit().
        self.qc = None
        self.regs = None

    def set_from_indices(self, dir_indices: List[int]):
        """Store a fold given as one direction index (0-7) per bond.

        Raises:
            ValueError: on wrong length or an immediate reversal between bonds.
            KeyError: if an index is not a valid direction index.
        """
        if len(dir_indices) != self.n_bonds:
            raise ValueError("Length of dir_indices must be N-1.")
        if not self.lattice.no_backtracking(dir_indices):
            raise ValueError("Backtracking detected.")
        # Copy so later mutation of the caller's list cannot change the stored fold,
        # and resolve the bit codes *before* touching state so that an invalid
        # index leaves dir_indices / dir_bits consistent with each other.
        indices = list(dir_indices)
        bits = [self.lattice.IDX_to_CODE[i] for i in indices]
        self.dir_indices = indices
        self.dir_bits = bits

    def set_from_bits(self, dir_bits: List[Tuple[int,int,int]]):
        """Store a fold given as one 3-bit code per bond (only the low bit of each entry is used).

        Raises:
            ValueError: on wrong length or an immediate reversal between bonds.
        """
        if len(dir_bits) != self.n_bonds:
            raise ValueError("Length of dir_bits must be N-1.")
        idxs = [self.lattice.CODE_to_IDX[tuple(int(b)&1 for b in bits)] for bits in dir_bits]
        self.set_from_indices(idxs)

    def coordinates(self, start=(0,0,0)) -> List[Tuple[int,int,int]]:
        """Return the ``n`` lattice sites of the stored fold, beginning at ``start``.

        Raises:
            RuntimeError: if no fold has been set yet.
        """
        if self.dir_indices is None:
            raise RuntimeError("No classical directions set. Use set_from_indices/bits or sample_random_fold().")
        return self.lattice.build_walk(start, self.dir_indices)

    def allocate_qiskit(self):
        """Create a circuit with one 3-qubit register ``b<k>`` per bond.

        The result is stored on ``self.qc`` / ``self.regs`` and returned.

        Raises:
            RuntimeError: if Qiskit cannot be imported.
        """
        try:
            from qiskit import QuantumCircuit, QuantumRegister
        except Exception as e:
            raise RuntimeError("Qiskit not available in this env.") from e
        qc = QuantumCircuit()
        regs = []
        for b in range(self.n_bonds):
            qreg = QuantumRegister(3, name=f"b{b}")
            qc.add_register(qreg)
            regs.append(qreg)
        self.qc, self.regs = qc, regs
        return qc, regs

    def sample_random_fold(self, max_restarts=400):
        """Grow a random self-avoiding, non-backtracking fold and store it.

        Each attempt extends the walk one bond at a time, taking the first
        randomly ordered direction that reaches an unvisited site. A dead end
        abandons the attempt and starts over from the origin.

        Returns:
            ``(dir_indices, path)`` of the accepted fold.

        Raises:
            RuntimeError: if every one of ``max_restarts`` attempts dead-ends.
        """
        for _ in range(max_restarts):
            path = [(0,0,0)]
            used = {path[0]}
            idxs = []
            for _bond in range(self.n_bonds):
                candidates = list(range(8))
                if idxs:
                    # Exclude the exact reverse of the previous step.
                    candidates = [j for j in candidates if j != (idxs[-1] ^ 7)]
                random.shuffle(candidates)
                for j in candidates:
                    dx,dy,dz = self.lattice.DIRS[j].as_tuple()
                    nxt = (path[-1][0]+dx, path[-1][1]+dy, path[-1][2]+dz)
                    if nxt not in used:
                        path.append(nxt)
                        used.add(nxt)
                        idxs.append(j)
                        break
                else:
                    # No free neighbouring site: dead end, restart.
                    break
            else:
                # Every bond was placed.
                self.set_from_indices(idxs)
                return idxs, path
        raise RuntimeError(
            f"Could not build a self-avoiding fold for {self.n} residues "
            f"within {max_restarts} restarts."
        )


In [None]:
class CoordinateEncoding:
    """Explicit per-residue lattice coordinates for a chain."""

    def __init__(self, lattice: BCCLattice, seq: List[str]):
        """Remember the lattice and sequence; coordinates start unset."""
        self.lattice = lattice
        self.seq = seq
        self.n = len(seq)
        self.coords: Optional[List[Tuple[int,int,int]]] = None

    def from_turn_encoding(self, turn: TurnEncoding, start=(0,0,0)):
        """Copy coordinates from ``turn`` and return them.

        Raises ValueError when the walk length does not match this sequence.
        """
        walk = turn.coordinates(start)
        if len(walk) == self.n:
            self.coords = walk
            return walk
        raise ValueError("Turn encoding err.")

    def is_self_avoiding(self) -> bool:
        """Return True when no lattice site is occupied twice."""
        if self.coords is not None:
            return BCCLattice.self_avoiding(self.coords)
        raise RuntimeError("Coordinates missin")

    def contacts_HP(self, seq: List[str]) -> int:
        """Count non-consecutive H-H residue pairs at Manhattan distance 2 or 3.

        Residues are classified through ``HP_MAP``; unknown symbols count as 'P'.
        """
        if self.coords is None:
            raise RuntimeError("No coordinates set.")
        sites = [self.coords[k] for k in range(self.n)]
        hydrophobic = {k for k, symbol in enumerate(seq) if HP_MAP.get(symbol, 'P') == 'H'}
        total = 0
        for first in range(self.n):
            if first not in hydrophobic:
                continue
            # Start two positions later so bonded neighbours are never counted.
            for second in range(first + 2, self.n):
                if second not in hydrophobic:
                    continue
                separation = sum(abs(a - b) for a, b in zip(sites[first], sites[second]))
                if separation in (2,3):
                    total += 1
        return total


In [None]:
# End-to-end demo: sample a random fold for a short peptide and report it.
sequence = "ACDEFGHIK"
seq = tokenize_sequence(sequence)

lattice = BCCLattice()
turn = TurnEncoding(lattice, seq)
dir_indices, path = turn.sample_random_fold()
print("Turn-based dir indices:", dir_indices)
print("Turn-based bits:      ", [lattice.IDX_to_CODE[i] for i in dir_indices])
angles = [round(a,2) for a in lattice.successive_angles(dir_indices)]
print("Angles (deg):         ", angles)
coord = CoordinateEncoding(lattice, seq)
coords = coord.from_turn_encoding(turn)
print("Coordinates:")
for i,(aa,xyz) in enumerate(zip(seq, coords)):
    print(f"  {i:2d} {aa} -> {xyz}")
# The recorded cell output shows these two lines, but the code no longer printed them.
print("Self-avoiding? ", coord.is_self_avoiding())
print("Toy H–H contacts:", coord.contacts_HP(seq))
try:
    qc, regs = turn.allocate_qiskit()
    print("Qiskit registers:", [r.name for r in regs], "(3 qubits each)")
except RuntimeError as e:
    # Report why allocation was skipped instead of an opaque "Err".
    print("Qiskit allocation skipped:", e)


Turn-based dir indices: [1, 0, 6, 3, 3, 3, 6, 0]
Turn-based bits:       [(0, 0, 1), (0, 0, 0), (1, 1, 0), (0, 1, 1), (0, 1, 1), (0, 1, 1), (1, 1, 0), (0, 0, 0)]
Angles (deg):          [70.53, 109.47, 109.47, 0.0, 0.0, 109.47, 109.47]
Coordinates:
   0 A -> (0, 0, 0)
   1 C -> (1, 1, -1)
   2 D -> (2, 2, 0)
   3 E -> (1, 1, 1)
   4 F -> (2, 0, 0)
   5 G -> (3, -1, -1)
   6 H -> (4, -2, -2)
   7 I -> (3, -3, -1)
   8 K -> (4, -2, 0)
Self-avoiding?  True
Toy H–H contacts: 3
Qiskit registers: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7'] (3 qubits each)


In [None]:
from math import isclose
# Regression check: re-derive bits and coordinates for a hard-coded fold and
# compare them with hard-coded expected ("reported") values.
lattice = BCCLattice()
dir_indices = [6,4,2,0,6,6,6,5]
reported_bits = [(1,1,0),(1,0,0),(0,1,0),(0,0,0),(1,1,0),(1,1,0),(1,1,0),(1,0,1)]
reported_coords = [(0,0,0),(-1,-1,1),(-2,0,2),(-1,-1,3),(0,0,4),(-1,-1,5),(-2,-2,6),(-3,-3,7),(-4,-2,6)]
# Index -> 3-bit code must reproduce the reported bit triples.
derived_bits = [lattice.IDX_to_CODE[i] for i in dir_indices]
assert derived_bits == reported_bits, f"Bit mismatch:\n got {derived_bits}\n exp {reported_bits}"
# Walking the indices from the origin must reproduce the reported coordinates.
recalc_coords = lattice.build_walk((0,0,0), dir_indices)
assert recalc_coords == reported_coords, f"Coord mismatch:\n got {recalc_coords}\n exp {reported_coords}"
assert BCCLattice.self_avoiding(reported_coords), "Path is not self-avoiding"
assert BCCLattice.no_backtracking(dir_indices), "Backtracking detected"
angles = lattice.successive_angles(dir_indices)
# A 0-degree angle means two consecutive bonds point the same way.
zero_angle_positions = [k for k,a in enumerate(angles) if isclose(a,0.0,abs_tol=1e-9)]

print("All checks passed")
print("Angles (deg):", [round(a,2) for a in angles])
print("Zero-angle (straight) steps between bonds:", zero_angle_positions,
      "(these are where direction[i] == direction[i+1])")
print("Registers expected:", [f"b{k}" for k in range(len(dir_indices))], "(3 qubits each)")


All checks passed
Angles (deg): [70.53, 109.47, 70.53, 109.47, 0.0, 0.0, 109.47]
Zero-angle (straight) steps between bonds: [4, 5] (these are where direction[i] == direction[i+1])
Registers expected: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7'] (3 qubits each)


# Hamiltonian

In [None]:
from typing import List, Tuple, Dict, Optional
import itertools
def bcc_neighbors(p: Tuple[int,int,int]) -> List[Tuple[int,int,int]]:
    """Return the 8 sites reached from ``p`` by one (+-1, +-1, +-1) step.

    Order is fixed: x sign varies slowest, z sign fastest, with +1 before -1.
    """
    px, py, pz = p
    return [(px + sx, py + sy, pz + sz)
            for sx, sy, sz in itertools.product((1, -1), repeat=3)]

def reachable_bcc_positions(n_steps: int) -> List[Tuple[int,int,int]]:
    """Return, sorted, every site reachable from the origin in at most ``n_steps`` steps."""
    origin = (0, 0, 0)
    seen = {origin}
    shell = {origin}
    for _step in range(n_steps):
        # Expand the outermost shell by one step, keeping only unseen sites.
        touched = {site for base in shell for site in bcc_neighbors(base)}
        shell = touched - seen
        seen.update(shell)
    return sorted(seen)

def hp_energy(a: str, b: str) -> float:
    """Return the HP contact energy: -1.0 when both residues are in the H set, else 0.0."""
    hydrophobic = set("AVLIMFWYCGP")
    if a in hydrophobic and b in hydrophobic:
        return -1.0
    return 0.0

def load_mj96_table3() -> Dict[Tuple[str,str], float]:
    """
    Build the symmetric residue-residue contact-energy table.

    The embedded numbers are those supplied with this notebook, labelled there as
    Miyazawa-Jernigan (1996) Table 3 contact energies e_ij in RT units
    (negative = favorable). They have not been re-checked against the paper here.

    Returns a dict keyed by ordered pairs of one-letter codes. Both (a, b) and
    (b, a) are present with the same value, and the diagonal is included.
    """
    # Residue order shared by the rows and columns of the table.
    residues = ["C","M","F","I","L","V","W","Y","A","G",
                "T","S","N","Q","D","E","H","R","K","P"]
    # Upper triangle including the diagonal: row r lists columns r..19.
    upper_rows = [
        [-5.44, -4.99, -5.80, -5.50, -5.83, -4.96, -4.95, -4.16, -3.57, -3.16, -3.11, -2.86, -2.59, -2.85, -2.41, -2.27, -3.60, -2.57, -1.95, -3.07],  # C
        [-5.46, -6.56, -6.02, -6.41, -5.32, -5.55, -4.91, -3.94, -3.39, -3.51, -3.03, -2.95, -3.30, -2.57, -2.89, -3.98, -3.12, -2.48, -3.45],  # M
        [-7.26, -6.84, -7.28, -6.29, -6.16, -5.66, -4.81, -4.13, -4.28, -4.02, -3.75, -4.10, -3.48, -3.56, -4.77, -3.98, -3.36, -4.25],  # F
        [-6.54, -7.04, -6.05, -5.78, -5.25, -4.58, -3.78, -4.03, -3.52, -3.24, -3.67, -3.17, -3.27, -4.14, -3.63, -3.01, -3.76],  # I
        [-7.37, -6.48, -6.14, -5.67, -4.91, -4.16, -4.34, -3.92, -3.74, -4.04, -3.40, -3.59, -4.54, -4.03, -3.37, -4.20],  # L
        [-5.52, -5.18, -4.62, -4.04, -3.38, -3.46, -3.05, -2.83, -3.07, -2.48, -2.67, -3.58, -3.07, -2.49, -3.32],  # V
        [-5.06, -4.66, -3.82, -3.42, -3.22, -2.99, -3.07, -3.11, -2.84, -2.99, -3.98, -3.41, -2.69, -3.73],  # W
        [-4.17, -3.36, -3.01, -3.01, -2.78, -2.76, -2.97, -2.76, -2.79, -3.52, -3.16, -2.60, -3.19],  # Y
        [-2.72, -2.31, -2.32, -2.01, -1.84, -1.89, -1.70, -1.51, -2.41, -1.83, -1.31, -2.03],  # A
        [-2.24, -2.08, -1.82, -1.74, -1.66, -1.59, -1.22, -2.15, -1.72, -1.15, -1.87],  # G
        [-2.12, -1.96, -1.88, -1.90, -1.80, -1.74, -2.42, -1.90, -1.31, -1.90],  # T
        [-1.67, -1.58, -1.49, -1.63, -1.48, -2.11, -1.62, -1.05, -1.57],  # S
        [-1.68, -1.71, -1.68, -1.51, -2.08, -1.64, -1.21, -1.53],  # N
        [-1.54, -1.46, -1.42, -1.98, -1.80, -1.29, -1.73],  # Q
        [-1.21, -1.02, -2.32, -2.29, -1.68, -1.33],  # D
        [-0.91, -2.15, -2.27, -1.80, -1.26],  # E
        [-3.05, -2.16, -1.35, -2.25],  # H
        [-1.55, -0.59, -1.70],  # R
        [-0.12, -0.97],  # K
        [-1.75],  # P
    ]
    table = {}
    for r, res_r in enumerate(residues):
        for c, res_c in enumerate(residues):
            # Mirror lower-triangle lookups into the stored upper triangle.
            lo, hi = (r, c) if r <= c else (c, r)
            table[(res_r, res_c)] = upper_rows[lo][hi - lo]
    return table

def mj_energy(a: str, b: str, MJ: Dict[Tuple[str,str],float]) -> float:
    """Look up the contact energy of the ordered pair ``(a, b)``; 0.0 when absent."""
    pair = (a, b)
    if pair in MJ:
        return MJ[pair]
    return 0.0

class LatticeQUBO:
    """QUBO formulation of lattice folding with one binary variable per (residue, site).

    Variable ``ip2var[(i, p_idx)]`` is 1 when residue ``i`` sits on
    ``positions[p_idx]``. The workspace ``positions`` is every site reachable
    from the origin in at most ``N - 1`` steps.
    """

    def __init__(self, seq: str, model: str = "HP", max_len: Optional[int]=None):
        """Prepare the variable layout and the contact-energy function.

        Args:
            seq: residue string; non-alphabetic characters are dropped and the
                rest upper-cased.
            model: ``"HP"`` or ``"MJ"`` (case-insensitive) contact potential.
            max_len: optional truncation length applied to the cleaned sequence.

        Raises:
            ValueError: if the cleaned sequence is not 5-30 residues long, or
                ``model`` is not recognised.
        """
        self.seq = [c for c in seq.upper() if c.isalpha()]
        if max_len is not None:
            self.seq = self.seq[:max_len]
        if not (5 <= len(self.seq) <= 30):
            raise ValueError("Use a short sequence (5–30 aa) for this demo.")
        self.N = len(self.seq)
        self.positions = reachable_bcc_positions(self.N-1)
        self.index_pos = {p:k for k,p in enumerate(self.positions)}
        # Dense numbering: variable id = i * len(positions) + p_idx.
        self.var2ip: Dict[int, Tuple[int,int]] = {}
        self.ip2var: Dict[Tuple[int,int], int] = {}
        vid = 0
        for i in range(self.N):
            for p_idx in range(len(self.positions)):
                self.var2ip[vid] = (i, p_idx)
                self.ip2var[(i, p_idx)] = vid
                vid += 1
        self.num_vars = vid

        self.model = model.upper()
        if self.model == "HP":
            self.contact_fn = lambda a,b: hp_energy(a,b)
        elif self.model == "MJ":
            self.MJ = load_mj96_table3()
            self.contact_fn = lambda a,b: mj_energy(a,b, self.MJ)
        else:
            raise ValueError("model must be 'HP' or 'MJ'")

        # Neighbour sets may contain sites outside the workspace; users filter via index_pos.
        self.adj = {p: set(bcc_neighbors(p)) for p in self.positions}

    def _Q_add(self, Q: Dict[Tuple[int,int], float], u: int, v: int, w: float):
        """Accumulate weight ``w`` on the unordered pair ``{u, v}``, stored with the smaller id first."""
        if w == 0.0: return
        a,b = (u,v) if u<=v else (v,u)
        Q[(a,b)] = Q.get((a,b), 0.0) + w

    def build(self, lam_onehot=6.0, lam_adj=6.0, lam_collision=6.0) -> Dict[Tuple[int,int], float]:
        """Assemble the QUBO as a sparse upper-triangular dict ``{(u, v): weight}``, ``u <= v``.

        Terms, in order:
          1. one-hot: each residue occupies exactly one site. The constant
             ``+lam_onehot`` per residue from expanding ``(sum x - 1)^2`` is
             not stored, so a one-hot-feasible assignment contributes
             ``-lam_onehot`` per residue.
          2. adjacency: ``lam_adj`` for consecutive residues on non-neighbouring sites.
          3. collision: ``lam_collision`` for every pair of residues on the same site.
          4. contact: ``contact_fn(a_i, a_j)`` for residues ``j >= i + 2`` on
             neighbouring sites.
        """
        Q: Dict[Tuple[int,int], float] = {}

        Pn = len(self.positions)
        for i in range(self.N):
            idxs = [self.ip2var[(i, p_idx)] for p_idx in range(Pn)]
            for u, v in itertools.combinations(idxs, 2):
                self._Q_add(Q, u, v, 2.0*lam_onehot)

            for u in idxs:
                self._Q_add(Q, u, u, -1.0*lam_onehot)  # from (sum x)^2 - 2 sum x + 1

        P = self.positions
        for p_idx, p in enumerate(P):
            nbrs = self.adj[p]
            # The non-neighbours of a site (the site itself included) do not depend on
            # the residue index, so build the list once per site instead of once per
            # (residue, site) pair.
            bad_qs = [q_idx for q_idx, q in enumerate(P) if q not in nbrs]
            for i in range(self.N - 1):
                u = self.ip2var[(i, p_idx)]
                for q_idx in bad_qs:
                    v = self.ip2var[(i+1, q_idx)]
                    self._Q_add(Q, u, v, lam_adj)

        for p_idx in range(Pn):
            vars_at_p = [self.ip2var[(i, p_idx)] for i in range(self.N)]
            for u, v in itertools.combinations(vars_at_p, 2):
                self._Q_add(Q, u, v, lam_collision)

        for i in range(self.N):
            ai = self.seq[i]
            for j in range(i+2, self.N):
                aj = self.seq[j]
                e_ij = self.contact_fn(ai, aj)
                if e_ij == 0.0:
                    continue
                for p_idx, p in enumerate(P):
                    u = self.ip2var[(i, p_idx)]
                    for q in self.adj[p]:
                        q_idx = self.index_pos.get(q)
                        if q_idx is None:
                            # Neighbour lies outside the workspace.
                            continue
                        v = self.ip2var[(j, q_idx)]
                        self._Q_add(Q, u, v, e_ij)
        return Q

    def to_qubo_matrix(self, Q: Dict[Tuple[int,int],float]):
        """Expand ``Q`` into a dense symmetric ``num_vars x num_vars`` list of lists.

        Each off-diagonal weight ``w`` of pair ``(u, v)`` is written to both
        ``M[u][v]`` and ``M[v][u]``, so ``x^T M x`` counts every coupling twice
        compared with summing ``Q`` once per pair.
        """
        n = self.num_vars
        M = [[0.0]*n for _ in range(n)]
        for (u,v), w in Q.items():
            M[u][v] += w
            if u != v:
                M[v][u] += w
        return M

    def to_qiskit_quadratic_program(self, Q: Dict[Tuple[int,int],float]):
        """Export ``Q`` as a ``qiskit_optimization`` minimisation over binary variables ``x_<id>``."""
        from qiskit_optimization import QuadraticProgram
        qp = QuadraticProgram()
        for vid in range(self.num_vars):
            qp.binary_var(name=f"x_{vid}")
        linear, quad = {}, {}
        for (u,v), w in Q.items():
            if u == v:
                linear[u] = linear.get(u, 0.0) + w
            else:
                quad[(u,v)] = quad.get((u,v), 0.0) + w
        qp.minimize(linear=linear, quadratic=quad)
        return qp

# Build the QUBO for the demo peptide with the MJ contact potential.
sequence = "ACDEFGHIK" 
model = "MJ"           

qubo = LatticeQUBO(sequence, model=model)
# All three penalty weights are set to the same value here.
Q = qubo.build(lam_onehot=12.0, lam_adj=12.0, lam_collision=12.0)

print("Model:", model, "Vars:", qubo.num_vars, "positions:", len(qubo.positions), "N:", qubo.N)
print("Nonzero Q terms:", len(Q))
# The Qiskit export is best-effort: any failure is reported and the cell carries on.
try:
    qp = qubo.to_qiskit_quadratic_program(Q)
    print("Qiskit QuadraticProgram vars:", qp.get_num_vars(), "| quadratic terms:", len(qp.objective.quadratic.to_dict()))
except Exception as e:
    print("Qiskit export skipped:", e)


Model: MJ Vars: 11169 positions: 1241 N: 9
Nonzero Q terms: 19455185
Qiskit QuadraticProgram vars: 11169 | quadratic terms: 19444016


In [None]:
def build_sparse(Q: Dict[Tuple[int,int], float], num_vars: int):
    """Split a QUBO dict into linear terms ``h`` and upper-triangular couplings ``J``.

    Returns:
        ``(h, J)``: ``h[u]`` is the summed diagonal weight of variable ``u``;
        ``J[a][b]`` (``a < b``) is the summed off-diagonal weight of the pair.
        ``J`` is a list of ``defaultdict(float)``.
    """
    h = [0.0 for _ in range(num_vars)]
    J = [defaultdict(float) for _ in range(num_vars)]
    for pair, weight in Q.items():
        u, v = pair
        if u == v:
            h[u] += weight
            continue
        # Off-diagonal: file the weight under the smaller index.
        J[min(u, v)][max(u, v)] += weight
    return h, J

def energy_from_assignment(h, J, active_vars: List[int]) -> float:
    """Return the energy when exactly the variables in ``active_vars`` are set to 1.

    Adds ``h[u]`` for every active variable, then the coupling of every
    unordered pair of active variables (0.0 for pairs absent from ``J``).
    """
    total = 0.0
    for var in active_vars:
        total += h[var]
    for first, second in itertools.combinations(active_vars, 2):
        # J is keyed with the smaller variable id first.
        lo, hi = (first, second) if first < second else (second, first)
        total += J[lo].get(hi, 0.0)
    return total

def delta_energy_move(h, J, active_vars: List[int], i: int, u_old: int, u_new: int) -> float:
    """Return the energy change from replacing ``u_old`` with ``u_new`` at slot ``i``.

    Only the linear terms of the two variables and their couplings to the
    other active variables (every slot except ``i``) enter the difference.
    """
    def coupling(x, y):
        # J is keyed with the smaller variable id first.
        lo, hi = (x, y) if x < y else (y, x)
        return J[lo].get(hi, 0.0)

    delta = h[u_new] - h[u_old]
    for slot, other in enumerate(active_vars):
        if slot != i:
            delta += coupling(u_new, other)
            delta -= coupling(u_old, other)
    return delta
def var_of(qubo, i: int, p_idx: int) -> int:
    """Return the variable id for residue ``i`` on site index ``p_idx``."""
    key = (i, p_idx)
    return qubo.ip2var[key]

def pos_index(qubo, pos) -> Optional[int]:
    """Return the workspace index of site ``pos``, or None when it is outside the workspace."""
    lookup = qubo.index_pos
    return lookup.get(pos)

def indices_to_coords(qubo, active_vars: List[int]) -> List[Tuple[int,int,int]]:
    """Translate one active variable per residue into lattice coordinates.

    ``active_vars[k]`` must belong to residue ``k``; this is asserted.
    """
    coords = []
    for expected_residue, var in enumerate(active_vars):
        residue, site_idx = qubo.var2ip[var]
        assert residue == expected_residue
        coords.append(qubo.positions[site_idx])
    return coords

def bcc_are_neighbors(p, q) -> bool:
    """Return True when ``q`` is one body-diagonal step (+-1, +-1, +-1) away from ``p``."""
    offsets = (abs(q[0] - p[0]), abs(q[1] - p[1]), abs(q[2] - p[2]))
    return offsets == (1, 1, 1)


In [None]:
def random_bcc_saw(N: int, max_restarts=2000):
    """Grow a random self-avoiding walk of ``N`` sites from the origin.

    Each step takes the first randomly ordered diagonal move that lands on an
    unvisited site. A dead end discards the walk and restarts.

    Raises:
        RuntimeError: if no attempt out of ``max_restarts`` reaches ``N`` sites.
    """
    origin = (0, 0, 0)
    for _attempt in range(max_restarts):
        walk = [origin]
        occupied = {origin}
        while len(walk) < N:
            # Fresh, identically ordered move list for every step, then shuffled.
            steps = list(itertools.product((+1, -1), repeat=3))
            random.shuffle(steps)
            x, y, z = walk[-1]
            free = next(
                ((x + dx, y + dy, z + dz) for dx, dy, dz in steps
                 if (x + dx, y + dy, z + dz) not in occupied),
                None,
            )
            if free is None:
                break
            walk.append(free)
            occupied.add(free)
        if len(walk) == N:
            return walk
    raise RuntimeError("SAW construction failed; increase restarts or reduce N.")

def project_path_to_variables(qubo, path: List[Tuple[int,int,int]]) -> List[int]:
    """Map each site of ``path`` to the variable id of (residue index, site).

    Raises:
        RuntimeError: if a site of the path lies outside the QUBO workspace.
    """
    def to_variable(i, pos):
        site_idx = pos_index(qubo, pos)
        if site_idx is None:
            raise RuntimeError(f"Path position {pos} not in workspace; try regenerating the SAW.")
        return var_of(qubo, i, site_idx)

    return [to_variable(i, pos) for i, pos in enumerate(path)]
def check_chain_adjacency(coords: List[Tuple[int,int,int]]) -> bool:
    """Return True when every pair of consecutive sites are BCC neighbours."""
    for here, there in zip(coords, coords[1:]):
        if not bcc_are_neighbors(here, there):
            return False
    return True

def check_no_collisions(coords: List[Tuple[int,int,int]]) -> bool:
    """Return True when no site occurs more than once in ``coords``."""
    distinct = set(coords)
    return len(distinct) == len(coords)


In [None]:
def anneal_qubo_onehot(qubo, Q, sweeps=200, T_start=5.0, T_end=0.05, neighbor_ratio=0.8, random_ratio=0.2, seed=None):
    """
    Block-move SA that keeps one-hot feasibility.

    Every residue always has exactly one active variable; a move relocates one
    residue to another site. The start state is a random self-avoiding walk.

    Args:
        qubo: problem object providing ``N``, ``num_vars``, ``positions``,
            ``var2ip``, ``ip2var`` and ``index_pos``.
        Q: sparse QUBO dict ``{(u, v): weight}``.
        sweeps: number of passes over all residues (visited in random order).
        T_start, T_end: temperature is interpolated geometrically from
            ``T_start`` (first sweep) to ``T_end`` (last sweep).
        neighbor_ratio: probability to propose moving residue i to a
            *neighboring* lattice site.
        random_ratio: chance to additionally propose 4 random sites (escape).
            Random sites are always proposed when there are no neighbour proposals.
        seed: if given, seeds the *global* ``random`` module state.

    Returns:
        ``(best_active, best_E)``: the lowest-energy assignment seen (one
        variable id per residue) and its energy.
    """
    if seed is not None:
        random.seed(seed)

    h, J = build_sparse(Q, qubo.num_vars)

    path0 = random_bcc_saw(qubo.N)
    active = project_path_to_variables(qubo, path0)
    E = energy_from_assignment(h, J, active)

    P = qubo.positions

    def temperature(t):
        # Geometric schedule; max(1, ...) avoids dividing by zero when sweeps == 1.
        return T_start * (T_end / T_start) ** (t / max(1, sweeps-1))

    best_active = list(active)
    best_E = E

    for t in range(sweeps):
        T = temperature(t)
        order = list(range(qubo.N))
        random.shuffle(order)
        for i in order:
            u_old = active[i]
            p_old_idx = qubo.var2ip[u_old][1]
            p_old = P[p_old_idx]

            proposals = []

            if random.random() < neighbor_ratio:
                for q in bcc_neighbors(p_old):
                    q_idx = qubo.index_pos.get(q)
                    if q_idx is not None:
                        proposals.append(q_idx)

            if not proposals or random.random() < random_ratio:
                for _ in range(4):
                    proposals.append(random.randrange(len(P)))

            random.shuffle(proposals)
            for q_idx in proposals[:8]:  # limit local work
                if q_idx == p_old_idx:
                    continue
                u_new = var_of(qubo, i, q_idx)
                dE = delta_energy_move(h, J, active, i, u_old, u_new)
                # Metropolis rule: always take downhill moves, uphill with prob exp(-dE/T).
                if dE <= 0 or random.random() < math.exp(-dE / max(1e-12, T)):
                    # accept
                    active[i] = u_new
                    E += dE
                    if E < best_E:
                        best_E = E
                        best_active = list(active)
                    break  # move to next residue

    return best_active, best_E


In [None]:
# Single annealing run with a fixed seed, then decode and validate the best state found.
best_active, best_E = anneal_qubo_onehot(qubo, Q, sweeps=400, T_start=8.0, T_end=0.02, seed=42)

coords = indices_to_coords(qubo, best_active)
# best_E is the QUBO energy; build() does not store the one-hot constant term.
print("Best energy:", round(best_E, 4))
print("Adjacency OK? ", check_chain_adjacency(coords))
print("No collisions?", check_no_collisions(coords))
for i, (aa, xyz) in enumerate(zip(qubo.seq, coords)):
    print(f"{i:2d} {aa}  {xyz}")


Best energy: -108.0
Adjacency OK?  True
No collisions? True
 0 A  (0, 0, 0)
 1 C  (1, -1, -1)
 2 D  (2, -2, -2)
 3 E  (3, -3, -3)
 4 F  (4, -4, -2)
 5 G  (5, -5, -1)
 6 H  (6, -6, -2)
 7 I  (7, -7, -3)
 8 K  (6, -6, -4)


In [None]:
from collections import Counter

def contact_pairs_BCC(coords):
    """Return index pairs ``(i, j)``, ``j >= i + 2``, whose sites are BCC nearest neighbours."""
    pairs = []
    N = len(coords)
    for i in range(N):
        for j in range(i+2, N):  # non-consecutive
            p, q = coords[i], coords[j]
            dx,dy,dz = q[0]-p[0], q[1]-p[1], q[2]-p[2]
            if abs(dx) == abs(dy) == abs(dz) == 1:  # BCC neighbor
                pairs.append((i,j))
    return pairs

def breakdown(qubo, coords, lam_onehot, lam_adj, lam_collision):
    """Split the energy of a decoded conformation into penalty and contact parts.

    ``coords`` holds one site per residue, so the one-hot penalty is 0 by
    construction. ``LatticeQUBO.build`` does not store the one-hot constant,
    so a QUBO energy for the same conformation equals ``total`` minus
    ``N * lam_onehot``.

    Args:
        qubo: object with ``model`` (``"HP"`` or anything else for MJ), ``seq``
            and, for MJ, optionally a precomputed ``MJ`` table.
        coords: one ``(x, y, z)`` site per residue.
        lam_onehot, lam_adj, lam_collision: penalty weights used to build the QUBO.

    Returns:
        dict with float entries ``total``, ``contact``, ``onehot``,
        ``adjacency``, ``collision`` and the integer ``contacts_count``.
    """
    P_onehot = 0.0

    # Adjacency: one penalty per consecutive pair that is not a BCC neighbour step.
    broken_bonds = 0
    for i in range(len(coords)-1):
        p, q = coords[i], coords[i+1]
        dx,dy,dz = q[0]-p[0], q[1]-p[1], q[2]-p[2]
        if not (abs(dx)==abs(dy)==abs(dz)==1):
            broken_bonds += 1
    P_adj = float(broken_bonds) * lam_adj

    # Collision: the QUBO penalises every *pair* of residues sharing a site, so a
    # site holding c residues costs c*(c-1)/2 penalties (not c-1).
    counts = Counter(coords)
    colliding_pairs = sum(c*(c-1)//2 for c in counts.values())
    P_coll = float(colliding_pairs) * lam_collision

    pairs = contact_pairs_BCC(coords)
    if qubo.model == "HP":
        H = set("AVLIMFWYCGP")
        E_contact = sum((-1.0 for i,j in pairs if qubo.seq[i] in H and qubo.seq[j] in H), 0.0)
    else:
        # Only rebuild the table when the qubo object does not already carry one.
        MJ = getattr(qubo, "MJ", None)
        if MJ is None:
            MJ = load_mj96_table3()
        E_contact = sum((MJ.get((qubo.seq[i], qubo.seq[j]), 0.0) for i,j in pairs), 0.0)

    total = P_onehot + P_adj + P_coll + E_contact
    return dict(total=total, contact=E_contact, onehot=P_onehot, adjacency=P_adj, collision=P_coll, contacts_count=len(pairs))

# Report the energy breakdown of the annealed conformation, using one weight for all three penalties.
lam = 12.0  
report = breakdown(qubo, coords, lam, lam, lam)
print("Breakdown:", {k: (round(v,3) if isinstance(v,float) else v) for k,v in report.items()})


Breakdown: {'total': 0.0, 'contact': 0, 'onehot': 0.0, 'adjacency': 0.0, 'collision': 0.0, 'contacts_count': 0}


In [None]:
import multiprocessing as mp
import time

def worker_run(args):
    """Run one seeded annealing job on the module-level ``qubo`` and ``Q``.

    ``args`` is ``(seed, sweeps, T0, T1)``; returns ``(seed, best_energy, best_active)``.
    """
    seed, sweeps, T0, T1 = args
    state, energy = anneal_qubo_onehot(qubo, Q, sweeps=sweeps, T_start=T0, T_end=T1, seed=seed)
    return seed, energy, state

def multi_start_sa(num_runs=64, sweeps=400, T0=8.0, T1=0.02):
    """Run ``num_runs`` independently seeded annealing jobs and keep the best.

    Seeds are ``1337, 1338, ...``. Jobs run in a process pool when the
    ``fork`` start method is available. ``worker_run`` and the ``qubo`` / ``Q``
    globals are defined in the notebook, so child processes need to inherit
    them via fork rather than re-import them. Without ``fork`` the jobs run
    serially in this process, which also reseeds the global ``random`` state.

    Returns:
        ``(best, results)`` where each result is ``(seed, energy, active_vars)``
        and ``best`` is the result with the lowest energy.

    Raises:
        ValueError: if ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")
    seeds = [1337 + k for k in range(num_runs)]
    jobs = [(s, sweeps, T0, T1) for s in seeds]
    t0 = time.perf_counter()  # monotonic clock for interval timing
    if "fork" in mp.get_all_start_methods():
        with mp.get_context("fork").Pool() as pool:
            results = pool.map(worker_run, jobs)
    else:
        results = [worker_run(job) for job in jobs]
    dt = time.perf_counter() - t0
    best = min(results, key=lambda r: r[1])
    print(f"Completed {num_runs} runs in {dt:.1f}s; best seed {best[0]} energy {best[1]:.3f}")
    return best, results

# Multi-start annealing: keep the lowest-energy run and validate its decoded conformation.
best, all_results = multi_start_sa(num_runs=64, sweeps=600, T0=10.0, T1=0.01)
# Each result is (seed, energy, active variable ids).
best_seed, bestE, best_active = best
best_coords = indices_to_coords(qubo, best_active)

print("Best multi-start energy:", round(bestE,3))
print("Adj OK?", check_chain_adjacency(best_coords), "No collisions?", check_no_collisions(best_coords))
