In [1]:
from QG_functions import *

import numpy as np
import pandas as pd

from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core.periodic_table import Element
from pymatgen.io.cif import *

from ase.visualize import view


from pymatgen.io.ase import AseAtomsAdaptor
import sys

import re
import shutil as sh
import pickle
from tqdm import tqdm


import copy
from sklearn.metrics import mean_squared_error 

#import dataframe_image as dfi

from scipy import constants
from scipy.spatial import KDTree, distance_matrix

import matplotlib.pyplot as plt

import itertools
from itertools import chain

from sklearn.linear_model import LinearRegression
# from sklearn.metrics import mean_squared_error as mse


k_b = constants.physical_constants['Boltzmann constant in eV/K'][0]
# print(k_b)
def vview(structure):
    view(AseAtomsAdaptor().get_atoms(structure))

np.seterr(divide='ignore')
plt.style.use('tableau-colorblind10')

import seaborn as sns
import time
# from QG_functions import *

In [2]:
fully_lithiated_structure_init = Structure.from_file('data/fully_lithiated_tmp.cif')
delithiated_structure_init = Structure.from_file('data/delithiated_tmp.cif')
n_sites = fully_lithiated_structure_init.num_sites
fully_lithiated_structure_init.translate_sites(np.arange(n_sites),[1,1,1],to_unit_cell=True)

# fully_lithiated_structure_init.replace_species({'Li':'He'})

# fully_lithiated_structure_init.replace_species({'Tc':'Mn'})
vview(fully_lithiated_structure_init)

## Ewald

#### pymatgen

In [3]:
from pymatgen.core import Structure
from pymatgen.analysis.ewald import EwaldSummation

# load or define a Structure (with oxidation states!)
structure = copy.deepcopy(fully_lithiated_structure_init)
structure.add_oxidation_state_by_element({"Tc": +3, "O": -2,"Li": +1})
# ensure each site has a charge (oxidation_state)
# e.g., structure.add_oxidation_state_by_element({'Na': +1, 'Cl': -1})

# create EwaldSummation object
ewald = EwaldSummation(structure, eta=None, w=1)

# total electrostatic energy (eV)
total_energy = ewald.total_energy

ewald.total_energy_matrix


array([[-1.9664096 , -0.02659309, -0.29047144, ...,  0.75373109,
        -0.06390126,  0.26197271],
       [-0.02659309, -1.9664096 , -0.25817048, ...,  0.59371889,
         0.26197271,  0.52282942],
       [-0.29047144, -0.25817048, -1.9664096 , ..., -1.78840077,
         0.55450697, -0.13948516],
       ...,
       [ 0.75373109,  0.59371889, -1.78840077, ..., -7.86563842,
        -1.13648963, -0.03764788],
       [-0.06390126,  0.26197271,  0.55450697, ..., -1.13648963,
        -7.86563842,  2.53286222],
       [ 0.26197271,  0.52282942, -0.13948516, ..., -0.03764788,
         2.53286222, -7.86563842]], shape=(96, 96))

## Workflow
Define intial structure (initial grid, equally spaced)
Ewald+Buckingham discrete (W)
Map W to CP-SAT + constraints (one-hot and composition)
Solve CP-SAT
GULP optimise N low energy structures
Analyse geometry and build new grid


## Being new full script

In [63]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2
from full_script_functions import *

N_li = 2

N_initial_grid=100
min_dist_grid=1.

threshold_li=1.5
prox_penalty=1000

one_hot_value = 200
weight = 500

N_structures_opt = 2

number_iterations = 1000
number_runs = 100

input_name='gulp_klmc.gin'
gulp_io_path='klmc/'
mace_io_path='mace_io_files'

M = 20 #grid definition
N_positions_final = 100

threshold = 0.05  # THIS IS AN IMPORTANT PARAMETER

num_iterations = 5

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Build CP-SAT model

In [65]:
from ortools.sat.python import cp_model

def build_site_option_maps_from_indices(li_indices, mn_indices):
    """
    Input:
      li_indices: list[int]  -> QUBO columns that mean 'Li present'
      mn_indices: list[int]  -> [Mn4_0, Mn3_0, Mn4_1, Mn3_1, ...]
    Output:
      site_options: dict[site_id] -> list[str] of options
      var2siteopt: dict[qubo_col] -> (site_id, option_name)
      li_sites: list[int] of site_ids that are Li grid sites
      mn_sites: list[int] of site_ids that are Mn sites
    """
    assert len(mn_indices) % 2 == 0, "mn_indices must be even length (pairs)."

    site_options = {}
    var2siteopt  = {}
    li_sites, mn_sites = [], []
    print(len(li_indices))
    # 1) Li grid sites: create a site with options ["Empty","Li"] for each li_index
    for k in li_indices:
        s = len(site_options)
        site_options[s] = ["Empty", "Li"]
        var2siteopt[k]  = (s, "Li")      # the QUBO var corresponds to the "Li" option
        li_sites.append(s)
    print(len(mn_indices))
    # 2) Mn sites: every consecutive pair -> one Mn site with ["Mn4","Mn3"]
    for p in range(0, len(mn_indices), 2):
        k4 = mn_indices[p]
        k3 = mn_indices[p+1]
        s = len(site_options)
        site_options[s] = ["Mn4", "Mn3"]
        var2siteopt[k4] = (s, "Mn4")
        var2siteopt[k3] = (s, "Mn3")
        mn_sites.append(s)

    return site_options, var2siteopt, li_sites, mn_sites

site_options, var2siteopt, li_sites, mn_sites = build_site_option_maps_from_indices(li_indices, mn_indices)

65
48


In [66]:
from ortools.sat.python import cp_model
def build_x_vars_and_onehot(model: cp_model.CpModel, site_options):
    """
    Make BoolVars x[(s,a)] and add one-hot per site: sum_a x[s,a] == 1
    Returns: x dict
    """
    x = {}
    for s, opts in site_options.items():
        for a in opts:
            x[(s, a)] = model.NewBoolVar(f"x_{s}_{a}")
        # one-hot: exactly one option per site
        model.Add(sum(x[(s, a)] for a in opts) == 1)
    return x

model = cp_model.CpModel()
build_x_vars_and_onehot(model, site_options)

{(0, 'Empty'): x_0_Empty(0..1),
 (0, 'Li'): x_0_Li(0..1),
 (1, 'Empty'): x_1_Empty(0..1),
 (1, 'Li'): x_1_Li(0..1),
 (2, 'Empty'): x_2_Empty(0..1),
 (2, 'Li'): x_2_Li(0..1),
 (3, 'Empty'): x_3_Empty(0..1),
 (3, 'Li'): x_3_Li(0..1),
 (4, 'Empty'): x_4_Empty(0..1),
 (4, 'Li'): x_4_Li(0..1),
 (5, 'Empty'): x_5_Empty(0..1),
 (5, 'Li'): x_5_Li(0..1),
 (6, 'Empty'): x_6_Empty(0..1),
 (6, 'Li'): x_6_Li(0..1),
 (7, 'Empty'): x_7_Empty(0..1),
 (7, 'Li'): x_7_Li(0..1),
 (8, 'Empty'): x_8_Empty(0..1),
 (8, 'Li'): x_8_Li(0..1),
 (9, 'Empty'): x_9_Empty(0..1),
 (9, 'Li'): x_9_Li(0..1),
 (10, 'Empty'): x_10_Empty(0..1),
 (10, 'Li'): x_10_Li(0..1),
 (11, 'Empty'): x_11_Empty(0..1),
 (11, 'Li'): x_11_Li(0..1),
 (12, 'Empty'): x_12_Empty(0..1),
 (12, 'Li'): x_12_Li(0..1),
 (13, 'Empty'): x_13_Empty(0..1),
 (13, 'Li'): x_13_Li(0..1),
 (14, 'Empty'): x_14_Empty(0..1),
 (14, 'Li'): x_14_Li(0..1),
 (15, 'Empty'): x_15_Empty(0..1),
 (15, 'Li'): x_15_Li(0..1),
 (16, 'Empty'): x_16_Empty(0..1),
 (16, 'Li'): x

#### One-hot


In [67]:
def add_li_count_constraint(model: cp_model.CpModel, x, li_sites, N_li: int):
    # Sum of "Li" selections over Li sites equals the target number
    model.Add(sum(x[(s, "Li")] for s in li_sites) == N_li)


#### Charge Balance Constraints


In [68]:
def add_li_mn_charge_balance_constraints(model: cp_model.CpModel, x, li_sites, mn_sites, N_li: int):
    """
    Enforces:
      - total number of Li atoms == N_li
      - total number of Mn3+ ions == N_li
    """
    # Li count constraint
    model.Add(sum(x[(s, "Li")] for s in li_sites) == N_li)

    # Mn3+ count constraint
    model.Add(sum(x[(s, "Mn3")] for s in mn_sites) == N_li)


#### Li-proximity constraint


In [69]:
def add_li_proximity_exclusions(model: cp_model.CpModel, x, forbidden_pairs_site_ids):
    # For any pair of Li sites (s,t) that are too close: not both Li
    for (s, t) in forbidden_pairs_site_ids:
        model.Add(x[(s, "Li")] + x[(t, "Li")] <= 1)

In [103]:
def _frac_coords(coords_cart, lattice):
    """Convert cartesian coords to fractional with given 3x3 lattice matrix."""
    return np.dot(coords_cart, np.linalg.inv(lattice).T)

def _mic_delta_frac(df):
    """Minimum-image wrap for fractional deltas to (-0.5, 0.5]."""
    return df - np.round(df)

def _pair_edges_with_threshold(frac_coords, lattice, threshold):
    """
    Build edges (i,j) where PBC distance < threshold.
    frac_coords: (N,3) fractional coords in [0,1)
    lattice: 3x3 cartesian lattice matrix
    """
    N = len(frac_coords)
    edges = []
    L = np.asarray(lattice)
    for i in range(N):
        for j in range(i+1, N):
            df = _mic_delta_frac(frac_coords[j] - frac_coords[i])
            dcart = df @ L
            if np.linalg.norm(dcart) < threshold:
                edges.append((i, j))
    return edges

def _bron_kerbosch(R, P, X, adj, cliques, max_nodes=1000):
    """Simple Bron–Kerbosch to enumerate maximal cliques (no pivot)."""
    # small safeguard
    if len(cliques) > max_nodes:
        return
    if not P and not X:
        if len(R) >= 2:
            cliques.append(tuple(sorted(R)))
        return
    # iterate over a copy since P will mutate
    for v in list(P):
        _bron_kerbosch(R | {v}, P & adj[v], X & adj[v], adj, cliques, max_nodes)
        P.remove(v)
        X.add(v)

def _maximal_cliques_from_edges(N, edges, cap_cliques=10000):
    """Return list of maximal cliques (each a tuple of node indices)."""
    # adjacency as sets
    adj = [set() for _ in range(N)]
    for i, j in edges:
        adj[i].add(j); adj[j].add(i)
    cliques = []
    _bron_kerbosch(set(), set(range(N)), set(), adj, cliques, max_nodes=cap_cliques)
    return cliques

def build_li_proximity_groups(
    li_grid_coords,
    threshold_ang,
    *,
    lattice=None,
    coords_are_cartesian=True,
    site_ids=None,
    return_pairs_also=True,
):
    """
    Build Li–Li proximity groups for CP-SAT 'AtMostOne' constraints.

    Parameters
    ----------
    li_grid_coords : (M,3) array
        Li candidate site coordinates. If 'coords_are_cartesian' is False,
        they are treated as fractional.
    threshold_ang : float
        Distance threshold (Å) for "too close".
    lattice : (3,3) array-like or pymatgen Lattice, optional
        Required if periodic MIC distances matter. If None and
        coords_are_cartesian=True, uses plain Euclidean distances (no PBC).
    coords_are_cartesian : bool
        Whether li_grid_coords are in Cartesian. If True and lattice is given,
        MIC distances are used. If False, coords are treated as fractional.
    site_ids : list[int], optional
        CP site IDs aligned with li_grid_coords. If None, uses range(M).
    return_pairs_also : bool
        If True, also return the raw list of close pairs.

    Returns
    -------
    groups : list[list[int]]
        Each group is a clique (size ≥ 2) of site IDs that are mutually
        within the threshold (use one AddAtMostOne per group).
    pairs  : list[tuple[int,int]]  (only if return_pairs_also=True)
        All offending pairs (site_i, site_j) by ID.
    """
    coords = np.asarray(li_grid_coords, dtype=float)
    M = coords.shape[0]
    if site_ids is None:
        site_ids = list(range(M))

    # Build fractional coords (needed for MIC)
    if coords_are_cartesian:
        if lattice is None:
            # no PBC: simple Euclidean threshold
            # build edges directly
            edges = []
            for i in range(M):
                for j in range(i+1, M):
                    if np.linalg.norm(coords[j] - coords[i]) < threshold_ang:
                        edges.append((i, j))
        else:
            L = lattice.matrix if hasattr(lattice, "matrix") else np.asarray(lattice, dtype=float)
            f = _frac_coords(coords, L)
            edges = _pair_edges_with_threshold(f % 1.0, L, threshold_ang)
    else:
        # coords already fractional
        if lattice is None:
            raise ValueError("lattice is required when using fractional coords to compute distances.")
        L = lattice.matrix if hasattr(lattice, "matrix") else np.asarray(lattice, dtype=float)
        edges = _pair_edges_with_threshold(coords % 1.0, L, threshold_ang)

    # Maximal cliques from the proximity graph
    cliques = _maximal_cliques_from_edges(M, edges)

    # Map internal indices -> site_ids
    groups = [[site_ids[i] for i in clique] for clique in cliques]
    if return_pairs_also:
        pairs = [(site_ids[i], site_ids[j]) for (i, j) in edges]
        return groups, pairs
    return groups

In [None]:
def add_li_proximity_exclusions(model: cp_model.CpModel, x, proximity_groups):
    """
    Enforce Li–Li exclusion constraints.

    Parameters
    ----------
    model : cp_model.CpModel
        The model to which constraints are added.
    x : dict
        Variable dictionary {(site_id, option): BoolVar}.
    proximity_groups : list
        Each element is either:
            - a tuple/list of two site IDs (s, t)  → pairwise exclusion
            - a list/tuple of ≥2 site IDs forming a clique (all mutually too close)

    Effect
    ------
    For every group g of sites, ensures that at most one can be occupied by Li:
        sum_{s∈g} x[s, "Li"] ≤ 1
    Uses CP-SAT’s AddAtMostOne for stronger propagation.
    """
    num_groups = 0
    for g in proximity_groups:
        # Normalise input
        if isinstance(g, tuple) or isinstance(g, list):
            sites = list(g)
        else:
            raise ValueError("Each proximity group must be a list/tuple of site IDs.")
        if len(sites) < 2:
            continue

        # Collect the BoolVars for these sites
        li_vars = [x[(s, "Li")] for s in sites]
        model.AddAtMostOne(li_vars)
        num_groups += 1

    # print(f"Added {num_groups} Li–Li exclusion groups (AtMostOne).")
    # return model

### Add constraints together

In [None]:
def cpsat_core_from_indices(li_indices, mn_indices, N_li, proximity_groups=[]):
    model = cp_model.CpModel()

    # A) sites & options
    site_options, var2siteopt, li_sites, mn_sites = build_site_option_maps_from_indices(
        li_indices, mn_indices
    )

    # B) x vars + one-hot per site
    x = build_x_vars_and_onehot(model, site_options)

    # C) Li constraints
    add_li_mn_charge_balance_constraints(model, x, li_sites, mn_sites, N_li)
    add_li_proximity_exclusions(model, x, proximity_groups)

    # (No objective yet; we’ll add it when we map your pair energies W)
    return model, x, site_options, var2siteopt, li_sites, mn_sites

In [71]:
model, x, site_options, var2siteopt, li_sites, mn_sites = cpsat_core_from_indices(li_indices, mn_indices, N_li=2)

65
48


## Add objective

In [72]:
import numpy as np
from ortools.sat.python import cp_model

def add_ut_qubo_objective(
    model: cp_model.CpModel,
    x: dict,                  # {(site_id, option_name): BoolVar}
    var2siteopt: dict,        # {qubo_col_index: (site_id, option_name)}
    Q_ut: np.ndarray,         # upper-triangular QUBO matrix (shape n x n)
    *,
    scale: float = 1000.0,    # integer scaling for CP-SAT
    tiny: float = 1e-12,
    name_prefix: str = "y"
):
    """
    Add a minimization objective equivalent to an *upper-triangular* QUBO.

    Energy = sum_i     Q[i,i] * x_(s,a)
           + sum_{i<j} Q[i,j] * (x_(s,a) AND x_(t,b))

    where Q's columns/rows index your original binary vars (Li, Mn4, Mn3),
    and var2siteopt maps each original var index -> (site, option).

    Notes:
    - We only iterate i<=j because Q is upper-triangular (no double counting).
    - Same-site off-diagonals (i<j but s==t) are skipped (redundant under one-hot).
    - Coefficients are scaled to integers for CP-SAT.
    """
    n = Q_ut.shape[0]

    # 1) Integerize (keep upper triangle semantics)
    Q = np.array(Q_ut, dtype=float, copy=True)
    Qi = np.rint(Q * scale).astype(int)
    # prune tiny
    Qi[np.abs(Qi) < tiny] = 0
    SCALE = int(scale)

    obj_terms = []
    num_diag_added = 0
    num_pairs_added = 0
    num_pairs_skipped_same_site = 0
    num_pairs_skipped_zero = 0
    num_pairs_total_seen = 0

    # 2) Diagonal: Q[i,i] * x_(s,a)
    for i in range(n):
        if i not in var2siteopt:
            continue
        s, a = var2siteopt[i]
        c = Qi[i, i]
        if c != 0:
            obj_terms.append(c * x[(s, a)])
            num_diag_added += 1

    # 3) Off-diagonals (upper triangle): Q[i,j] * y, with y = AND(x_(s,a), x_(t,b))
    for i in range(n):
        if i not in var2siteopt:
            continue
        s, a = var2siteopt[i]
        row = Qi[i]
        for j in range(i + 1, n):      # i<j, upper-triangular entries only
            num_pairs_total_seen += 1
            c = row[j]
            if c == 0:
                num_pairs_skipped_zero += 1
                continue
            if j not in var2siteopt:
                continue
            t, b = var2siteopt[j]
            if s == t:
                # cross-terms within the same physical site are redundant with one-hot
                num_pairs_skipped_same_site += 1
                continue

            y = model.NewBoolVar(f"{name_prefix}_{s}_{a}_{t}_{b}")
            model.Add(y <= x[(s, a)])
            model.Add(y <= x[(t, b)])
            model.Add(y >= x[(s, a)] + x[(t, b)] - 1)
            obj_terms.append(c * y)
            num_pairs_added += 1

    # 4) Set objective
    model.Minimize(sum(obj_terms))

    # 5) Diagnostics
    summary = {
        "num_diag_added": num_diag_added,
        "num_pairs_total_seen": num_pairs_total_seen,
        "num_pairs_added": num_pairs_added,
        "num_pairs_skipped_zero": num_pairs_skipped_zero,
        "num_pairs_skipped_same_site": num_pairs_skipped_same_site,
        "scale": SCALE,
    }
    return SCALE, summary

In [73]:
SCALE, info = add_ut_qubo_objective(model, x, var2siteopt, QUBO)  # QUBO is upper-triangular

## Solve

In [74]:
SCALE, info = add_ut_qubo_objective(model, x, var2siteopt, QUBO)  # QUBO is upper-triangular
print("Objective wiring:", info)

solver = cp_model.CpSolver()
solver.parameters.use_lns = True
solver.parameters.num_search_workers = 8
solver.parameters.max_time_in_seconds = 180
status = solver.Solve(model)
print("Status:", solver.StatusName(status))
if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
    for s, opts in site_options.items():
        chosen = [a for a in opts if solver.Value(x[(s,a)])==1][0]
        print(f"Site {s}: {chosen}")
    print("Energy (eV):", solver.ObjectiveValue() / SCALE)

Objective wiring: {'num_diag_added': 113, 'num_pairs_total_seen': 6328, 'num_pairs_added': 6304, 'num_pairs_skipped_zero': 24, 'num_pairs_skipped_same_site': 0, 'scale': 1000}
Status: OPTIMAL
Site 0: Empty
Site 1: Empty
Site 2: Empty
Site 3: Empty
Site 4: Empty
Site 5: Empty
Site 6: Empty
Site 7: Empty
Site 8: Empty
Site 9: Empty
Site 10: Empty
Site 11: Empty
Site 12: Empty
Site 13: Li
Site 14: Empty
Site 15: Empty
Site 16: Empty
Site 17: Empty
Site 18: Empty
Site 19: Empty
Site 20: Empty
Site 21: Empty
Site 22: Empty
Site 23: Empty
Site 24: Empty
Site 25: Empty
Site 26: Empty
Site 27: Li
Site 28: Empty
Site 29: Empty
Site 30: Empty
Site 31: Empty
Site 32: Empty
Site 33: Empty
Site 34: Empty
Site 35: Empty
Site 36: Empty
Site 37: Empty
Site 38: Empty
Site 39: Empty
Site 40: Empty
Site 41: Empty
Site 42: Empty
Site 43: Empty
Site 44: Empty
Site 45: Empty
Site 46: Empty
Site 47: Empty
Site 48: Empty
Site 49: Empty
Site 50: Empty
Site 51: Empty
Site 52: Empty
Site 53: Empty
Site 54: Empty

In [76]:
class IncumbentSaver(cp_model.CpSolverSolutionCallback):
    """
    Saves every feasible solution (incumbent) the CP-SAT search discovers.
    - Writes a JSON with assignment & energy.
    - Optionally builds and saves a structure file via `build_structure(assignment)`.
    """
    def __init__(self, x, site_options, *, scale=1.0, out_dir="solutions",
                 build_structure=None, struct_fmt="cif", limit=None, print_every=1):
        super().__init__()
        self.x = x                              # dict[(s,a)] -> BoolVar
        self.site_options = site_options        # dict[s] -> [options...]
        self.scale = scale
        self.out_dir = out_dir
        self.build_structure = build_structure  # callable: dict[s]->option -> pymatgen.Structure
        self.struct_fmt = struct_fmt            # e.g. "cif"
        self.limit = limit                      # max number of solutions to save (None = unlimited)
        self.print_every = print_every
        self.solution_count = 0
        os.makedirs(out_dir, exist_ok=True)

    def on_solution_callback(self):
        if self.limit is not None and self.solution_count >= self.limit:
            return

        # Decode assignment: for each site pick the unique option with value 1
        assign = { s: next(a for a in opts if self.Value(self.x[(s,a)]) == 1)
                   for s, opts in self.site_options.items() }

        # Energy (unscale back to eV if you used scaling in the objective)
        energy = self.ObjectiveValue() / self.scale if self.ObjectiveValue() is not None else None

        # Save JSON
        ts = time.strftime("%Y%m%d-%H%M%S")
        idx = self.solution_count
        base = os.path.join(self.out_dir, f"sol_{idx:05d}")
        meta = {
            "index": idx,
            "timestamp": ts,
            "energy": energy,
            "assignment": assign
        }
        with open(base + ".json", "w") as f:
            json.dump(meta, f, indent=2)

        # Optional: build and write a structure file
        if callable(self.build_structure):
            try:
                struct = self.build_structure(assign)  # must return a pymatgen.Structure
                # You can pick filename scheme as you like:
                struct.to(filename=base + "." + self.struct_fmt)
            except Exception as e:
                # Don’t crash the solver; just log an error stub
                with open(base + ".err.txt", "w") as f:
                    f.write(f"Structure build failed: {repr(e)}\n")

        # Optional console progress
        if (idx % self.print_every) == 0:
            print(f"[incumbent {idx}] E = {energy:.6f} eV")

        self.solution_count += 1

In [97]:
from ortools.sat.python import cp_model
import hashlib, json, time, gzip, os

class StreamingIncumbentSaver(cp_model.CpSolverSolutionCallback):
    def __init__(self, x, site_options, li_sites, mn_sites, scale, out_dir, limit=None):
        super().__init__()
        self.x = x
        self.site_options = site_options
        self.li_sites = li_sites
        self.mn_sites = mn_sites
        self.scale = scale
        self.out_dir = out_dir
        self.limit = limit
        self.count = 0
        os.makedirs(out_dir, exist_ok=True)
        self.inc_path = os.path.join(out_dir, "incumbents.jsonl.gz")

    def on_solution_callback(self):
        if self.limit is not None and self.count >= self.limit:
            return
        # Decode current incumbent
        assignment = {s: next(a for a in opts if self.Value(self.x[(s,a)]) == 1)
                      for s, opts in self.site_options.items()}
        E = None if self.ObjectiveValue() is None else self.ObjectiveValue() / self.scale

        # Minimal record (same shape as append_incumbent)
        li_on  = sorted(int(s) for s in self.li_sites if assignment[s] == "Li")
        mn3_on = sorted(int(s) for s in self.mn_sites if assignment[s] == "Mn3")
        cfg_bytes = json.dumps({"li_on": li_on, "mn3_on": mn3_on}, separators=(",", ":")).encode()
        cfg_hash = hashlib.sha256(cfg_bytes).hexdigest()[:16]

        rec = {
            "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "E": E, "li_on": li_on, "mn3_on": mn3_on,
            "n_li": len(li_on), "n_mn3": len(mn3_on),
            "cfg": cfg_hash, "tags": {"status": "INCUMBENT"}
        }
        with gzip.open(self.inc_path, "ab") as gz:
            gz.write((json.dumps(rec, separators=(",", ":")) + "\n").encode("utf-8"))

        self.count += 1

In [100]:
import os, json, gzip, hashlib, time
import numpy as np

# ===============================================================
#   Initialize run folder (geometry, QUBO, metadata)
# ===============================================================
def _sha256_of_arrays(*arrays) -> str:
    h = hashlib.sha256()
    for a in arrays:
        h.update(np.ascontiguousarray(a).tobytes())
    return h.hexdigest()[:16]  # short hash for filenames/IDs


def init_run_store(
    output_dir: str,
    initial_structure,                 # pymatgen.Structure (framework, no Li)
    li_sites: list,                    # CP site IDs for Li grid
    mn_sites: list,                    # CP site IDs for Mn
    initial_grid_cart: np.ndarray,     # (M,3) Li grid in CARTESIAN
    mn_atom_indices: list,             # len == len(mn_sites); atom indices in initial_structure
    QUBO_ut: np.ndarray,               # upper-triangular QUBO matrix (n x n)
    SCALE: int,                        # integer scaling used in objective
    solver_params: dict,               # e.g. {"time":180,"workers":8,"seed":42}
    extra_meta: dict = None,           # optional extra info
):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "energy_model"), exist_ok=True)

    # --- Geometry data ---
    lat = initial_structure.lattice.matrix.astype(np.float32)
    species_Z = np.array([sp.Z for sp in initial_structure.species], dtype=np.int16)
    frac_coords = initial_structure.frac_coords.astype(np.float32)
    li_grid_frac = initial_structure.lattice.get_fractional_coords(initial_grid_cart).astype(np.float32)

    # Save geometry
    np.savez_compressed(
        os.path.join(output_dir, "geometry.npz"),
        lattice=lat,
        species_Z=species_Z,
        frac_coords=frac_coords,
        li_grid_frac=li_grid_frac,
        mn_atom_indices=np.array(mn_atom_indices, dtype=np.int32),
    )

    # Save mapping
    with open(os.path.join(output_dir, "mapping.json"), "w") as f:
        json.dump(
            {"li_sites": list(map(int, li_sites)), "mn_sites": list(map(int, mn_sites))},
            f,
            indent=2,
        )

    # Save QUBO model
    np.savez_compressed(
        os.path.join(output_dir, "energy_model", "qubo_ut.npz"),
        Q_ut=QUBO_ut.astype(np.float32),
        SCALE=int(SCALE),
    )

    # --- Hashes for provenance ---
    geom_hash = _sha256_of_arrays(lat, species_Z, frac_coords, li_grid_frac)
    qubo_hash = _sha256_of_arrays(QUBO_ut)

    # --- Meta info ---
    meta = {
        "run_id": os.path.basename(output_dir),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "solver_params": solver_params,
        "SCALE": int(SCALE),
        "geom_hash": geom_hash,
        "qubo_hash": qubo_hash,
        "files": {
            "geometry": "geometry.npz",
            "mapping": "mapping.json",
            "qubo": "energy_model/qubo_ut.npz",
            "incumbents": "incumbents.jsonl.gz",
        },
    }
    if extra_meta:
        meta.update(extra_meta)

    with open(os.path.join(output_dir, "meta.json"), "w") as f:
        json.dump(meta, f, indent=2)

    return {"geom_hash": geom_hash, "qubo_hash": qubo_hash}


# ===============================================================
#   Append an incumbent (single configuration)
# ===============================================================
def append_incumbent(
    output_dir: str,
    assignment: dict,            # {site_id: option_name}
    energy_ev: float | None,
    *,
    li_sites: list,
    mn_sites: list,
    tags: dict = None           # optional metadata, e.g. {"status":"FINAL"}
):
    """Append one incumbent configuration to incumbents.jsonl.gz."""
    li_on = sorted(int(s) for s in li_sites if assignment[s] == "Li")
    mn3_on = sorted(int(s) for s in mn_sites if assignment[s] == "Mn3")

    # Stable short hash for deduplication
    cfg_bytes = json.dumps({"li_on": li_on, "mn3_on": mn3_on},
                           separators=(",", ":")).encode("utf-8")
    cfg_hash = hashlib.sha256(cfg_bytes).hexdigest()[:16]

    rec = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "E": None if energy_ev is None else float(energy_ev),
        "li_on": li_on,
        "mn3_on": mn3_on,
        "n_li": len(li_on),
        "n_mn3": len(mn3_on),
        "cfg": cfg_hash,
    }
    if tags:
        rec["tags"] = tags

    inc_path = os.path.join(output_dir, "incumbents.jsonl.gz")
    with gzip.open(inc_path, "ab") as gz:
        gz.write((json.dumps(rec, separators=(",", ":")) + "\n").encode("utf-8"))

    return cfg_hash

In [None]:
# ======= YOUR LOOP (fixed order + small corrections) =======
# Assumes these are defined in your environment:
#   generate_filtered_grid, join_structure_grid, build_QUBO,
#   build_site_option_maps_from_indices, cpsat_core_from_indices, add_ut_qubo_objective
# And you have set: N_initial_grid, min_dist_grid, threshold_li, prox_penalty, N_structures_opt

initial_structure = Structure.from_file('data/delithiated_tmp.cif')

initial_grid = generate_filtered_grid(
    initial_structure,
    N_initial_grid=N_initial_grid,
    min_dist_grid=min_dist_grid
)
structure = join_structure_grid(initial_structure, initial_grid)
grid = copy.deepcopy(initial_grid)

num_iterations = 1
for i in range(num_iterations):
    print(f'************ Begin Iteration {i} ************')
    output_dir = f'output_folder_{i}'

    # --- Build energy matrix (QUBO) and mapping ---
    QUBO, li_indices, mn_indices = build_QUBO(
        structure,
        threshold_li=threshold_li,
        prox_penalty=prox_penalty
    )

    groups, pairs = build_li_proximity_groups(
        li_grid_coords=grid,              # (M,3) Cartesian
        threshold_ang=1.8,                # your cutoff
        lattice=initial_structure.lattice, # for PBC-aware distances
        coords_are_cartesian=True,
        site_ids=li_sites                 # CP site IDs aligned to grid order
        )

    # CP-SAT core (one-hot, counts, charge balance)
    model, x, site_options, var2siteopt, li_sites, mn_sites = cpsat_core_from_indices(
        li_indices, mn_indices, N_li=2
    )

    # Objective from upper-triangular QUBO
    SCALE, info = add_ut_qubo_objective(model, x, var2siteopt, QUBO)
    print("Objective wiring:", info)

    # --- Mn atom indices (Z=25). Align to mn_sites order if needed ---
    mn_atom_indices_all = np.where(np.array(initial_structure.atomic_numbers) == 25)[0]
    assert len(mn_atom_indices_all) >= len(mn_sites), \
        "Not enough Mn atoms in initial_structure to map mn_sites."
    mn_atom_indices = list(mn_atom_indices_all[:len(mn_sites)])

    # --- Save run-level artifacts once ---
    solver_params = {"time": 180, "workers": 8, "seed": 42}
    _ = init_run_store(
        output_dir=output_dir,
        initial_structure=initial_structure,
        li_sites=li_sites,
        mn_sites=mn_sites,
        initial_grid_cart=grid,
        mn_atom_indices=mn_atom_indices,
        QUBO_ut=QUBO,
        SCALE=SCALE,
        solver_params=solver_params,
    )

    # --- Solver setup ---
    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = 180
    solver.parameters.num_search_workers = 8
    solver.parameters.random_seed = 42
    solver.parameters.log_search_progress = True
    if hasattr(solver.parameters, "use_lns"):
        solver.parameters.use_lns = True

    # --- Create and attach incumbent saver callback ---
    cb = StreamingIncumbentSaver(
        x=x,
        site_options=site_options,
        li_sites=li_sites,
        mn_sites=mn_sites,
        scale=SCALE,
        out_dir=output_dir,
        limit=500,  # optional cap on saved incumbents
    )

    # --- Solve with callback ---
    status = solver.Solve(model, cb)
    print("Status:", solver.StatusName(status))
    print("Incumbents saved during search:", cb.count)

    if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        # Decode final best assignment
        assignment = {s: next(a for a in opts if solver.Value(x[(s, a)]) == 1)
                      for s, opts in site_options.items()}
        try:
            best_E = solver.ObjectiveValue() / SCALE
        except Exception:
            best_E = None

        # Save terminal incumbent (tagged as FINAL)
        cfg_hash = append_incumbent(
            output_dir=output_dir,
            assignment=assignment,
            energy_ev=best_E,
            li_sites=li_sites,
            mn_sites=mn_sites,
            tags={"status": "FINAL", "solver_status": solver.StatusName(status)}
        )
        print(f"Saved final incumbent cfg: {cfg_hash}, E = {best_E:.3f} eV")
    else:
        print("No feasible solution found.")

    print(f'************ End Iteration {i} ************\n')

************ Begin Iteration 0 ************


Buckingham matrix: 100%|██████████| 137/137 [00:09<00:00, 15.11it/s]


65
48
Objective wiring: {'num_diag_added': 113, 'num_pairs_total_seen': 6328, 'num_pairs_added': 6304, 'num_pairs_skipped_zero': 24, 'num_pairs_skipped_same_site': 0, 'scale': 1000}

Starting CP-SAT solver v9.14.6206
Parameters: random_seed: 42 max_time_in_seconds: 180 log_search_progress: true num_search_workers: 8 use_lns: true

Initial optimization model '': (model_fingerprint: 0x343327f82b2a62d1)
#Variables: 6'482 (#bools: 6'417 in objective) (87 primary variables)
  - 6'482 Booleans in [0,1]
#kLinear2: 12'697
#kLinear3: 6'304
#kLinearN: 2 (#terms: 89)

Starting presolve at 0.01s
  1.08e-03s  0.00e+00d  [DetectDominanceRelations] 
  2.85e-02s  0.00e+00d  [PresolveToFixPoint] #num_loops=2 #num_dual_strengthening=1 
  3.20e-05s  0.00e+00d  [ExtractEncodingFromLinear] 
  7.29e-04s  0.00e+00d  [DetectDuplicateColumns] 
  1.74e-03s  0.00e+00d  [DetectDuplicateConstraints] 
[Symmetry] Graph for symmetry has 25'574 nodes and 57'003 arcs.
[Symmetry] Symmetry computation done. time: 0.00247

In [102]:
import gzip, json

# Example: read all incumbents for iteration 0
with gzip.open("output_folder_0/incumbents.jsonl.gz", "rt") as f:
    for line in f:
        record = json.loads(line)
        print(record)

{'ts': '2025-11-12T11:09:39Z', 'E': -771.244, 'li_on': [51, 58], 'mn3_on': [65, 68], 'n_li': 2, 'n_mn3': 2, 'cfg': '2407cd0e53eca32f', 'tags': {'status': 'INCUMBENT'}}
{'ts': '2025-11-12T11:09:39Z', 'E': -785.214, 'li_on': [28, 51], 'mn3_on': [65, 68], 'n_li': 2, 'n_mn3': 2, 'cfg': '18d3cfd95b7d8762', 'tags': {'status': 'INCUMBENT'}}
{'ts': '2025-11-12T11:09:39Z', 'E': -785.579, 'li_on': [28, 51], 'mn3_on': [65, 87], 'n_li': 2, 'n_mn3': 2, 'cfg': '96bb99b901ac1aef', 'tags': {'status': 'INCUMBENT'}}
{'ts': '2025-11-12T11:09:40Z', 'E': -851.271, 'li_on': [28, 55], 'mn3_on': [65, 86], 'n_li': 2, 'n_mn3': 2, 'cfg': '4641153f4ac829bc', 'tags': {'status': 'INCUMBENT'}}
{'ts': '2025-11-12T11:09:40Z', 'E': -854.281, 'li_on': [26, 55], 'mn3_on': [65, 75], 'n_li': 2, 'n_mn3': 2, 'cfg': '691abb9ebada7da9', 'tags': {'status': 'INCUMBENT'}}
{'ts': '2025-11-12T11:09:40Z', 'E': -854.641, 'li_on': [26, 55], 'mn3_on': [65, 67], 'n_li': 2, 'n_mn3': 2, 'cfg': '53033daee588a84a', 'tags': {'status': 'INCUM

In [None]:
# initial_structure = Structure.from_file('data/delithiated_tmp.cif')

# initial_grid = generate_filtered_grid(initial_structure, N_initial_grid=N_initial_grid, min_dist_grid=min_dist_grid)
# structure = join_structure_grid(initial_structure,initial_grid)

# grid = copy.deepcopy(initial_grid)
# num_iterations = 1
# for i in range(num_iterations):
#     print(f'************ Begin Iteration {i} ************')

#     mn_atom_indices = np.where(np.array(structure.atomic_numbers)==24)[0]
#     solver_params = {"time":180, "workers":8, "seed":42}

#     hashes = init_run_store(
#         output_dir=output_dir,
#         initial_structure=initial_structure,
#         li_sites=li_sites,
#         mn_sites=mn_sites,
#         initial_grid_cart=grid,          # your grid in CARTESIAN
#         mn_atom_indices=mn_atom_indices, # you provide this mapping
#         QUBO_ut=QUBO,
#         SCALE=SCALE,
#         solver_params=solver_params,
#     )

#     N_structures_opt_input = copy.deepcopy(N_structures_opt)
#     output_dir = f'output_folder_{i}'

#     QUBO, li_indices, mn_indices = build_QUBO(structure, threshold_li=threshold_li, prox_penalty=prox_penalty)
    
#     site_options, var2siteopt, li_sites, mn_sites = build_site_option_maps_from_indices(li_indices, mn_indices)

#     model = cp_model.CpModel()
#     build_x_vars_and_onehot(model, site_options)
    
#     model, x, site_options, var2siteopt, li_sites, mn_sites = cpsat_core_from_indices(li_indices, mn_indices, N_li=2)
    
#     SCALE, info = add_ut_qubo_objective(model, x, var2siteopt, QUBO)

#     cb = None  # ← no callback

#     solver = cp_model.CpSolver()

#     # Version-safe, portable knobs
#     solver.parameters.max_time_in_seconds = 180
#     solver.parameters.num_search_workers = 8
#     solver.parameters.random_seed = 42
#     solver.parameters.log_search_progress = True
#     if hasattr(solver.parameters, "use_lns"):
#         solver.parameters.use_lns = True

#     # Solve (with or without callback)
#     status = solver.Solve(model, cb) if cb is not None else solver.Solve(model)

#     print("Status:", solver.StatusName(status))

#     if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
#         # Best assignment
#         for s, opts in site_options.items():
#             chosen = next(a for a in opts if solver.Value(x[(s, a)]) == 1)
#             print(f"Site {s}: {chosen}")
#         try:
#             print("Best energy (eV):", solver.ObjectiveValue() / SCALE)
#         except Exception:
#             pass

#         # If using a callback, how many incumbents were saved
#         if cb is not None:
#             print("Saved solutions:", cb.solution_count)
#     else:
#         print("No feasible solution found.")
    
#     # --- after Solve() succeeds ---
#     if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
#         assignment = {s: next(a for a in opts if solver.Value(x[(s, a)]) == 1)
#                     for s, opts in site_options.items()}
#         energy_ev = solver.ObjectiveValue() / SCALE
#         cfg_hash = append_incumbent(
#             output_dir=output_dir,
#             assignment=assignment,
#             energy_ev=energy_ev,
#             li_sites=li_sites,
#             mn_sites=mn_sites,
#             tags={"status": solver.StatusName(status)}
#         )
#         print("Saved incumbent cfg:", cfg_hash)
    

************ Begin Iteration 0 ************


Buckingham matrix: 100%|██████████| 137/137 [00:09<00:00, 14.96it/s]


65
48
65
48

Starting CP-SAT solver v9.14.6206
Parameters: random_seed: 42 max_time_in_seconds: 180 log_search_progress: true num_search_workers: 8 use_lns: true

Initial optimization model '': (model_fingerprint: 0xdb2c0d68009930ce)
#Variables: 6'482 (#bools: 6'417 in objective) (88 primary variables)
  - 6'482 Booleans in [0,1]
#kLinear2: 12'697
#kLinear3: 6'304
#kLinearN: 1 (#terms: 65)

Starting presolve at 0.01s
  1.66e-03s  0.00e+00d  [DetectDominanceRelations] 
  2.83e-02s  0.00e+00d  [PresolveToFixPoint] #num_loops=2 #num_dual_strengthening=1 
  2.80e-05s  0.00e+00d  [ExtractEncodingFromLinear] 
  3.10e-04s  0.00e+00d  [DetectDuplicateColumns] 
  1.63e-03s  0.00e+00d  [DetectDuplicateConstraints] 
[Symmetry] Graph for symmetry has 25'573 nodes and 56'979 arcs.
[Symmetry] Symmetry computation done. time: 0.003001 dtime: 0.00544876
[SAT presolve] num removable Booleans: 0 / 6393
[SAT presolve] num trivial clauses: 0
[SAT presolve] [0s] clauses:18912 literals:44128 vars:6393 one_s