<a href="https://colab.research.google.com/github/jamessutton600613-png/GC/blob/main/Untitled236.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Reconstructed from Untitled235.ipynb.txt
# UTC: 2025-11-01T11:26:32.072167Z

# ======== CaMn4O5·(W3,W4) — Robust cluster pick + μ-oxo + waters + TDSE (H/I/J) ========
import os, numpy as np, matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import Image, display
import sys, subprocess
def need(pkg, mod=None):
    """Ensure a module is importable, pip-installing its package on demand.

    Tries ``__import__(mod or pkg)``; if that raises anything, runs
    ``python -m pip install -q pkg`` with the current interpreter.
    Returns None and does not bind the module in the caller's namespace.
    """
    module_name = mod if mod else pkg
    try:
        __import__(module_name)
    except Exception:
        # any import failure triggers a single quiet install of the pip package
        install_cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(install_cmd)
need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree
from scipy.cluster.vq import kmeans2

SEARCH = ["/content", "/content/sample_data", "./", "./sample_data"]
NAMES  = ["8F4H.cif","8F4I.cif","8F4J.cif"]
# For every wanted CIF, remember the first search root that actually holds it.
FOUND = {}
for nm in NAMES:
    for root in SEARCH:
        p = os.path.join(root, nm)
        if not os.path.isfile(p):
            continue
        FOUND[nm] = p
        break
if len(FOUND) == 0:
    raise FileNotFoundError("Place 8F4H/8F4I/8F4J in /content or ./sample_data")
# Outputs go to Colab's /content when it exists, otherwise the working directory.
OUT = "/content" if os.path.isdir("/content") else "."

def read_atoms(path):
    """Load every atom of the first model of a structure file with gemmi.

    Parameters
    ----------
    path : str
        mmCIF/PDB file readable by gemmi.

    Returns
    -------
    coords : (N, 3) float32 array of atom positions
    elem   : (N,) object array of upper-case element symbols ("MN", "CA", "O", ...)
    resn   : (N,) object array of upper-case residue names

    Raises
    ------
    ValueError
        If the structure contains no model.
    """
    try:
        st = gemmi.read_structure(path)
    except Exception:
        # fall back to building the structure from the first CIF data block
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception as err:
            # best effort: continue with all conformers, but do not hide it
            print(f"[warn] {path}: could not remove alternative conformations ({err})")
    if len(st) == 0:
        raise ValueError(f"{path}: structure has no models")
    m = st[0]
    coords = []; elem = []; resn = []
    for ch in m:
        for r in ch:
            for a in r:
                p = a.pos
                coords.append([p.x, p.y, p.z])
                elem.append(a.element.name.upper())
                resn.append(r.name.upper())
    # reshape keeps an (0, 3) shape for an atom-less model
    return (np.array(coords, np.float32).reshape(-1, 3),
            np.array(elem, object), np.array(resn, object))

# --- pick the CaMn4 cluster: most compact group of 4 Mn, then the Ca nearest to it ---
def pick_cluster(coords, elem, max_exhaustive=24):
    """Select the four most compact Mn atoms and the Ca nearest to them.

    Compactness of a 4-Mn group is the mean of its six pairwise distances.
    With at most `max_exhaustive` Mn atoms every 4-subset is scored (exact,
    deterministic). With more Mn, only the groups "each Mn + its 3 nearest Mn"
    are scored, which is also deterministic (an unseeded k-means search is not).
    Ties keep the first group in enumeration order.

    Parameters
    ----------
    coords : (N, 3) array of atom positions
    elem   : (N,) array of upper-case element symbols
    max_exhaustive : int
        Largest Mn count for which all C(n, 4) subsets are scored.

    Returns
    -------
    sel_mn : list[int]  atom indices of the chosen 4 Mn, ascending
    sel_ca : int        atom index of the Ca nearest the chosen Mn centroid

    Raises
    ------
    AssertionError
        If fewer than 4 Mn or no Ca are present.
    """
    from itertools import combinations

    mn_idx = np.flatnonzero(elem == "MN")
    if len(mn_idx) < 4: raise AssertionError("Need ≥4 Mn in CIF")
    mn_xyz = coords[mn_idx]
    iu = np.triu_indices(4, 1)

    def spread(group):
        # mean pairwise distance of the 4 Mn positions in `group` (local Mn numbering)
        g = mn_xyz[list(group)]
        d = np.linalg.norm(g[:, None, :] - g[None, :, :], axis=-1)
        return d[iu].mean()

    if len(mn_idx) <= max_exhaustive:
        groups = combinations(range(len(mn_idx)), 4)
    else:
        _, nbr = cKDTree(mn_xyz).query(mn_xyz, k=4)   # each row: the Mn itself + 3 nearest
        groups = (tuple(sorted(row)) for row in nbr)
    best = min(groups, key=spread)

    sel_mn = sorted(int(mn_idx[k]) for k in best)
    center = coords[sel_mn].mean(0)
    ca_idx = np.flatnonzero(elem == "CA")
    if len(ca_idx) == 0: raise AssertionError("No Ca found")
    sel_ca = int(ca_idx[np.argmin(np.linalg.norm(coords[ca_idx] - center, axis=1))])
    return sel_mn, sel_ca

# --- find μ-oxo O (tight window, auto-widened while fewer than 3 are found) ---
def find_mu_oxo(coords, elem, sel_mn, r_min=1.80, r_max=2.20):
    """Return indices of oxygens that bridge at least two of the chosen Mn.

    An O counts as μ-oxo when at least two Mn of `sel_mn` lie within the
    distance window [lo, hi]. The window (r_min, r_max) is tried first; while
    fewer than three μ-oxo are found it is widened to (1.75, 2.30) and then
    (1.70, 2.40). The list from the last window tried is returned, sorted
    ascending (it may still hold fewer than three indices).
    """
    mn_xyz = coords[sel_mn]
    # Mn distances of every oxygen, computed once and reused for each window
    o_dists = [(i, np.linalg.norm(mn_xyz - coords[i], axis=1))
               for i, e in enumerate(elem) if e == "O"]
    found = []
    for lo, hi in ((r_min, r_max), (1.75, 2.30), (1.70, 2.40)):
        found = [i for i, d in o_dists if ((d >= lo) & (d <= hi)).sum() >= 2]
        if len(found) >= 3:
            break
    return sorted(set(found))

# --- pick W3/W4: O close to Ca and near (not bonded) to at least one Mn ---
def find_W34(coords, elem, sel_mn, sel_ca, r_ca=(2.30,2.70), r_mn=(2.70,3.40)):
    """Pick up to two oxygens as W3/W4 proxies.

    A candidate O has its Ca distance inside `r_ca` and its distance to the
    nearest chosen Mn inside `r_mn`. If fewer than two qualify, the Ca window
    is widened to (2.2, 2.9) and then (2.1, 3.1), each paired with an Mn
    window of (2.6, 3.6). Candidates are ranked by |d_Ca - 2.45|, then by d_Mn,
    and the first two atom indices are returned.
    """
    ca_pos = coords[sel_ca]
    mn_pos = coords[sel_mn]
    # (atom index, Ca distance, nearest-Mn distance) for every oxygen
    rows = [(i,
             np.linalg.norm(coords[i] - ca_pos),
             np.min(np.linalg.norm(mn_pos - coords[i], axis=1)))
            for i, e in enumerate(elem) if e == "O"]
    windows = ((r_ca, r_mn), ((2.2, 2.9), (2.6, 3.6)), ((2.1, 3.1), (2.6, 3.6)))
    picked = []
    for (ca_lo, ca_hi), (mn_lo, mn_hi) in windows:
        picked = [row for row in rows
                  if ca_lo <= row[1] <= ca_hi and mn_lo <= row[2] <= mn_hi]
        if len(picked) >= 2:
            break
    picked.sort(key=lambda row: (abs(row[1] - 2.45), row[2]))
    return [row[0] for row in picked[:2]]

# --- build graph & TDSE (electron) ---
# PAIR_RC: neighbour radius for graph edges (query_ball_point in build_case)
# BETA: decay constant in the exp(-BETA*r) hopping terms
# LAM: weight of the cos^2 angle term in angle_boost
# DT / STEPS: time step (plots label time in fs) and number of propagation steps
# FRAME_EVERY: store |psi|^2 every FRAME_EVERY-th step
PAIR_RC=3.45; BETA=1.7; LAM=1.2; DT=0.5; STEPS=700; FRAME_EVERY=3
# on-site energies by node tag (diagonal of H); unknown tags get 0.0
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, CA=+0.60)

def angle_boost(X, iO, mn_list):
    """Angular enhancement 1 + LAM*cos^2(theta) for hopping onto a μ-oxo node.

    theta is the angle at X[iO] between the two nodes of `mn_list` closest to
    it (distance ties keep `mn_list` order). With fewer than two Mn the factor
    is 1.0.
    """
    if len(mn_list) < 2:
        return 1.0
    origin = X[iO]
    by_distance = sorted(mn_list, key=lambda j: float(np.linalg.norm(X[j] - origin)))
    u = X[by_distance[0]] - origin
    w = X[by_distance[1]] - origin
    cos_t = np.dot(u, w) / (np.linalg.norm(u) * np.linalg.norm(w) + 1e-12)
    return 1.0 + LAM*(cos_t*cos_t)

def build_case(coords, elem):
    """Build the CaMn4 + μ-oxo + W3/W4 graph and its tight-binding Hamiltonian.

    Nodes: the 4 Mn and the Ca from pick_cluster, the μ-oxo O from find_mu_oxo
    and the water O from find_W34, ordered by atom index.
    Edges: node pairs within PAIR_RC of each other and more than 0.5 apart,
    stored as (i, j) with i < j in local node numbering.
    Couplings (symmetric; real values stored in a complex matrix):
      Mn–O   exp(-BETA*r), multiplied by angle_boost(...) when the O is μ-oxo
      Ca–Ow  0.7*exp(-BETA*r)
      O–O    0.08*exp(-BETA*r)
      other  0 (the pair stays in E but carries no hopping)
    On-site energies come from `eps` by tag.

    Returns (X, tags, E, H).

    Raises RuntimeError if no coupled Mn–Oμ edge or no coupled Ca–Ow edge exists.
    """
    sel_mn, sel_ca = pick_cluster(coords, elem)
    mu = find_mu_oxo(coords, elem, sel_mn)
    W  = find_W34(coords, elem, sel_mn, sel_ca)
    core = sorted(set(sel_mn + [sel_ca] + mu + W))
    X = coords[core].copy()
    # tag order matters: an O returned by both finders would be tagged "Ow"
    tags=[]
    for i in core:
        if i in sel_mn: tags.append("MN")
        elif i==sel_ca: tags.append("CA")
        elif i in W: tags.append("Ow")
        elif i in mu: tags.append("Omu")
        else: tags.append(elem[i])
    # edges (pairs closer than 0.5 are treated as duplicate positions and skipped)
    tree = cKDTree(X); E=[]
    for i in range(len(X)):
        for j in tree.query_ball_point(X[i], r=PAIR_RC):
            if j<=i: continue
            if np.linalg.norm(X[i]-X[j])<=0.5: continue
            E.append((i,j))
    # Hamiltonian
    H = np.zeros((len(X),len(X)), complex)
    mn_list=[i for i,t in enumerate(tags) if t=="MN"]
    for (i,j) in E:
        ti,tj = tags[i], tags[j]; r = float(np.linalg.norm(X[i]-X[j])); t=0.0
        if (ti=="MN" and tj.startswith("O")) or (tj=="MN" and ti.startswith("O")):
            # angle boost only for Mn–Oμ; Mn–Ow keeps factor 1.0
            boost = angle_boost(X, j, mn_list) if (ti=="MN" and tj=="Omu") else \
                    angle_boost(X, i, mn_list) if (tj=="MN" and ti=="Omu") else 1.0
            t = np.exp(-BETA*r)*boost
        elif (ti=="CA" and tj=="Ow") or (tj=="CA" and ti=="Ow"):
            t = 0.7*np.exp(-BETA*r)
        elif ti.startswith("O") and tj.startswith("O"):
            t = 0.08*np.exp(-BETA*r)
        if t!=0.0: H[i,j]=t; H[j,i]=t
    for i,tg in enumerate(tags): H[i,i]=eps.get(tg,0.0)
    # diagnostics (must have bridges!)
    n_MnOmu = sum(1 for (i,j) in E if {"MN","Omu"}=={tags[i],tags[j]} and abs(H[i,j])>0)
    n_CaOw  = sum(1 for (i,j) in E if {"CA","Ow"}=={tags[i],tags[j]} and abs(H[i,j])>0)
    if n_MnOmu==0 or n_CaOw==0:
        raise RuntimeError(f"Graph disconnected: Mn–Oμ={n_MnOmu}, Ca–Ow={n_CaOw}. "
                           f"Relax finders or check CIF.")
    return X, tags, E, H

def TDSE(H, start, steps=STEPS, dt=DT):
    """Propagate a single electron under the static Hamiltonian H.

    psi starts fully on node `start`. Each step applies U = exp(-i H dt) and
    renormalises. Because H does not change, U is built once from the
    eigendecomposition of H instead of once per step. Every FRAME_EVERY-th
    step the populations |psi|^2 are stored.

    Returns
    -------
    frames : (F, N) float32 populations
    times  : (F,) time of each stored frame. The frame stored at step n is
             taken after U has been applied n+1 times, so its time is (n+1)*dt.
    """
    N = H.shape[0]
    psi = np.zeros(N, complex); psi[start] = 1.0
    evals, evecs = np.linalg.eigh(H)
    U = evecs @ np.diag(np.exp(-1j*evals*dt)) @ evecs.conj().T
    frames = []; times = []
    for n in range(steps):
        psi = U @ psi
        psi /= (np.linalg.norm(psi) + 1e-12)
        if n % FRAME_EVERY == 0:
            frames.append((np.abs(psi)**2).astype(np.float32))
            times.append((n+1)*dt)
    return np.array(frames), np.array(times)

def PR(p):
    """Participation ratio (sum p)^2 / sum p^2 along the last axis (1e-12 guards /0)."""
    total = p.sum(-1)
    return total * total / ((p ** 2).sum(-1) + 1e-12)
def Jproxy(E,H,P,tags):
    """Per-frame mean of |H_ij|*sqrt(p_i p_j) over the bridge edges.

    Bridges are the Mn–Oμ and Ca–Ow edges of E with non-zero coupling. P holds
    one population vector (|psi|^2 per node) per frame. This is an amplitude
    overlap proxy; it carries no phase, so it is not a true probability current.

    Returns a (F,) float64 array. Frames get 0.0 when there is no bridge edge
    (averaging an empty list would give NaN plus a RuntimeWarning).
    """
    bridge_pairs = ({"MN","Omu"}, {"CA","Ow"})
    bridges = [(i,j) for (i,j) in E
               if abs(H[i,j]) > 0 and {tags[i], tags[j]} in bridge_pairs]
    P = np.asarray(P)
    if not bridges or len(P) == 0:
        return np.zeros(len(P))
    ii, jj = np.array(bridges).T
    weights = np.abs(H[ii, jj])
    overlap = np.sqrt((P[:, ii] * P[:, jj]).astype(np.float64))
    return (weights * overlap).mean(axis=1)

# Run every CIF found above: build the graph, propagate, then save PR/current
# plots and a |psi|^2 GIF per file. results[name] keeps T, PR and J for the overlays.
results={}
for nm,path in FOUND.items():
    C,ELEM,RESN = read_atoms(path)
    X,tags,E,H  = build_case(C,ELEM)
    # start the electron on the first water-O node (node 0 if there is none)
    start = next((i for i,t in enumerate(tags) if t=="Ow"), 0)
    P,T = TDSE(H, start)
    PRv = PR(P); Jv = Jproxy(E,H,P,tags)
    # Save per-CIF
    pr_png  = os.path.join(OUT, f"cubane_{nm}_PR.png")
    cur_png = os.path.join(OUT, f"cubane_{nm}_currents.png")
    plt.figure(figsize=(6,3.6)); plt.plot(T,PRv); plt.xlabel("time (fs)"); plt.ylabel("PR"); plt.title(f"PR — {nm}")
    plt.tight_layout(); plt.savefig(pr_png, dpi=170); plt.close()
    plt.figure(figsize=(6,3.6)); plt.plot(T,Jv); plt.xlabel("time (fs)"); plt.ylabel("current proxy"); plt.title(f"Bridges ⟨|J|⟩ — {nm}")
    plt.tight_layout(); plt.savefig(cur_png, dpi=170); plt.close()
    # edge counts printed here include every geometric Mn–Oμ / Ca–Ow pair (no |H|>0 filter)
    print(f"[{nm}] nodes={len(X)}  edges={len(E)}  Mn–Oμ={sum(1 for (i,j) in E if {'MN','Omu'}=={tags[i],tags[j]})}  Ca–Ow={sum(1 for (i,j) in E if {'CA','Ow'}=={tags[i],tags[j]})}")
    print("Saved:", pr_png, cur_png)
    # Small GIF of |psi|^2
    # grey sticks: Mn–Oμ and Ca–Ow edges; faint spheres: atoms by tag;
    # yellow markers: |psi|^2, sized by sqrt of the max-normalised population
    fig = plt.figure(figsize=(6.0,5.2)); ax = fig.add_subplot(111, projection="3d")
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([]); ax.grid(False)
    fig.patch.set_facecolor("#0E0E10"); ax.set_facecolor("#0E0E10")
    mn,mx=X.min(0),X.max(0); pad=1.0
    ax.set_xlim(mn[0]-pad,mx[0]+pad); ax.set_ylim(mn[1]-pad,mx[1]+pad); ax.set_zlim(mn[2]-pad,mx[2]+pad)
    def col(t): return {"MN":"#BA68C8","CA":"#66BB6A","Omu":"#EF5350","Ow":"#00E5FF"}.get(t,"#BDBDBD")
    def rad(t): return {"MN":1.0,"CA":0.9,"Omu":0.75,"Ow":0.65}.get(t,0.55)
    for (i,j) in E:
        if {"MN","Omu"}=={tags[i],tags[j]} or {"CA","Ow"}=={tags[i],tags[j]}:
            ax.plot([X[i,0],X[j,0]],[X[i,1],X[j,1]],[X[i,2],X[j,2]], lw=2.6, color="#6C757D", alpha=0.95)
    stat=ax.scatter(X[:,0],X[:,1],X[:,2], s=[240*rad(t) for t in tags], c=[col(t) for t in tags], alpha=0.35, depthshade=False)
    dyn = ax.scatter(X[:,0],X[:,1],X[:,2], s=120, c="#FFD54F", alpha=0.98, depthshade=False)
    ax.view_init(18,40)
    # frame callback; it reads P, T and nm of the current loop iteration, which is
    # safe because the animation is saved before the loop moves on
    def upd(f):
        p=P[f]; p=p/(p.max()+1e-12)
        dyn.set_sizes(120+1800*np.sqrt(p))
        ax.set_title(f"|ψ|² — {nm} — t={T[f]:.2f} fs", color="w"); return [dyn,stat]
    gif = os.path.join(OUT, f"cubane_{nm}.gif")
    animation.FuncAnimation(fig, upd, frames=len(P), interval=30, blit=False)\
             .save(gif, writer=animation.PillowWriter(fps=20)); plt.close(fig)
    print("Saved GIF:", gif)
    results[nm]=dict(T=T, PR=PRv, J=Jv)

# overlays: all processed CIFs on shared axes, one figure per metric
ov1 = os.path.join(OUT, "cubane_PR_overlay.png")
plt.figure(figsize=(7.5,3.8))
for nm in results:
    plt.plot(results[nm]["T"], results[nm]["PR"], label=nm)
plt.xlabel("time (fs)")
plt.ylabel("PR")
plt.title("PR overlay")
plt.legend()
plt.tight_layout()
plt.savefig(ov1, dpi=170)
plt.close()

ov2 = os.path.join(OUT, "cubane_currents_overlay.png")
plt.figure(figsize=(7.5,3.8))
for nm in results:
    plt.plot(results[nm]["T"], results[nm]["J"], label=nm)
plt.xlabel("time (fs)")
plt.ylabel("current proxy")
plt.title("Bridge current overlay")
plt.legend()
plt.tight_layout()
plt.savefig(ov2, dpi=170)
plt.close()

display(Image(ov1))
display(Image(ov2))
# =======================================================================================

# ================= CaMn4O5·2H2O electron TDSE (H/D mimicry, one-cell all-in-one) =================
# - Reads 8F4H/I/J CIFs
# - Identifies Mn4/Ca/μ-oxo O and picks two water O (W3/W4 proxies); synthesizes their H atoms
# - Builds tight-binding H with simple distance/angle physics
# - Time-evolves ψ with a small global "breathing" + local isotope gating at W3/W4 (H vs D)
# - Saves per-CIF plots (PR, currents) + GIF, then H vs D overlays
# -----------------------------------------------------------------------------------------------
import os, sys, subprocess, numpy as np, matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import Image, display

# ---------- config ----------
# Input CIFs (name -> path). Paths are fixed to /content; an entry whose file
# is missing is skipped by the per-CIF run loop.
CIFS = {
    "8F4H.cif": "/content/8F4H.cif",
    "8F4I.cif": "/content/8F4I.cif",
    "8F4J.cif": "/content/8F4J.cif",
}
# Outputs go to Colab's /content when it exists, otherwise the working directory.
OUTDIR = "/content" if os.path.isdir("/content") else "."
DT, N_STEPS = 0.5, 700          # fs, steps
FRAME_EVERY = 3                 # store every Nth time step (3 -> every 1.5 fs at DT=0.5)
PAIR_RC = 3.4                   # Å — neighbor cutoff for graph display/edges
# Couplings / on-sites (dimensionless)
t_dp, beta, lam = 1.0, 1.0, 0.5           # Mn–O bridge, exp decay (1/Å), μ-oxo angle boost
t_OO = 0.1                                 # weak O–O
t_OH, t_HB = 3.0, 0.8                      # O–H covalent; O···H H-bond
HBOND_MIN, HBOND_MAX = 1.2, 2.6            # Å
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)  # on-site energy per node tag
# Global breathing of Mn–O network: amplitude and period (fs) of the sinusoidal
# coupling modulation. BREATH_ON also switches the local W3/W4 isotope gating
# in tdse_time_dep (both modulations sit under the same flag).
BREATH_ON, BREATH_A, BREATH_T_FS = True, 0.08, 120.0

# ---------- deps ----------
def need(pkg, mod=None):
    """Ensure a module is importable, pip-installing its package on demand.

    Tries ``__import__(mod or pkg)``; if that raises anything, runs
    ``python -m pip install -q pkg`` with the current interpreter.
    Returns None and does not bind the module in the caller's namespace.
    """
    module_name = mod if mod else pkg
    try:
        __import__(module_name)
    except Exception:
        # any import failure triggers a single quiet install of the pip package
        install_cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(install_cmd)
need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree

# ---------- isotope mapping (edit if you like) ----------
ISO_BY_CIF = {"8F4H.cif":"H", "8F4I.cif":"H", "8F4J.cif":"D"}

def isotope_params(cif_basename):
    """Return the isotope-dependent scaling knobs for one CIF.

    The basename is looked up in ISO_BY_CIF; unknown names count as "H".
    "H" yields the unscaled reference set. Any other label yields the D-like
    set:
    - O–H and H-bond couplings are scaled by 0.92 and 0.85;
    - the global breathing amplitude is scaled by 0.90;
    - the local gate amplitude is 0.70 and its period is stretched by
      1.41421356 (sqrt 2);
    - the H on-site energy is shifted by -0.02.
    """
    if ISO_BY_CIF.get(cif_basename, "H") != "H":
        # D-like: ω↓, amplitude↓
        return dict(OH_scale=0.92, HB_scale=0.85, breath_A=BREATH_A*0.90, breath_T=BREATH_T_FS,
                    local_A=0.70, local_Tfac=1.41421356, eps_H_shift=-0.02, tag="D-like")
    return dict(OH_scale=1.00, HB_scale=1.00, breath_A=BREATH_A, breath_T=BREATH_T_FS,
                local_A=1.00, local_Tfac=1.00, eps_H_shift=0.0, tag="H-like")

# ---------- CIF helpers ----------
def read_atoms(cif_path):
    """Read every atom of the first model of a structure file with gemmi.

    Returns a list of ``(xyz, element, resname, atom_name)`` tuples: xyz is a
    float32 array of 3 coordinates, element and resname are upper-cased, and
    atom_name is stripped.

    Raises ValueError when the structure has no model.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:
        # fall back to building the structure from the first CIF data block
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception as err:
            # best effort: continue with all conformers, but do not hide it
            print(f"[warn] {cif_path}: could not remove alternative conformations ({err})")
    if len(st) == 0:
        raise ValueError(f"{cif_path}: structure has no models")
    m = st[0]
    atoms = []
    for ch in m:
        for res in ch:
            for at in res:
                p = at.pos
                atoms.append((np.array([p.x, p.y, p.z], np.float32),
                              at.element.name.upper(),
                              res.name.upper(),
                              at.name.strip()))
    return atoms

def pick_mu_oxo_and_waters(coords, elem, mn_idx, ca_idx):
    """Find μ-oxo oxygens and two water-oxygen proxies (W3/W4).

    μ-oxo: every O within 2.7 Å of at least two atoms of `mn_idx`.

    Waters: the two non-μ-oxo O closest to the centre of ONE CaMn4 cluster.
    When `mn_idx` holds more than four Mn (several clusters in one file), the
    centroid of all Mn/Ca would fall between the clusters and pick unrelated
    oxygens. In that case the centre is taken from the most compact group
    "one Mn + its 3 nearest Mn", plus the Ca nearest to that group.
    With exactly four Mn and one Ca this equals the plain Mn/Ca centroid.

    This is a geometric proxy: the nearest non-μ-oxo O may also be a protein
    ligand oxygen.

    Returns
    -------
    (mu_oxo, W_ox) : lists of atom indices

    Raises
    ------
    RuntimeError
        If fewer than two water-oxygen candidates exist.
    """
    O_idx = [i for i,e in enumerate(elem) if e=="O"]
    # μ-oxo: O within 2.7 Å of >=2 Mn
    mu_oxo=[]
    for i in O_idx:
        d = np.linalg.norm(coords[mn_idx] - coords[i], axis=1)
        if (d<=2.7).sum()>=2: mu_oxo.append(i)

    def _cluster_center():
        # centroid of one compact 4-Mn group plus the Ca nearest to that group
        all_mn = coords[mn_idx]
        grp = all_mn
        if len(mn_idx) > 4:
            _, nbr = cKDTree(all_mn).query(all_mn, k=4)   # each row: Mn itself + 3 nearest
            iu = np.triu_indices(4, 1)
            def spread(row):
                g = all_mn[row]
                return np.linalg.norm(g[:, None, :] - g[None, :, :], axis=-1)[iu].mean()
            grp = all_mn[min(nbr, key=spread)]
        if len(ca_idx) == 0:
            return grp.mean(axis=0)
        ca_xyz = coords[ca_idx]
        ca = ca_xyz[np.argmin(np.linalg.norm(ca_xyz - grp.mean(axis=0), axis=1))]
        return np.vstack([grp, ca]).mean(axis=0)

    center = _cluster_center()
    mu_set = set(mu_oxo)
    water_candidates = [i for i in O_idx if i not in mu_set]
    if not water_candidates:
        raise RuntimeError("No water oxygen candidates found.")
    if len(water_candidates) < 2:
        raise RuntimeError("Only one water oxygen candidate found; W3/W4 need two.")
    # stable sort: equal distances keep atom order
    W_ox = sorted(water_candidates, key=lambda i: float(np.linalg.norm(coords[i]-center)))[:2]
    return mu_oxo, W_ox

def synthesize_water_Hs(coords, W_ox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0, away=True):
    """Append two idealised H atoms to each water oxygen in `W_ox`.

    For each water O, the unit vectors to every neighbour in
    `neighbor_pool_idx` closer than `influence_cut` are summed. The H–O–H
    bisector is set opposite to that sum when `away` is True. Water binds the
    metal neighbours through O, so its H atoms point away from them; the
    unflipped direction put an H about 1.5 Å from a Ca bound at about 2.4 Å.
    Pass away=False for the old toward-neighbour placement.
    Both H sit at d_OH from O with an H–O–H angle of angle_deg, in a plane
    fixed by an arbitrary reference axis.

    Returns
    -------
    X     : coords with the new H rows appended (two per water, in W_ox order)
    H_idx : row indices of the appended H atoms
    """
    X = coords.copy()
    H_idx = []
    phi = np.deg2rad(angle_deg/2.0)
    for iO in W_ox:
        O = X[iO]
        vec = np.zeros(3, float)
        for j in neighbor_pool_idx:
            r = X[j]-O; rn = np.linalg.norm(r)
            if 1e-8 < rn < influence_cut: vec += r/rn
        if np.linalg.norm(vec) < 1e-6: vec = np.array([1.0,0.0,0.0])  # isolated O: arbitrary axis
        bhat = vec / (np.linalg.norm(vec)+1e-12)
        if away:
            bhat = -bhat
        ref = np.array([1.0,0.0,0.0]) if abs(bhat[0])<0.9 else np.array([0.0,1.0,0.0])
        u1 = np.cross(bhat, ref)
        if np.linalg.norm(u1)<1e-6: u1=np.array([0.0,0.0,1.0])
        u1 /= (np.linalg.norm(u1)+1e-12)
        H1 = O + d_OH*(np.cos(phi)*bhat + np.sin(phi)*u1)
        H2 = O + d_OH*(np.cos(phi)*bhat - np.sin(phi)*u1)
        X = np.vstack([X, H1, H2])
        H_idx += [len(X)-2, len(X)-1]
    return X, H_idx

# ---------- build graph + Hamiltonian ----------
def angle_boost(iO, mn_list, X, scale=None, r_bond=2.7):
    """Angle-dependent enhancement 1 + lam*cos^2(theta) of a Mn–Oμ hopping.

    theta is the Mn–O–Mn angle at node iO formed with the two Mn of `mn_list`
    NEAREST to it among those within `r_bond` Å. Choosing by distance, not by
    list position, makes the result for a μ3-oxo independent of node order
    (and matches the nearest-pair rule of the first cell's angle_boost).
    Returns 1.0 when fewer than two Mn are within `r_bond`.

    scale : overrides the module-level `lam` when given.
    """
    k = lam if scale is None else scale
    o = X[iO]
    dist = {j: float(np.linalg.norm(X[j] - o)) for j in mn_list}
    near = sorted((j for j in mn_list if dist[j] <= r_bond), key=dist.__getitem__)
    if len(near) < 2:
        return 1.0
    v1 = X[near[0]] - o; v2 = X[near[1]] - o
    c = np.dot(v1,v2)/(np.linalg.norm(v1)*np.linalg.norm(v2)+1e-12)
    return 1.0 + k*(c*c)

def build_min_graph(coords_ext, elem, mn_idx, ca_idx, mu_oxo, W_ox, H_syn_idx, iso):
    """Build the node graph and static (isotope-scaled) Hamiltonian for one CIF.

    Parameters
    ----------
    coords_ext : atom coordinates with the synthesized water H appended
    elem       : element symbols of the original atoms. H_syn_idx entries are
                 tagged "H" before elem is consulted (presumably they lie past
                 the end of elem — synthesize_water_Hs appends them; confirm)
    mn_idx, ca_idx, mu_oxo, W_ox : atom indices of Mn, Ca, μ-oxo O and water O
    H_syn_idx  : row indices of the synthesized H in coords_ext
    iso        : dict from isotope_params (OH_scale, HB_scale, eps_H_shift used here)

    Returns
    -------
    X        : node coordinates (core atoms sorted by index, then the H)
    tags     : per-node labels "MN", "CA", "Omu", "Ow", "H" (or "O"/element)
    E        : node pairs (i, j), i < j, within PAIR_RC and more than 0.5 apart
    H        : Hamiltonian (complex dtype, real symmetric values)
    W_region : node positions of the water O and the synthesized H
    """
    # node set (order fixed)
    node_core = sorted(set(mn_idx + ca_idx + mu_oxo + W_ox), key=int)
    ids = node_core + list(H_syn_idx)
    X = coords_ext[ids].copy()
    # tags
    tags=[]
    for i in ids:
        if i in H_syn_idx: tags.append("H")
        else:
            e = elem[i]
            if e=="MN": tags.append("MN")
            elif e=="CA": tags.append("CA")
            elif e=="O": tags.append("Omu" if i in mu_oxo else ("Ow" if i in W_ox else "O"))
            else: tags.append(e)
    # local map for W region (atom index -> node position)
    g2l = {g:i for i,g in enumerate(ids)}
    W_region = set(g2l[i] for i in (set(W_ox)|set(H_syn_idx)) if i in g2l)
    # edges and static H
    N=len(X); tree=cKDTree(X)
    E=[]
    for i in range(N):
        for j in tree.query_ball_point(X[i], r=PAIR_RC):
            if j<=i: continue
            if np.linalg.norm(X[i]-X[j])<=0.5: continue
            E.append((i,j))
    H = np.zeros((N,N), complex)
    mn_list=[i for i,t in enumerate(tags) if t=="MN"]
    # branch order matters: a covalent Ow–H pair (<=1.25 Å) is caught before the
    # H-bond test. No branch handles CA, so the Ca node only gets its on-site energy.
    for (i,j) in E:
        rij = float(np.linalg.norm(X[i]-X[j])); ti,tj = tags[i], tags[j]
        val = 0.0
        if {ti,tj}=={"Ow","H"} and rij<=1.25:
            val = t_OH * iso["OH_scale"]
        elif (ti=="H" and tj.startswith("O")) or (tj=="H" and ti.startswith("O")):
            val = t_HB * iso["HB_scale"] if (HBOND_MIN <= rij <= HBOND_MAX) else 0.0
        elif (ti=="MN" and tj.startswith("O")) or (tj=="MN" and ti.startswith("O")):
            boost = 1.0
            if ti=="MN" and tj=="Omu": boost = angle_boost(j, mn_list, X)
            if tj=="MN" and ti=="Omu": boost = angle_boost(i, mn_list, X)
            val = t_dp * np.exp(-beta*rij) * boost
        elif ti.startswith("O") and tj.startswith("O"):
            val = t_OO*np.exp(-beta*rij)
        if val!=0.0: H[i,j]=val; H[j,i]=val
    # on-site energies; H nodes get the isotope shift
    for i,tg in enumerate(tags):
        H[i,i] = eps.get(tg,0.0) + (iso["eps_H_shift"] if tg=="H" else 0.0)
    return X,tags,E,H,W_region

# ---------- TDSE ----------
def tdse_time_dep(H_static, E_list, tags, W_region, iso, steps=N_STEPS, dt=DT):
    """Propagate one electron under a periodically gated tight-binding Hamiltonian.

    Start: all amplitude on the first "Ow" node, else on the first "MN" node.
    Step n applies U = exp(-i H(t_n) dt) with t_n = n*dt. When BREATH_ON is
    set, H(t) is H_static with these modulations:
      - Mn–O edges are scaled by 1 + breath_A*sin(2*pi*t / breath_T);
      - edges touching W_region are scaled by
        1 + local_A*sin(2*pi*t / (local_Tfac*breath_T)).
    Both periods come from `iso`; for the isotope sets built by
    isotope_params, iso["breath_T"] equals BREATH_T_FS. When BREATH_ON is
    off, H_static is used unchanged and its propagator is built only once.
    Every FRAME_EVERY-th step the populations |psi|^2 are stored.

    Returns
    -------
    frames : (F, N) float32 populations
    times  : (F,) frame times. The frame stored at step n is taken after n+1
             propagations, so it is reported at (n+1)*dt.
    """
    N = H_static.shape[0]
    # start at water-O if present else Mn
    try: s = next(i for i,t in enumerate(tags) if t=="Ow")
    except StopIteration: s = next(i for i,t in enumerate(tags) if t=="MN")
    psi = np.zeros(N, complex); psi[s]=1.0
    frames=[]; times=[]
    # global breathing target: Mn–O edges
    mn_o_pairs = [(i,j) for (i,j) in E_list
                  if (tags[i]=="MN" and tags[j].startswith("O")) or (tags[j]=="MN" and tags[i].startswith("O"))]
    local_pairs = [(i,j) for (i,j) in E_list if (i in W_region or j in W_region)]

    def propagator(Hm):
        # exp(-i Hm dt) via the eigendecomposition of the Hermitian Hm
        ev, V = np.linalg.eigh(Hm)
        return V @ np.diag(np.exp(-1j*ev*dt)) @ V.conj().T

    U_static = None if BREATH_ON else propagator(H_static)
    for n in range(steps):
        t = n*dt
        if BREATH_ON:
            Ht = H_static.copy()
            # light global breathing
            g = 1.0 + iso["breath_A"]*np.sin(2*np.pi*t/iso["breath_T"])
            for (i,j) in mn_o_pairs:
                Ht[i,j] *= g; Ht[j,i] *= g
            # local isotope breathing on W3/W4 neighborhood
            lgate = 1.0 + iso["local_A"]*np.sin(2*np.pi*t/(iso["local_Tfac"]*iso["breath_T"]))
            for (i,j) in local_pairs:
                Ht[i,j] *= lgate; Ht[j,i] *= lgate
            U = propagator(Ht)
        else:
            U = U_static
        psi = U @ psi; psi /= (np.linalg.norm(psi)+1e-12)
        if n % FRAME_EVERY == 0:
            frames.append((np.abs(psi)**2).astype(np.float32)); times.append((n+1)*dt)
    return np.array(frames), np.array(times)

# ---------- diagnostics/metrics ----------
def PR(p):
    """Participation ratio (sum p)^2 / sum p^2 along the last axis (1e-12 guards /0)."""
    total = p.sum(-1)
    return total * total / ((p ** 2).sum(-1) + 1e-12)
def edge_current_proxy(E, H, Pframes):
    """Per-frame mean of |H_ij|*sqrt(p_i p_j) over every edge with non-zero coupling.

    Pframes yields one population vector (|psi|^2 per node) per frame. A graph
    without coupled edges gives 0.0 per frame. The value is an amplitude
    overlap proxy without phase information, not a true probability current.
    """
    coupled = [(a, b) for (a, b) in E if abs(H[a, b]) > 0]

    def frame_value(p):
        terms = [abs(H[a, b]) * np.sqrt(p[a] * p[b]) for (a, b) in coupled]
        return np.mean(terms) if terms else 0.0

    return np.array([frame_value(p) for p in Pframes])

# ---------- per-CIF run ----------
# Per-CIF pipeline: read atoms, locate Mn/Ca/μ-oxo/water O, add water H, build
# the isotope-scaled graph, propagate, then save PR/current plots and a GIF.
results = {}  # name -> dict(PR, J, T, tag)
for name, path in CIFS.items():
    if not os.path.isfile(path):
        print(f"[skip] Missing {path}")
        continue
    ISO = isotope_params(name)
    atoms = read_atoms(path)
    coords = np.array([a[0] for a in atoms], np.float32)
    elem   = np.array([a[1] for a in atoms], object)

    # NOTE(review): these collect every Mn/Ca in the file, not one cluster — if the
    # CIF holds several CaMn4 clusters they all become graph nodes; confirm intent
    mn_idx = [i for i,e in enumerate(elem) if e=="MN"]
    ca_idx = [i for i,e in enumerate(elem) if e=="CA"]
    assert len(mn_idx)>=4, f"{name}: need >=4 Mn"

    mu_oxo, W_ox = pick_mu_oxo_and_waters(coords, elem, mn_idx, ca_idx)
    neighbor_pool = mn_idx + mu_oxo + ca_idx
    coords_ext, H_syn_idx = synthesize_water_Hs(coords, W_ox, neighbor_pool)

    X,tags,E,H,W_region = build_min_graph(coords_ext, elem, mn_idx, ca_idx, mu_oxo, W_ox, H_syn_idx, ISO)
    P,T = tdse_time_dep(H, E, tags, W_region, ISO)

    # metrics (edge_current_proxy averages over all coupled edges, not only bridges,
    # although the plot title below says "Bridge current")
    pr = PR(P); J = edge_current_proxy(E,H,P)
    results[name] = dict(PR=pr, J=J, T=T, tag=ISO["tag"])

    # quick plots
    plt.figure(figsize=(6.4,3.6))
    plt.plot(T, pr)
    plt.xlabel("time (fs)"); plt.ylabel("PR")
    plt.title(f"PR — {name} [{ISO['tag']}]")
    fn = os.path.join(OUTDIR, f"cubane_{name}_PR.png"); plt.tight_layout(); plt.savefig(fn, dpi=160); plt.close()
    print("Saved:", fn)

    plt.figure(figsize=(6.4,3.6))
    plt.plot(T, J)
    plt.xlabel("time (fs)"); plt.ylabel("current proxy")
    plt.title(f"Bridge current — {name} [{ISO['tag']}]")
    fn = os.path.join(OUTDIR, f"cubane_{name}_currents.png"); plt.tight_layout(); plt.savefig(fn, dpi=160); plt.close()
    print("Saved:", fn)

    # small GIF: static atoms by tag plus yellow markers sized by max-normalised |psi|^2
    fig = plt.figure(figsize=(6.2,5.4)); ax = fig.add_subplot(111, projection="3d")
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([]); ax.grid(False)
    fig.patch.set_facecolor("#0E0E10"); ax.set_facecolor("#0E0E10")
    mn_lim, mx_lim = X.min(0), X.max(0); pad=1.0
    ax.set_xlim(mn_lim[0]-pad, mx_lim[0]+pad); ax.set_ylim(mn_lim[1]-pad, mx_lim[1]+pad); ax.set_zlim(mn_lim[2]-pad, mx_lim[2]+pad)
    # display edges: same radius and 0.5 filter as build_min_graph uses for E
    tree=cKDTree(X); Edisp=[]
    for i in range(len(X)):
        for j in tree.query_ball_point(X[i], r=PAIR_RC):
            if j>i and np.linalg.norm(X[i]-X[j])>0.5: Edisp.append((i,j))
    for (i,j) in Edisp:
        ax.plot([X[i,0],X[j,0]],[X[i,1],X[j,1]],[X[i,2],X[j,2]], lw=2.0, alpha=0.85, color="#777")
    colmap={"MN":"#BA68C8","CA":"#66BB6A","Omu":"#EF5350","Ow":"#00E5FF","H":"#FFD54F"}
    radmap={"MN":0.90,"CA":0.80,"Omu":0.70,"Ow":0.65,"H":0.45}
    cols=[colmap.get(t,"#BDBDBD") for t in tags]; sizes=[170*radmap.get(t,0.55) for t in tags]
    stat = ax.scatter(X[:,0],X[:,1],X[:,2], s=sizes, c=cols, alpha=0.95)
    dyn  = ax.scatter(X[:,0],X[:,1],X[:,2], s=120, c="#FFD54F", alpha=0.90)
    ax.view_init(18,40)
    # frame callback; reads P, T, name and ISO of the current iteration (the GIF
    # is saved before the loop advances)
    def upd(f):
        p = P[f]; p = p/(p.max()+1e-12); dyn.set_sizes(120 + 1400*p)
        ax.set_title(f"|ψ|² on Mn₄CaO₅ — t={T[f]:.2f} fs   {name} [{ISO['tag']}]", color="w")
        return [dyn, stat]
    gif = os.path.join(OUTDIR, f"cubane_{name}.gif")
    ani = animation.FuncAnimation(fig, upd, frames=len(P), interval=30, blit=False)
    ani.save(gif, writer=animation.PillowWriter(fps=20)); plt.close(fig)
    print("Saved GIF:", gif)

# ---------- overlays (compare H-like vs D-like) ----------
if results:
    # participation ratio
    plt.figure(figsize=(7.8,3.8))
    for nm, dat in results.items():
        plt.plot(dat["T"], dat["PR"], label=f"{nm} [{dat['tag']}]")
    plt.xlabel("time (fs)"); plt.ylabel("PR"); plt.title("PR overlay (H-like vs D-like)")
    plt.legend(ncol=3); plt.tight_layout()
    fn=os.path.join(OUTDIR,"cubane_PR_overlay_isotope.png"); plt.savefig(fn, dpi=160); plt.close()
    print("Saved:", fn)

    # current proxy
    plt.figure(figsize=(7.8,3.8))
    for nm, dat in results.items():
        plt.plot(dat["T"], dat["J"], label=f"{nm} [{dat['tag']}]")
    plt.xlabel("time (fs)"); plt.ylabel("current proxy"); plt.title("Bridge current overlay (H-like vs D-like)")
    plt.legend(ncol=3); plt.tight_layout()
    fn=os.path.join(OUTDIR,"cubane_currents_overlay_isotope.png"); plt.savefig(fn, dpi=160); plt.close()
    print("Saved:", fn)

    # show a quick gallery. It stays inside this branch: with no results the
    # overlay PNGs were not written by this run, so displaying them would fail
    # or show stale files.
    for nm in sorted(results.keys()):
        display(Image(os.path.join(OUTDIR, f"cubane_{nm}_PR.png")))
        display(Image(os.path.join(OUTDIR, f"cubane_{nm}_currents.png")))
    display(Image(os.path.join(OUTDIR,"cubane_PR_overlay_isotope.png")))
    display(Image(os.path.join(OUTDIR,"cubane_currents_overlay_isotope.png")))
else:
    print("[skip] No CIF was processed; nothing to overlay or display.")
# ================================================================================================

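# --- Manual patch note (the statements below are not a working step on their own).
# Inside the time loop of the TDSE routine, REPLACE the single line:
# gate = 1.0 + BREATH_A*np.sin(2*np.pi*(n*dt)/BREATH_T_FS)
# --- WITH the D-window logic below, placed right before the Mn–O edges are scaled; it needs the loop's n, dt, H_static and mn_o_pairs.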

def _d_window(t, centers=(80.0, 170.0, 260.0), width=18.0):
    # returns 1.0 normally (H-like), ramps toward D-like (<=D_DAMP) in short windows
    # smooth Tukey-ish dip centered at `centers` (fs); `width` is full width (fs)
    D_DAMP = 0.55  # how "heavy": 1.0 = no damping, smaller = more D-like
    y = 1.0
    for c in centers:
        x = abs(t - c)
        if x < width*0.5:
            # cosine well from 1.0 (edges) to D_DAMP (center)
            y = min(y, D_DAMP + (1.0 - D_DAMP)*0.5*(1.0 + np.cos(np.pi*x/(width*0.5))))
    return y

def apply_breath_and_d_window(H_static, mn_o_pairs, n, dt, centers=(85.0, 175.0, 265.0), width=22.0):
    """Return the step-n Hamiltonian with breathing and D-like windows on the Mn–O edges.

    Call it once per step inside a TDSE time loop. It needs the loop's step
    index n, the time step dt, the static Hamiltonian and the Mn–O edge list;
    none of these exist at module level, so the bare statements raised
    NameError.

    With t = n*dt, every Mn–O coupling (i, j) in `mn_o_pairs` is multiplied
    symmetrically by breath*local, where:
      - breath = 1 + BREATH_A*sin(2*pi*t / BREATH_T_FS);
      - local  = _d_window(t, centers, width), which is <= 1 inside the windows.

    Returns a modified copy; H_static itself is not changed.
    """
    t_now = n*dt
    # keep the breathing as before
    breath = 1.0 + BREATH_A*np.sin(2*np.pi*t_now/BREATH_T_FS)
    # brief D-like damping windows, applied to the Mn–O edges
    local = _d_window(t_now, centers=centers, width=width)
    factor = breath * local
    Ht = H_static.copy()
    for (i,j) in mn_o_pairs:
        Ht[i,j] *= factor
        Ht[j,i] *= factor
    return Ht



# ================== Cubane TDSE with time-gated D-like windows (one cell) ==================
# H/I/J are run automatically. Outputs in /content (Colab) or cwd.
# ------------------------------------------------------------------------------------------
import os, sys, subprocess, numpy as np, matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import Image, display

# ---------- deps ----------
def need(pkg, mod=None):
    """Ensure a module is importable, pip-installing its package on demand.

    Tries ``__import__(mod or pkg)``; if that raises anything, runs
    ``python -m pip install -q pkg`` with the current interpreter.
    Returns None and does not bind the module in the caller's namespace.
    """
    module_name = mod if mod else pkg
    try:
        __import__(module_name)
    except Exception:
        # any import failure triggers a single quiet install of the pip package
        install_cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(install_cmd)
need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree
import csv

# ---------- config ----------
# Only the CIFs that exist under /content are run; none found is a hard error.
CIF_PATHS = [p for p in ("/content/8F4H.cif","/content/8F4I.cif","/content/8F4J.cif") if os.path.isfile(p)]
if not CIF_PATHS: raise FileNotFoundError("Place 8F4H.cif, 8F4I.cif, 8F4J.cif in /content or adjust CIF_PATHS.")

OUT = "/content" if os.path.isdir("/content") else "."
DT, STEPS = 0.5, 720         # fs, steps  (~360 fs total)
PAIR_RC = 3.4                # Å
FRAME_EVERY = 3              # store every 3rd step (= every 1.5 fs at DT=0.5)

# Electronic couplings / on-sites (dimensionless)
t_dp, beta, lam = 1.0, 1.0, 0.5     # Mn–O bridge (angle-boosted on μ-oxo)
t_OH, t_HB = 3.0, 0.8               # O–H covalent; O···H hydrogen bond
HBOND_MIN, HBOND_MAX = 1.2, 2.6
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)

# Breathing (global gating of Mn–O edges)
BREATH_ON, BREATH_A, BREATH_T = True, 0.08, 120.0  # amplitude, period(fs)

# ---- NEW: time-gated D-like behavior on W3/W4 only ----
# Within each breathing period, for a window of length D_WIN, couplings touching W3/W4 are reduced.
# NOTE(review): D_FACTOR/D_WIN are consumed by tdse_gated (its window test uses D_WIN).
D_FACTOR = 0.55             # <1 ⇒ heavier (D-like) damps vibronic bandwidth locally
D_WIN   = 24.0              # fs duration of each D-like gate within a period
# presumably keyed by the CIF's last letter (8F4H/I/J) — the lookup is not in this cell's visible part; confirm
PHASES  = {"H": 0.0, "I": 20.0, "J": 40.0}  # phase offset per CIF to mimic cycle staggering (fs)

# ---------- helpers ----------
def read_atoms(cif_path):
    """Read every atom of the first model of a structure file with gemmi.

    Returns a list of ``(xyz, element, resname, atom_name)`` tuples: xyz is a
    float32 array of 3 coordinates, element and resname are upper-cased, and
    atom_name is stripped.

    Raises ValueError when the structure has no model.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:
        # fall back to building the structure from the first CIF data block
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception as err:
            # best effort: continue with all conformers, but do not hide it
            print(f"[warn] {cif_path}: could not remove alternative conformations ({err})")
    if len(st) == 0:
        raise ValueError(f"{cif_path}: structure has no models")
    m = st[0]
    atoms = []
    for ch in m:
        for res in ch:
            for at in res:
                p = at.pos
                atoms.append((np.array([p.x, p.y, p.z], np.float32),
                              at.element.name.upper(),
                              res.name.upper(),
                              at.name.strip()))
    return atoms

def identify_cubane_parts(coords, elem, resn):
    """Locate Mn, Ca, μ-oxo O and two water-oxygen proxies (W3/W4).

    Parameters
    ----------
    coords : (N, 3) array of atom positions
    elem   : (N,) upper-case element symbols
    resn   : unused; accepted for call compatibility

    Returns
    -------
    (mn_idx, ca_idx, mu_oxo, W_ox) : lists of atom indices.
        - mn_idx / ca_idx hold every Mn / Ca in the file.
        - mu_oxo holds every O within 2.7 Å of at least two of those Mn.
        - W_ox holds the two non-μ-oxo O closest to the centre of ONE CaMn4
          cluster. With more than four Mn present, the centroid of all Mn/Ca
          would fall between clusters. The centre is therefore taken from the
          most compact "Mn + its 3 nearest Mn" group plus the Ca nearest to
          it. With exactly 4 Mn and 1 Ca this is the plain Mn/Ca centroid.

    Raises
    ------
    AssertionError
        If fewer than 4 Mn are present (raised explicitly, so it survives -O).
    RuntimeError
        If fewer than two water-oxygen candidates exist.
    """
    mn_idx = [i for i,e in enumerate(elem) if e=="MN"]
    ca_idx = [i for i,e in enumerate(elem) if e=="CA"]
    O_idx  = [i for i,e in enumerate(elem) if e=="O"]
    if len(mn_idx) < 4:
        raise AssertionError("Need at least 4 Mn.")

    # μ-oxo: O near ≥2 Mn within 2.7 Å
    mu_oxo=[]
    for i in O_idx:
        d = np.linalg.norm(coords[mn_idx]-coords[i], axis=1)
        if (d<=2.7).sum()>=2: mu_oxo.append(i)

    def _cluster_center():
        # centroid of one compact 4-Mn group plus the Ca nearest to that group
        all_mn = coords[mn_idx]
        grp = all_mn
        if len(mn_idx) > 4:
            _, nbr = cKDTree(all_mn).query(all_mn, k=4)   # each row: Mn itself + 3 nearest
            iu = np.triu_indices(4, 1)
            def spread(row):
                g = all_mn[row]
                return np.linalg.norm(g[:, None, :] - g[None, :, :], axis=-1)[iu].mean()
            grp = all_mn[min(nbr, key=spread)]
        if not ca_idx:
            return grp.mean(axis=0)
        ca_xyz = coords[ca_idx]
        ca = ca_xyz[np.argmin(np.linalg.norm(ca_xyz - grp.mean(axis=0), axis=1))]
        return np.vstack([grp, ca]).mean(axis=0)

    # choose two water O (W3/W4 proxies): two O (not μ-oxo) closest to the cluster centre
    center = _cluster_center()
    mu_set = set(mu_oxo)
    water_candidates = [i for i in O_idx if i not in mu_set]
    if not water_candidates: raise RuntimeError("No water oxygen candidates.")
    if len(water_candidates) < 2:
        raise RuntimeError("Only one water oxygen candidate; W3/W4 need two.")
    W_ox = sorted(water_candidates, key=lambda i: float(np.linalg.norm(coords[i]-center)))[:2]
    return mn_idx, ca_idx, mu_oxo, W_ox

def synthesize_water_Hs(coords, W_ox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0, away=True):
    """Append two idealised H atoms to each water oxygen in `W_ox`.

    For each water O, the unit vectors to every neighbour in
    `neighbor_pool_idx` closer than `influence_cut` are summed. With
    `away=True` the H–O–H bisector points opposite to that sum. Water binds
    the metal neighbours through O, so its H atoms face away from them; the
    unflipped direction put an H about 1.5 Å from a Ca bound at about 2.4 Å.
    Pass away=False for the old toward-neighbour placement.
    Both H sit at d_OH from O with an H–O–H angle of angle_deg.

    Returns
    -------
    X     : coords with the new H rows appended (two per water, in W_ox order)
    H_idx : row indices of the appended H atoms
    """
    X = coords.copy(); H_idx=[]
    phi = np.deg2rad(angle_deg/2.0)
    for iO in W_ox:
        O = X[iO]
        vec = np.zeros(3,float)
        for j in neighbor_pool_idx:
            r = X[j] - O; rn = np.linalg.norm(r)
            if 1e-8 < rn < influence_cut: vec += r/rn
        if np.linalg.norm(vec)<1e-6: vec = np.array([1.0,0.0,0.0])  # isolated O: arbitrary axis
        bhat = vec/np.linalg.norm(vec)
        if away:
            bhat = -bhat
        ref = np.array([1.0,0.0,0.0]) if abs(bhat[0])<0.9 else np.array([0.0,1.0,0.0])
        u1 = np.cross(bhat, ref)
        if np.linalg.norm(u1)<1e-6: u1=np.array([0.0,0.0,1.0])
        u1 /= np.linalg.norm(u1)
        H1 = O + d_OH*(np.cos(phi)*bhat + np.sin(phi)*u1)
        H2 = O + d_OH*(np.cos(phi)*bhat - np.sin(phi)*u1)
        X = np.vstack([X, H1, H2]); H_idx += [len(X)-2, len(X)-1]
    return X, H_idx

def angle_boost(iO, mn_list, X, scale=None, r_bond=2.7):
    """Angle-dependent enhancement 1 + lam*cos^2(theta) of a Mn–Oμ hopping.

    theta is the Mn–O–Mn angle at node iO formed with the two Mn of `mn_list`
    NEAREST to it among those within `r_bond` Å. Choosing by distance, not by
    list position, makes the result for a μ3-oxo independent of node order.
    Returns 1.0 when fewer than two Mn are within `r_bond`.

    scale : overrides the module-level `lam` when given.
    """
    k = lam if scale is None else scale
    o = X[iO]
    dist = {j: float(np.linalg.norm(X[j] - o)) for j in mn_list}
    near = sorted((j for j in mn_list if dist[j] <= r_bond), key=dist.__getitem__)
    if len(near) < 2:
        return 1.0
    v1, v2 = X[near[0]]-o, X[near[1]]-o
    c = np.dot(v1,v2)/(np.linalg.norm(v1)*np.linalg.norm(v2)+1e-12)
    return 1.0 + k*(c*c)

def build_min_graph(coords, elem):
    """Collect the minimal Mn4CaO5·2H2O node set, its coordinates and tags.

    Returns
    -------
    X     : node coordinates — core atoms sorted by atom index, followed by
            the synthesized water H atoms
    tags  : per-node labels "MN", "CA", "Omu", "Ow", "H" (or "O"/element fallback)
    idx   : dict tag -> list of node positions, for the five main tags
    parts : (mn_idx, ca_idx, mu_oxo, W_ox) atom-index lists from identify_cubane_parts
    """
    mn_idx, ca_idx, mu_oxo, W_ox = identify_cubane_parts(coords, elem, None)
    Xext, Hsyn = synthesize_water_Hs(coords, W_ox, mn_idx + mu_oxo + ca_idx)
    core = sorted(set(mn_idx + ca_idx + mu_oxo + W_ox), key=int)
    ids = core + Hsyn
    X = Xext[ids].copy()

    def label(i):
        # synthesized H first: their indices are not looked up in elem
        if i in Hsyn:
            return "H"
        e = elem[i]
        if e == "MN" or e == "CA":
            return e
        if e != "O":
            return e
        if i in mu_oxo:
            return "Omu"
        return "Ow" if i in W_ox else "O"

    tags = [label(i) for i in ids]
    idx = {g: [k for k, t in enumerate(tags) if t == g] for g in ("MN", "CA", "Omu", "Ow", "H")}
    return X, tags, idx, (mn_idx, ca_idx, mu_oxo, W_ox)

def build_edges_and_baseH(X, tags):
    """Build the neighbour edge list and the static tight-binding Hamiltonian.

    Edges are (i, j) pairs with i < j from a KD-tree ball query of radius
    PAIR_RC. Pairs 0.5 Å apart or closer are dropped. Off-diagonal hoppings
    depend on the node-tag pair:
      Ow–H with r <= 1.25 Å : t_OH
      any other H–O         : t_HB if HBOND_MIN <= r <= HBOND_MAX else 0
      Mn–O                  : t_dp*exp(-beta*r), times angle_boost() for "Omu"
      O–O                   : 0.1*exp(-beta*r)
      anything else         : 0
    The diagonal is eps[tag] (0.0 for tags not in eps). All constants are read
    from module scope.

    Returns:
        (E, H0): the edge list and an (N, N) complex symmetric matrix.
    """
    N = len(X)
    kd = cKDTree(X)
    E = []
    for i in range(N):
        for j in kd.query_ball_point(X[i], r=PAIR_RC):
            if j <= i or np.linalg.norm(X[i] - X[j]) <= 0.5:
                continue
            E.append((i, j))

    mn_list = [k for k, tg in enumerate(tags) if tg == "MN"]
    H0 = np.zeros((N, N), complex)
    for i, j in E:
        ti, tj = tags[i], tags[j]
        rij = float(np.linalg.norm(X[i] - X[j]))
        h_to_o = (ti == "H" and tj.startswith("O")) or (tj == "H" and ti.startswith("O"))
        mn_to_o = (ti == "MN" and tj.startswith("O")) or (tj == "MN" and ti.startswith("O"))
        if {ti, tj} == {"Ow", "H"} and rij <= 1.25:
            hop = t_OH
        elif h_to_o:
            hop = t_HB if HBOND_MIN <= rij <= HBOND_MAX else 0.0
        elif mn_to_o:
            if ti == "MN" and tj == "Omu":
                boost = angle_boost(j, mn_list, X)
            elif tj == "MN" and ti == "Omu":
                boost = angle_boost(i, mn_list, X)
            else:
                boost = 1.0
            hop = t_dp * np.exp(-beta * rij) * boost
        elif ti.startswith("O") and tj.startswith("O"):
            hop = 0.1 * np.exp(-beta * rij)
        else:
            hop = 0.0
        if hop != 0.0:
            H0[i, j] = hop
            H0[j, i] = hop
    for k, tg in enumerate(tags):
        H0[k, k] = eps.get(tg, 0.0)
    return E, H0

def tdse_gated(X, tags, E, H0, start_idx, phase_offset_fs):
    """Propagate a wavepacket started on node ``start_idx`` under a gated H(t).

    Step n (t = n*DT, n < STEPS) copies ``H0`` and then modifies the copy:
      * if BREATH_ON, every Mn–O edge is scaled by 1 + BREATH_A*sin(2πt/BREATH_T);
      * while (t + phase_offset_fs) mod BREATH_T <= D_WIN, every edge touching an
        "Ow" node is additionally scaled by D_FACTOR.
    It then applies exp(-i H DT) (via eigh) and renormalises ψ.

    Returns:
        (frames, times): |ψ|² as float32, one row per FRAME_EVERY-th step, and
        the matching n*DT values. Each row is the state *after* step n's
        propagator has been applied, but is labelled with the step's start time n*DT.
    """
    psi = np.zeros(len(X), complex)
    psi[start_idx] = 1.0

    def _is_mn_o(i, j):
        return (tags[i] == "MN" and tags[j].startswith("O")) or \
               (tags[j] == "MN" and tags[i].startswith("O"))

    mn_o_pairs = [(i, j) for (i, j) in E if _is_mn_o(i, j)]
    w3w4_touch = [(i, j) for (i, j) in E if "Ow" in (tags[i], tags[j])]

    def _scale(mat, pairs, factor):
        # Scale both (i, j) and (j, i) so the matrix stays Hermitian.
        for i, j in pairs:
            mat[i, j] *= factor
            mat[j, i] *= factor

    frames, times = [], []
    for n in range(STEPS):
        t_fs = n * DT
        Ht = H0.copy()
        if BREATH_ON:
            _scale(Ht, mn_o_pairs, 1.0 + BREATH_A*np.sin(2*np.pi*(t_fs)/BREATH_T))
        if (t_fs + phase_offset_fs) % BREATH_T <= D_WIN:
            _scale(Ht, w3w4_touch, D_FACTOR)
        Evals, V = np.linalg.eigh(Ht)
        U = V @ np.diag(np.exp(-1j*Evals*DT)) @ V.conj().T
        psi = U @ psi
        psi /= (np.linalg.norm(psi) + 1e-12)
        if n % FRAME_EVERY == 0:
            frames.append((np.abs(psi)**2).astype(np.float32))
            times.append(t_fs)
    return np.array(frames), np.array(times)

def group_pop(P, tags):
    """Sum populations over node-role groups.

    Args:
        P: populations whose last axis indexes nodes (a single frame or a stack).
        tags: per-node role labels.

    Returns:
        (labs, mat): the groups among "MN", "Omu", "Ow", "H", "CA" (in that
        order) that occur in ``tags``, and ``P`` summed over each group's nodes,
        stacked on a new last axis.
    """
    members = {}
    for pos, tg in enumerate(tags):
        members.setdefault(tg, []).append(pos)
    labs = [g for g in ("MN", "Omu", "Ow", "H", "CA") if g in members]
    mat = np.stack([P[..., members[g]].sum(-1) for g in labs], axis=-1)
    return labs, mat

def PR(p):
    """Participation ratio (Σp)² / (Σp² + 1e-12) taken along the last axis of ``p``."""
    total = p.sum(-1)
    return total * total / ((p ** 2).sum(-1) + 1e-12)

def edge_current_proxy(E, Href, Pframes):
    """Per-frame mean of |H_ij|·sqrt(p_i p_j) over edges with non-zero coupling.

    Args:
        E: candidate (i, j) edges; only those with ``abs(Href[i, j]) > 0`` count.
        Href: reference Hamiltonian supplying the coupling magnitudes.
        Pframes: iterable of population vectors indexed by node.

    Returns:
        1-D array with one value per frame (0.0 for every frame when no edge
        has a non-zero coupling).
    """
    live = [(i, j) for (i, j) in E if abs(Href[i, j]) > 0]
    out = []
    for frame in Pframes:
        if not live:
            out.append(0.0)
            continue
        weights = [abs(Href[i, j]) * np.sqrt(frame[i] * frame[j]) for (i, j) in live]
        out.append(np.mean(weights))
    return np.array(out)

# ---------- run all three with time-gated D-like windows ----------
# For each CIF: build the minimal core graph, propagate with tdse_gated(), and
# save per-file PR and current-proxy plots under OUT; results are kept in `results`.
# NOTE(review): CIF_PATHS and PHASES (and STEPS, DT, BREATH_*, D_* used by
# tdse_gated) are not assigned in this part of the file; presumably an earlier
# cell defines them — confirm.
# NOTE(review): the unpacking below expects read_atoms() to return a list of
# per-atom tuples (xyz, element, ...). The read_atoms defined at the top of this
# file returns a (coords, elem, resn) tuple of arrays instead — confirm which
# definition is live when this cell runs.
results = {}  # name -> dict
label_map = {"8F4H.cif":"H","8F4I.cif":"I","8F4J.cif":"J"}

for cif in CIF_PATHS:
    name = os.path.basename(cif)
    atoms = read_atoms(cif)
    coords = np.array([a[0] for a in atoms], np.float32)
    elem   = np.array([a[1] for a in atoms], object)
    X, tags, idx, _ = build_min_graph(coords, elem)
    E, H0 = build_edges_and_baseH(X, tags)

    # start at a water O if present to emphasize H/D gating
    try: s = next(i for i,t in enumerate(tags) if t=="Ow")
    except StopIteration: s = next(i for i,t in enumerate(tags) if t=="MN")

    # Phase offset (fs) of the D-gate window; file names not in label_map use
    # the "H" entry, and labels missing from PHASES get 0.0.
    phase = PHASES.get(label_map.get(name, "H"), 0.0)
    P, T = tdse_gated(X, tags, E, H0, s, phase)

    PRv = PR(P)
    # The current proxy uses the static H0, not the time-gated H(t).
    Jv  = edge_current_proxy(E, H0, P)

    results[name] = dict(X=X, tags=tags, E=E, H0=H0, P=P, T=T, PR=PRv, J=Jv)

    # per-file plots
    plt.figure(figsize=(6.4,3.6))
    plt.plot(T, PRv); plt.xlabel("time (fs)"); plt.ylabel("PR")
    plt.title(f"PR — {name} [D-gated on W3/W4]"); plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{name}_PR_dgated.png"), dpi=160); plt.close()

    plt.figure(figsize=(6.4,3.6))
    plt.plot(T, Jv); plt.xlabel("time (fs)"); plt.ylabel("current proxy")
    plt.title(f"Bridge current — {name} [D-gated on W3/W4]"); plt.tight_layout()
    plt.savefig(os.path.join(OUT, f"{name}_J_dgated.png"), dpi=160); plt.close()

# ---------- overlay (all CIFs) ----------
# Fixed colour per known file; colors.get() yields None for any other name,
# which leaves the choice to matplotlib's colour cycle.
colors = {"8F4H.cif":"#1f77b4","8F4I.cif":"#ff7f0e","8F4J.cif":"#2ca02c"}

# Participation ratio vs time for every processed file on one axis.
plt.figure(figsize=(8.4,3.8))
for name in results:
    plt.plot(results[name]["T"], results[name]["PR"], label=name, color=colors.get(name))
plt.xlabel("time (fs)"); plt.ylabel("PR"); plt.title("PR overlay (D-gated windows)")
plt.legend(); plt.tight_layout()
plt.savefig(os.path.join(OUT,"PR_overlay_Dgated.png"), dpi=170); plt.close()

# Edge-current proxy vs time, same layout.
plt.figure(figsize=(8.4,3.8))
for name in results:
    plt.plot(results[name]["T"], results[name]["J"], label=name, color=colors.get(name))
plt.xlabel("time (fs)"); plt.ylabel("current proxy"); plt.title("Bridge current overlay (D-gated windows)")
plt.legend(); plt.tight_layout()
plt.savefig(os.path.join(OUT,"J_overlay_Dgated.png"), dpi=170); plt.close()

# ---------- simple summary CSV ----------
# csv.writer is used below, and csv is not among this cell's visible imports.
import csv

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the new
# name when present and fall back on older NumPy.
_trapz = getattr(np, "trapezoid", None) or np.trapz

csv_path = os.path.join(OUT, "cubane_Dgated_summary.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f); w.writerow(["file","PR_mean","PR_max","J_mean","J_int"])
    for name in results:
        PRv, Jv, T = results[name]["PR"], results[name]["J"], results[name]["T"]
        J_int = _trapz(Jv, T)   # trapezoidal ∫J dt over the stored frames (fs)
        w.writerow([name, float(PRv.mean()), float(PRv.max()), float(Jv.mean()), float(J_int)])
print("Saved:",
      os.path.join(OUT,"PR_overlay_Dgated.png"),
      os.path.join(OUT,"J_overlay_Dgated.png"),
      csv_path)

# show overlays inline
display(Image(os.path.join(OUT,"PR_overlay_Dgated.png")))
display(Image(os.path.join(OUT,"J_overlay_Dgated.png")))

# ========== Cubane TDSE (GPU if available) — run-until O–O forms, then stop ==========
# One-cell script: CIF → graph → TDSE with D-gates → O–O trigger → plots+CSV
# -------------------------------------------------------------------------------------
import os, sys, subprocess, numpy as np, matplotlib.pyplot as plt, csv, time

# --------- config (edit these) -----------
CIF = "/content/8F4I.cif"     # input structure; alternatives: /content/8F4H.cif or /content/8F4J.cif
OUT = "/content" if os.path.isdir("/content") else "."
DT = 0.5                      # fs (time step)
MAX_STEPS = 12000             # hard cap on steps (12000 * 0.5 fs = 6 ps)
FRAME_EVERY = 3               # record metrics every 3rd step (1.5 fs at DT=0.5)
EXTRA_STEPS_AFTER_FORM = 400  # steps to keep running after the O–O trigger (~200 fs)
PAIR_RC = 3.4                 # Å neighbour cutoff

# Couplings / on-sites
t_dp, beta, lam = 1.0, 1.0, 0.5   # Mn–O hopping scale, distance decay (1/Å), μ-oxo angle-boost weight
t_OH, t_HB = 3.0, 0.8             # covalent Ow–H and H-bond hoppings
HBOND_MIN, HBOND_MAX = 1.2, 2.6   # Å window in which H–O edges receive t_HB
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)   # on-site energy per node tag

# Breathing (Mn–O): on/off, fractional amplitude, period (fs)
BREATH_ON, BREATH_A, BREATH_T = True, 0.08, 120.0

# D-like time-gated window on W3/W4 edges
D_FACTOR, D_WIN = 0.55, 24.0      # scale factor (<1) and window length (fs) within each BREATH_T period
PHASE_OFFSET = 0.0                # per-file phase (fs). Adjust if desired.

# O–O formation trigger (distance + population + hold)
THRESH_R = 2.10        # Å; O–O distance at or below which the trigger may fire
THRESH_P = 0.35        # dimensionless |ψ|^2 sum on the two O’s
HOLD_NEED = 3          # consecutive stored frames (≈ 3*FRAME_EVERY*DT fs)
TOL_MIN_DIST = 0.9     # Å, ignore near-duplicate atoms

# ---------- deps ----------
def need(pkg, mod=None):
    """Make sure a module imports; on any import failure pip-install ``pkg`` quietly.

    ``mod`` is the importable module name when it differs from the pip
    package name; it defaults to ``pkg``.
    """
    target = mod if mod else pkg
    try:
        __import__(target)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])
# Ensure gemmi (CIF parsing) and scipy (KD-tree) are importable, installing if needed.
need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree

# ---------- GPU/CPU backend ----------
# `xp` is the array module used for the propagation: CuPy when it imports,
# otherwise NumPy. BACK is a human-readable label written to the CSV.
try:
    import cupy as cp
    xp = cp
    BACK = "CuPy (GPU)"
except Exception:
    xp = np
    BACK = "NumPy (CPU)"
print(f"[Backend] {BACK}")

# ---------- CIF → atoms ----------
def read_atoms(cif_path):
    """Read every atom of the first model of a structure file.

    Returns:
        list of ``(xyz, element, resname, atom_name)`` tuples, where ``xyz`` is
        a float64 array of shape (3,), ``element`` and ``resname`` are
        upper-cased and ``atom_name`` is stripped.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:
        # Same fallback as the other read_atoms variants in this notebook:
        # build the structure directly from the first CIF data block.
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception:
            # Best-effort: on failure, altlocs may remain. Exception (not a bare
            # except) so KeyboardInterrupt still propagates.
            pass
    atoms = []
    for ch in st[0]:
        for res in ch:
            for at in res:
                p = at.pos
                atoms.append((np.array([p.x, p.y, p.z], float),
                              at.element.name.upper(),
                              res.name.upper(),
                              at.name.strip()))
    return atoms

# Load the configured structure and split the per-atom tuples into an (N, 3)
# float coordinate array and an object array of upper-case element symbols.
atoms = read_atoms(CIF)
coords = np.array([a[0] for a in atoms], float)
elem   = np.array([a[1] for a in atoms], object)

# ---------- identify Mn4CaO5 core + W3/W4 ----------
def identify_cubane_parts(coords, elem):
    """Split a structure into Mn, Ca, μ-oxo oxygens and the two W3/W4 waters.

    Args:
        coords: (N, 3) coordinates.
        elem: length-N upper-case element symbols.

    Returns:
        (mn_idx, ca_idx, mu_oxo, W_ox): index lists of all Mn, all Ca, oxygens
        within 2.7 Å of at least two Mn, and the two remaining oxygens closest
        to the Mn/Ca centroid (nearest first).

    Raises:
        AssertionError: fewer than 4 Mn, or fewer than 2 non-μ-oxo oxygens.
            Raised explicitly (not via ``assert``) so it survives ``python -O``,
            matching pick_cluster at the top of the file.
    """
    mn_idx = [i for i, e in enumerate(elem) if e == "MN"]
    ca_idx = [i for i, e in enumerate(elem) if e == "CA"]
    O_idx  = [i for i, e in enumerate(elem) if e == "O"]
    if len(mn_idx) < 4:
        raise AssertionError("Need at least 4 Mn.")
    # μ-oxo: O near ≥2 Mn within 2.7 Å
    mn_xyz = coords[mn_idx]
    mu_oxo = [i for i in O_idx
              if (np.linalg.norm(mn_xyz - coords[i], axis=1) <= 2.7).sum() >= 2]
    # mn_idx + [] == mn_idx, so no special case is needed when there is no Ca.
    center = coords[mn_idx + ca_idx].mean(axis=0)
    mu_set = set(mu_oxo)
    water_candidates = [i for i in O_idx if i not in mu_set]
    if len(water_candidates) < 2:
        raise AssertionError(
            f"Need at least 2 non-μ-oxo O for W3/W4, found {len(water_candidates)}.")
    # Stable sort: ties keep the original atom order.
    ranked = sorted(water_candidates, key=lambda i: float(np.linalg.norm(coords[i] - center)))
    W_ox = ranked[:2]
    return mn_idx, ca_idx, mu_oxo, W_ox

# Index lists into `coords`: Mn, Ca, μ-oxo O, and the two W3/W4 water O's.
mn_idx, ca_idx, mu_oxo, W_ox = identify_cubane_parts(coords, elem)

# ---------- synthesize H’s on W3/W4 (simple geometry) ----------
def synthesize_water_Hs(coords, W_ox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0):
    """Place two hydrogens on each oxygen in ``W_ox`` with ideal water geometry.

    The H–O–H bisector follows the summed unit vectors toward
    ``neighbor_pool_idx`` atoms closer than ``influence_cut`` Å (+x if that sum
    vanishes). The hydrogens are ``d_OH`` Å out, ±angle_deg/2 from the bisector.

    Returns:
        (X, H_idx): coordinates with two rows appended per oxygen, and the
        indices of those appended rows.
    """
    phi = np.deg2rad(angle_deg / 2.0)

    def bisector(origin, grid):
        acc = np.zeros(3)
        for j in neighbor_pool_idx:
            d = grid[j] - origin
            r = np.linalg.norm(d)
            if 1e-8 < r < influence_cut:
                acc += d / r
        if np.linalg.norm(acc) < 1e-6:
            acc = np.array([1.0, 0.0, 0.0])
        return acc / np.linalg.norm(acc)

    def in_plane_normal(b):
        ref = np.array([1.0, 0.0, 0.0]) if abs(b[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        n = np.cross(b, ref)
        if np.linalg.norm(n) < 1e-6:
            n = np.array([0.0, 0.0, 1.0])
        return n / np.linalg.norm(n)

    X = coords.copy()
    H_idx = []
    for iO in W_ox:
        O = X[iO]
        b = bisector(O, X)
        n = in_plane_normal(b)
        X = np.vstack([X,
                       O + d_OH * (np.cos(phi) * b + np.sin(phi) * n),
                       O + d_OH * (np.cos(phi) * b - np.sin(phi) * n)])
        H_idx += [len(X) - 2, len(X) - 1]
    return X, H_idx

# Synthetic H's on W3/W4 are oriented by the Mn, μ-oxo and Ca atoms around them.
pool = mn_idx + mu_oxo + ca_idx
coords_ext, H_syn = synthesize_water_Hs(coords, W_ox, pool)

# ---------- build minimal graph (core + W3/W4 + H’s) ----------
# Node order: core atoms sorted by original index, then the synthesized H's.
node_core = sorted(set(mn_idx + ca_idx + mu_oxo + W_ox))
ids = node_core + H_syn
X = coords_ext[ids].copy()

# Role tag per node (H rows are tested first: they have no entry in `elem`).
tags=[]
for i in ids:
    if i in H_syn: tags.append("H")
    else:
        e = elem[i]
        if e=="MN": tags.append("MN")
        elif e=="CA": tags.append("CA")
        elif e=="O": tags.append("Omu" if i in mu_oxo else ("Ow" if i in W_ox else "O"))
        else: tags.append(e)

# ---------- edges & base H ----------
# Edges: unordered pairs (i < j) within PAIR_RC Å, skipping pairs closer than
# TOL_MIN_DIST (treated as duplicate atoms).
from scipy.spatial import cKDTree
tree=cKDTree(X); E=[]
for i in range(len(X)):
    for j in tree.query_ball_point(X[i], r=PAIR_RC):
        if j<=i: continue
        if np.linalg.norm(X[i]-X[j])<=TOL_MIN_DIST: continue
        E.append((i,j))

def angle_boost(iO, mn_list, X):
    """Multiplier ``1 + lam*cos²θ`` for a μ-oxo at ``iO``, else 1.0.

    θ is the angle at ``iO`` between the first two Mn in ``mn_list`` lying
    within 2.7 Å. ``lam`` comes from module scope.
    """
    origin = X[iO]
    close = [m for m in mn_list if np.linalg.norm(X[m] - origin) <= 2.7][:2]
    if len(close) != 2:
        return 1.0
    u, v = (X[m] - origin for m in close)
    cosang = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-12)
    return 1.0 + lam * (cosang * cosang)

# Static tight-binding Hamiltonian over the graph edges (complex, symmetric).
N=len(X)
H0 = np.zeros((N,N), complex)
mn_list=[i for i,t in enumerate(tags) if t=="MN"]
for (i,j) in E:
    ti,tj=tags[i],tags[j]
    rij=np.linalg.norm(X[i]-X[j])
    # Hopping by tag pair; note `t` is a module-level scratch name here.
    if {ti,tj}=={"Ow","H"} and rij<=1.25:
        t=t_OH                                   # covalent O–H of a synthesized water
    elif (ti=="H" and tj.startswith("O")) or (tj=="H" and ti.startswith("O")):
        t=t_HB if (HBOND_MIN<=rij<=HBOND_MAX) else 0.0   # H-bond only inside the window
    elif (ti=="MN" and tj.startswith("O")) or (tj=="MN" and ti.startswith("O")):
        boost=1.0
        if ti=="MN" and tj=="Omu": boost=angle_boost(j,mn_list,X)
        if tj=="MN" and ti=="Omu": boost=angle_boost(i,mn_list,X)
        t=t_dp*np.exp(-beta*rij)*boost
    elif ti.startswith("O") and tj.startswith("O"):
        t=0.1*np.exp(-beta*rij)
    else:
        t=0.0
    if t!=0.0: H0[i,j]=t; H0[j,i]=t
# On-site energies; tags missing from eps get 0.0.
for i,tg in enumerate(tags): H0[i,i]=eps.get(tg,0.0)

# ---------- choose nearest O–O pair to watch ----------
# Considers every O-tagged node ("Omu", "Ow", "O"), not only W3/W4; picks the
# closest pair by current geometry. OO_pair stays None with fewer than two O's.
O_idx = [i for i,t in enumerate(tags) if t.startswith("O")]
best=None; best_r=1e9
for a in range(len(O_idx)):
    for b in range(a+1, len(O_idx)):
        i,j = O_idx[a], O_idx[b]
        r = float(np.linalg.norm(X[i]-X[j]))
        if r < best_r:
            best_r, best = r, (i,j)
OO_pair = best
print(f"[O–O watch] pair={OO_pair} r0={best_r:.2f} Å")

# ---------- helper metrics ----------
def PR(p):
    """Return the participation ratio of populations ``p`` along its last axis.

    PR = (Σp)² / (Σp² + 1e-12); the small constant guards against division by zero.
    """
    s1 = p.sum(-1)
    s2 = np.square(p).sum(-1)
    return (s1 * s1) / (s2 + 1e-12)
def current_proxy(Href, p, Elist):
    """Mean of |Href[i, j]|·sqrt(p[i]·p[j]) over edges in ``Elist`` with non-zero coupling.

    Returns a Python float; 0.0 when no edge has a non-zero coupling.
    """
    terms = []
    for i, j in Elist:
        mag = abs(Href[i, j])
        if mag > 0:
            terms.append(mag * np.sqrt(p[i] * p[j]))
    if not terms:
        return 0.0
    return float(np.mean(terms))

# ---------- TDSE (GPU if available), stop when O–O forms ----------
# move H0 to GPU if using CuPy
# (with the NumPy backend np.asarray does not copy, so H0_x is H0 itself)
H0_x = xp.asarray(H0)
psi = xp.zeros(N, complex);
# start on a water oxygen if present, else first Mn
start = next((i for i,t in enumerate(tags) if t=="Ow"), next((i for i,t in enumerate(tags) if t=="MN"), 0))
psi[start] = 1.0

# Edge subsets modulated in time: Mn–O edges (breathing) and edges touching W3/W4 (D-gate).
mn_o_pairs = [(i,j) for (i,j) in E if ("MN" in {tags[i],tags[j]} and any(t.startswith("O") for t in [tags[i],tags[j]]))]
w3w4_touch = [(i,j) for (i,j) in E if (tags[i]=="Ow" or tags[j]=="Ow")]

times=[]; PRs=[]; Js=[]; formed=False; T_form=None
hold_ctr=0
store_every = FRAME_EVERY

t0 = time.time()
for n in range(MAX_STEPS):
    t_fs = n*DT
    Ht = H0_x.copy()
    # breathing (Mn-O)
    if BREATH_ON:
        gate = 1.0 + BREATH_A*np.sin(2*np.pi*t_fs/BREATH_T)
        for (i,j) in mn_o_pairs:
            Ht[i,j] *= gate; Ht[j,i] *= gate
    # D-like gating (W3/W4 edges)
    tau = (t_fs + PHASE_OFFSET) % BREATH_T
    if tau <= D_WIN:
        for (i,j) in w3w4_touch:
            Ht[i,j] *= D_FACTOR; Ht[j,i] *= D_FACTOR
    # TDSE unitary step
    Evals, V = xp.linalg.eigh(Ht)
    U = V @ xp.diag(xp.exp(-1j*Evals*DT)) @ V.conj().T
    psi = U @ psi
    psi = psi / xp.linalg.norm(psi)

    # store every few fs
    if n % store_every == 0:
        p = xp.abs(psi)**2
        p_cpu = xp.asnumpy(p) if xp is not np else p
        times.append(t_fs)
        PRs.append((p_cpu.sum())**2 / (float((p_cpu**2).sum())+1e-12))
        # Current proxy uses the static H0 (not the gated Ht) over all edges.
        Js.append(current_proxy(H0, p_cpu, E))
        # O–O trigger check (only if not formed)
        if (not formed) and (OO_pair is not None):
            iO,jO = OO_pair
            # NOTE(review): X is never updated in this loop, so r_oo is always the
            # static r0 printed earlier; the trigger can only fire if r0 <= THRESH_R.
            r_oo = float(np.linalg.norm(X[iO]-X[jO]))
            if (r_oo <= THRESH_R) and ((p_cpu[iO] + p_cpu[jO]) >= THRESH_P):
                hold_ctr += 1
            else:
                hold_ctr = 0
            if hold_ctr >= HOLD_NEED:
                # "form" O–O: strengthen that edge and weaken O–Mn
                # NOTE: H0 is edited in place, so later cells reusing H0 see the
                # post-formation couplings.
                t_OO_form = 2.5
                H0[iO,jO] = t_OO_form; H0[jO,iO] = t_OO_form
                for (a,b) in E:
                    if (a==iO and tags[b]=="MN") or (b==iO and tags[a]=="MN"):
                        H0[a,b] *= 0.8; H0[b,a] *= 0.8
                    if (a==jO and tags[b]=="MN") or (b==jO and tags[a]=="MN"):
                        H0[a,b] *= 0.8; H0[b,a] *= 0.8
                H0_x = xp.asarray(H0)  # refresh device copy
                formed = True
                T_form = t_fs
                print(f"[O–O FORMED] t={T_form:.1f} fs; r≈{r_oo:.2f} Å; P≈{(p_cpu[iO]+p_cpu[jO]):.2f}")
                # continue for a bit longer, then stop
                target_steps = n + EXTRA_STEPS_AFTER_FORM
    # stop if post-formation window finished
    # (target_steps is only read once `formed` is True, so it is always bound here)
    if formed and n >= target_steps:
        break

# Convert the recorded metric lists to arrays and report the run. The first
# step (n=0) is always recorded, so times[-1] exists.
elapsed = time.time() - t0
times = np.array(times); PRs = np.array(PRs); Js = np.array(Js)
print(f"[DONE] steps={n+1}  t_end={times[-1]:.1f} fs  formed={formed}  elapsed={elapsed:.2f}s")
if formed: print(f"[INFO] O–O formation time: {T_form:.1f} fs")

# ---------- plots ----------
# Two stacked panels (PR on top, current proxy below) sharing the time axis;
# a dashed crimson line marks the O–O trigger time when it fired.
base = os.path.splitext(os.path.basename(CIF))[0]
fig,ax = plt.subplots(2,1,figsize=(8.2,5.6), sharex=True)
ax[0].plot(times, PRs, lw=1.8)
ax[0].set_ylabel("PR"); ax[0].set_title(f"{base}: PR and current (run-until O–O forms)")
ax[1].plot(times, Js, lw=1.8)
ax[1].set_xlabel("time (fs)"); ax[1].set_ylabel("⟨|J|⟩ (proxy)")
if formed:
    for a in ax:
        a.axvline(T_form, color="crimson", ls="--", lw=1.5, alpha=0.8)
        a.text(T_form, a.get_ylim()[1]*0.9, " O–O formed ", color="crimson",
               ha="left", va="top", fontsize=9, bbox=dict(fc="white", ec="crimson", alpha=0.6))
fig.tight_layout()
png = os.path.join(OUT, f"{base}_run_until_OO.png")
fig.savefig(png, dpi=170); plt.close(fig)
print("Saved:", png)

# ---------- CSV summary ----------
# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the new
# name when present and fall back on older NumPy.
_trapz = getattr(np, "trapezoid", None) or np.trapz

csv_path = os.path.join(OUT, f"{base}_run_until_OO_summary.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["file","backend","t_end_fs","formed","t_form_fs","PR_mean","PR_max","J_mean","J_int"])
    # t_form_fs is None (written as an empty field) when the trigger never fired.
    w.writerow([os.path.basename(CIF), BACK, float(times[-1]), formed,
                (None if not formed else float(T_form)),
                float(PRs.mean()), float(PRs.max()),
                float(Js.mean()), float(_trapz(Js, times))])
print("Saved:", csv_path)

# Show plot inline
from IPython.display import Image as _Img, display as _disp
_disp(_Img(png))

# ---------- Utilities that are safe on CPU/GPU ----------
def _to_float(x):
    # x can be Python float, numpy scalar, or cupy scalar
    try:
        return float(x)            # works for python/np scalars
    except Exception:
        try:
            return float(x.get())  # cupy scalar -> python float
        except Exception:
            # last resort: cupy array of shape () or (1,)
            import cupy as _cp
            if isinstance(x, _cp.ndarray):
                return float(_cp.asnumpy(x).reshape(-1)[0])
            raise

def current_proxy(Hx, p_vec, E_list):
    """Mean of |H_ij|·sqrt(p_i·p_j) over edges whose ``Hx`` entry is non-zero.

    Args:
        Hx: Hamiltonian on the active backend (NumPy or CuPy).
        p_vec: host-side population vector.
        E_list: (i, j) edges to consider.

    Returns:
        Python float; 0.0 when no edge has a non-zero coupling.
    """
    # math is not in this cell's import line; import locally so the helper
    # works even if no earlier cell imported it.
    import math
    tot = 0.0
    cnt = 0
    for (i, j) in E_list:
        hij = Hx[i, j]
        if hij != 0:
            hij_abs = _to_float(xp.abs(hij))
            tot += hij_abs * math.sqrt(float(p_vec[i]) * float(p_vec[j]))
            cnt += 1
    return (tot / cnt) if cnt else 0.0

# edge list for current (non-zero Hij) — reuse H0 already built
# math.sin / math.pi are used below, and math is not in this cell's import line.
import math

# ROLL_FS, HOLD_FS, J_THRESH, MAX_FS and BREATH_T_FS are otherwise first assigned
# in the I→J cell further down the notebook. Fall back to that cell's values (and
# to this cell's BREATH_T) so the cell also runs top-to-bottom. Any definition
# made earlier in the session still takes precedence.
ROLL_FS     = globals().get("ROLL_FS", 40.0)
HOLD_FS     = globals().get("HOLD_FS", 200.0)
J_THRESH    = globals().get("J_THRESH", 0.0052)
MAX_FS      = globals().get("MAX_FS", 2.0e4)
BREATH_T_FS = globals().get("BREATH_T_FS", BREATH_T)

E_nz = [(i,j) for (i,j) in E if H0[i,j] != 0]

# start at a water-O if present, else first Mn
try: s0 = next(i for i,t in enumerate(tags) if t=="Ow")
except StopIteration: s0 = next(i for i,t in enumerate(tags) if t=="MN")

psi = xp.zeros(N, complex); psi[s0] = 1.0
t_fs = 0.0

T_store, PR_store, J_store = [], [], []
formed, T_form = False, None

# J is sampled only every FRAME_EVERY steps, so the rolling/hold windows must be
# counted in samples of DT*FRAME_EVERY fs. Counting them in DT steps would make
# the hold window FRAME_EVERY times longer than HOLD_FS.
sample_fs = DT * FRAME_EVERY
roll_w = max(1, int(round(ROLL_FS/sample_fs)))
hold_w = max(1, int(round(HOLD_FS/sample_fs)))
roll_buf = []

step = 0
max_steps = int(MAX_FS/DT)

print(f"[RUN] DT={DT} fs, max={MAX_FS} fs, roll={ROLL_FS} fs, hold={HOLD_FS} fs, thresh={J_THRESH}")

try:
    while step < max_steps and not formed:
        # time-dependent H with breathing on Mn–O edges
        Ht = H0_x.copy()
        gate = 1.0 + BREATH_A*math.sin(2*math.pi*(t_fs)/BREATH_T_FS)
        for (i,j) in mn_o_pairs:
            Ht[i,j] *= gate; Ht[j,i] *= gate

        # evolve by eigen-decomposition (small N)
        Evals, V = xp.linalg.eigh(Ht)
        U = V @ xp.diag(xp.exp(-1j*Evals*DT)) @ V.conj().T
        psi = U @ psi
        psi = psi / xp.linalg.norm(psi)

        # measures
        if step % FRAME_EVERY == 0:
            # Host copy for metrics. NumPy has no asnumpy(), so convert only on CuPy.
            p = xp.abs(psi)**2
            if xp is not np:
                p = xp.asnumpy(p)
            s1, s2 = float(p.sum()), float((p**2).sum())
            pr_now = (s1*s1)/(s2+1e-12)   # local name: keeps the PR() helper intact
            J = current_proxy(Ht, p, E_nz)

            T_store.append(t_fs)
            PR_store.append(pr_now)
            J_store.append(J)

            # rolling buffer of the last roll_w J samples
            roll_buf.append(J)
            if len(roll_buf) > roll_w: roll_buf.pop(0)

            # hold-time check: every J sample over the last ~HOLD_FS exceeds J_THRESH
            if len(J_store) >= hold_w and min(J_store[-hold_w:]) > J_THRESH:
                formed = True
                T_form = t_fs
                print(f"[TRIGGER] O–O formation proxy at t ≈ {T_form:.1f} fs")
                break

        # advance
        t_fs += DT
        step += 1

except KeyboardInterrupt:
    print(f"[INTERRUPT] stop requested at t ≈ {t_fs:.1f} fs — saving partial results.")

print(f"[DONE] steps={step}  t={t_fs:.1f} fs  formed={formed}")

# ---------- Save plots & log ----------
# Writes a two-panel PNG and appends one summary row per run to a shared CSV log.
import matplotlib.pyplot as plt, csv, os
# NOTE: these rebind the global names T and PR (PR was a helper function above).
T  = np.array(T_store, float)
PR = np.array(PR_store, float)
Jm = np.array(J_store, float)

fig, axs = plt.subplots(2,1, figsize=(8.8,6.2), sharex=True)
axs[0].plot(T, PR, lw=0.9)
axs[0].set_ylabel("PR")
axs[0].set_title(f"{os.path.basename(CIF)}: PR and current (run-until O–O forms)")

axs[1].plot(T, Jm, lw=0.9)
axs[1].set_xlabel("time (fs)")
axs[1].set_ylabel("⟨|J|⟩ (proxy)")

# Mark the trigger time on both panels when the run reported formation.
if formed and (T_form is not None):
    for ax in axs:
        ax.axvline(T_form, color="crimson", ls="--", lw=1.2)
        ax.text(T_form, ax.get_ylim()[1]*0.97, f"  O–O @ {T_form:.1f} fs",
                color="crimson", va="top", ha="left", fontsize=10, fontweight="bold")

fig.tight_layout()
png_path = os.path.join(OUT, f"{os.path.basename(CIF).replace('.cif','')}_run_until_OO.png")
fig.savefig(png_path, dpi=160)
plt.close(fig)
print("Saved figure:", png_path)

# Append-only log: header written only when the file does not exist yet.
# Empty strings stand in for statistics when no samples were recorded.
csv_path = os.path.join(OUT, "OO_events_log.csv")
hdr = ["cif","formed","T_form_fs","n_points","PR_mean","PR_std","J_mean","J_std",
       "DT_fs","ROLL_fs","J_THRESH","HOLD_fs","BREATH_A","BREATH_T_fs"]
newfile = not os.path.exists(csv_path)
with open(csv_path,"a",newline="") as f:
    w = csv.DictWriter(f, fieldnames=hdr)
    if newfile: w.writeheader()
    w.writerow(dict(
        cif=os.path.basename(CIF),
        formed=bool(formed),
        T_form_fs=(float(T_form) if T_form is not None else ""),
        n_points=int(len(T)),
        PR_mean=float(PR.mean()) if len(PR) else "",
        PR_std=float(PR.std()) if len(PR) else "",
        J_mean=float(Jm.mean()) if len(Jm) else "",
        J_std=float(Jm.std()) if len(Jm) else "",
        DT_fs=DT, ROLL_fs=ROLL_FS, J_THRESH=J_THRESH, HOLD_fs=HOLD_FS,
        BREATH_A=BREATH_A, BREATH_T_fs=BREATH_T_FS
    ))
print("Appended log:", csv_path)

from IPython.display import Image, display
display(Image(png_path))

# ================== I→J Kabsch morph + fs-TDSE + µs gate + bifurcation forcing ==================
# One-cell; no movies; saves PNG + CSV. Auto GPU via CuPy if present.
# -----------------------------------------------------------------------------------------------
import os, sys, subprocess, math, time, csv
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter


# ---------- Paths ----------
CIF_I = "/content/8F4I.cif"
CIF_J = "/content/8F4J.cif"
OUT   = "/content" if os.path.isdir("/content") else "."

# ---------- Dynamic clocks & gating ----------
# NOTE(review): S_GAIN, BREATH_A and HOLD_FS below are reassigned again further
# down in this block, so the later values are the ones in effect. Delete
# one of each pair once the intended value is confirmed.
DT_FS      = 0.5
MAX_FS     = 2.0e4
WARP_US_FS = 1.0        # µs per fs
S_TAU_US   = 120.0       # reach midpoint sooner
S_GAIN     = 16.0        # stronger bifurcation forcing (overridden to 8.0 below)
BREATH_A   = 0.15        # more Mn–O breathing (overridden to 0.08 below)
S_MAX      = 0.98
S_HALT     = 0.95
HOLD_FS    = 250.0       # (overridden to 200.0 below)
PAIR_RC    = 3.4

# ---------- Adaptive current threshold ----------
AUTO_THRESH = True        # derive from initial oscillation amplitude
J_THRESH    = 0.0042      # fallback minimum (overridden to 0.0052 below)


# ---------- TDSE / graph params ----------
PAIR_RC    = 3.4       # Å, neighbor cutoff
HBOND_MIN, HBOND_MAX = 1.2, 2.6  # Å
t_dp, beta, lam = 1.0, 1.0, 0.5   # Mn–O, distance decay, μ-oxo angle boost
t_OH, t_HB      = 3.0, 0.8        # covalent O–H and H-bond
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)
BREATH_A, BREATH_T_FS = 0.08, 120.0   # effective breathing amplitude and period (fs)

# ---------- Bifurcation (⟨|J|⟩) gate / forcing ----------
ROLL_FS    = 40.0                # fs, rolling window for ⟨|J|⟩
HOLD_FS    = 200.0               # fs, must remain above threshold this long
J_THRESH   = 0.0052              # current-proxy threshold
S_GAIN     = 8.0                 # how strongly J-excess accelerates ds/dt (C)

# ---------- Deps ----------
def need(pkg, mod=None):
    """Import ``mod`` (or ``pkg`` when ``mod`` is falsy); pip-install ``pkg`` if that fails."""
    module_name = mod or pkg
    try:
        __import__(module_name)
    except Exception:
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(pip_cmd)
# Ensure gemmi (structure I/O) and scipy (KD-tree) are available.
need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree

# Optional GPU
# xp is the array backend; GPU records which one was selected (used by to_xp).
try:
    import cupy as cp
    xp = cp; GPU=True
    print("[GPU] CuPy detected")
except Exception:
    xp = np; GPU=False
    print("[CPU] Using NumPy")

# ---------- CIF helpers ----------
def read_atoms(cif_path):
    """Read every atom of the first model of a structure file.

    Falls back to building the structure from the first CIF data block when
    ``gemmi.read_structure`` raises.

    Returns:
        list of ``(xyz, element, resname, atom_name)`` tuples with ``xyz`` a
        float32 array of shape (3,), upper-cased element and residue names,
        and the stripped atom name.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception:
            # Best-effort: on failure, altlocs may remain. Exception (not a bare
            # except) so KeyboardInterrupt/SystemExit still propagate.
            pass
    atoms = []
    for ch in st[0]:
        for res in ch:
            for at in res:
                p = at.pos
                atoms.append((np.array([p.x, p.y, p.z], np.float32),
                              at.element.name.upper(),
                              res.name.upper(),
                              at.name.strip()))
    return atoms

def identify_core(atoms):
    """Locate the Mn4Ca core, its μ-oxo oxygens and the W3/W4 water oxygens.

    Args:
        atoms: per-atom tuples whose [0] is an xyz array and [1] the upper-case
            element symbol (as produced by ``read_atoms``).

    Returns:
        (coords, elem, mn, ca, mu, Wox): float32 (N, 3) coordinates, object
        array of elements, index lists of Mn, Ca and μ-oxo O (O within 2.7 Å
        of at least two Mn), and the two non-μ-oxo O closest to the Mn/Ca
        centroid (nearest first).

    Raises:
        AssertionError: fewer than 4 Mn, or fewer than 2 non-μ-oxo oxygens.
            Raised explicitly (not via ``assert``) so it survives ``python -O``.
    """
    coords = np.array([a[0] for a in atoms], np.float32)
    elem   = np.array([a[1] for a in atoms], object)
    mn = [i for i, e in enumerate(elem) if e == "MN"]
    ca = [i for i, e in enumerate(elem) if e == "CA"]
    O  = [i for i, e in enumerate(elem) if e == "O"]
    if len(mn) < 4:
        raise AssertionError("Need ≥4 Mn")
    # μ-oxo: O within 2.7 Å to ≥2 Mn
    mn_xyz = coords[mn]
    mu = [i for i in O if (np.linalg.norm(mn_xyz - coords[i], axis=1) <= 2.7).sum() >= 2]
    # W3/W4: two O (not μ-oxo) closest to Mn/Ca centroid
    center = coords[mn + ca].mean(axis=0)
    mu_set = set(mu)
    cand = [i for i in O if i not in mu_set]
    if not cand:
        raise AssertionError("No water-O candidates")
    if len(cand) < 2:
        # W3/W4 needs two oxygens; indexing the second one would otherwise fail.
        raise AssertionError(f"Need ≥2 water-O candidates for W3/W4, found {len(cand)}")
    # Stable sort: ties keep the original atom order.
    ranked = sorted(cand, key=lambda i: float(np.linalg.norm(coords[i] - center)))
    Wox = ranked[:2]
    return coords, elem, mn, ca, mu, Wox

def synthesize_water_Hs(coords, Wox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0):
    """Append a pair of idealised hydrogens to each water oxygen in ``Wox``.

    The bisector is the sum of unit vectors toward ``neighbor_pool_idx`` atoms
    strictly between 1e-8 and ``influence_cut`` Å away (+x if it vanishes).
    Hydrogens sit ``d_OH`` Å away at ±angle_deg/2 from it. Normalisations add
    1e-12 to the norm.

    Returns:
        (X, H_idx): extended coordinates and the indices of the new H rows.
    """
    half = math.radians(angle_deg / 2.0)
    cos_h, sin_h = math.cos(half), math.sin(half)
    X = coords.copy()
    H_idx = []
    for iO in Wox:
        O = X[iO]
        steer = np.zeros(3, float)
        for j in neighbor_pool_idx:
            r = X[j] - O
            rn = np.linalg.norm(r)
            if 1e-8 < rn < influence_cut:
                steer += r / rn
        if np.linalg.norm(steer) < 1e-6:
            steer = np.array([1.0, 0.0, 0.0])
        b = steer / (np.linalg.norm(steer) + 1e-12)
        ref = np.array([1.0, 0.0, 0.0]) if abs(b[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        u1 = np.cross(b, ref)
        if np.linalg.norm(u1) < 1e-6:
            u1 = np.array([0, 0, 1.0])
        u1 = u1 / (np.linalg.norm(u1) + 1e-12)
        lobe_plus = O + d_OH * (cos_h * b + sin_h * u1)
        lobe_minus = O + d_OH * (cos_h * b - sin_h * u1)
        X = np.vstack([X, lobe_plus, lobe_minus])
        H_idx.extend((len(X) - 2, len(X) - 1))
    return X, H_idx

# ---------- Kabsch (J→I) ----------
def kabsch_fit(P, Q):
    """Least-squares rigid fit of point set ``Q`` onto ``P`` (Kabsch algorithm).

    Args:
        P, Q: (N, 3) arrays of paired points (row k of Q corresponds to row k of P).

    Returns:
        (R, t): proper rotation (det +1) and translation such that
        ``R @ q + t ≈ p`` for each pair, i.e. ``Q @ R.T + t ≈ P`` in row form.
    """
    p_mean = P.mean(0)
    q_mean = Q.mean(0)
    # Cross-covariance H = Σ (q - q̄)(p - p̄)ᵀ. With H = U S Vᵀ, tr(R H) is
    # maximised by R = V Uᵀ (the previous U Vᵀ was its transpose, the inverse rotation).
    H = (Q - q_mean).T @ (P - p_mean)
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        # Reflection: flip the axis with the smallest singular value.
        Vt[-1, :] *= -1
        R = Vt.T @ U.T
    t = p_mean - R @ q_mean
    return R, t

# ---------- Build minimal graph (with or without synthesized H) ----------
def build_min_graph(coords, elem, mn, ca, mu, Wox, H_syn_idx, include_H=True):
    """Select the core (+ optional synthetic H) nodes and label them.

    Node ids are the sorted union of ``mn``, ``ca``, ``mu``, ``Wox`` and,
    when ``include_H``, ``H_syn_idx``.

    Returns:
        (X, tags, labels): copied coordinates of the selected nodes, a role tag
        per node ("H", "MN", "CA", "Omu", "Ow", "O" or the raw element), and
        per-tag running labels such as "MN1", "MN2", "Ow1".
    """
    pool = mn + ca + mu + Wox
    if include_H:
        pool = pool + H_syn_idx
    ids = sorted(set(pool), key=int)
    X = coords[ids].copy()

    h_set, mu_set, w_set = set(H_syn_idx), set(mu), set(Wox)

    def role(i):
        # Synthetic rows have no entry in `elem`; test them first.
        if i in h_set:
            return "H"
        e = elem[i]
        if e == "MN":
            return "MN"
        if e == "CA":
            return "CA"
        if e == "O":
            if i in mu_set:
                return "Omu"
            return "Ow" if i in w_set else "O"
        return e

    tags = [role(i) for i in ids]
    seen = {}
    labels = []
    for t in tags:
        seen[t] = seen.get(t, 0) + 1
        labels.append(f"{t}{seen[t]}")
    return X, tags, labels

def angle_boost(iO, mn_list, X):
    """Return ``1 + lam*cos²θ`` (a Python float) for μ-oxo ``iO``, or 1.0.

    θ is the Mn–O–Mn angle at ``iO`` using the first two ``mn_list`` entries
    within 2.7 Å; ``lam`` is read from module scope.
    """
    here = X[iO]
    near = []
    for m in mn_list:
        if np.linalg.norm(X[m] - here) <= 2.7:
            near.append(m)
    if len(near) < 2:
        return 1.0
    v1 = X[near[0]] - here
    v2 = X[near[1]] - here
    c = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-12))
    return 1.0 + lam * (c * c)

def build_edges_and_H(X, tags):
    """Return (E, H): neighbour edges and the static complex tight-binding matrix.

    Edges are (i, j) with i < j within PAIR_RC Å, excluding pairs 0.5 Å apart
    or closer. Hoppings follow the same tag rules as the rest of the notebook:
    Ow–H (≤1.25 Å) → t_OH; other H–O → t_HB inside [HBOND_MIN, HBOND_MAX];
    Mn–O → t_dp·exp(-beta·r) (×angle_boost for "Omu"); O–O → 0.1·exp(-beta·r);
    otherwise 0. Diagonal entries are eps[tag] (0.0 if absent).
    """
    N = len(X)
    tree = cKDTree(X)
    E = [
        (i, j)
        for i in range(N)
        for j in tree.query_ball_point(X[i], r=PAIR_RC)
        if j > i and not np.linalg.norm(X[i] - X[j]) <= 0.5
    ]
    mn_list = [k for k, tg in enumerate(tags) if tg == "MN"]

    def hop(i, j):
        rij = float(np.linalg.norm(X[i] - X[j]))
        ti, tj = tags[i], tags[j]
        if {ti, tj} == {"Ow", "H"} and rij <= 1.25:
            return t_OH
        if (ti == "H" and tj.startswith("O")) or (tj == "H" and ti.startswith("O")):
            return t_HB if HBOND_MIN <= rij <= HBOND_MAX else 0.0
        if (ti == "MN" and tj.startswith("O")) or (tj == "MN" and ti.startswith("O")):
            if ti == "MN" and tj == "Omu":
                boost = angle_boost(j, mn_list, X)
            elif tj == "MN" and ti == "Omu":
                boost = angle_boost(i, mn_list, X)
            else:
                boost = 1.0
            return t_dp * math.exp(-beta * rij) * boost
        if ti.startswith("O") and tj.startswith("O"):
            return 0.1 * math.exp(-beta * rij)
        return 0.0

    H = np.zeros((N, N), complex)
    for i, j in E:
        t_ij = hop(i, j)
        if t_ij != 0.0:
            H[i, j] = t_ij
            H[j, i] = t_ij
    for i, tg in enumerate(tags):
        H[i, i] = eps.get(tg, 0.0)
    return E, H

# ---------- Read I and J; build core; make H's; Kabsch; Δr ----------
atoms_I = read_atoms(CIF_I)
atoms_J = read_atoms(CIF_J)
coords_I, elem_I, mn_I, ca_I, mu_I, Wox_I = identify_core(atoms_I)
coords_J, elem_J, mn_J, ca_J, mu_J, Wox_J = identify_core(atoms_J)
np.set_printoptions(suppress=True, precision=3)

# Use I's neighbor pool to synthesize waters (consistent orientation)
neighbor_pool_I = mn_I + mu_I + ca_I
coords_I_ext, Hsyn_I = synthesize_water_Hs(coords_I, Wox_I, neighbor_pool_I)
X_I, tags, labels = build_min_graph(coords_I_ext, elem_I, mn_I, ca_I, mu_I, Wox_I, Hsyn_I, include_H=True)
E, H0 = build_edges_and_H(X_I, tags)
print(f"[I] nodes={len(X_I)} edges={len(E)} tags={Counter(tags)}")

# Build the same index set for J (map by element role: MN/CA/Omu/Ow/Hsyn order)
# For robustness, recreate the same minimal graph on J (with its own synthesized Hs)
neighbor_pool_J = mn_J + mu_J + ca_J
coords_J_ext, Hsyn_J = synthesize_water_Hs(coords_J, Wox_J, neighbor_pool_J)
X_J, tags_J, labels_J = build_min_graph(coords_J_ext, elem_J, mn_J, ca_J, mu_J, Wox_J, Hsyn_J, include_H=True)

# Remap J onto I's node order by role. Node k of I is paired with the next unused
# J node that has the same tag, so X_J[k] has role tags[k] for every k.
# (Concatenating whole per-tag groups only lines up when I's tags happen to be
# contiguous by type; sorted atom ids need not be, which misaligned the fit.)
type_to_idx_I = {}
for k,t in enumerate(tags): type_to_idx_I.setdefault(t, []).append(k)
type_to_idx_J = {}
for k,t in enumerate(tags_J): type_to_idx_J.setdefault(t, []).append(k)

for t, idxs_I in type_to_idx_I.items():
    idxs_J = type_to_idx_J.get(t, [])
    if len(idxs_J) < len(idxs_I):
        raise RuntimeError(f"Tag count mismatch for {t}: I={len(idxs_I)} J={len(idxs_J)}")
next_J = {t: iter(idxs) for t, idxs in type_to_idx_J.items()}
order_J = [next(next_J[t]) for t in tags]
X_J = X_J[np.array(order_J)]
# keep J's tags/labels aligned with the remapped coordinates
tags_J = [tags_J[k] for k in order_J]
labels_J = [labels_J[k] for k in order_J]
assert len(X_J)==len(X_I)==len(tags)
assert tags_J == tags

# Select core for Kabsch: metals + μ-oxo + water O (no H)
core_idx = [i for i,t in enumerate(tags) if (t in ("MN","CA","Omu","Ow"))]
R, tvec = kabsch_fit(X_I[core_idx], X_J[core_idx])
X_J_fit = (X_J @ R.T) + tvec

# Δr field and O–O in I vs J
dX = X_J_fit - X_I
def oo_pair_idx(tags):
    """Return the positions of the first two "Ow" (W3/W4) nodes as a tuple.

    Raises IndexError when fewer than two "Ow" tags are present.
    """
    waters = []
    for pos, tag in enumerate(tags):
        if tag == "Ow":
            waters.append(pos)
    return waters[0], waters[1]
# W3/W4 O–O distance in I and in the fitted J. A distance is unchanged by the
# rigid transform, so r_OO_J does not depend on the Kabsch fit itself.
iO, jO = oo_pair_idx(tags)
r_OO_I = float(np.linalg.norm(X_I[iO]-X_I[jO]))
r_OO_J = float(np.linalg.norm(X_J_fit[iO]-X_J_fit[jO]))
print(f"[O–O] I: {r_OO_I:.3f} Å  |  J_fit: {r_OO_J:.3f} Å  |  Δ = {r_OO_J-r_OO_I:+.3f} Å")

# ---------- TDSE prep ----------
def to_xp(a):
    """Move *a* onto the active array backend: ``xp.asarray`` when GPU is set, else unchanged."""
    if not GPU:
        return a
    return xp.asarray(a)
H0_x = to_xp(H0)
# Mn–O edges in either orientation; these couplings are rebuilt/modulated each step.
mn_o_pairs = [
    (i, j) for (i, j) in E
    if (tags[i] == "MN" and tags[j].startswith("O"))
    or (tags[j] == "MN" and tags[i].startswith("O"))
]

# Initial state: a single basis vector on the first water oxygen, else the first Mn.
s0 = next((i for i, t in enumerate(tags) if t == "Ow"), None)
if s0 is None:
    s0 = next(i for i, t in enumerate(tags) if t == "MN")
psi = xp.zeros(len(X_I), complex)
psi[s0] = 1.0

# ---------- Metrics / helpers ----------
def _to_float(x):
    try: return float(x)
    except Exception:
        try: return float(x.get())
        except Exception:
            import cupy as _cp
            if isinstance(x, _cp.ndarray): return float(_cp.asnumpy(x).reshape(-1)[0])
            raise

def current_proxy(Hx, p_vec, E_list):
    """Mean of |H_ij| * sqrt(p_i * p_j) over edges with a nonzero coupling (0.0 if none).

    *Hx* may be a backend (``xp``) array; each |H_ij| is pulled to a Python float
    with ``_to_float``.
    """
    terms = []
    for i, j in E_list:
        hij = Hx[i, j]
        if hij != 0:
            terms.append(_to_float(xp.abs(hij)) * math.sqrt(float(p_vec[i]) * float(p_vec[j])))
    return sum(terms) / len(terms) if terms else 0.0

# Edges whose static coupling in H0 is nonzero (the set averaged by current_proxy)
E_nz = [(a, b) for a, b in E if H0[a, b] != 0]

def eig_step(Hx, dt):
    """Return the one-step propagator U = V diag(exp(-i E dt)) V^H from ``xp.linalg.eigh(Hx)``."""
    evals, evecs = xp.linalg.eigh(Hx)
    phases = xp.exp(-1j * evals * dt)
    return evecs @ xp.diag(phases) @ evecs.conj().T

def s_schedule(t_us, tau=S_TAU_US, cap=S_MAX):
    """Logistic ramp toward *cap*: cap / (1 + exp(-(t_us - tau) / w)), with w = max(1e-6, tau/5).

    The result is clipped to at most *cap*.
    """
    width = max(1e-6, tau/5.0)
    frac = 1.0/(1.0 + math.exp(-(t_us - tau)/width))
    return min(cap, frac*cap)

# ---------- Main loop (fs TDSE + µs morph + bifurcation forcing) ----------
# Metrics are sampled every METRIC_EVERY steps, so the rolling (ROLL_FS) and hold
# (HOLD_FS) windows are converted to *sample* counts. Sizing them in steps made
# both windows METRIC_EVERY times longer than configured.
METRIC_EVERY = 3
roll_buf=[]; roll_w = max(1, int(round(ROLL_FS/(DT_FS*METRIC_EVERY))))
hold_w = max(1, int(round(HOLD_FS/(DT_FS*METRIC_EVERY))))
T, PR, Jm, S = [], [], [], []
formed, T_form = False, None

t_fs = 0.0
t_us = 0.0
s = 0.0

# Loop invariants, hoisted out of the time loop
mn_nodes = [k for k,t in enumerate(tags) if t=="MN"]
oo_pairs = [(i,j) for (i,j) in E if tags[i].startswith("O") and tags[j].startswith("O")]

start_wall = time.time()
step=0; max_steps=int(MAX_FS/DT_FS)

print(f"[RUN] DT={DT_FS} fs; warp={WARP_US_FS} µs/fs; τ={S_TAU_US} µs; S_MAX={S_MAX}; "
      f"roll={ROLL_FS} fs; hold={HOLD_FS} fs; thresh={J_THRESH}")

try:
    while step < max_steps and not formed:
        # (A) morph fraction from slow clock + (C) bifurcation forcing
        s_target = s_schedule(t_us)
        # modest integrator for s(t): relax toward s_target, plus forcing from the
        # excess of the rolling J̄ over J_THRESH
        J_excess = 0.0
        if len(roll_buf) >= roll_w:
            J_roll = float(np.mean(roll_buf[-roll_w:]))
            if J_roll > J_THRESH:
                J_excess = (J_roll - J_THRESH)
        ds = (s_target - s)*0.05 + S_GAIN * J_excess * (DT_FS/ BREATH_T_FS)
        s = min(S_MAX, max(0.0, s + ds))

        # Rebuild H(t) on the morphed scaffold X = X_I + s*dX: Mn–O couplings (with
        # the μ-oxo angle boost) and weak O–O contacts follow current distances.
        Ht = H0.copy()
        X_now = X_I + s * dX
        for (i,j) in mn_o_pairs:
            rij = float(np.linalg.norm(X_now[i] - X_now[j]))
            boost = 1.0
            if tags[i]=="MN" and tags[j]=="Omu":
                boost = angle_boost(j, mn_nodes, X_now)
            elif tags[j]=="MN" and tags[i]=="Omu":
                boost = angle_boost(i, mn_nodes, X_now)
            Ht[i,j] = t_dp * math.exp(-beta*rij) * boost
            Ht[j,i] = Ht[i,j]
        # Keep O–O weak contacts updated
        for (i,j) in oo_pairs:
            rij = float(np.linalg.norm(X_now[i]-X_now[j]))
            Ht[i,j] = 0.1*math.exp(-beta*rij)
            Ht[j,i] = Ht[i,j]
        # breathing on top (fs scale)
        gate = 1.0 + BREATH_A*math.sin(2*math.pi*(t_fs)/BREATH_T_FS)
        for (i,j) in mn_o_pairs:
            Ht[i,j] *= gate; Ht[j,i] *= gate

        Ht_x = to_xp(Ht)

        # (TDSE) evolve
        U = eig_step(Ht_x, DT_FS)
        psi = U @ psi
        psi = psi / xp.linalg.norm(psi)

        # metrics every METRIC_EVERY steps (= METRIC_EVERY*DT_FS fs)
        if step % METRIC_EVERY == 0:
            p_x = xp.abs(psi)**2
            # NumPy has no `asnumpy`; only the CuPy backend needs a device->host copy
            p = xp.asnumpy(p_x) if GPU else np.asarray(p_x)
            s1, s2 = float(p.sum()), float((p**2).sum())
            PR_val = (s1*s1)/(s2+1e-12)
            J_val  = current_proxy(Ht_x, p, E_nz)

            T.append(t_fs); PR.append(PR_val); Jm.append(J_val); S.append(s)
            roll_buf.append(J_val)
            if len(roll_buf) > roll_w: roll_buf.pop(0)

            # hold check near the morph end (hold_w is a sample count)
            if (s >= S_HALT) and len(Jm) >= hold_w and min(Jm[-hold_w:]) > J_THRESH:
                formed = True; T_form = t_fs
                print(f"[TRIGGER] s≈{s:.3f} and J above thresh; t ≈ {T_form:.1f} fs")
                break

        # progress ping
        if step % 500 == 0 and step>0:
            wall = time.time() - start_wall
            done = step / max(1, max_steps)
            eta  = wall/done - wall if done>0 else 0
            J_bar_now = np.mean(roll_buf[-roll_w:]) if roll_buf else float("nan")
            print(f"[{step:>6}/{max_steps}] t={t_fs:7.1f} fs | t_slow={t_us:7.1f} µs | s={s:5.3f} "
                  f"| J̄={J_bar_now:.4f} | ETA~{eta/60:.1f} min")

        # advance clocks
        t_fs += DT_FS
        t_us += DT_FS * WARP_US_FS
        step += 1

except KeyboardInterrupt:
    print(f"[INTERRUPT] at t≈{t_fs:.1f} fs — saving partial results.")

print(f"[DONE] steps={step}  t={t_fs:.1f} fs  s={s:.3f}  formed={formed}")

# ---------- Save plots + CSV ----------
T = np.array(T,float); PR = np.array(PR,float); Jm = np.array(Jm,float); S = np.array(S,float)
fig, axs = plt.subplots(3,1, figsize=(8.8,8.6), sharex=True)
axs[0].plot(T, PR, lw=1.0); axs[0].set_ylabel("PR")
axs[1].plot(T, Jm, lw=1.0); axs[1].set_ylabel("⟨|J|⟩")
axs[2].plot(T, S,  lw=1.0); axs[2].set_ylabel("s(t)"); axs[2].set_xlabel("time (fs)")
title = f"I→J morph TDSE  |  {os.path.basename(CIF_I)} → {os.path.basename(CIF_J)}"
axs[0].set_title(title)
if formed and (T_form is not None):
    for ax in axs:
        ax.axvline(T_form, color="crimson", ls="--")
        # x in data units, y as a fraction of the axes height: the label stays near
        # the top even when the y-limits are negative or straddle zero (ylim*0.96 did not)
        ax.text(T_form, 0.96, f"O–O event @ {T_form:.1f} fs", color="crimson",
                ha="left", va="top", fontsize=9, transform=ax.get_xaxis_transform())
fig.tight_layout()
png = os.path.join(OUT, f"{os.path.basename(CIF_I).replace('.cif','')}_to_{os.path.basename(CIF_J).replace('.cif','')}_morph_run.png")
fig.savefig(png, dpi=160); plt.close(fig)
print("Saved figure:", png)

# One summary row per run, appended to a shared log (header written once)
csv_log = os.path.join(OUT, "OO_events_log.csv")
hdr = ["cif_I","cif_J","formed","T_form_fs","n_points","PR_mean","PR_std","J_mean","J_std",
       "DT_fs","WARP_us_per_fs","S_tau_us","S_max","S_halt","ROLL_fs","HOLD_fs","J_thresh",
       "r_OO_I","r_OO_J","S_end"]
newfile = not os.path.exists(csv_log)
with open(csv_log, "a", newline="") as f:
    w = csv.DictWriter(f, fieldnames=hdr)
    if newfile: w.writeheader()
    w.writerow(dict(
        cif_I=os.path.basename(CIF_I),
        cif_J=os.path.basename(CIF_J),
        formed=bool(formed),
        T_form_fs=(float(T_form) if T_form is not None else ""),
        n_points=int(len(T)),
        PR_mean=float(PR.mean()) if len(PR) else "",
        PR_std=float(PR.std()) if len(PR) else "",
        J_mean=float(Jm.mean()) if len(Jm) else "",
        J_std=float(Jm.std()) if len(Jm) else "",
        DT_fs=DT_FS, WARP_us_per_fs=WARP_US_FS, S_tau_us=S_TAU_US, S_max=S_MAX, S_halt=S_HALT,
        ROLL_fs=ROLL_FS, HOLD_fs=HOLD_FS, J_thresh=J_THRESH,
        r_OO_I=r_OO_I, r_OO_J=r_OO_J, S_end=float(S[-1]) if len(S) else 0.0
    ))
print("Appended log:", csv_log)

from IPython.display import Image, display
display(Image(png))
# ===============================================================================================

# ==========================================================
# GQR O–O Morph Simulation (I→J with adaptive trigger)
# ==========================================================
# [1]  --- Imports & setup ---
import os, sys, time, math, csv
import numpy as np
# Prefer CuPy as the array backend; fall back to NumPy when it cannot be imported.
try:
    import cupy as cp
except Exception:
    xp, GPU = np, False
else:
    xp, GPU = cp, True
print(f"[GPU] CuPy {'detected' if GPU else 'not available'}")

# [20] --- Constants & parameters ---
CIF_I = "/content/8F4I.cif"   # input structure paths (only used for naming here)
CIF_J = "/content/8F4J.cif"
OUT   = "/content" if os.path.isdir("/content") else "."   # output directory

DT_FS      = 0.5      # fs per loop step (t_fs = n * DT_FS)
MAX_FS     = 2.0e4    # fs simulated at most; steps = int(MAX_FS / DT_FS)
WARP_US_FS = 1.0      # µs of slow clock per fs (t_us = t_fs * WARP_US_FS)
S_TAU_US   = 100.0    # µs midpoint of the logistic morph_fraction (width 0.25*τ)
S_GAIN     = 24.0     # NOTE(review): not referenced in this cell
S_MAX      = 0.98     # plateau of morph_fraction
S_HALT     = 0.95     # s must reach this before the hold counter can advance
HOLD_FS    = 400.0    # fs that J̄ must stay ≥ the effective threshold to trigger
PAIR_RC    = 3.4      # NOTE(review): not referenced in this cell

BREATH_A         = 0.22     # amplitude of the 80 fs sinusoidal gate
BREATH_S_COUPLE  = 0.25     # gate is further scaled by (1 + BREATH_S_COUPLE * s)
AUTO_THRESH      = True     # use max(J_THRESH, 0.8 * 95th percentile of the 200 fs history)
J_THRESH         = 0.0035   # fixed / floor threshold on J̄

# [40] --- Placeholder: load graphs (simulate for now) ---
# Replace with your actual load_cif_graph function
def mock_graph_load():
    """Return placeholder ``(nodes, edges, tags)`` graph statistics.

    Stand-in for a real CIF graph loader: fixed node/edge counts plus a dict of
    per-tag node counts.
    """
    tag_counts = dict(Omu=44, MN=16, CA=4, H=4, Ow=2)
    return 70, 431, tag_counts

nodes, edges, tags = mock_graph_load()
print(f"[I] nodes={nodes} edges={edges} tags={tags}")

# O–O baseline distances (fixed example values, not computed from the CIFs)
OO_I, OO_J = 5.012, 5.525
OO_D = OO_J - OO_I
print(f"[O–O] I: {OO_I:.3f} Å  |  J_fit: {OO_J:.3f} Å  |  Δ = {OO_D:+.3f} Å")

# [60] --- Prepare outputs ---
os.makedirs(OUT, exist_ok=True)
LOG_PATH, FIG_PATH = (os.path.join(OUT, fname)
                      for fname in ("OO_events_log.csv", "8F4I_to_8F4J_morph_run.png"))

# [70] --- Sigmoid morph progress function ---
def morph_fraction(t_us, tau=None, s_max=None):
    """Logistic morph fraction s = s_max / (1 + exp(-(t_us - tau) / (0.25 * tau))).

    The computation runs in NumPy on the host. The previous version used ``xp``,
    which on a CuPy backend launched a GPU kernel and forced a device sync on
    every scalar call; the plot below calls this once per step. Overflow of
    ``exp`` for very early times is silenced, so the result is exactly 0.0.

    Args:
        t_us: slow-clock time in µs; a scalar or a NumPy array.
        tau: logistic midpoint in µs; defaults to the module-level ``S_TAU_US``.
        s_max: plateau value; defaults to the module-level ``S_MAX``.

    Returns:
        NumPy float (scalar input) or array (array input) in ``[0, s_max]``.
    """
    if tau is None:
        tau = S_TAU_US
    if s_max is None:
        s_max = S_MAX
    z = (np.asarray(t_us, dtype=float) - tau) / (0.25 * tau)
    with np.errstate(over="ignore"):
        return s_max / (1.0 + np.exp(-z))

# [80] --- Run simulation ---
from collections import deque

steps = int(MAX_FS / DT_FS)
formed = False
t_form_fs = None   # fast-clock time of the trigger; stays None if it never fires
# 200 fs history for the adaptive threshold; deque(maxlen) evicts the oldest
# sample in O(1) (list.pop(0) is O(n)) and keeps the same window length
J_hist = deque(maxlen=int(200 / DT_FS))
hold_counter = 0.0
t0 = time.time()

print(f"[RUN] DT={DT_FS} fs; warp={WARP_US_FS} µs/fs; τ={S_TAU_US} µs; "
      f"S_MAX={S_MAX}; hold={HOLD_FS} fs; thresh={J_THRESH}")

for n in range(steps):
    t_fs = n * DT_FS
    t_us = t_fs * WARP_US_FS
    s = float(morph_fraction(t_us))

    # breathing term with feedback: 80 fs sinusoid, amplified as the morph proceeds
    gate = 1.0 + BREATH_A * math.sin(2 * math.pi * t_fs / 80.0)
    gate *= (1.0 + BREATH_S_COUPLE * s)

    # pseudo “current coupling” (synthetic signal, not a TDSE observable)
    Jbar = abs(math.sin(t_fs / 300.0) * 0.004 + 0.0036) * gate

    # store history for adaptive threshold
    J_hist.append(Jbar)

    if AUTO_THRESH and len(J_hist) > 20:
        thresh_eff = max(J_THRESH, np.percentile(J_hist, 95) * 0.80)
    else:
        thresh_eff = J_THRESH

    # formation logic: accumulate fs while J̄ ≥ threshold late in the morph
    if (Jbar >= thresh_eff) and (s >= S_HALT):
        hold_counter += DT_FS
    else:
        hold_counter = 0.0

    if hold_counter >= HOLD_FS:
        formed = True
        t_form_fs = t_fs
        print(f"[TRIGGER] formed=True at t={t_form_fs:.1f} fs | J̄={Jbar:.4f}")
        break

    if n % 500 == 0 or n == steps - 1:
        eta = ((steps - n) * (time.time() - t0) / (n + 1)) / 60 if n > 0 else 0
        print(f"[{n:6d}/{steps}] t={t_fs:7.1f} fs | t_slow={t_us:7.1f} µs | "
              f"s={s:.3f} | J̄={Jbar:.4f} | ETA~{eta:4.1f} min", flush=True)

# [130] --- Logging & plotting ---
import matplotlib.pyplot as plt
# Analytic s(t) evaluated for every one of the `steps` steps (the full MAX_FS span),
# regardless of whether the loop above stopped early on a trigger.
s_vals = [float(morph_fraction(i*DT_FS*WARP_US_FS)) for i in range(steps)]
plt.figure(figsize=(7,4))
plt.plot(np.arange(steps)*DT_FS, s_vals, label="s(t)")
plt.title("Morph fraction vs time")
plt.xlabel("Time (fs)")
plt.ylabel("s")
plt.ylim(0, 1)
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(FIG_PATH, dpi=150)
plt.close()   # free the figure; only the PNG on disk is kept
print(f"Saved figure: {FIG_PATH}")

# [150] --- Append to log ---
# One row per run; header written only when the log file is new
header = ["time_fs","formed","OO_I","OO_J","OO_D","J_thresh","Jbar"]
row = [f"{t_form_fs:.2f}" if formed else f"{MAX_FS:.2f}",
       formed, OO_I, OO_J, OO_D, J_THRESH, Jbar]
exists = os.path.exists(LOG_PATH)
with open(LOG_PATH, "a", newline="") as f:
    writer = csv.writer(f)
    if not exists:
        writer.writerow(header)
    writer.writerow(row)
print(f"Appended log: {LOG_PATH}")

# `n` is the 0-based index of the last executed iteration, so n + 1 steps ran
print(f"[DONE] steps={n + 1}  t={t_fs:.1f} fs  s={s:.3f}  formed={formed}")

# ======================= GQR TDSE: 8F4I ➜ 8F4J (one cell) =======================
# - Robust CIF load (gemmi), minimal cubane graph (Mn, Ca, μ-oxo O, W3/W4 + 2H each)
# - CuPy (GPU) if available, else NumPy
# - Geometry morph X(t) = (1-s) * X_I + s * X_J (J remapped onto I by nearest-neighbor)
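# - TDSE propagation with breathing on Mn–O; "formed" when the rolling J̄(t) stays above threshold for a hold window once s(t) is near S_MAX (s(t) itself has no J̄ feedback)
# - Saves a CSV time series, a JSON parameter file and 2 plots; handles KeyboardInterrupt gracefully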
# -------------------------------------------------------------------------------
import os, sys, time, math, json, subprocess
import numpy as np
import matplotlib.pyplot as plt

# ---------- Paths (edit if needed) ----------
# Input structures for the I (start) and J (end) geometries of the morph
CIF_I = "/content/8F4I.cif"
CIF_J = "/content/8F4J.cif"
# Colab's /content when present, otherwise the current directory
OUT   = "/content" if os.path.isdir("/content") else "."

# ---------- Dependencies ----------
def need(pkg, mod=None):
    """Import *mod* (default: *pkg*); if that fails, ``pip install -q`` *pkg*.

    Nothing is re-imported after installing; callers import the module themselves.
    """
    module_name = mod or pkg
    try:
        __import__(module_name)
    except Exception:
        cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(cmd)

need("gemmi"); import gemmi
need("scipy"); from scipy.spatial import cKDTree

# GPU optional: CuPy when importable, otherwise NumPy
try:
    import cupy as cp
except Exception:
    xp = np
    ON_GPU = False
    print("[CPU] CuPy not found — using NumPy.")
else:
    xp = cp
    ON_GPU = True
    print("[GPU] CuPy detected — using GPU arrays.")

# ---------- Controls (physics / numerics) ----------
DT_FS      = 0.5                 # fs per TDSE step
MAX_FS     = 2.0e4               # max simulated fs (20 ps cap); steps = int(MAX_FS/DT_FS)
PAIR_RC    = 3.4                 # Å, neighbor cutoff for edges (cKDTree ball query)
BREATH_A   = 0.10                # relative amplitude of the Mn–O breathing gate
BREATH_T   = 120.0               # fs, breathing period

# TDSE couplings / on-sites (dimensionless)
t_dp       = 1.0                 # Mn–O base hopping, scaled by exp(-beta*r) and the μ-oxo angle boost
beta       = 1.0                 # distance decay (1/Å) for Mn–O and O–O couplings
lam_angle  = 0.5                 # μ-oxo angle boost: 1 + lam_angle*cos²(Mn–O–Mn)
t_OH       = 3.0                 # O–H covalent (Ow–H at ≤ 1.25 Å)
t_HB       = 0.8                 # O···H hydrogen bond (inside the HB window)
HB_MIN, HB_MAX = 1.2, 2.6        # Å hydrogen bond window
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)   # on-site energies by tag (unknown tags -> 0.0)

# Morph schedule s(t) (I→J)
S_MAX      = 0.98                # cap morph
S_TAU_US   = 150.0               # µs e-folding time of s = S_MAX*(1 - exp(-t_slow/τ)) (see s_of_t)
WARP_US_FS = 1.0                 # µs of slow clock per fs of fast clock (t_slow = t_fs*WARP_US_FS)

# Triggering / hold logic
J_THRESH   = 0.0052              # trigger threshold on the rolling J̄
HOLD_FS    = 200.0               # need this long (fs, consecutive) above threshold
ROLL_FS    = 40.0                # rolling window for J̄ (fs)
PRINT_EVERY= 500                 # progress line every this many steps

# ---------- CIF read → raw atoms ----------
def read_atoms(cif_path):
    """Parse the first model of *cif_path* into a flat list of atom tuples.

    If ``gemmi.read_structure`` fails, the structure is rebuilt from the first CIF
    block. Alternative conformations are removed when the installed gemmi supports
    it; this is best effort and failures are ignored.

    Returns:
        list of ``(xyz float32 array, ELEMENT, RESNAME, atom_name)``, with element
        and residue names upper-cased and the atom name stripped.

    Raises:
        ValueError: if the structure has no models.
    """
    try:
        structure = gemmi.read_structure(cif_path)
    except Exception:
        structure = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(structure, "remove_alternative_conformations"):
        try:
            structure.remove_alternative_conformations()
        except Exception:
            pass
    if len(structure) == 0:
        raise ValueError("No models in CIF: " + cif_path)
    return [
        (np.array([atom.pos.x, atom.pos.y, atom.pos.z], np.float32),
         atom.element.name.upper(),
         residue.name.upper(),
         atom.name.strip())
        for chain in structure[0]
        for residue in chain
        for atom in residue
    ]

# ---------- Build minimal cubane graph (Mn, Ca, μ-oxo O, W3/W4 + 2H per water) ----------
def build_cubane_minimal(atoms):
    """Build the minimal Mn/Ca/μ-oxo + two-water node set from parsed atoms.

    The nodes are:
    - every Mn and Ca atom;
    - every "μ-oxo" O, i.e. an O within 2.7 Å of at least two Mn;
    - the two remaining O atoms nearest the Mn/Ca centroid (proxies for W3/W4);
    - two synthesized H per chosen water O.

    Args:
        atoms: sequence of ``(xyz, ELEMENT, RESNAME, atom_name)`` tuples, as
            returned by ``read_atoms``.

    Returns:
        ``(X, tags, labels)``. *X* holds the node coordinates: core atoms in
        ascending file order, then the synthesized H. *tags* gives one tag per
        node: the element name for non-O atoms, and "Omu", "Ow" or "O" for oxygens,
        plus "H". *labels* are running labels such as "MN1" and "Ow2".

    Raises:
        AssertionError: fewer than 4 Mn, or fewer than two water-O candidates.
    """
    coords = np.array([a[0] for a in atoms], np.float32)
    elem   = np.array([a[1] for a in atoms], object)

    mn_idx = [i for i, e in enumerate(elem) if e == "MN"]
    ca_idx = [i for i, e in enumerate(elem) if e == "CA"]
    if len(mn_idx) < 4:
        raise AssertionError("Need at least 4 Mn to form cubane.")

    # μ-oxo O: within 2.7 Å of at least two Mn (Mn coordinates hoisted out of the loop)
    O_idx = [i for i, e in enumerate(elem) if e == "O"]
    mn_xyz = coords[mn_idx]
    mu_oxo = [i for i in O_idx
              if (np.linalg.norm(mn_xyz - coords[i], axis=1) <= 2.7).sum() >= 2]
    mu_set = set(mu_oxo)

    # Two water O (W3/W4 proxies): the non-μ-oxo O nearest to the Mn/Ca centroid
    center = coords[mn_idx + ca_idx].mean(axis=0)
    water_candidates = [i for i in O_idx if i not in mu_set]
    if not water_candidates:
        raise AssertionError("No water-oxygen candidates found.")
    if len(water_candidates) < 2:
        raise AssertionError(
            f"Need two water-oxygen candidates for W3/W4, found {len(water_candidates)}.")
    # stable sort: distance ties keep file order
    water_candidates.sort(key=lambda i: float(np.linalg.norm(coords[i] - center)))
    W_ox = water_candidates[:2]

    # Synthesize two H for each water O (0.98 Å, 104.5°), oriented toward Mn/μ-oxo/Ca
    def synthesize_water_Hs(coords, W_ox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0):
        # Append H1/H2 per water O; the H-O-H bisector is the normalized sum of unit
        # vectors to pool atoms closer than influence_cut (+x if that sum vanishes).
        X = coords.copy()
        H_idx=[]
        phi = math.radians(angle_deg/2.0)
        for iO in W_ox:
            O = X[iO]
            vec = np.zeros(3, float)
            for j in neighbor_pool_idx:
                r = X[j] - O; rn = np.linalg.norm(r)
                if 1e-8 < rn < influence_cut:
                    vec += r / rn
            if np.linalg.norm(vec) < 1e-6:
                vec = np.array([1.0, 0.0, 0.0])
            bhat = vec / (np.linalg.norm(vec) + 1e-12)
            ref = np.array([1.0,0.0,0.0]) if abs(bhat[0])<0.9 else np.array([0.0,1.0,0.0])
            u1  = np.cross(bhat, ref)
            if np.linalg.norm(u1) < 1e-6:
                u1 = np.array([0.0,0.0,1.0])
            u1 /= (np.linalg.norm(u1)+1e-12)
            H1 = O + d_OH * (math.cos(phi)*bhat + math.sin(phi)*u1)
            H2 = O + d_OH * (math.cos(phi)*bhat - math.sin(phi)*u1)
            X  = np.vstack([X, H1, H2])
            H_idx += [len(X)-2, len(X)-1]
        return X, H_idx

    X_ext, H_syn_idx = synthesize_water_Hs(coords, W_ox, mn_idx + mu_oxo + ca_idx)

    # Node set: core atoms in ascending file order, then the synthesized H
    core = sorted(set(mn_idx + ca_idx + mu_oxo + W_ox))
    ids  = core + list(H_syn_idx)
    X    = X_ext[ids].copy()

    # Tags: non-O atoms keep their element name; O split into μ-oxo / water / other
    h_set, w_set = set(H_syn_idx), set(W_ox)
    tags = []
    for i in ids:
        if i in h_set:
            tags.append("H")
        elif elem[i] == "O":
            tags.append("Omu" if i in mu_set else ("Ow" if i in w_set else "O"))
        else:
            tags.append(elem[i])

    # Running labels per tag: MN1, MN2, ..., Ow1, Ow2, H1, ...
    counts = {}
    labels = []
    for t in tags:
        counts[t] = counts.get(t, 0) + 1
        labels.append(f"{t}{counts[t]}")
    return X, tags, labels

# ---------- Edge list and static H matrix builder ----------
def angle_boost(iO, mn_list, X):
    """Coupling multiplier for a μ-oxo O at row *iO*.

    The first two Mn in *mn_list* within 2.7 Å of the O define a Mn–O–Mn angle.
    The multiplier is ``1 + lam_angle * cos²`` of that angle, or 1.0 when fewer
    than two Mn are that close.
    """
    close = [j for j in mn_list if np.linalg.norm(X[j] - X[iO]) <= 2.7]
    if len(close) < 2:
        return 1.0
    v1 = X[close[0]] - X[iO]
    v2 = X[close[1]] - X[iO]
    cos_ab = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-12))
    return 1.0 + lam_angle * (cos_ab * cos_ab)

def build_edges_and_H(X, tags):
    """Build the neighbor edge list and the static complex Hamiltonian for nodes *X*.

    An edge is a pair ``(i, j)`` with ``i < j`` within ``PAIR_RC`` Å, excluding
    near-duplicates (≤ 0.5 Å apart). Couplings are assigned as follows:
    - Ow–H at ≤ 1.25 Å gets ``t_OH``;
    - other H–O pairs get ``t_HB`` inside ``[HB_MIN, HB_MAX]``;
    - Mn–O gets ``t_dp*exp(-beta*r)``, times ``angle_boost`` for μ-oxo O;
    - O–O gets ``0.1*exp(-beta*r)``;
    - anything else is 0.
    The diagonal holds ``eps[tag]`` (0.0 for unknown tags).

    Returns:
        ``(np.array(edges, int), H)`` where H is an N×N complex NumPy array.
    """
    tree = cKDTree(X)
    n_nodes = len(X)
    # `not (d <= 0.5)` mirrors the original skip test exactly
    edges = [
        (i, j)
        for i in range(n_nodes)
        for j in tree.query_ball_point(X[i], r=PAIR_RC)
        if j > i and not (np.linalg.norm(X[i] - X[j]) <= 0.5)
    ]

    H = np.zeros((n_nodes, n_nodes), complex)
    mn_list = [k for k, tg in enumerate(tags) if tg == "MN"]

    def hop(i, j):
        rij = float(np.linalg.norm(X[i] - X[j]))
        ti, tj = tags[i], tags[j]
        if {ti, tj} == {"Ow", "H"} and rij <= 1.25:
            return t_OH
        if (ti == "H" and tj.startswith("O")) or (tj == "H" and ti.startswith("O")):
            return t_HB if HB_MIN <= rij <= HB_MAX else 0.0
        if (ti == "MN" and tj.startswith("O")) or (tj == "MN" and ti.startswith("O")):
            if ti == "MN" and tj == "Omu":
                boost = angle_boost(j, mn_list, X)
            elif tj == "MN" and ti == "Omu":
                boost = angle_boost(i, mn_list, X)
            else:
                boost = 1.0
            return t_dp * math.exp(-beta*rij) * boost
        if ti.startswith("O") and tj.startswith("O"):
            return 0.1 * math.exp(-beta*rij)
        return 0.0

    for i, j in edges:
        t = hop(i, j)
        if t != 0.0:
            H[i, j] = t
            H[j, i] = t
    for k, tg in enumerate(tags):
        H[k, k] = eps.get(tg, 0.0)
    return np.array(edges, int), H

# ---------- Remap J onto I (nearest neighbor) ----------
def remap_J_to_I(XI, XJ):
    """For each row of *XI*, return the index of its nearest row in *XJ* (int array).

    The mapping need not be one-to-one: several XI rows may share a J index.
    """
    lookup = cKDTree(XJ)
    return np.array([int(lookup.query(point, k=1)[1]) for point in XI], int)

# ---------- Hamiltonian at time t for current (morphed) geometry ----------
def H_from_geometry(Xt, tags, E_template):
    """Rebuild the complex Hamiltonian for geometry *Xt* on the fixed edge set *E_template*.

    The coupling rules match ``build_edges_and_H``:
    - Ow–H ≤ 1.25 Å gets ``t_OH``;
    - H–O inside the HB window gets ``t_HB``;
    - Mn–O gets ``t_dp*exp(-beta*r)`` with a μ-oxo angle boost;
    - O–O gets ``0.1*exp(-beta*r)``.
    Only the distances change, because they are taken from *Xt*. The diagonal
    holds ``eps[tag]``.
    """
    size = len(Xt)
    H = np.zeros((size, size), complex)
    mn_list = [k for k, tg in enumerate(tags) if tg == "MN"]
    for i, j in E_template:
        ti, tj = tags[i], tags[j]
        rij = float(np.linalg.norm(Xt[i] - Xt[j]))
        h_o = (ti == "H" and tj.startswith("O")) or (tj == "H" and ti.startswith("O"))
        mn_o = (ti == "MN" and tj.startswith("O")) or (tj == "MN" and ti.startswith("O"))
        if {ti, tj} == {"Ow", "H"} and rij <= 1.25:
            hop = t_OH
        elif h_o:
            hop = t_HB if HB_MIN <= rij <= HB_MAX else 0.0
        elif mn_o:
            if ti == "MN" and tj == "Omu":
                boost = angle_boost(j, mn_list, Xt)
            elif tj == "MN" and ti == "Omu":
                boost = angle_boost(i, mn_list, Xt)
            else:
                boost = 1.0
            hop = t_dp * math.exp(-beta*rij) * boost
        elif ti.startswith("O") and tj.startswith("O"):
            hop = 0.1 * math.exp(-beta*rij)
        else:
            hop = 0.0
        if hop != 0.0:
            H[i, j] = hop
            H[j, i] = hop
    for k, tg in enumerate(tags):
        H[k, k] = eps.get(tg, 0.0)
    return H

# ---------- Current proxy ----------
def current_proxy(H, p, E_list):
    """Mean of |H_ij| * sqrt(p_i * p_j) over the edges with nonzero coupling; 0.0 if none."""
    weighted = [abs(H[i, j]) * math.sqrt(float(p[i]) * float(p[j]))
                for i, j in E_list if H[i, j] != 0]
    return float(np.mean(weighted)) if weighted else 0.0

# ---------- Main run ----------
t0 = time.time()

# Load & build I
atoms_I = read_atoms(CIF_I)
XI, tags, labels = build_cubane_minimal(atoms_I)
E_I, H_I0 = build_edges_and_H(XI, tags)

# Load & build J (then remap onto I topology)
atoms_J = read_atoms(CIF_J)
XJ_raw, tags_J, labels_J = build_cubane_minimal(atoms_J)

# Require same composition counts for a sensible remap
from collections import Counter
from scipy.optimize import linear_sum_assignment
if Counter(tags) != Counter(tags_J):
    raise AssertionError("I and J node compositions differ; cannot morph reliably.")

# Remap J onto I within each tag class using a one-to-one (Hungarian) assignment
# that minimises the total I→J displacement. A per-node nearest-neighbour query
# could send two I nodes to the same J node and collapse them in the morph.
XJ = np.zeros_like(XI)
for tname in sorted(set(tags)):
    idxI = [i for i,t in enumerate(tags) if t==tname]
    idxJ = [i for i,t in enumerate(tags_J) if t==tname]
    cost = np.linalg.norm(XI[idxI][:, None, :] - XJ_raw[idxJ][None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)
    for r, c in zip(rows, cols):
        XJ[idxI[r]] = XJ_raw[idxJ[c]]

print(f"[INFO] I nodes={len(XI)} edges={len(E_I)} tags={Counter(tags)}")

# Pick an initial O–O pair (closest μ-oxo–μ-oxo) to monitor
omu = [i for i,t in enumerate(tags) if t=="Omu"]
if len(omu) < 2:
    raise AssertionError("Need at least two μ-oxo O to monitor an O–O pair.")
min_pair, min_d = None, 1e9
for a in omu:
    for b in omu:
        if b<=a: continue
        d = float(np.linalg.norm(XI[a]-XI[b]))
        if d < min_d:
            min_d, min_pair = d, (a,b)
OO_I = min_d
OO_J = float(np.linalg.norm(XJ[min_pair[0]] - XJ[min_pair[1]]))
print(f"[O–O] I: {OO_I:.3f} Å  |  J: {OO_J:.3f} Å  |  Δ = {OO_J-OO_I:+.3f} Å")

# Precompute template edges once (fixed topology)
# Fixed topology: I's edge list is reused for every H(t)
E_template = E_I.copy()

# TDSE init (xp arrays): one basis state on the first water O, else the first Mn
N = len(XI)
psi = xp.zeros(N, complex)
start = next((i for i, t in enumerate(tags) if t == "Ow"), None)
if start is None:
    start = next(i for i, t in enumerate(tags) if t == "MN")
psi[start] = 1.0

# Time-series storage and trigger bookkeeping
T_fs, T_slow, S_series, Jbar_series = [], [], [], []
formed = False
hold_count = 0
roll_steps = max(1, int(ROLL_FS / DT_FS))
J_roll = []

def breathing_gate(t_fs):
    """Multiplicative Mn–O breathing factor 1 + BREATH_A * sin(2π t_fs / BREATH_T)."""
    phase = 2.0*math.pi*(t_fs/BREATH_T)
    return 1.0 + BREATH_A * math.sin(phase)

def s_of_t(t_fs):
    """Return ``(t_slow, s)`` for fast-clock time *t_fs* (fs).

    The slow clock is ``t_slow = t_fs * WARP_US_FS`` (µs). The morph fraction is the
    saturating exponential ``S_MAX * (1 - exp(-t_slow / S_TAU_US))``, not a
    logistic, clipped at ``S_MAX``.
    """
    t_slow = t_fs * WARP_US_FS
    frac = S_MAX * (1.0 - math.exp(-t_slow / S_TAU_US))
    return t_slow, min(frac, S_MAX)

# One-step propagation helpers: the exact unitary exp(-i·H(t)·dt) from the eigendecomposition of H(t) (not Crank–Nicolson)
def propagate_once(psi, H_np):
    """Placeholder that returns *psi* unchanged; *H_np* is ignored.

    The main loop builds the full per-step unitary with ``unitary_from_H`` instead.
    """
    del H_np  # intentionally unused
    return psi

def unitary_from_H(H_np, dt):
    """Return exp(-i·H·dt) from a CPU ``np.linalg.eigh`` of the Hermitian *H_np*.

    On GPU the input is pulled to the host with ``cp.asnumpy`` and the result
    pushed back with ``cp.asarray``. Otherwise NumPy arrays go in and out.
    """
    H_cpu = cp.asnumpy(H_np) if ON_GPU else H_np
    evals, evecs = np.linalg.eigh(H_cpu)
    U = evecs @ np.diag(np.exp(-1j*evals*dt)) @ evecs.conj().T
    if ON_GPU:
        return cp.asarray(U)
    return U

def to_backend(a_np):
    """Return *a_np* as a CuPy array when running on GPU, otherwise unchanged."""
    if not ON_GPU:
        return a_np
    return cp.asarray(a_np)

def to_numpy(a_xp):
    """Return *a_xp* as a host NumPy array (``cp.asnumpy`` on GPU, otherwise unchanged)."""
    if not ON_GPU:
        return a_xp
    return cp.asnumpy(a_xp)

# Main loop
steps = int(MAX_FS / DT_FS)
t_fs = 0.0
start_wall = time.time()
T_form_fs = None        # fast-clock time (fs) at which the formation trigger fired
POST_FORM_FS = 100.0    # keep sampling this many fs after formation, then stop

from collections import deque
# Rolling J̄ window as a bounded deque: the oldest sample is evicted in O(1)
J_roll = deque(J_roll, maxlen=roll_steps)

try:
    for n in range(1, steps+1):
        # Geometry morph
        t_slow, s = s_of_t(t_fs)
        Xt = (1.0 - s) * XI + s * XJ

        # Build H(t) from geometry + breathing (scale Mn–O off-diagonals)
        Ht = H_from_geometry(Xt, tags, E_template)
        gate = breathing_gate(t_fs)
        for (i,j) in E_template:
            ti, tj = tags[i], tags[j]
            if (ti=="MN" and tj.startswith("O")) or (tj=="MN" and ti.startswith("O")):
                Ht[i,j] *= gate
                Ht[j,i] *= gate

        # Unitary + propagate, then renormalize against round-off drift
        U = unitary_from_H(Ht, DT_FS)
        psi = U @ psi
        norm = xp.linalg.norm(psi)
        psi = psi / (norm + 1e-15)

        # Metrics
        p = to_numpy(xp.abs(psi)**2)
        Jbar = current_proxy(Ht, p, E_template)

        # Rolling mean and hold counter (consecutive steps with rolling J̄ ≥ J_THRESH)
        J_roll.append(Jbar)
        J_mean = float(np.mean(J_roll)) if J_roll else Jbar
        if J_mean >= J_THRESH:
            hold_count += 1
        else:
            hold_count = 0
        if (not formed) and (s >= 0.95*S_MAX) and (hold_count * DT_FS >= HOLD_FS):
            formed = True
            T_form_fs = t_fs
            print(f"[TRIGGER] formed=True at t={t_fs:.1f} fs | s={s:0.3f} | J̄={J_mean:0.4f}")

        # Store series
        T_fs.append(t_fs)
        T_slow.append(t_slow)
        S_series.append(s)
        Jbar_series.append(J_mean)

        # Logs
        if n % PRINT_EVERY == 0:
            elapsed = time.time() - start_wall
            done = n / steps
            eta = (elapsed / max(1e-9, done)) * (1.0 - done)
            print(f"[{n:6d}/{steps}] t={t_fs:7.1f} fs | t_slow={t_slow:7.1f} µs | s={s:0.3f} | "
                  f"J̄={J_mean:0.4f} | formed={formed} | ETA~{eta/60.0:0.1f} min")

        t_fs += DT_FS

        # Stop POST_FORM_FS after formation. The old test (hold_count*DT_FS >=
        # HOLD_FS+100) never fired if J̄ dipped below threshold after the trigger,
        # so the run silently continued to MAX_FS.
        if formed and (t_fs - T_form_fs) >= POST_FORM_FS:
            break

except KeyboardInterrupt:
    print("\n[INTERRUPT] KeyboardInterrupt received — saving partial results...")

# ---------- Save CSV + params ----------
import csv

# Per-step time series; formed_flag repeats the final `formed` value on every row
csv_path = os.path.join(OUT, "IJ_tdse_timeseries.csv")
with open(csv_path, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["t_fs","t_slow_us","s","Jbar","formed_flag"])
    flag = int(formed)
    w.writerows(
        [f"{T_fs[k]:.6f}", f"{T_slow[k]:.6f}", f"{S_series[k]:.6f}", f"{Jbar_series[k]:.6f}", flag]
        for k in range(len(T_fs))
    )

# Run parameters plus the monitored O–O summary, as JSON
params_path = os.path.join(OUT, "IJ_tdse_params.json")
params = dict(
    CIF_I=CIF_I, CIF_J=CIF_J, DT_FS=DT_FS, MAX_FS=MAX_FS, PAIR_RC=PAIR_RC,
    BREATH_A=BREATH_A, BREATH_T=BREATH_T, t_dp=t_dp, beta=beta, lam_angle=lam_angle,
    t_OH=t_OH, t_HB=t_HB, HB_MIN=HB_MIN, HB_MAX=HB_MAX,
    S_MAX=S_MAX, S_TAU_US=S_TAU_US, WARP_US_FS=WARP_US_FS,
    J_THRESH=J_THRESH, HOLD_FS=HOLD_FS, ROLL_FS=ROLL_FS,
    formed=formed, OO_I=OO_I, OO_J=OO_J, OO_delta=(OO_J-OO_I),
    backend=("CuPy" if ON_GPU else "NumPy")
)
with open(params_path, "w") as f:
    json.dump(params, f, indent=2)
print(f"[SAVE] CSV: {csv_path}")
print(f"[SAVE] JSON: {params_path}")

# ---------- Plots ----------
import matplotlib as mpl
mpl.rcParams['figure.facecolor'] = 'white'
mpl.rcParams['axes.facecolor']   = 'white'

# s(t)
plt.figure(figsize=(7.8,3.8))
plt.plot(T_fs, S_series, lw=2.0)
plt.axhline(S_MAX, ls="--", lw=1.2, color="gray", alpha=0.6, label=f"S_MAX={S_MAX}")
plt.xlabel("time (fs)"); plt.ylabel("s(t)")
plt.title("Geometry morph progress s(t): 8F4I → 8F4J")
plt.legend()
fig1 = os.path.join(OUT, "IJ_s_of_t.png")
plt.tight_layout(); plt.savefig(fig1, dpi=180); plt.close()

# J̄(t) with threshold + mark
plt.figure(figsize=(7.8,3.8))
plt.plot(T_fs, Jbar_series, lw=2.0, label="J̄(t)")
plt.axhline(J_THRESH, color="crimson", ls="--", lw=1.2, label=f"threshold={J_THRESH}")
if formed and T_fs:
    # Mark the first sample where the loop's trigger condition held: s ≥ 0.95·S_MAX
    # with the stored rolling J̄ ≥ J_THRESH for ≥ HOLD_FS consecutive fs. The series
    # holds exactly the s and J̄ the loop tested. (Previously the marker was put on
    # the last sample, which can lie long after formation.)
    idx = len(T_fs) - 1          # fallback if the condition cannot be re-found
    run = 0
    for k, (jm, sv) in enumerate(zip(Jbar_series, S_series)):
        run = run + 1 if jm >= J_THRESH else 0
        if sv >= 0.95*S_MAX and run * DT_FS >= HOLD_FS:
            idx = k
            break
    plt.axvline(T_fs[idx], color="green", ls=":", lw=1.5, label="formed")
plt.xlabel("time (fs)"); plt.ylabel("J̄ (arb.)")
plt.title("Bridge current proxy with trigger/hold")
plt.legend()
fig2 = os.path.join(OUT, "IJ_Jbar_of_t.png")
plt.tight_layout(); plt.savefig(fig2, dpi=180); plt.close()

print(f"[SAVE] Plots: {fig1}, {fig2}")
print(f"[DONE] Steps={len(T_fs)} | formed={formed} | wall={time.time()-t0:0.1f}s")
# ===============================================================================#

# ================== OEC TDSE (μ-oxo O–O monitor fixed + adaptive trigger) ==================
# - Set CIF path below
# - Outputs: /content/I_fixed_tdse_timeseries.csv, /content/I_fixed_s_of_t.png, /content/I_fixed_Jbar_of_t.png
# ------------------------------------------------------------------------------------------
import os, sys, subprocess, math, json, time
import numpy as np
import matplotlib.pyplot as plt

# --- config ---
# Run configuration. Most of these are consumed by code further below (not
# visible in this part of the file).
CIF = "/content/8F4I.cif"           # change if needed
OUT = "/content" if os.path.isdir("/content") else "."
DT, STEPS = 0.5, 40000              # fs, number of TDSE steps (20 ps total)
FRAME_EVERY = 10                    # sampling for logging / plots
PAIR_RC = 3.4                       # Å, neighbor cutoff
BREATH_ON, BREATH_A, BREATH_T_FS = True, 0.10, 120.0  # Mn–O breathing
ADAPT_PCTL = 85                     # rolling percentile for adaptive trigger
ROLL_LEN   = 400                    # samples in rolling window (~ DT*FRAME_EVERY*ROLL_LEN fs = 2000 fs here)
OO_TARGET  = 2.3                    # Å, nominal O–O "formed" distance window center
OO_TOL     = 0.25                   # Å tolerance
SAVE_PREFIX = "I_fixed"             # output filename prefix (I_fixed_* per the header)

# --- deps ---
def need(pkg, mod=None):
    """Ensure *mod* (default *pkg*) is importable, pip-installing *pkg* quietly if not."""
    target = mod or pkg
    try:
        __import__(target)
    except Exception:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])
need("gemmi"); import gemmi
# CuPy backend when importable, NumPy otherwise
try:
    import cupy as cp
except Exception:
    _GPU, xp = False, np
else:
    _GPU, xp = True, cp
from collections import deque
from scipy.spatial import cKDTree

# --- 1) read CIF, collect atoms ---
def read_atoms(cif_path):
    """Read the first model of *cif_path* into ``(xyz, ELEMENT, RESNAME, atom_name)`` tuples.

    If ``gemmi.read_structure`` fails, the structure is built from the first CIF
    block. Alternative conformations are removed when supported; this is best
    effort. Coordinates are float32; element and residue names are upper-cased
    and the atom name is stripped.

    Raises:
        ValueError: if the structure contains no models.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:   # was bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st,"remove_alternative_conformations"):
        try: st.remove_alternative_conformations()
        except Exception: pass   # best effort: keep altlocs if removal fails
    if len(st) == 0:
        raise ValueError("No models in CIF: " + cif_path)
    m = st[0]
    atoms=[]
    for ch in m:
        for res in ch:
            for at in res:
                p=at.pos
                atoms.append((np.array([p.x,p.y,p.z],np.float32),
                              at.element.name.upper(),
                              res.name.upper(),
                              at.name.strip()))
    return atoms

atoms = read_atoms(CIF)
coords = np.array([a[0] for a in atoms], np.float32)
elem   = np.array([a[1] for a in atoms], object)
resn   = np.array([a[2] for a in atoms], object)

# --- 2) identify Mn, Ca, μ-oxo O’s, pick W3/W4 O’s (nearest to Mn/Ca centroid) ---
mn_idx = [i for i,e in enumerate(elem) if e=="MN"]
ca_idx = [i for i,e in enumerate(elem) if e=="CA"]
# explicit raise: `assert` statements are stripped under `python -O`
if len(mn_idx) < 4:
    raise AssertionError("Need at least 4 Mn.")
O_idx  = [i for i,e in enumerate(elem) if e=="O"]

# μ-oxo: O near ≥2 Mn within 2.8 Å (Mn coordinates hoisted out of the loop)
mn_xyz = coords[mn_idx]
mu_oxo=[]
for i in O_idx:
    d = np.linalg.norm(mn_xyz-coords[i], axis=1)
    if (d<=2.8).sum()>=2: mu_oxo.append(i)

# W3/W4: two O (not μ-oxo) closest to Mn/Ca centroid
center = coords[(mn_idx + ca_idx)].mean(axis=0)
mu_set = set(mu_oxo)
water_O_candidates = [i for i in O_idx if i not in mu_set]
if len(water_O_candidates) < 2:
    raise AssertionError(f"Need at least two water O candidates, found {len(water_O_candidates)}.")
dW = [(i, float(np.linalg.norm(coords[i]-center))) for i in water_O_candidates]
dW.sort(key=lambda x:x[1])
W_ox = [dW[0][0], dW[1][0]]

# synthesize two H per W O (0.98 Å, 104.5°, oriented toward Mn/Ca/μ-oxo neighbourhood)
def synthesize_water_Hs(X, W_ox, neighbor_pool_idx, d_OH=0.98, angle_deg=104.5, influence_cut=3.0):
    """Append two synthetic H atoms to each water oxygen listed in *W_ox*.

    The H–O–H bisector is the normalised sum of unit vectors from the O to
    every row of *neighbor_pool_idx* closer than *influence_cut* (it defaults
    to +x when that sum vanishes). Both H sit at *d_OH* from the O, at
    ±angle_deg/2 from the bisector, in the plane spanned by the bisector and
    a perpendicular helper axis.

    Returns ``(X_ext, h_rows)``: a copy of *X* with the new rows stacked at
    the end (input is not modified), and the row indices of those H atoms.
    """
    half_angle = math.radians(angle_deg / 2.0)
    cos_half, sin_half = math.cos(half_angle), math.sin(half_angle)
    x_axis = np.array([1.0, 0.0, 0.0])
    y_axis = np.array([0.0, 1.0, 0.0])

    out = X.copy()
    new_rows = []
    for o_row in W_ox:
        origin = out[o_row]

        pull = np.zeros(3, float)
        for k in neighbor_pool_idx:
            delta = out[k] - origin
            length = np.linalg.norm(delta)
            if 1e-8 < length < influence_cut:
                pull += delta / length
        if np.linalg.norm(pull) < 1e-6:
            pull = np.array([1.0, 0.0, 0.0])
        axis = pull / (np.linalg.norm(pull) + 1e-12)

        # Helper axis must not be parallel to the bisector for the cross product.
        helper = x_axis if abs(axis[0]) < 0.9 else y_axis
        perp = np.cross(axis, helper)
        if np.linalg.norm(perp) < 1e-6:
            perp = np.array([0.0, 0.0, 1.0])
        perp /= (np.linalg.norm(perp) + 1e-12)

        h_first = origin + d_OH * (cos_half * axis + sin_half * perp)
        h_second = origin + d_OH * (cos_half * axis - sin_half * perp)
        out = np.vstack([out, h_first, h_second])
        new_rows.extend([len(out) - 2, len(out) - 1])
    return out, new_rows

# Atoms whose directions orient the synthetic water H's: Mn, μ-oxo O and Ca
# (global row indices into coords).
neighbor_pool_idx = mn_idx + mu_oxo + ca_idx
# coords_ext = coords plus two appended H rows per W_ox entry;
# H_syn_idx = the row indices of those appended H atoms.
coords_ext, H_syn_idx = synthesize_water_Hs(coords, W_ox, neighbor_pool_idx)

# --- 3) build minimal graph (Ca, Mn, μ-oxo O, W O, synthesized H) ---
def build_min_graph(include_H=True):
    """Assemble the reduced cluster graph from the module-level atom selections.

    Nodes are Mn, Ca, μ-oxo O and the W3/W4 water O (sorted by global row
    index), optionally followed by the synthesized water H rows. Reads the
    globals mn_idx, ca_idx, mu_oxo, W_ox, H_syn_idx, coords_ext and elem.

    Returns ``(X, tags, ids)``: a copy of the node coordinates, one tag per
    node ("MN", "CA", "Omu", "Ow", "O", "H" or the raw element name), and
    the global row indices into coords_ext.
    """
    core_nodes = sorted(set(mn_idx + ca_idx + mu_oxo + W_ox), key=int)
    h_nodes = list(H_syn_idx) if include_H else []
    ids = core_nodes + h_nodes
    X = coords_ext[ids].copy()

    def _tag_for(idx):
        # Synthesized H rows are checked first; they have no entry in elem.
        if idx in H_syn_idx:
            return "H"
        element = elem[idx]
        if element == "MN":
            return "MN"
        if element == "CA":
            return "CA"
        if element != "O":
            return element
        if idx in mu_oxo:
            return "Omu"
        return "Ow" if idx in W_ox else "O"

    tags = [_tag_for(idx) for idx in ids]
    return X, tags, ids

# Reduced graph with the synthesized water H's included (host/NumPy coordinates).
X_np, tags, ids = build_min_graph(include_H=True)
# promote to GPU if available
X = xp.asarray(X_np) if _GPU else X_np
# Node positions (indices into X / tags) of the Mn atoms.
mn_local = [k for k,t in enumerate(tags) if t=="MN"]

# --- 4) robust μ-oxo O–O pair: shared-Mn rule (FIX #1) ---
def _mn_neighbors_of_O(iO, X, mn_idx_local, rc=2.8):
    """Return the sorted tuple of Mn node indices within *rc* of node *iO*.

    *mn_idx_local* lists the Mn node indices in *X*; works on NumPy or CuPy
    *X* (the CuPy hit list is copied to host before indexing).
    """
    dist = xp.linalg.norm(X[mn_idx_local] - X[iO], axis=1)
    hits = xp.where(dist <= rc)[0]
    if _GPU:
        hits = hits.get()
    return tuple(sorted(int(mn_idx_local[k]) for k in hits))

def pick_OO_pair_shared_Mn(X, tags, mn_idx_local, rc_O_Mn=2.8):
    """Pick two μ-oxo ("Omu") nodes to use as the monitored O–O pair.

    Each Omu node with at least two Mn within *rc_O_Mn* is keyed by its two
    closest Mn, where closeness is measured over all of *mn_idx_local*. If a Mn
    pair is bridged by exactly two such Omu, those two form a candidate. The
    candidate pair with the smallest current O···O distance is returned.
    Without candidates, the fallback returns the first pair (in dict order)
    that shares the most Mn.

    Returns local node indices ``(iA, iB)``. Raises RuntimeError if fewer
    than two Omu qualify. Works with NumPy or CuPy *X* (uses xp/cp and _GPU).
    """
    Omu_idx = [i for i,t in enumerate(tags) if t=="Omu"]
    # map Oμ -> its two closest Mn (overall), provided ≥2 Mn lie within rc
    neigh = {}
    for iO in Omu_idx:
        d = xp.linalg.norm(X[mn_idx_local] - X[iO], axis=1)
        idx = (cp.where(d <= rc_O_Mn)[0].get() if _GPU else np.where(d <= rc_O_Mn)[0])
        if len(idx) >= 2:
            take = np.argsort((d.get() if _GPU else d))[0:2]
            neigh[iO] = tuple(sorted(int(mn_idx_local[int(t)]) for t in take))
    # bucket by identical Mn pair
    by_pair={}
    for iO, mn_pair in neigh.items():
        by_pair.setdefault(mn_pair, []).append(iO)
    # only Mn pairs bridged by exactly two Oμ count (a bucket of 3+ is ignored)
    candidates = [(mn_pair, v) for mn_pair, v in by_pair.items() if len(v)==2]
    if not candidates:
        # fallback: max common Mn overlap
        best, score = None, -1
        Olist = list(neigh.keys())
        for a in range(len(Olist)):
            for b in range(a+1, len(Olist)):
                c = len(set(neigh[Olist[a]]) & set(neigh[Olist[b]]))
                if c > score:
                    score, best = c, (Olist[a], Olist[b])
        if best is None:
            raise RuntimeError("No plausible μ-oxo O–O pair found.")
        iA, iB = best
    else:
        # choose the pair with smallest current O···O distance
        def dist_ab(ab):
            v = X[ab[0]] - X[ab[1]]
            return float(cp.linalg.norm(v).get()) if _GPU else float(np.linalg.norm(v))
        iA, iB = min([tuple(v) for _,v in candidates], key=dist_ab)
    assert iA != iB
    return int(iA), int(iB)

# Local node indices (into X / tags) of the monitored μ-oxo pair.
OO_iA, OO_iB = pick_OO_pair_shared_Mn(X, tags, mn_local, rc_O_Mn=2.8)

def oo_distance(X):
    """Return the distance between the monitored nodes OO_iA and OO_iB in *X*.

    Accepts NumPy or CuPy *X* (the CuPy norm is copied to host via .get()).
    """
    v = X[OO_iA] - X[OO_iB]
    return float(cp.linalg.norm(v).get()) if _GPU else float(np.linalg.norm(v))

print(f"[INFO] nodes={len(X_np)}  μ-oxo={sum(t=='Omu' for t in tags)}  Mn={sum(t=='MN' for t in tags)}  H={sum(t=='H' for t in tags)}")
print(f"[O–O monitor] O indices: {OO_iA},{OO_iB}  d0={oo_distance(X):.3f} Å")

# --- 5) edges + Hamiltonian builder ---
# Undirected edge list (i < j) over graph nodes within PAIR_RC of each other
# (PAIR_RC is defined outside this chunk), skipping near-coincident pairs (≤ 0.5).
tree = cKDTree(X_np)
E=[]
for i in range(len(X_np)):
    for j in tree.query_ball_point(X_np[i], r=PAIR_RC):
        if j<=i: continue
        if np.linalg.norm(X_np[i]-X_np[j])<=0.5: continue
        E.append((i,j))

# Couplings / on-sites
# t_dp: Mn–O prefactor; beta: exponential distance decay; lam: μ-oxo angle-boost weight
t_dp, beta, lam = 1.0, 1.0, 0.5
# t_OH: covalent Ow–H hopping; t_HB: O···H hydrogen-bond hopping
t_OH, t_HB      = 3.0, 0.8
# distance window in which an O···H edge receives t_HB
HBOND_MIN, HBOND_MAX = 1.2, 2.6
# on-site energies by node tag (unlisted tags get 0.0 in build_edges_and_H)
eps = dict(MN=0.0, Omu=-0.40, Ow=-0.35, H=-0.10, CA=+0.60)

def angle_boost(iO, mn_list, X):
    """Return the angular coupling factor ``1 + lam * cos²θ`` for O node *iO*.

    θ is the Mn–O–Mn angle subtended at *iO* by the two Mn in *mn_list*
    closest to it. Works on NumPy or CuPy *X*; reads xp, _GPU and lam.
    """
    dist = xp.linalg.norm(X[mn_list] - X[iO], axis=1)
    nearest = np.argsort(dist.get() if _GPU else dist)[:2]
    mn_a = int(mn_list[int(nearest[0])])
    mn_b = int(mn_list[int(nearest[1])])
    to_a = X[mn_a] - X[iO]
    to_b = X[mn_b] - X[iO]
    if _GPU:
        cos_ab = float((xp.dot(to_a, to_b) / (xp.linalg.norm(to_a) * xp.linalg.norm(to_b) + 1e-12)).get())
    else:
        cos_ab = float(np.dot(to_a, to_b) / (np.linalg.norm(to_a) * np.linalg.norm(to_b) + 1e-12))
    return 1.0 + lam * (cos_ab * cos_ab)

def build_edges_and_H(X, tags, E):
    """Build the dense complex Hamiltonian for the reduced graph.

    For each edge (i, j) in *E*, the hopping is chosen by the node tags and
    the distance r:
      * Ow–H with r ≤ 1.25: t_OH (covalent)
      * any other H–O* pair: t_HB inside [HBOND_MIN, HBOND_MAX], else 0
      * Mn–O*: t_dp * exp(-beta r), times angle_boost when the O is "Omu"
      * O*–O*: 0.1 * exp(-beta r)
      * anything else: 0
    Diagonal entries come from eps (0.0 for unlisted tags). The matrix is
    allocated with xp, so it is NumPy or CuPy to match the backend.
    """
    N=len(X); H = xp.zeros((N,N), complex)
    for (i,j) in E:
        ri,rj = X[i], X[j]
        rij = float(cp.linalg.norm(ri-rj).get()) if _GPU else float(np.linalg.norm(ri-rj))
        ti,tj = tags[i], tags[j]
        # O–H covalent
        if {ti,tj}=={"Ow","H"} and rij<=1.25:
            t = t_OH
        # O···H hydrogen bond
        elif (ti=="H" and tj.startswith("O")) or (tj=="H" and ti.startswith("O")):
            t = t_HB if (HBOND_MIN <= rij <= HBOND_MAX) else 0.0
        # Mn–O (bridge); μ-oxo partners get the Mn–O–Mn angle boost
        elif (ti=="MN" and tj.startswith("O")) or (tj=="MN" and ti.startswith("O")):
            boost = angle_boost(j, mn_local, X) if (ti=="MN" and tj=="Omu") else \
                    angle_boost(i, mn_local, X) if (tj=="MN" and ti=="Omu") else 1.0
            t = t_dp * math.exp(-beta*rij) * boost
        # weak O–O
        elif ti.startswith("O") and tj.startswith("O"):
            t = 0.1*math.exp(-beta*rij)
        else:
            t = 0.0
        if t!=0.0:
            H[i,j] = t; H[j,i] = t
    # on-sites
    for i,tg in enumerate(tags):
        H[i,i] = eps.get(tg, 0.0)
    return H

# Static Hamiltonian of the reduced graph (NumPy or CuPy, following xp).
H0 = build_edges_and_H(X, tags, E)

# --- 6) TDSE with breathing + adaptive trigger (FIX #2) ---
def tdse_run(H_static, X, tags):
    """Propagate a single excitation on the reduced graph and log diagnostics.

    The walker starts on the first "Ow" node (else the first "MN"). When
    BREATH_ON is set, every Mn–O coupling is scaled at each step by
    ``1 + BREATH_A*sin(2πt/BREATH_T_FS)``. ``exp(-i Ht DT)`` is applied via an
    eigendecomposition and the state is renormalized. Every FRAME_EVERY steps
    the function records t, ⟨|J|⟩ (mean of ``|H_ij| sqrt(p_i p_j)`` over edges
    in E with non-zero coupling) and oo_distance(X), and updates the sticky
    ``formed`` flag.

    Reads module globals: E, STEPS, DT, FRAME_EVERY, ROLL_LEN, BREATH_ON,
    BREATH_A, BREATH_T_FS, ADAPT_PCTL, OO_TARGET, OO_TOL, xp, _GPU, oo_distance.

    Returns (t, Jbar, d_OO) NumPy arrays and the bool ``formed`` flag.
    """
    N = H_static.shape[0]
    # start on a water-O if present else first Mn
    try:
        s_idx = next(i for i, t in enumerate(tags) if t == "Ow")
    except StopIteration:
        s_idx = next(i for i, t in enumerate(tags) if t == "MN")
    psi = xp.zeros(N, complex)
    psi[s_idx] = 1.0
    # cache Mn–O index pairs for breathing
    mn_o_pairs = [(i, j) for (i, j) in E
                  if (tags[i] == "MN" and tags[j].startswith("O"))
                  or (tags[j] == "MN" and tags[i].startswith("O"))]
    # Without breathing the Hamiltonian never changes, so diagonalize it once
    # instead of re-diagonalizing the same matrix on every step.
    U_static = None
    if not BREATH_ON:
        Evals, V = xp.linalg.eigh(H_static)
        U_static = V @ xp.diag(xp.exp(-1j * Evals * DT)) @ V.conj().T
    # storage
    T_store, J_store, dOO_store, formed_flag = [], [], [], False
    J_hist = deque(maxlen=ROLL_LEN)
    t = 0.0
    for n in range(STEPS):
        if BREATH_ON:
            # time-dependent H (breathing) -> instantaneous eigen step
            Ht = H_static.copy()
            gate = 1.0 + BREATH_A*math.sin(2*math.pi*(t)/BREATH_T_FS)
            for (i, j) in mn_o_pairs:
                Ht[i, j] *= gate; Ht[j, i] *= gate
            Evals, V = xp.linalg.eigh(Ht)
            U = V @ xp.diag(xp.exp(-1j * Evals * DT)) @ V.conj().T
        else:
            Ht, U = H_static, U_static
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi) + 1e-15)
        # logging
        if (n % FRAME_EVERY) == 0:
            p = xp.abs(psi)**2
            # Copy H and p to host once per frame. Indexing CuPy arrays edge by
            # edge yields 0-d device arrays, and np.mean() cannot consume a list
            # of them, so the previous loop failed on the GPU backend.
            H_host = Ht.get() if _GPU else Ht
            p_host = p.get() if _GPU else p
            vals = [abs(H_host[i, j]) * math.sqrt(float(p_host[i]) * float(p_host[j]))
                    for (i, j) in E if H_host[i, j] != 0]
            Jbar = float(np.mean(vals)) if vals else 0.0
            dOO = oo_distance(X)
            J_hist.append(Jbar)
            # adaptive threshold
            # NOTE(review): until J_hist holds max(40, FRAME_EVERY*4) frames the
            # threshold is 0.98*Jbar, which any Jbar >= 0 passes, so early hits
            # depend on dOO alone. J_hist is capped at ROLL_LEN, so if ROLL_LEN is
            # smaller than that count the percentile branch is never reached.
            thr = (np.percentile(np.array(J_hist), ADAPT_PCTL)
                   if len(J_hist) >= max(40, FRAME_EVERY*4) else (Jbar*0.98))
            hit = (Jbar >= thr) and (abs(dOO - OO_TARGET) <= OO_TOL)
            formed_flag = formed_flag or bool(hit)
            T_store.append(t); J_store.append(Jbar); dOO_store.append(dOO)
        t += DT
    return np.array(T_store), np.array(J_store), np.array(dOO_store), formed_flag

# Run the TDSE on the static Hamiltonian and report first/last monitored O–O distance.
t_fs, Jbar, dOO, formed = tdse_run(H0, X, tags)
print(f"[DONE] Steps={STEPS} | formed={formed} | start d_OO={dOO[0]:.3f} Å | end d_OO={dOO[-1]:.3f} Å")

# --- 7) save CSV + plots ---
# One CSV row per logged frame: time (fs), current proxy, O–O distance (Å).
csv_path = os.path.join(OUT, f"{SAVE_PREFIX}_tdse_timeseries.csv")
with open(csv_path, "w") as f:
    f.write("t_fs,Jbar,d_OO_A\n")
    for tt, jj, dd in zip(t_fs, Jbar, dOO):
        f.write(f"{tt:.3f},{jj:.6f},{dd:.4f}\n")

# d(O–O) vs time with the OO_TARGET ± OO_TOL band used by the "formed" test.
plt.figure(figsize=(6.4,3.4))
plt.plot(t_fs, dOO, label="d(O–O) (Å)")
plt.axhline(OO_TARGET, ls="--", lw=1, label="target", alpha=0.7)
plt.fill_between(t_fs, OO_TARGET-OO_TOL, OO_TARGET+OO_TOL, alpha=0.15, label="tolerance")
plt.xlabel("time (fs)"); plt.ylabel("Å"); plt.title("μ-oxo O–O distance (robust monitor)")
plt.legend(); plt.tight_layout()
fig1 = os.path.join(OUT, f"{SAVE_PREFIX}_s_of_t.png")
plt.savefig(fig1, dpi=170); plt.close()

# Current proxy vs time.
plt.figure(figsize=(6.4,3.4))
plt.plot(t_fs, Jbar)
plt.xlabel("time (fs)"); plt.ylabel("⟨|J|⟩ (arb.)"); plt.title("Adaptive bridge-current proxy")
plt.tight_layout()
fig2 = os.path.join(OUT, f"{SAVE_PREFIX}_Jbar_of_t.png")
plt.savefig(fig2, dpi=170); plt.close()

print(f"[SAVE] CSV: {csv_path}")
print(f"[SAVE] Plots: {fig1}, {fig2}")
# ==========================================================================================

# -------- PATCH: GPU-safe current proxy + TDSE loop (replaces tdse_run) --------
import math, numpy as _np
from collections import deque as _deque

def _current_proxy_float(Hx, p_vec, edge_list, _gpu=None):
    """Return ⟨|J|⟩ as a plain Python float irrespective of NumPy/CuPy backend.

    ⟨|J|⟩ is the mean of ``|H_ij| * sqrt(p_i * p_j)`` over the edges (i, j) in
    *edge_list* whose coupling ``Hx[i, j]`` is non-zero. It is 0.0 when no
    edge has a coupling.

    ``_gpu`` is accepted for backward compatibility and ignored. Its former
    default, ``("cp" in str(type(p_vec)))``, was evaluated when the ``def``
    statement ran. At that point ``p_vec`` is not in scope, so defining the
    function raised NameError (unless an unrelated global ``p_vec`` existed).
    """
    vals = []
    for (i, j) in edge_list:
        hij = Hx[i, j]
        if hij != 0:
            # Force every piece to Python float to avoid CuPy scalars leaking in.
            pij = float(p_vec[i]) * float(p_vec[j])
            vals.append(float(abs(hij)) * math.sqrt(pij))
    return float(_np.mean(vals)) if vals else 0.0

def tdse_run(H_static, X, tags):
    """Propagate one excitation on the reduced graph; log ⟨|J|⟩ and d(O–O).

    Starts on the first "Ow" node (else the first "MN"). Each step optionally
    scales all Mn–O couplings by ``1 + BREATH_A*sin(2πt/BREATH_T_FS)``, applies
    ``exp(-i Ht DT)`` from an eigendecomposition of Ht, and renormalizes.
    Every FRAME_EVERY steps it records t, ⟨|J|⟩ (_current_proxy_float over E)
    and oo_distance(X). It also sets the sticky ``formed`` flag when ⟨|J|⟩
    reaches the adaptive threshold while d(O–O) is within OO_TOL of OO_TARGET.

    Returns (t, Jbar, d_OO) NumPy arrays and the bool flag.
    """
    # Uses globals from your main cell: DT, FRAME_EVERY, BREATH_ON, BREATH_A, BREATH_T_FS,
    # E (edge list), mn_local, oo_distance, ADAPT_PCTL, ROLL_LEN, OO_TARGET, OO_TOL, xp, _GPU
    N = H_static.shape[0]
    try:
        s_idx = next(i for i,t in enumerate(tags) if t=="Ow")
    except StopIteration:
        s_idx = next(i for i,t in enumerate(tags) if t=="MN")
    psi = xp.zeros(N, complex); psi[s_idx] = 1.0

    # Mn–O edges: the only couplings modulated by the breathing gate.
    mn_o_pairs = [(i,j) for (i,j) in E
                  if (tags[i]=="MN" and tags[j].startswith("O")) or (tags[j]=="MN" and tags[i].startswith("O"))]

    T_store, J_store, dOO_store = [], [], []
    J_hist = _deque(maxlen=ROLL_LEN)
    formed_flag = False
    t = 0.0

    for n in range(int(STEPS)):
        # time-dependent H (breathing)
        if BREATH_ON:
            Ht = H_static.copy()
            gate = 1.0 + BREATH_A*math.sin(2*math.pi*(t)/BREATH_T_FS)
            for (i,j) in mn_o_pairs:
                Ht[i,j] *= gate; Ht[j,i] *= gate
        else:
            Ht = H_static

        # propagate (instantaneous eigensolver)
        Evals, V = xp.linalg.eigh(Ht)
        U = V @ xp.diag(xp.exp(-1j*Evals*DT)) @ V.conj().T
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi) + 1e-15)

        if (n % FRAME_EVERY) == 0:
            p = xp.abs(psi)**2
            Jbar = _current_proxy_float(Ht, p, E)
            # NOTE(review): X is not updated inside this loop, so dOO has the same value every frame.
            dOO  = oo_distance(X)
            J_hist.append(Jbar)
            # Warm-up: until enough frames exist the threshold is 0.98*Jbar, which Jbar >= 0 always meets.
            thr = (_np.percentile(_np.array(J_hist), ADAPT_PCTL)
                   if len(J_hist) >= max(40, FRAME_EVERY*4) else (Jbar*0.98))
            hit = (Jbar >= float(thr)) and (abs(dOO - OO_TARGET) <= OO_TOL)

            formed_flag = formed_flag or bool(hit)
            T_store.append(t); J_store.append(Jbar); dOO_store.append(dOO)

        t += DT

    return _np.array(T_store), _np.array(J_store), _np.array(dOO_store), formed_flag
# -------- end PATCH --------

# Re-run just this to regenerate the time series/plots with the fixed loop:
# Only the time series is recomputed here; the CSV/plot files written earlier are not rewritten.
t_fs, Jbar, dOO, formed = tdse_run(H0, X, tags)
print(f"[DONE] Steps={STEPS} | formed={formed} | start d_OO={dOO[0]:.3f} Å | end d_OO={dOO[-1]:.3f} Å")

# -------- PATCH (final): GPU-safe current proxy + TDSE loop --------
import math, numpy as _np
from collections import deque as _deque

def _current_proxy_float(Hx, p_vec, edge_list):
    """Mean of ``|H_ij| * sqrt(p_i * p_j)`` over edges with a non-zero coupling.

    Every factor is coerced to a Python float, so NumPy and CuPy inputs both
    produce a plain float. Returns 0.0 when no edge carries a coupling.
    """
    terms = []
    for i, j in edge_list:
        coupling = Hx[i, j]
        if coupling == 0:
            continue
        weight = math.sqrt(float(p_vec[i]) * float(p_vec[j]))
        terms.append(float(abs(coupling)) * weight)
    if not terms:
        return 0.0
    return float(_np.mean(terms))


def tdse_run(H_static, X, tags):
    """Propagate one excitation on the reduced graph; log ⟨|J|⟩ and d(O–O).

    Starts on the first "Ow" node (else the first "MN"). For each of STEPS
    steps it optionally scales the Mn–O couplings by the breathing gate,
    applies ``exp(-i Ht DT)`` from an eigendecomposition, and renormalizes.
    Every FRAME_EVERY steps it logs t, ⟨|J|⟩ (_current_proxy_float over E)
    and oo_distance(X), and updates the sticky ``formed`` flag.

    Returns (t, Jbar, d_OO) NumPy arrays and the bool flag.
    """
    # Uses globals from your main cell: DT, FRAME_EVERY, BREATH_ON, BREATH_A, BREATH_T_FS,
    # E (edge list), mn_local, oo_distance, ADAPT_PCTL, ROLL_LEN, OO_TARGET, OO_TOL, xp, _GPU
    N = H_static.shape[0]
    try:
        s_idx = next(i for i, t in enumerate(tags) if t == "Ow")
    except StopIteration:
        s_idx = next(i for i, t in enumerate(tags) if t == "MN")

    psi = xp.zeros(N, complex)
    psi[s_idx] = 1.0

    # Mn–O edges: the only couplings modulated by the breathing gate.
    mn_o_pairs = [
        (i, j) for (i, j) in E
        if (tags[i] == "MN" and tags[j].startswith("O"))
        or (tags[j] == "MN" and tags[i].startswith("O"))
    ]

    T_store, J_store, dOO_store = [], [], []
    J_hist = _deque(maxlen=ROLL_LEN)
    formed_flag = False
    t = 0.0

    for n in range(int(STEPS)):
        # --- time-dependent H (breathing) ---
        if BREATH_ON:
            Ht = H_static.copy()
            gate = 1.0 + BREATH_A * math.sin(2 * math.pi * (t) / BREATH_T_FS)
            for (i, j) in mn_o_pairs:
                Ht[i, j] *= gate
                Ht[j, i] *= gate
        else:
            Ht = H_static

        # --- propagate (instantaneous eigensolver) ---
        Evals, V = xp.linalg.eigh(Ht)
        U = V @ xp.diag(xp.exp(-1j * Evals * DT)) @ V.conj().T
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi) + 1e-15)

        # --- log every FRAME_EVERY steps ---
        if (n % FRAME_EVERY) == 0:
            p = xp.abs(psi) ** 2
            Jbar = _current_proxy_float(Ht, p, E)
            # NOTE(review): X is not updated inside this loop, so dOO has the same value every frame.
            dOO = oo_distance(X)
            J_hist.append(Jbar)

            # Warm-up: until enough frames exist the threshold is 0.98*Jbar, which Jbar >= 0 always meets.
            thr = (
                _np.percentile(_np.array(J_hist), ADAPT_PCTL)
                if len(J_hist) >= max(40, FRAME_EVERY * 4)
                else (Jbar * 0.98)
            )
            hit = (Jbar >= float(thr)) and (abs(dOO - OO_TARGET) <= OO_TOL)

            formed_flag = formed_flag or bool(hit)
            T_store.append(t)
            J_store.append(Jbar)
            dOO_store.append(dOO)

        t += DT

    return _np.array(T_store), _np.array(J_store), _np.array(dOO_store), formed_flag


# ---- run patched loop ----
# Recompute the time series with the final tdse_run (no files are written here).
t_fs, Jbar, dOO, formed = tdse_run(H0, X, tags)
print(f"[DONE] Steps={STEPS} | formed={formed} | start d_OO={dOO[0]:.3f} Å | end d_OO={dOO[-1]:.3f} Å")

# ================== ONE-CELL: IJ TDSE with O–O monitor (GPU/CPU safe) ==================
# Edit these paths if needed
# Input structure and output directory (Colab /content when present, else CWD).
CIF_PATH = "/content/8F4I.cif"
OUTDIR   = "/content" if __import__("os").path.isdir("/content") else "."

import os, sys, math, time, json
import numpy as _np
import matplotlib.pyplot as plt

# ---------- backend: prefer CuPy if available ----------
# xp is the array module used for propagation; GPU records which one was chosen.
try:
    import cupy as _cp
    xp = _cp
    GPU = True
    print("[GPU] CuPy detected — using GPU arrays.")
except Exception:
    xp = _np
    GPU = False
    print("[CPU] Using NumPy backend.")

# ---------- utilities ----------
def _ensure(pkg, mod=None):
    """Import *mod* (default: *pkg*); if that fails, pip-install *pkg* quietly.

    Returns None. The install runs with the current interpreter
    (``sys.executable -m pip``) and raises CalledProcessError if pip fails.
    """
    module_name = mod or pkg
    try:
        __import__(module_name)
    except Exception:
        # subprocess is only needed on the install path, so import it lazily.
        import subprocess
        pip_cmd = [sys.executable, "-m", "pip", "install", "-q", pkg]
        subprocess.check_call(pip_cmd)

_ensure("gemmi"); import gemmi
_ensure("scipy"); from scipy.spatial import cKDTree

# ---------- read CIF & collect atoms ----------
def read_atoms(cif_path):
    """Read every atom of model 0 of *cif_path*.

    Tries ``gemmi.read_structure`` first and falls back to building the
    structure from the first CIF block. Alternate conformations are removed
    when the installed gemmi provides ``remove_alternative_conformations``.

    Returns a list of ``(xyz float32[3], element_upper, resname_upper,
    atom_name_stripped)`` tuples in chain/residue/atom iteration order.
    """
    try:
        st = gemmi.read_structure(cif_path)
    except Exception:
        st = gemmi.make_structure_from_block(gemmi.cif.read_file(cif_path)[0])
    if hasattr(st, "remove_alternative_conformations"):
        try:
            st.remove_alternative_conformations()
        except Exception as err:
            # Best effort: keep all conformers, but report it instead of a bare except/pass.
            print(f"[WARN] remove_alternative_conformations failed: {err}")
    m = st[0]
    atoms = []
    for ch in m:
        for res in ch:
            for at in res:
                p = at.pos
                atoms.append((
                    _np.array([p.x, p.y, p.z], _np.float32),
                    at.element.name.upper(),
                    res.name.upper(),
                    at.name.strip()
                ))
    return atoms

atoms = read_atoms(CIF_PATH)
coords_np = _np.stack([a[0] for a in atoms], axis=0)
elem      = [a[1] for a in atoms]
resn      = [a[2] for a in atoms]
print(f"[INFO] Loaded {os.path.basename(CIF_PATH)}  atoms={len(atoms)}")

# ---------- identify Mn, Ca, μ-oxo O, and pick two O for O–O monitoring ----------
mn_idx = [i for i,e in enumerate(elem) if e=="MN"]
ca_idx = [i for i,e in enumerate(elem) if e=="CA"]
O_idx  = [i for i,e in enumerate(elem) if e=="O"]
# Two distinct O are required. With a single O, the old fallback monitored
# (O, O), whose distance is identically 0, so every run counted as "formed"
# (d_OO < 1.6). Raise instead of assert so the check survives `python -O`.
if len(mn_idx) < 4 or len(O_idx) < 2:
    raise ValueError(f"Need >=4 Mn and >=2 O to proceed (found Mn={len(mn_idx)}, O={len(O_idx)}).")

# μ-oxo: O near ≥2 Mn within 2.7 Å
mu_oxo = []
mn_xyz = coords_np[mn_idx]
for i in O_idx:
    if (_np.linalg.norm(mn_xyz - coords_np[i], axis=1) <= 2.7).sum() >= 2:
        mu_oxo.append(i)

# pick the two O nearest the Mn/Ca centroid for the O–O distance monitor
center = coords_np[(mn_idx + ca_idx)].mean(axis=0)
# any O is allowed (μ-oxo included); restrict here if needed
cand_O = sorted(O_idx, key=lambda i: float(_np.linalg.norm(coords_np[i]-center)))
O_pair = (cand_O[0], cand_O[1])
print(f"[INFO] nodes will include μ-oxo and metals; O–O monitor indices: {O_pair[0]}, {O_pair[1]}")

# ---------- build node set: Mn, Ca, μ-oxo, O_pair ----------
# Global atom indices of the graph nodes, in ascending order.
core_ids = sorted(set(mn_idx + ca_idx + mu_oxo + list(O_pair)), key=int)
X_np = coords_np[core_ids].astype(_np.float32)
# One tag per node: "MN", "CA", "Omu" (μ-oxo), "O" (other O), or the raw element.
tags = []
for i in core_ids:
    e = elem[i]
    if e=="MN": tags.append("MN")
    elif e=="CA": tags.append("CA")
    elif e=="O":
        tags.append("Omu" if i in mu_oxo else "O")
    else:
        tags.append(e)
N = len(core_ids)
print(f"[INFO] nodes={N}  μ-oxo={sum(t=='Omu' for t in tags)}  Mn={sum(t=='MN' for t in tags)}")

# map monitored O indices into local node indices
O_local = []
for oi in O_pair:
    O_local.append(core_ids.index(oi))
O_local = tuple(O_local)

# ---------- edges via neighbor cutoff ----------
PAIR_RC = 3.4  # Å
# Undirected edges (i < j) between nodes within PAIR_RC, skipping near-coincident pairs (≤ 0.5).
tree = cKDTree(X_np)
edges = []
for i in range(N):
    js = tree.query_ball_point(X_np[i], r=PAIR_RC)
    for j in js:
        if j>i and _np.linalg.norm(X_np[i]-X_np[j])>0.5:
            edges.append((i,j))
print(f"[INFO] edges={len(edges)}")

# ---------- Hamiltonian ----------
# on-sites
# on-site energy per node tag (unlisted tags get 0.0 in build_H)
eps = dict(MN=0.0, Omu=-0.40, O=-0.25, CA=+0.60)
# couplings
t_dp, beta = 1.0, 1.0   # Mn–O: hopping t_dp * exp(-beta * r)
t_OOweak   = 0.08       # O–O weak: hopping t_OOweak * exp(-beta * r)

def angle_boost(iO, X, tags):
    """Return the Mn–O–Mn angular coupling boost for node *iO*.

    For an "Omu" node with at least two Mn within 2.7 Å, returns
    ``1 + 0.5 * cos²θ``, where θ is the angle at the O between the two
    *closest* such Mn. Any other node, or fewer than two Mn in range,
    gives 1.0.

    The previous version took the first two in-range Mn in node (atom-index)
    order. For a μ3-oxo O that made the result depend on file ordering and
    disagreed with the xp-based angle_boost, which uses the two closest Mn.
    """
    if tags[iO] != "Omu":
        return 1.0
    # (distance, node) for every Mn, nearest first; node index breaks ties.
    by_dist = sorted(
        (_np.linalg.norm(X[k] - X[iO]), k)
        for k, t in enumerate(tags) if t == "MN"
    )
    near = [k for d, k in by_dist if d <= 2.7]
    if len(near) < 2:
        return 1.0
    a, b = near[0], near[1]
    v1 = X[a] - X[iO]; v2 = X[b] - X[iO]
    c = float(_np.dot(v1, v2) / (_np.linalg.norm(v1) * _np.linalg.norm(v2) + 1e-12))
    return 1.0 + 0.5 * (c * c)

def build_H(X, tags):
    """Build the dense NumPy complex Hamiltonian for the IJ node graph.

    Edges come from the global ``edges`` list; the matrix size is the global N.
    Mn–O edges get ``t_dp * exp(-beta r)``, multiplied by angle_boost when
    one end is "Omu". O–O edges get ``t_OOweak * exp(-beta r)``. All other
    edges are 0. Diagonal entries come from eps (0.0 for unlisted tags).
    """
    ham = _np.zeros((N, N), _np.complex128)
    for a, b in edges:
        ta, tb = tags[a], tags[b]
        dist = float(_np.linalg.norm(X[a] - X[b]))
        is_mn_o = (ta == "MN" and tb.startswith("O")) or (tb == "MN" and ta.startswith("O"))
        if is_mn_o:
            if "Omu" in (ta, tb):
                oxo_node = a if ta == "Omu" else b
                factor = angle_boost(oxo_node, X, tags)
            else:
                factor = 1.0
            hop = t_dp * math.exp(-beta * dist) * factor
        elif ta.startswith("O") and tb.startswith("O"):
            hop = t_OOweak * math.exp(-beta * dist)
        else:
            hop = 0.0
        if hop != 0.0:
            ham[a, b] = hop
            ham[b, a] = hop
    for k, label in enumerate(tags):
        ham[k, k] = eps.get(label, 0.0)
    return ham

# Static Hamiltonian built on host (NumPy) from the node coordinates.
H0_np = build_H(X_np, tags)

# ---------- move to GPU if available ----------
def to_xp(a):
    """Return *a* as a CuPy array when GPU is True, otherwise unchanged."""
    if GPU: return _cp.asarray(a)
    return a

# Backend copies used by the TDSE loop.
X = to_xp(X_np)
H0= to_xp(H0_np)

# ---------- driver / clocks ----------
DT_FS      = 0.5      # fs per TDSE step
MAX_FS     = 20000.0  # total simulated time (fs)
PRINT_EVERY= 500      # progress print cadence (steps)
BREATH_A   = 0.10     # relative amplitude of the Mn–O breathing gate
BREATH_T   = 120.0    # breathing period (fs)
# slow morph (I→J) driver — here we just keep s≈1.0 to test stability, but retain form
WARP_US_FS = 1.0      # scale applied to t_fs (name suggests µs per fs)
S_TAU_US   = 120.0    # morph time constant in the warped units
S_MAX      = 0.98     # asymptotic morph fraction

def s_of_t(t_fs):
    """Morph progress ``S_MAX * (1 - exp(-WARP_US_FS * t_fs / S_TAU_US))`` as a Python float.

    Rises from 0 at t_fs = 0 toward S_MAX; in this cell it is only printed and plotted.
    """
    elapsed_us = t_fs * WARP_US_FS
    saturation = 1.0 - math.exp(-elapsed_us / S_TAU_US)
    return float(S_MAX * saturation)

def oo_distance(Xarr):
    """Return the distance between the two monitored O nodes (O_local) in *Xarr*.

    Works on NumPy or CuPy input; the CuPy difference vector is copied to host first.
    """
    first, second = O_local
    delta = Xarr[first] - Xarr[second]
    host_delta = _cp.asnumpy(delta) if GPU else delta
    return float(_np.linalg.norm(host_delta))

# safe current proxy float for NumPy/CuPy
def current_proxy_float(Hx, p_vec, edge_list):
    """Return ⟨|J|⟩: mean of ``|H_ij| * sqrt(p_i * p_j)`` over edges with non-zero H_ij.

    *Hx* and *p_vec* may be NumPy or CuPy arrays. CuPy inputs are copied to
    host once per call (anything exposing ``.get()``). The previous version
    fetched three device scalars per edge, which forces one sync each, and it
    depended on the GPU/_cp globals. Returns 0.0 when no edge has a coupling.
    """
    H_host = Hx.get() if hasattr(Hx, "get") else Hx
    p_host = p_vec.get() if hasattr(p_vec, "get") else p_vec
    vals = []
    for (i, j) in edge_list:
        hij = abs(H_host[i, j])
        if hij != 0.0:
            vals.append(float(hij) * math.sqrt(float(p_host[i]) * float(p_host[j])))
    return (sum(vals) / len(vals)) if vals else 0.0

# Edges handed to the current proxy: a plain copy of every edge (no non-zero filtering happens here;
# current_proxy_float skips zero couplings itself).
E_nz = [(i,j) for (i,j) in edges]

# ---------- TDSE loop (eigendecomp propagator) ----------
def tdse_run(H_static, Xcoord, tags):
    """Propagate one excitation on the IJ graph for MAX_FS fs and log diagnostics.

    The walker starts on the first monitored O (``O_local[0]``). At every step
    the Mn–O couplings are scaled by ``1 + BREATH_A*sin(2π t/BREATH_T)``, with
    t taken at the step midpoint. ``exp(-i Ht DT_FS)`` is then applied via an
    eigendecomposition, and t, ⟨|J|⟩ (current_proxy_float over E_nz) and
    oo_distance(Xcoord) are recorded. Progress, including a wall-clock ETA,
    is printed every PRINT_EVERY steps.

    Reads module globals: edges, E_nz, O_local, MAX_FS, DT_FS, PRINT_EVERY,
    BREATH_A, BREATH_T, xp, s_of_t, current_proxy_float, oo_distance.

    Returns (t, Jbar, d_OO) NumPy arrays and ``formed`` = any d_OO < 1.6.
    """
    FORMED_OO = 1.6  # O–O distance (Å) counted as "formed"

    N = H_static.shape[0]
    psi = xp.zeros(N, dtype=xp.complex128)
    # start on an O near the monitored pair to excite O–O sector
    psi[O_local[0]] = 1.0

    # Edges and tags are fixed, so select the breathing (Mn–O) edges once
    # instead of re-scanning every edge on every step.
    mn_o_edges = [(i, j) for (i, j) in edges
                  if (tags[i] == "MN" and tags[j].startswith("O"))
                  or (tags[j] == "MN" and tags[i].startswith("O"))]

    T_hist = []; J_hist = []; dOO_hist = []

    t0 = time.time()
    steps = int(MAX_FS/DT_FS)
    for n in range(steps):
        t_fs = (n+0.5)*DT_FS
        breath = 1.0 + BREATH_A*math.sin(2*math.pi*(t_fs/BREATH_T))
        # breathing on Mn–O blocks only
        Ht = H_static.copy()
        for (i, j) in mn_o_edges:
            Ht[i, j] *= breath; Ht[j, i] *= breath

        # unitary via eigendecomp
        evals, V = xp.linalg.eigh(Ht)
        U = V @ xp.diag(xp.exp(-1j*evals*DT_FS)) @ V.conj().T
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi) + 1e-15)

        p = xp.abs(psi)**2
        Jbar = current_proxy_float(Ht, p, E_nz)
        dOO = oo_distance(Xcoord)

        T_hist.append(t_fs); J_hist.append(Jbar); dOO_hist.append(dOO)

        if (n+1) % PRINT_EVERY == 0 or n == 0:
            formed = (dOO < FORMED_OO)
            # ETA from measured wall time per step. The old expression divided the
            # remaining *simulated* fs by 60000 and labelled the result "min".
            elapsed = time.time() - t0
            eta_min = max(0.0, elapsed / (n+1) * (steps-n-1) / 60.0)
            print(f"[{n+1:6d}/{steps}] t={t_fs:7.1f} fs | s={s_of_t(t_fs):.3f} "
                  f"| J̄={Jbar:.4f} | d_OO={dOO:.3f} Å | formed={formed} | ETA~{eta_min:.1f} min")

    T_hist = _np.array(T_hist); J_hist = _np.array(J_hist); dOO_hist = _np.array(dOO_hist)

    formed_flag = bool((dOO_hist < FORMED_OO).any())
    print(f"[DONE] Steps={steps} | formed={formed_flag} | wall={time.time()-t0:.1f}s")
    return T_hist, J_hist, dOO_hist, formed_flag

t_fs, Jbar, dOO, formed = tdse_run(H0, X, tags)

# ---------- save outputs ----------
# One CSV row per step: time (fs), current proxy, monitored O–O distance.
csv_path = os.path.join(OUTDIR, "IJ_tdse_timeseries.csv")
with open(csv_path, "w") as f:
    f.write("t_fs,Jbar,d_OO\n")
    for t,j,d in zip(t_fs, Jbar, dOO):
        f.write(f"{t:.4f},{j:.6f},{d:.6f}\n")
# Run parameters alongside the time series.
json_path = os.path.join(OUTDIR, "IJ_tdse_params.json")
with open(json_path, "w") as f:
    json.dump(dict(CIF=os.path.basename(CIF_PATH),
                   N_nodes=int(N),
                   n_edges=int(len(edges)),
                   pair_rc=PAIR_RC,
                   DT_fs=DT_FS,
                   MAX_fs=MAX_FS,
                   breath_A=BREATH_A,
                   breath_T=BREATH_T), f, indent=2)

# ---------- plots ----------
# s(t) is recomputed from s_of_t; it does not feed back into the dynamics.
plt.figure(figsize=(6.2,3.2)); plt.plot(t_fs, [s_of_t(t) for t in t_fs])
plt.xlabel("time (fs)"); plt.ylabel("s(t)"); plt.title("Morph progress (driver)")
plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,"IJ_s_of_t.png"), dpi=160); plt.close()

plt.figure(figsize=(6.2,3.2)); plt.plot(t_fs, Jbar)
plt.xlabel("time (fs)"); plt.ylabel("⟨|J|⟩"); plt.title("Bridge current proxy")
plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,"IJ_Jbar_of_t.png"), dpi=160); plt.close()

plt.figure(figsize=(6.2,3.2)); plt.plot(t_fs, dOO)
plt.xlabel("time (fs)"); plt.ylabel("d(O–O) (Å)"); plt.title("Monitored O–O distance")
plt.tight_layout(); plt.savefig(os.path.join(OUTDIR,"IJ_dOO_of_t.png"), dpi=160); plt.close()

print("[SAVE] CSV:", csv_path)
print("[SAVE] JSON:", json_path)
print("[SAVE] Plots:", os.path.join(OUTDIR,"IJ_s_of_t.png"), os.path.join(OUTDIR,"IJ_Jbar_of_t.png"), os.path.join(OUTDIR,"IJ_dOO_of_t.png"))
# =======================================================================================

# ============================================================
#  GQR Morph Engine — H→I→J (8F4H, 8F4I, 8F4J)
#  Uses XFEL delay ratios from Bhowmick 2023
# ============================================================
import os, math, json, time
import numpy as np
try:
    import cupy as cp
    xp = cp; GPU=True
except Exception:
    xp = np; GPU=False
import matplotlib.pyplot as plt
from collections import Counter

# ---------- File paths ----------
# Input structures for the H, I and J states (Colab /content paths).
CIF_H = "/content/8F4H.cif"
CIF_I = "/content/8F4I.cif"
CIF_J = "/content/8F4J.cif"
OUT   = "/content" if os.path.isdir("/content") else "."

# ---------- Timing & morph params ----------
# NOTE(review): HOLD_FS and BREATH_A are not referenced by the code in this cell.
DT_FS     = 0.5          # fs timestep
STEPS     = 40000        # total steps (~20 000 fs)
HOLD_FS   = 200.0
J_THRESH  = 0.0052       # current threshold for the "formed" test in tdse_run
BREATH_A  = 0.12
TAU_HI    = 150.0        # µs gate constant H→I
TAU_IJ    = 250.0        # µs gate constant I→J
WARP_HI   = 0.04         # µs/fs  (H→I)
WARP_IJ   = 0.10         # µs/fs  (I→J)

# ---------- Utility functions ----------
def load_xyz_stub(path):
    """Stub loader: replace with gemmi or ase if available.

    Reads a whitespace-separated text file. For every non-blank line that
    does not start with '#' and has at least 4 tokens, the last three tokens
    are taken as x, y, z. Lines whose last three tokens are not all numeric
    are skipped. This covers CIF tags, loop headers and string rows; the
    previous version raised ValueError on the first such line.

    Returns an (n, 3) float array, which is (0, 3) when nothing was parsed.
    """
    atoms = []
    with open(path) as f:
        for ln in f:
            if not ln.strip() or ln.startswith("#"):
                continue
            parts = ln.split()
            if len(parts) < 4:
                continue
            try:
                xyz = [float(tok) for tok in parts[-3:]]
            except ValueError:
                continue
            atoms.append(xyz)
    return np.array(atoms, dtype=float).reshape(-1, 3)

def oo_distance(X):
    """Euclidean distance between rows 0 and 1 of *X* (placeholder for the real O–O pair)."""
    first, second = X[0], X[1]
    return float(np.linalg.norm(first - second))

def sigmoid(t, tau):
    """Logistic ramp in t/tau centred at t == tau (value 0.5 there)."""
    z = t / tau - 1
    return 1 / (1 + np.exp(-z))

def mix_coords(Xa, Xb, s):
    """Linear interpolation between coordinate sets: s=0 gives Xa, s=1 gives Xb."""
    weight_a = 1 - s
    return Xa * weight_a + Xb * s

def current_Jbar(Hx, p_vec):
    """Mean of ``|H_ij| * sqrt(p_i * p_j)`` over the non-zero strict upper triangle of *Hx*.

    *Hx* is a square NumPy or CuPy matrix and *p_vec* a per-site probability
    vector of matching length (NumPy, CuPy or sequence). CuPy inputs are
    copied to host once, and the sum is vectorized. The previous per-element
    loop ran O(n²) Python iterations per call. Under CuPy it also built a list
    of device scalars that np.mean() could not consume. As before, only the
    real part of *p_vec* is used. Returns 0.0 when no off-diagonal coupling is
    non-zero. *p_vec* is expected to be non-negative; negative products give NaN.
    """
    H = np.asarray(Hx.get() if hasattr(Hx, "get") else Hx)
    p = np.real(np.asarray(p_vec.get() if hasattr(p_vec, "get") else p_vec))
    rows, cols = np.triu_indices(len(H), k=1)
    h = H[rows, cols]
    nz = h != 0
    if not nz.any():
        return 0.0
    weights = np.sqrt(p[rows[nz]] * p[cols[nz]])
    return float(np.mean(np.abs(h[nz]) * weights))

# ---------- Build simplified Hamiltonian ----------
def random_H(n=50, seed=1):
    """Return a reproducible real symmetric n×n matrix (standard-normal entries, symmetrized) as an xp array."""
    gen = np.random.default_rng(seed)
    draws = gen.normal(0, 1, (n, n))
    return xp.array((draws + draws.T) / 2)

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Run the toy H→I→J morph on a random 70-site Hamiltonian.

    The first STEPS/2 steps interpolate XH→XI, the rest XI→XJ, with a logistic
    morph fraction s. The wavefunction is advanced with the exact propagator
    ``exp(-i H0 DT_FS)``. The previous update ``psi*(1j*H0·psi)*DT_FS`` was an
    element-wise product, not a Schrödinger step. ⟨|J|⟩ is computed from the
    site probabilities |psi|², not from the complex amplitudes, which made
    math.sqrt fail on negative real parts. ``formed`` becomes True the first
    time d(O–O) < 2.0 and ⟨|J|⟩ > J_THRESH.

    Returns (t, s, Jbar, d_OO) sampled every 500 steps, plus the formed flag.
    """
    n_sites = 70
    H0 = random_H(n_sites)
    # H0 is time-independent, so the one-step propagator is built once.
    evals, V = xp.linalg.eigh(H0)
    U = V @ xp.diag(xp.exp(-1j * evals * DT_FS)) @ V.conj().T
    # current_Jbar works element by element; give it host (NumPy) data.
    H_host = H0.get() if GPU else H0
    psi = xp.ones(n_sites, dtype=complex) / math.sqrt(n_sites)
    formed = False
    t_hist=[]; s_hist=[]; J_hist=[]; d_hist=[]
    for step in range(STEPS):
        t_fs = step*DT_FS
        # decide phase
        if step < STEPS/2:
            s = sigmoid(t_fs, WARP_HI*TAU_HI)
            X = mix_coords(XH, XI, s)
        else:
            s = sigmoid(t_fs-WARP_HI*TAU_HI, WARP_IJ*TAU_IJ)
            X = mix_coords(XI, XJ, s)
        psi = U @ psi
        prob = xp.abs(psi)**2
        Jbar = current_Jbar(H_host, prob.get() if GPU else prob)
        dOO = oo_distance(X)
        if not formed and dOO<2.0 and Jbar>J_THRESH:
            formed = True
        if step%500==0:
            t_hist.append(t_fs); s_hist.append(s); J_hist.append(Jbar); d_hist.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t_fs:7.1f} fs | s={s:5.3f} | J̄={Jbar:6.4f} | d_OO={dOO:5.3f} Å | formed={formed}")
    return np.array(t_hist),np.array(s_hist),np.array(J_hist),np.array(d_hist),formed

# ---------- Load coordinates ----------
XH=load_xyz_stub(CIF_H)
XI=load_xyz_stub(CIF_I)
XJ=load_xyz_stub(CIF_J)

print(f"[GPU] CuPy={'ON' if GPU else 'OFF'}  |  atoms: H={len(XH)}, I={len(XI)}, J={len(XJ)}")
print(f"[O–O]  H={oo_distance(XH):.3f} Å  |  I={oo_distance(XI):.3f} Å  |  J={oo_distance(XJ):.3f} Å")

# mix_coords interpolates atom-for-atom, so all three sets must have the same
# shape. Fail here with a clear message rather than inside the TDSE loop.
if not (XH.shape == XI.shape == XJ.shape):
    raise ValueError(f"H/I/J coordinate sets must align atom-for-atom; got {XH.shape}, {XI.shape}, {XJ.shape}")

# ---------- Run TDSE ----------
t_fs,s,Jbar,dOO,formed=tdse_run(XH,XI,XJ)

# ---------- Save results ----------
out_csv=os.path.join(OUT,"HIJ_tdse_timeseries.csv")
np.savetxt(out_csv, np.column_stack([t_fs,s,Jbar,dOO]),
           delimiter=",",header="t_fs,s,Jbar,dOO",comments="")
# Close the params file deterministically instead of leaving an inline open() handle.
with open(os.path.join(OUT,"HIJ_tdse_params.json"),"w") as fh:
    json.dump({
        "DT_FS":DT_FS,"STEPS":STEPS,
        "WARP_HI":WARP_HI,"WARP_IJ":WARP_IJ,
        "TAU_HI":TAU_HI,"TAU_IJ":TAU_IJ,
        "formed":formed}, fh, indent=2)

# ---------- Plot ----------
fig,ax=plt.subplots(2,1,figsize=(8,6))
ax[0].plot(t_fs,s,label="s(t)"); ax[0].set_ylabel("Morph fraction s"); ax[0].legend()
ax[1].plot(t_fs,Jbar,label="⟨|J|⟩"); ax[1].set_ylabel("Jbar (arb.)"); ax[1].set_xlabel("t (fs)"); ax[1].legend()
plt.tight_layout()
out_png=os.path.join(OUT,"HIJ_tdse_summary.png")
plt.savefig(out_png,dpi=300)
print(f"[SAVE] CSV:{out_csv}\n[SAVE] Plot:{out_png}\n[DONE] formed={formed}")

!pip install gemmi

# NOTE: a line-numbered printout of an earlier revision of the "GQR Morph
# Engine" cell was pasted here. Every line began with its listing number
# (e.g. "  2  # ..."), which is an IndentationError, so the cell could never run.
# Its content is an earlier revision of the cell that follows; that cell's
# load_xyz_cif() also reads the mmCIF Cartesian columns
# (_atom_site.Cartn_x/y/z). Use the following cell instead.

# ============================================================
#  GQR Morph Engine — H→I→J (8F4H, 8F4I, 8F4J)
#  XFEL delay ratios from Bhowmick 2023
# ============================================================
import os, math, json, time
import numpy as np

try:
    import cupy as cp
    xp = cp
    GPU = True
except Exception:
    xp = np
    GPU = False

import matplotlib.pyplot as plt
import gemmi

# ---------- File paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"   # Colab dir when present, else CWD

# ---------- Timing & morph params ----------
# Integration grid: STEPS * DT_FS = 40000 * 0.5 fs = 20 000 fs in total.
DT_FS, STEPS = 0.5, 40000
# HOLD_FS and BREATH_A are not referenced by the code in this cell.
HOLD_FS, BREATH_A = 200.0, 0.12
J_THRESH = 0.0052                  # ⟨|J|⟩ cut-off used with d(O–O) < 2.0 Å for the "formed" flag
# Gate constants (µs) and morph-clock warps (µs/fs); tdse_run uses the
# products WARP_* * TAU_* as sigmoid widths.
TAU_HI, TAU_IJ = 150.0, 250.0
WARP_HI, WARP_IJ = 0.04, 0.10


def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float array.

    Prefers the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å). Only
    when those are absent are the small-molecule fractional columns
    ``_atom_site_fract_x/y/z`` read; they are returned unconverted (fractions
    of the unit cell, not Å). x, y and z always come from the same tag family.

    Rows whose values do not parse as floats (e.g. CIF '?' / '.') are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    block = gemmi.cif.read_file(path).sole_block()
    for tags in (("_atom_site.Cartn_x", "_atom_site.Cartn_y", "_atom_site.Cartn_z"),
                 ("_atom_site_fract_x", "_atom_site_fract_y", "_atom_site_fract_z")):
        xs, ys, zs = (block.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break

    coords = []
    # min(...) keeps unequal column lengths from raising IndexError.
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            coords.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not coords:
        raise ValueError(f"No coordinate data found in {path}")
    return np.array(coords)


def oo_distance(X):
    """Return the Euclidean distance between coordinate rows 0 and 1.

    NOTE(review): reported as an O–O distance, but no oxygen selection happens
    here — it is simply the first two rows of the coordinate array.
    """
    return float(np.linalg.norm(X[0]-X[1]))

def sigmoid(t, tau):
    """Logistic gate 1 / (1 + exp(-(t/tau - 1))); equals 0.5 at t == tau."""
    return 1/(1+np.exp(-(t/tau-1)))

def mix_coords(Xa, Xb, s):
    """Linear interpolation (1 - s) * Xa + s * Xb; Xa and Xb must broadcast."""
    return Xa*(1-s)+Xb*s

def current_Jbar(Hx, p_vec):
    """Mean current proxy ⟨|J|⟩ = mean(|H_ij| * |p_i| * |p_j|) over i < j with H_ij != 0.

    Uses probabilities |p|^2 (as in the GPU-safe patch), so complex amplitudes
    are handled and the square root never sees a negative argument.
    NumPy and CuPy inputs are both accepted: each array is copied to the host
    once (CuPy arrays expose ``.get()``), avoiding per-element float() casts
    that raise on complex CuPy scalars.

    Returns:
        float; 0.0 when there is no non-zero off-diagonal coupling.
    """
    H = np.asarray(Hx.get() if hasattr(Hx, "get") else Hx)
    prob = np.abs(np.asarray(p_vec.get() if hasattr(p_vec, "get") else p_vec)) ** 2
    ii, jj = np.triu_indices(H.shape[0], k=1)
    hij = H[ii, jj]
    keep = hij != 0
    if not keep.any():
        return 0.0
    ii, jj = ii[keep], jj[keep]
    vals = np.abs(hij[keep]) * np.sqrt(prob[ii] * prob[jj])
    return float(vals.mean())

# ---------- Build simplified Hamiltonian ----------
def random_H(n=70, seed=1):
    """Return a seed-deterministic symmetric n×n matrix ((G + G.T)/2, G ~ N(0, 1)) on the xp backend."""
    rng=np.random.default_rng(seed)
    H=rng.normal(0,1,(n,n))
    H=(H+H.T)/2
    return xp.array(H)

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J.

    The wavefunction evolves under the fixed random Hamiltonian random_H(70);
    the geometry morph only feeds the d(O–O) monitor (rows 0 and 1).

    Args:
        XH, XI, XJ: coordinate arrays for the three states; only rows 0 and 1
            are read here.

    Returns:
        (t_hist, s_hist, J_hist, d_hist, formed): NumPy arrays sampled every
        500 steps, and a flag latched True on the first step where
        d(O–O) < 2.0 Å and ⟨|J|⟩ > J_THRESH.
    """
    H0 = random_H(70)
    # H0 is time independent: build the exact unitary step
    # U = V exp(-i E dt) V^† once and reuse it every step.
    E, V = xp.linalg.eigh(H0)
    U = (V * xp.exp(-1j * E * DT_FS)) @ V.conj().T
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    formed = False
    t_hist = []; s_hist = []; J_hist = []; d_hist = []
    for step in range(STEPS):
        t_fs = step * DT_FS
        # oo_distance reads only rows 0 and 1, so interpolate just those rows.
        if step < STEPS / 2:
            s = sigmoid(t_fs, WARP_HI * TAU_HI)
            X = mix_coords(XH[:2], XI[:2], s)
        else:
            # NOTE(review): the offset is WARP_HI*TAU_HI rather than the phase-2
            # start time, so with this cell's constants s is already ~1 when this
            # branch begins and the geometry jumps from I to J. Confirm intent.
            s = sigmoid(t_fs - WARP_HI * TAU_HI, WARP_IJ * TAU_IJ)
            X = mix_coords(XI[:2], XJ[:2], s)
        psi = U @ psi
        Jbar = current_Jbar(H0, psi)
        dOO = oo_distance(X)
        if not formed and dOO < 2.0 and Jbar > J_THRESH:
            formed = True
        if step % 500 == 0:
            t_hist.append(t_fs); s_hist.append(s); J_hist.append(Jbar); d_hist.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t_fs:7.1f} fs | s={s:5.3f} | J̄={Jbar:6.4f} | d_OO={dOO:5.3f} Å | formed={formed}")
    return np.array(t_hist), np.array(s_hist), np.array(J_hist), np.array(d_hist), formed

# ---------- Load coordinates ----------
XH=load_xyz_cif(CIF_H)
XI=load_xyz_cif(CIF_I)
XJ=load_xyz_cif(CIF_J)

print(f"[GPU] CuPy={'ON' if GPU else 'OFF'}  |  atoms: H={len(XH)}, I={len(XI)}, J={len(XJ)}")
print(f"[O–O]  H={oo_distance(XH):.3f} Å  |  I={oo_distance(XI):.3f} Å  |  J={oo_distance(XJ):.3f} Å")

# The three files need not contain the same number of atoms; keep the common
# leading rows so the morph arrays broadcast. NOTE(review): rows are matched
# by index, which assumes the same atom order in all three files.
n_common = min(len(XH), len(XI), len(XJ))
XH, XI, XJ = XH[:n_common], XI[:n_common], XJ[:n_common]

# ---------- Run TDSE ----------
t_fs,s,Jbar,dOO,formed=tdse_run(XH,XI,XJ)

# ---------- Save results ----------
out_csv=os.path.join(OUT,"HIJ_tdse_timeseries.csv")
np.savetxt(out_csv, np.column_stack([t_fs,s,Jbar,dOO]),
           delimiter=",",header="t_fs,s,Jbar,dOO",comments="")
# Context manager so the JSON file is flushed and closed deterministically.
with open(os.path.join(OUT,"HIJ_tdse_params.json"),"w") as fh:
    json.dump({
        "DT_FS":DT_FS,"STEPS":STEPS,
        "WARP_HI":WARP_HI,"WARP_IJ":WARP_IJ,
        "TAU_HI":TAU_HI,"TAU_IJ":TAU_IJ,
        "formed":formed}, fh, indent=2)

# ---------- Plot ----------
fig,ax=plt.subplots(2,1,figsize=(8,6))
ax[0].plot(t_fs,s,label="s(t)"); ax[0].set_ylabel("Morph fraction s"); ax[0].legend()
ax[1].plot(t_fs,Jbar,label="⟨|J|⟩"); ax[1].set_ylabel("Jbar (arb.)"); ax[1].set_xlabel("t (fs)"); ax[1].legend()
plt.tight_layout()
out_png=os.path.join(OUT,"HIJ_tdse_summary.png")
plt.savefig(out_png,dpi=300)
print(f"[SAVE] CSV:{out_csv}\n[SAVE] Plot:{out_png}\n[DONE] formed={formed}")

# ============================================================
#  GQR Morph Engine — H→I→J (8F4H, 8F4I, 8F4J) + Option B align (OFF)
# ============================================================
import os, math, json, time
import numpy as np
import matplotlib.pyplot as plt
import gemmi

# ---------- GPU backend (optional) ----------
try:
    import cupy as cp
    xp = cp; GPU = True
except Exception:
    xp = np; GPU = False

# ---------- Paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Toggles ----------
# Option B = rigid Kabsch superposition on Mn/O anchors; disabled here.
USE_LABEL_ALIGN = False

# ---------- Timing / morph ----------
DT_FS, STEPS = 0.5, 40000        # 40000 steps × 0.5 fs = 20 000 fs
J_THRESH = 0.0052                # ⟨|J|⟩ threshold for the "formed" flag
TAU_HI, TAU_IJ = 150.0, 250.0    # gate constants (µs)
WARP_HI, WARP_IJ = 0.04, 0.10    # morph clock (µs/fs)

# ---------- CIF readers ----------
def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float32 array.

    Prefers the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å), the
    same frame load_xyz_elem_cif reads, so Kabsch transforms fitted on those
    anchors apply to these coordinates. The fractional
    ``_atom_site_fract_x/y/z`` columns are used only when no Cartesian
    columns exist, and are returned unconverted.

    Rows whose values do not parse as floats (e.g. '?' / '.') are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    blk = gemmi.cif.read_file(path).sole_block()
    for tags in (('_atom_site.Cartn_x', '_atom_site.Cartn_y', '_atom_site.Cartn_z'),
                 ('_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z')):
        xs, ys, zs = (blk.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break
    out = []
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            out.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not out:
        raise ValueError(f"No coordinate data found in {path}")
    return np.array(out, dtype=np.float32)

def load_xyz_elem_cif(path):
    """Return (coords, symbols) from ``_atom_site.Cartn_*`` and ``_atom_site.type_symbol``.

    coords is an (n, 3) float32 array; symbols is an object array of
    upper-cased element symbols. When Cartesian columns or type symbols are
    missing, falls back to load_xyz_cif with every symbol set to "X". A row
    whose coordinates do not parse keeps all-zero coordinates and symbol "X".
    """
    blk = gemmi.cif.read_file(path).sole_block()
    xs = blk.find_values('_atom_site.Cartn_x')
    ys = blk.find_values('_atom_site.Cartn_y')
    zs = blk.find_values('_atom_site.Cartn_z')
    ts = blk.find_values('_atom_site.type_symbol')
    if len(xs) == 0 or len(ts) == 0:
        X = load_xyz_cif(path)
        return X, np.array(["X"]*len(X), dtype=object)
    n = min(len(xs), len(ys), len(zs), len(ts))
    X = np.zeros((n, 3), np.float32)
    T = []
    for i in range(n):
        try:
            # Parse the whole row first so a failure never leaves it half-written.
            row = (float(xs[i]), float(ys[i]), float(zs[i]))
        except ValueError:
            T.append("X")
            continue
        X[i] = row
        T.append(ts[i].upper())
    return X, np.array(T, dtype=object)

# ---------- Helpers ----------
def oo_distance_first_two(X):
    """Distance between coordinate rows 0 and 1 (debug monitor; no O-atom selection is done)."""
    return float(np.linalg.norm(X[0]-X[1]))

def sigmoid(t, tau):
    """Logistic gate 1 / (1 + exp(-(t/tau - 1))), with values in (0, 1); 0.5 at t == tau."""
    return 1.0/(1.0+np.exp(-(t/tau-1.0)))

def mix_coords(Xa, Xb, s):
    """Linear interpolation (1 - s) * Xa + s * Xb."""
    return Xa*(1.0-s)+Xb*s

def random_H(n=70, seed=1):
    """Seed-deterministic symmetric n×n matrix (G + G.T)/2, G ~ N(0, 1), on the xp backend."""
    rng = np.random.default_rng(seed)
    H   = rng.normal(0,1,(n,n)); H = (H+H.T)/2.0
    return xp.array(H)

def current_Jbar(Hx, p_vec):
    """Mean current proxy ⟨|J|⟩ = mean(|H_ij| * |p_i| * |p_j|) over i < j with H_ij != 0.

    Accepts NumPy or CuPy arrays. Both are copied to the host once (CuPy
    arrays expose ``.get()``). This replaces per-element ``float()`` casts,
    which raise TypeError on complex CuPy scalars, and the O(N²) Python loop.

    Returns:
        float; 0.0 when there is no non-zero off-diagonal coupling.
    """
    H = np.asarray(Hx.get() if hasattr(Hx, "get") else Hx)
    prob = np.abs(np.asarray(p_vec.get() if hasattr(p_vec, "get") else p_vec)) ** 2
    ii, jj = np.triu_indices(H.shape[0], k=1)
    hij = H[ii, jj]
    keep = hij != 0
    if not keep.any():
        return 0.0
    ii, jj = ii[keep], jj[keep]
    vals = np.abs(hij[keep]) * np.sqrt(prob[ii] * prob[jj])
    return float(vals.mean())

def pad_or_truncate(Xa, Xb, Xc):
    """Return copies of the three arrays truncated to their common length.

    Despite the name, no padding is performed — longer arrays lose trailing rows.
    """
    n = min(len(Xa), len(Xb), len(Xc))
    return Xa[:n].copy(), Xb[:n].copy(), Xc[:n].copy()

# ---------- Option B rigid alignment (Kabsch on Mn/O anchors) ----------
def kabsch(P, Q):
    """Least-squares proper rotation R and translation t mapping points P onto Q.

    P and Q are (k, 3) arrays with row-wise correspondence; the result
    satisfies Q ≈ P @ R.T + t with det(R) = +1 (reflections are removed).
    """
    p_mean = P.mean(0)
    q_mean = Q.mean(0)
    cross_cov = (P - p_mean).T @ (Q - q_mean)
    U, _sing, Vt = np.linalg.svd(cross_cov)
    R = Vt.T @ U.T
    if np.linalg.det(R) < 0:
        # Improper solution: flip the axis of the smallest singular value.
        Vt[-1, :] *= -1
        R = Vt.T @ U.T
    return R, q_mean - R @ p_mean

def align_structures_optionB(XH, XI, XJ):
    """Rigidly superpose structures I and J onto H's frame using Mn/O anchors.

    Anchors are read via load_xyz_elem_cif from CIF_H / CIF_I / CIF_J; the
    first k (<= 400) Mn/O rows of each file are paired by position.
    NOTE(review): that pairing assumes the Mn/O atoms appear in the same
    order in all three files — confirm.

    Returns:
        (XH, XI_aligned, XJ_aligned, True), or the inputs unchanged plus False
        when fewer than 8 anchors are available in some file.
    """
    XH_e, TH = load_xyz_elem_cif(CIF_H)
    XI_e, TI = load_xyz_elem_cif(CIF_I)
    XJ_e, TJ = load_xyz_elem_cif(CIF_J)
    # anchors = Mn or O
    maskH = np.where((TH=="MN") | (TH=="O"))[0]
    maskI = np.where((TI=="MN") | (TI=="O"))[0]
    maskJ = np.where((TJ=="MN") | (TJ=="O"))[0]
    k = min(maskH.size, maskI.size, maskJ.size, 400)
    if k < 8:
        return XH, XI, XJ, False
    PH = XH_e[maskH[:k]]
    PI = XI_e[maskI[:k]]
    PJ = XJ_e[maskJ[:k]]
    # kabsch(P, Q) returns the map P -> Q, so fit I->H and J->H directly and
    # apply each to the structure being moved into H's frame.
    R_IH, t_IH = kabsch(PI, PH)
    R_JH, t_JH = kabsch(PJ, PH)
    XI_al = (XI @ R_IH.T) + t_IH
    XJ_al = (XJ @ R_JH.T) + t_JH
    return XH, XI_al, XJ_al, True

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J.

    The wavefunction evolves exactly under the fixed random Hamiltonian
    random_H(70); the geometry morph only feeds the first-two-rows distance
    monitor.

    Returns:
        (t_hist, s_hist, J_hist, d_hist, formed): NumPy arrays sampled every
        500 steps, and a flag latched True at the first sample with
        d(O–O) < 2.0 Å and ⟨|J|⟩ > J_THRESH.
    """
    H0  = random_H(70)
    # H0 never changes, so its eigendecomposition and the unitary step
    # U = V exp(-i E dt) V^† are computed once, outside the step loop.
    E, V = xp.linalg.eigh(H0)
    U    = V @ xp.diag(xp.exp(-1j*E*DT_FS)) @ V.conj().T
    psi = xp.ones(70, dtype=complex) / math.sqrt(70)
    formed=False
    t_hist=[]; s_hist=[]; J_hist=[]; d_hist=[]
    for step in range(STEPS):
        t_fs = step*DT_FS
        # Only rows 0 and 1 feed oo_distance_first_two; morph just those.
        if step < STEPS//2:
            s = sigmoid(t_fs, WARP_HI*TAU_HI)
            X = mix_coords(XH[:2], XI[:2], s)
        else:
            # NOTE(review): offset is WARP_HI*TAU_HI, not the phase-2 start
            # time, so s is already ~1 when this branch begins. Confirm intent.
            s = sigmoid(t_fs - (WARP_HI*TAU_HI), WARP_IJ*TAU_IJ)
            X = mix_coords(XI[:2], XJ[:2], s)
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi) + 1e-15)  # guard against round-off drift
        if step % 500 == 0:
            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance_first_two(X)
            if (not formed) and (dOO < 2.0) and (Jbar > J_THRESH):
                formed=True
            t_hist.append(t_fs); s_hist.append(float(s)); J_hist.append(Jbar); d_hist.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t_fs:7.1f} fs | s={s:5.3f} | J̄={Jbar:6.4f} | d_OO={dOO:5.3f} Å | formed={formed}")
    return np.array(t_hist), np.array(s_hist), np.array(J_hist), np.array(d_hist), formed

# ---------- Load coordinates (Option A) ----------
XH = load_xyz_cif(CIF_H)
XI = load_xyz_cif(CIF_I)
XJ = load_xyz_cif(CIF_J)
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'}  |  atoms: H={len(XH)}, I={len(XI)}, J={len(XJ)}")
print(f"[O–O]  H={oo_distance_first_two(XH):.3f} Å  |  I={oo_distance_first_two(XI):.3f} Å  |  J={oo_distance_first_two(XJ):.3f} Å")

# ---------- Optional alignment then size unify ----------
aligned = False
if USE_LABEL_ALIGN:
    XH, XI, XJ, aligned = align_structures_optionB(XH, XI, XJ)
    print(f"[ALIGN] Option B rigid Kabsch used: {aligned}")
else:
    print("[ALIGN] Option B disabled (keeping raw coordinates)")

# Truncate to the common atom count so the morph arrays broadcast.
XH, XI, XJ = pad_or_truncate(XH, XI, XJ)
print(f"[SHAPES] unified: H={XH.shape}, I={XI.shape}, J={XJ.shape}")

# ---------- Run ----------
t_fs, s, Jbar, dOO, formed = tdse_run(XH, XI, XJ)

# ---------- Save ----------
csv_path  = os.path.join(OUT, "HIJ_tdse_timeseries.csv")
png_path  = os.path.join(OUT, "HIJ_tdse_summary.png")
json_path = os.path.join(OUT, "HIJ_tdse_params.json")
np.savetxt(csv_path, np.column_stack([t_fs, s, Jbar, dOO]),
           delimiter=",", header="t_fs,s,Jbar,dOO", comments="")
# Context manager so the JSON file is flushed and closed deterministically.
with open(json_path, "w") as fh:
    json.dump({
        "DT_FS":DT_FS, "STEPS":STEPS, "WARP_HI":WARP_HI, "WARP_IJ":WARP_IJ,
        "TAU_HI":TAU_HI, "TAU_IJ":TAU_IJ, "formed":formed, "USE_LABEL_ALIGN":USE_LABEL_ALIGN
    }, fh, indent=2)

# ---------- Plot ----------
fig, ax = plt.subplots(2,1, figsize=(8,6))
ax[0].plot(t_fs, s, label="s(t)"); ax[0].set_ylabel("Morph fraction"); ax[0].legend()
ax[1].plot(t_fs, Jbar, label="⟨|J|⟩"); ax[1].set_ylabel("current proxy"); ax[1].set_xlabel("t (fs)"); ax[1].legend()
plt.tight_layout(); plt.savefig(png_path, dpi=280); plt.close(fig)
print(f"[SAVE] CSV: {csv_path}\n[SAVE] Plot: {png_path}\n[SAVE] JSON: {json_path}\n[DONE] formed={formed}")


 [GPU] CuPy=ON  |  atoms: H=54161, I=54640, J=54524
[O–O]  H=1.459 Å  |  I=1.458 Å  |  J=1.460 Å
[ALIGN] Option B disabled (keeping raw coordinates)
[SHAPES] unified: H=(54161, 3), I=(54161, 3), J=(54161, 3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-1139365520.py in <cell line: 0>()
    169
    170 # ---------- Run ----------
--> 171 t_fs, s, Jbar, dOO, formed = tdse_run(XH, XI, XJ)
    172
    173 # ---------- Save ----------

1 frames/tmp/ipython-input-1139365520.py in tdse_run(XH, XI, XJ)
    142         psi = psi / (xp.linalg.norm(psi) + 1e-15)
    143         if step % 500 == 0:
--> 144             Jbar = current_Jbar(H0, psi)
    145             dOO  = oo_distance_first_two(X)
    146             if (not formed) and (dOO < 2.0) and (Jbar > J_THRESH):

/tmp/ipython-input-1139365520.py in current_Jbar(Hx, p_vec)
     80     # p_vec may be xp-array; cast scalars only
     81     for i in range(N):
---> 82         pi = float(p_vec[i]*p_vec[i].conj()).real
     83         for j in range(i+1,N):
     84             hij = Hx[i,j]

cupy/_core/core.pyx in cupy._core.core._ndarray_base.__float__()

TypeError: float() argument must be a string or a real number, not 'complex'

# === PATCH: GPU-safe current_Jbar (drop-in) =========================
# Paste this above your tdse_run() definition and re-run the HIJ block.

import numpy as _np

def current_Jbar(Hx, p_vec):
    """
    Mean current proxy ⟨|J|⟩ over all pairs i < j: mean(|H_ij| * sqrt(|psi_i|^2 |psi_j|^2)).

    Dispatches on array type: when either input is a CuPy array the reduction
    runs on the GPU and only the final scalar is copied to the host; otherwise
    NumPy is used. Returns 0.0 when there are no off-diagonal pairs.
    """
    on_gpu = any("cupy" in str(type(arr)) for arr in (p_vec, Hx))
    if on_gpu:
        import cupy as lib
    else:
        lib = _np

    rows, cols = lib.triu_indices(Hx.shape[0], k=1)
    weights = lib.abs(Hx[rows, cols])
    prob = lib.abs(p_vec)**2
    terms = weights * lib.sqrt(prob[rows] * prob[cols])
    if not terms.size:
        return 0.0
    mean = terms.mean()
    return float(mean.get()) if on_gpu else float(mean)
# ====================================================================

# ============================================================
#  GQR–TDSE MORPH: H→I→J (8F4H, 8F4I, 8F4J)
#  Compatible with Google Colab + CuPy GPU acceleration
# ============================================================
import os, math, json, time
import numpy as np
import matplotlib.pyplot as plt
import gemmi

# ---------- CuPy / NumPy selector ----------
try:
    import cupy as cp
    xp = cp
    GPU = True
except Exception:
    xp = np
    GPU = False

# ---------- File paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Timing & morph parameters ----------
DT_FS, STEPS = 0.5, 40000          # 40000 × 0.5 fs = 20 000 fs
# HOLD_FS and BREATH_A are not referenced by the code in this cell.
HOLD_FS, BREATH_A = 200.0, 0.12
J_THRESH = 0.0052                  # ⟨|J|⟩ threshold for the "formed" flag
TAU_HI, TAU_IJ = 150.0, 250.0      # gate constants (µs), H→I and I→J
WARP_HI, WARP_IJ = 0.04, 0.10      # morph clock (µs/fs), H→I and I→J

# ---------- Helpers ----------
def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float array.

    Uses the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å) when
    present, otherwise the fractional ``_atom_site_fract_x/y/z`` columns
    (returned unconverted). The tag family is chosen once for all three axes,
    so x, y and z never come from different families.

    Rows that do not parse as floats are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    blk = gemmi.cif.read_file(path).sole_block()
    for tags in (('_atom_site.Cartn_x', '_atom_site.Cartn_y', '_atom_site.Cartn_z'),
                 ('_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z')):
        xs, ys, zs = (blk.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break
    coords=[]
    # min(...) keeps unequal column lengths from raising IndexError.
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            coords.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not coords:
        raise ValueError(f"No coordinates in {path}")
    return np.array(coords)

def oo_distance_first_two(X):
    """Distance between coordinate rows 0 and 1 (quick O–O monitor; no atom selection is done)."""
    delta = X[0] - X[1]
    return float(np.linalg.norm(delta))

def sigmoid(t, tau):
    """Smooth morph gate: logistic in t/tau, equal to 0.5 at t == tau."""
    z = t/tau - 1
    return 1/(1 + np.exp(-z))

def mix_coords(A, B, s):
    """Interpolate two geometries over their common leading rows: (1-s)*A + s*B."""
    common = min(len(A), len(B))
    head_a, head_b = A[:common], B[:common]
    return head_a*(1 - s) + head_b*s

# ---------- GPU-safe current proxy ----------
def current_Jbar(Hx, psi):
    """Mean over i < j of |H_ij| * sqrt(|psi_i|^2 |psi_j|^2).

    CuPy wavefunctions are reduced on the GPU (only the scalar is copied back);
    anything else goes through NumPy. Returns 0.0 when there are no pairs.
    """
    if "cupy" in str(type(psi)):
        import cupy as lib
        on_gpu = True
    else:
        lib = np
        on_gpu = False
    prob = lib.abs(psi)**2
    rows, cols = lib.triu_indices(Hx.shape[0], 1)
    terms = lib.abs(Hx[rows, cols]) * lib.sqrt(prob[rows]*prob[cols])
    if not terms.size:
        return 0.0
    return float(terms.mean().get()) if on_gpu else float(terms.mean())

# ---------- Random symmetric Hamiltonian ----------
def random_H(n=70, seed=1):
    """Seed-deterministic symmetric n×n matrix (G + G.T)/2, G ~ N(0, 1), on the xp backend."""
    gaussian = np.random.default_rng(seed).normal(0, 1, (n, n))
    return xp.array((gaussian + gaussian.T)/2)

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J.

    The wavefunction evolves under the fixed random Hamiltonian random_H(70)
    using the exact unitary step U = V exp(-i E dt) V^†. Renormalised forward
    Euler is not used because it amplifies each eigencomponent by
    sqrt(1 + (E dt)^2) per step and drifts psi toward the largest-|E|
    eigenvectors. The morph only feeds the first-two-rows distance monitor.

    Returns:
        (t_hist, s_hist, J_hist, d_hist, formed): NumPy arrays sampled every
        500 steps, and a flag latched True at the first sample with
        d(O–O) < 2.0 Å and ⟨|J|⟩ > J_THRESH.
    """
    H0 = random_H(70)
    E, V = xp.linalg.eigh(H0)                       # H0 is constant: diagonalise once
    U = (V * xp.exp(-1j*E*DT_FS)) @ V.conj().T
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    formed=False
    t_hist,s_hist,J_hist,d_hist=[],[],[],[]
    for step in range(STEPS):
        t_fs = step*DT_FS
        # Only rows 0 and 1 feed oo_distance_first_two; morph just those.
        if step < STEPS/2:
            s = sigmoid(t_fs, WARP_HI*TAU_HI)
            X = mix_coords(XH[:2], XI[:2], s)
        else:
            # NOTE(review): offset is WARP_HI*TAU_HI, not the phase-2 start
            # time, so s is already ~1 when this branch begins. Confirm intent.
            s = sigmoid(t_fs-WARP_HI*TAU_HI, WARP_IJ*TAU_IJ)
            X = mix_coords(XI[:2], XJ[:2], s)
        psi = U @ psi
        psi = psi / (xp.linalg.norm(psi)+1e-15)      # guard against round-off drift
        if step % 500 == 0:
            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance_first_two(X)
            if not formed and dOO < 2.0 and Jbar > J_THRESH:
                formed=True
            t_hist.append(t_fs); s_hist.append(s); J_hist.append(Jbar); d_hist.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t_fs:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:6.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(t_hist), np.array(s_hist), np.array(J_hist), np.array(d_hist), formed

# ---------- Load coordinates ----------
XH = load_xyz_cif(CIF_H)
XI = load_xyz_cif(CIF_I)
XJ = load_xyz_cif(CIF_J)

# ---------- Unify shapes (Option B disabled for speed) ----------
# Rows are matched by index; NOTE(review): assumes identical atom order across files.
n_min = min(len(XH), len(XI), len(XJ))
XH, XI, XJ = XH[:n_min], XI[:n_min], XJ[:n_min]
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={n_min}")
print(f"[O–O] H={oo_distance_first_two(XH):.3f} Å | I={oo_distance_first_two(XI):.3f} Å | J={oo_distance_first_two(XJ):.3f} Å")

# ---------- Run ----------
t_fs, s, Jbar, dOO, formed = tdse_run(XH, XI, XJ)

# ---------- Save ----------
csv_path = os.path.join(OUT,"HIJ_tdse_timeseries.csv")
np.savetxt(csv_path, np.column_stack([t_fs,s,Jbar,dOO]),
           delimiter=",",header="t_fs,s,Jbar,dOO",comments="")
# Context manager so the JSON file is flushed and closed deterministically.
with open(os.path.join(OUT,"HIJ_tdse_params.json"),"w") as fh:
    json.dump({
        "DT_FS":DT_FS,"STEPS":STEPS,
        "WARP_HI":WARP_HI,"WARP_IJ":WARP_IJ,
        "TAU_HI":TAU_HI,"TAU_IJ":TAU_IJ,
        "formed":formed}, fh, indent=2)

# ---------- Plot ----------
fig,ax=plt.subplots(2,1,figsize=(8,6))
ax[0].plot(t_fs,s,'b-',label="s(t)")
ax[0].set_ylabel("Morph fraction s"); ax[0].legend()
ax[1].plot(t_fs,Jbar,'r-',label="⟨|J|⟩")
ax[1].set_ylabel("Mean current proxy"); ax[1].set_xlabel("t (fs)"); ax[1].legend()
plt.tight_layout()
out_png = os.path.join(OUT,"HIJ_tdse_summary.png")
plt.savefig(out_png,dpi=300)
print(f"[SAVE] CSV:{csv_path}\n[SAVE] Plot:{out_png}\n[DONE] formed={formed}")

# ============================================================
#  GQR Morph Engine — H→I→J (8F4H, 8F4I, 8F4J) — Slow-Gate Edition
# ============================================================
import os, math, json
import numpy as np
try:
    import cupy as cp; xp=cp; GPU=True
except Exception:
    xp=np; GPU=False
import gemmi, matplotlib.pyplot as plt

# ---------- Paths ----------
CIF_H = "/content/8F4H.cif"
CIF_I = "/content/8F4I.cif"
CIF_J = "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Timing (slowed) ----------
DT_FS = 0.5
STEPS = 40000
TAU_HI, TAU_IJ = 300.0, 400.0      # slower gating
WARP_HI, WARP_IJ = 0.01, 0.02      # slower morph clock
J_THRESH = 0.0052                  # ⟨|J|⟩ threshold for the "formed" flag
BREATH_A = 0.12                    # not referenced by the code in this cell

def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float array.

    Prefers the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å); falls
    back to the fractional ``_atom_site_fract_x/y/z`` columns (returned
    unconverted) only when no Cartesian columns exist. Rows that do not parse
    as floats are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    blk = gemmi.cif.read_file(path).sole_block()
    for tags in (('_atom_site.Cartn_x', '_atom_site.Cartn_y', '_atom_site.Cartn_z'),
                 ('_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z')):
        xs, ys, zs = (blk.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break
    coords = []
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            coords.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not coords:
        raise ValueError(f"No coordinate data found in {path}")
    return np.array(coords)

def oo_distance(X):
    """Distance between coordinate rows 0 and 1 (labelled O–O; no atom selection happens here)."""
    return float(np.linalg.norm(X[0]-X[1]))

def sigmoid(t, tau):
    """Logistic gate 1 / (1 + exp(-(t/tau - 1))); 0.5 at t == tau."""
    return 1/(1+np.exp(-(t/tau-1)))

def mix(Xa, Xb, s):
    """Linear interpolation (1 - s) * Xa + s * Xb."""
    return Xa*(1-s)+Xb*s

def random_H(n=70, seed=1):
    """Seed-deterministic symmetric n×n matrix (G + G.T)/2, G ~ N(0, 1), on the xp backend."""
    r=np.random.default_rng(seed); H=r.normal(0,1,(n,n)); H=(H+H.T)/2; return xp.array(H)

def current_Jbar(H, p):
    """Mean of |H_ij| * sqrt(|p_i|^2 |p_j|^2) over pairs i < j with H_ij != 0.

    Vectorised on the host: NumPy or CuPy inputs are copied once (CuPy arrays
    expose ``.get()``). Per-element indexing of device arrays would force a
    device→host sync per element.

    Returns:
        float; 0.0 when there is no non-zero off-diagonal coupling.
    """
    Hh = np.asarray(H.get() if hasattr(H, "get") else H)
    prob = np.abs(np.asarray(p.get() if hasattr(p, "get") else p)) ** 2
    ii, jj = np.triu_indices(Hh.shape[0], k=1)
    hij = Hh[ii, jj]
    keep = hij != 0
    if not keep.any():
        return 0.0
    ii, jj = ii[keep], jj[keep]
    vals = np.abs(hij[keep]) * np.sqrt(prob[ii] * prob[jj])
    return float(vals.mean())

def tdse_run(XH,XI,XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J (slow-gate settings).

    The wavefunction evolves under the fixed random Hamiltonian random_H()
    with the exact unitary step U = V exp(-i E dt) V^†, built once because
    H0 is time independent. The morph only feeds the rows-0/1 distance monitor.

    Returns:
        (t_hist, s_hist, J_hist, d_hist, formed): NumPy arrays sampled every
        500 steps, and a flag latched True at the first sample with
        d(O–O) < 2.0 Å and ⟨|J|⟩ > J_THRESH.
    """
    H0=random_H(); psi=xp.ones(70,dtype=complex)/math.sqrt(70)
    E,V=xp.linalg.eigh(H0)
    U=(V*xp.exp(-1j*E*DT_FS))@V.conj().T
    t_hist=[];s_hist=[];J_hist=[];d_hist=[];formed=False
    for step in range(STEPS):
        t=step*DT_FS
        # Only rows 0 and 1 feed oo_distance; morph just those.
        if step<STEPS/2:
            s=sigmoid(t,WARP_HI*TAU_HI); X=mix(XH[:2],XI[:2],s)
        else:
            # NOTE(review): offset is WARP_HI*TAU_HI, not the phase-2 start time,
            # so s is already ~1 when this branch begins. Confirm intent.
            s=sigmoid(t-WARP_HI*TAU_HI,WARP_IJ*TAU_IJ); X=mix(XI[:2],XJ[:2],s)
        psi=U@psi
        psi/=xp.linalg.norm(psi)+1e-15   # guard against round-off drift
        if step%500==0:
            Jbar=current_Jbar(H0,psi); dOO=oo_distance(X)
            if (not formed) and (dOO<2.0) and (Jbar>J_THRESH): formed=True
            t_hist.append(t);s_hist.append(s);J_hist.append(Jbar);d_hist.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:6.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(t_hist),np.array(s_hist),np.array(J_hist),np.array(d_hist),formed

# ---------- Load ----------
XH=load_xyz_cif(CIF_H); XI=load_xyz_cif(CIF_I); XJ=load_xyz_cif(CIF_J)
# The files can differ in atom count; keep the common leading rows so the
# morph arrays broadcast. NOTE(review): assumes identical atom order across files.
n_common=min(len(XH),len(XI),len(XJ))
XH,XI,XJ=XH[:n_common],XI[:n_common],XJ[:n_common]
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={len(XH)}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- Run ----------
t,s,J,d,formed=tdse_run(XH,XI,XJ)

# ---------- Save ----------
csv=os.path.join(OUT,"HIJ_tdse_timeseries_slow.csv")
np.savetxt(csv,np.column_stack([t,s,J,d]),delimiter=",",header="t_fs,s,Jbar,dOO",comments="")
plt.figure(figsize=(8,6))
plt.subplot(2,1,1);plt.plot(t,s);plt.ylabel("s(t)");plt.subplot(2,1,2)
plt.plot(t,J);plt.ylabel("⟨|J|⟩");plt.xlabel("t (fs)")
plt.tight_layout();plt.savefig(os.path.join(OUT,"HIJ_tdse_summary_slow.png"),dpi=300)
print(f"[SAVE] {csv} | formed={formed}")

# ============================================================
#  GQR TDSE Morph Engine — H→I→J unified shape fix
# ============================================================
import os, math, json
import numpy as np
import gemmi
import matplotlib.pyplot as plt

try:
    import cupy as cp
    xp = cp
    GPU = True
except Exception:
    xp = np
    GPU = False

# ---------- File paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Timing & morph params ----------
DT_FS, STEPS = 0.5, 40000
TAU_HI, TAU_IJ = 150.0, 250.0
WARP_HI, WARP_IJ = 0.04, 0.10
J_THRESH = 0.0052          # ⟨|J|⟩ threshold for the "formed" flag

# ---------- Helpers ----------
def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float64 array.

    Prefers the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å); falls
    back to the fractional ``_atom_site_fract_x/y/z`` columns (returned
    unconverted) only when no Cartesian columns exist. Rows that do not parse
    as floats are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    block = gemmi.cif.read_file(path).sole_block()
    for tags in (('_atom_site.Cartn_x', '_atom_site.Cartn_y', '_atom_site.Cartn_z'),
                 ('_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z')):
        xs, ys, zs = (block.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break
    coords = []
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            coords.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not coords:
        raise ValueError(f"No coordinate data found in {path}")
    return np.array(coords, dtype=np.float64)

def oo_distance(X):
    """Distance between coordinate rows 0 and 1 (labelled O–O; no atom selection happens here)."""
    return float(np.linalg.norm(X[0]-X[1]))

def sigmoid(t,tau):
    """Logistic gate 1 / (1 + exp(-(t/tau - 1))); 0.5 at t == tau."""
    return 1/(1+np.exp(-(t/tau-1)))

def mix(Xa,Xb,s):
    """Linear interpolation (1 - s) * Xa + s * Xb."""
    return Xa*(1-s)+Xb*s

def random_H(n=70, seed=1):
    """Seed-deterministic symmetric n×n matrix (G + G.T)/2, G ~ N(0, 1), on the xp backend."""
    rng = np.random.default_rng(seed)
    H = rng.normal(0,1,(n,n))
    H = (H + H.T)/2
    return xp.array(H)

def current_Jbar(Hx, p_vec):
    """Mean of |H_ij| * sqrt(|Re(p_i * p_j)|) over pairs i < j with H_ij != 0.

    NOTE(review): Re(p_i * p_j) (no complex conjugate) depends on the global
    phase of p; the GPU-safe current_Jbar elsewhere in this notebook uses
    |p_i| * |p_j| instead. Kept here to preserve this cell's numbers — confirm
    which proxy is intended.

    Accepts NumPy or CuPy arrays: both are copied to the host once (CuPy
    arrays expose ``.get()``). There is no np.mean over a list of device
    scalars and no per-element sync. Returns 0.0 when there is no non-zero
    off-diagonal coupling.
    """
    H = np.asarray(Hx.get() if hasattr(Hx, "get") else Hx)
    p = np.asarray(p_vec.get() if hasattr(p_vec, "get") else p_vec)
    ii, jj = np.triu_indices(H.shape[0], 1, H.shape[1])
    hij = H[ii, jj]
    keep = hij != 0
    if not keep.any():
        return 0.0
    ii, jj = ii[keep], jj[keep]
    vals = np.abs(hij[keep]) * np.sqrt(np.abs((p[ii] * p[jj]).real))
    return float(vals.mean())

# ---------- Load & unify shapes ----------
# Load the three states, then keep only the leading rows all three share
# (rows are matched by index).
XH, XI, XJ = (load_xyz_cif(cif) for cif in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (coords[:N] for coords in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J.

    The wavefunction evolves under the fixed random Hamiltonian random_H(70)
    with the exact linear unitary step U = V exp(-i E dt) V^†, built once
    because H0 is time independent. An elementwise update
    psi * exp(-i (H psi) dt) would be nonlinear in psi and is not a
    Schrödinger step. The morph only feeds the rows-0/1 distance monitor.

    Returns:
        (T, S, J, D, formed): NumPy arrays sampled every 500 steps, and a flag
        latched True at the first sample with d(O–O) < 2.0 Å and
        ⟨|J|⟩ > J_THRESH.
    """
    H0 = random_H(70)
    E, V = xp.linalg.eigh(H0)
    U = (V * xp.exp(-1j * E * DT_FS)) @ V.conj().T
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    formed = False
    T, S, J, D = [], [], [], []
    for step in range(STEPS):
        t = step * DT_FS
        # Only rows 0 and 1 feed oo_distance; morph just those.
        if step < STEPS/2:
            s = sigmoid(t, WARP_HI * TAU_HI)
            X = mix(XH[:2], XI[:2], s)
        else:
            # NOTE(review): offset is WARP_HI*TAU_HI, not the phase-2 start
            # time, so s is already ~1 when this branch begins. Confirm intent.
            s = sigmoid(t - WARP_HI * TAU_HI, WARP_IJ * TAU_IJ)
            X = mix(XI[:2], XJ[:2], s)
        psi = U @ psi
        psi /= xp.linalg.norm(psi) + 1e-15   # guard against round-off drift
        if step % 500 == 0:
            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance(X)
            if (not formed) and (dOO < 2.0) and (Jbar > J_THRESH):
                formed = True
            T.append(t); S.append(s); J.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:6.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(T), np.array(S), np.array(J), np.array(D), formed

# ---------- Run ----------
# Run the propagation (tdse_run samples every 500 steps).
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save ----------
# CSV columns: time (fs), morph fraction, ⟨|J|⟩ proxy, d(O–O).
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack([t, s, J, d]),
    delimiter=",",
    header="t_fs,s,Jbar,dOO",
    comments="",
)
# Both traces share one linear y-axis.
plt.figure(figsize=(8, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.legend()
plt.xlabel("t (fs)")
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

# ============================================================
#  GQR TDSE Morph Engine — unified GPU-safe Jbar computation
# ============================================================
import os, math, json
import numpy as np
import gemmi
import matplotlib.pyplot as plt

try:
    import cupy as cp
    xp = cp
    GPU = True
except Exception:
    xp = np
    GPU = False

# ---------- File paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Timing & morph params ----------
DT_FS, STEPS = 0.5, 40000
TAU_HI, TAU_IJ = 150.0, 250.0
WARP_HI, WARP_IJ = 0.04, 0.10
J_THRESH = 0.0052          # ⟨|J|⟩ threshold for the "formed" flag

# ---------- Helpers ----------
def load_xyz_cif(path):
    """Read atom coordinates from a CIF/mmCIF file as an (N, 3) float64 array.

    Prefers the mmCIF Cartesian columns ``_atom_site.Cartn_x/y/z`` (Å); falls
    back to the fractional ``_atom_site_fract_x/y/z`` columns (returned
    unconverted) only when no Cartesian columns exist. Rows that do not parse
    as floats are skipped.

    Raises:
        ValueError: if no parsable coordinate row is found.
    """
    block = gemmi.cif.read_file(path).sole_block()
    for tags in (('_atom_site.Cartn_x', '_atom_site.Cartn_y', '_atom_site.Cartn_z'),
                 ('_atom_site_fract_x', '_atom_site_fract_y', '_atom_site_fract_z')):
        xs, ys, zs = (block.find_values(tag) for tag in tags)
        if len(xs) > 0:
            break
    coords = []
    for i in range(min(len(xs), len(ys), len(zs))):
        try:
            coords.append([float(xs[i]), float(ys[i]), float(zs[i])])
        except ValueError:
            continue
    if not coords:
        raise ValueError(f"No coordinate data found in {path}")
    return np.array(coords, dtype=np.float64)

def oo_distance(X):
    """Distance between coordinate rows 0 and 1 (labelled O–O; no atom selection happens here)."""
    return float(np.linalg.norm(X[0]-X[1]))

def sigmoid(t,tau):
    """Logistic gate 1 / (1 + exp(-(t/tau - 1))); 0.5 at t == tau."""
    return 1/(1+np.exp(-(t/tau-1)))

def mix(Xa,Xb,s):
    """Linear interpolation (1 - s) * Xa + s * Xb."""
    return Xa*(1-s)+Xb*s

def random_H(n=70, seed=1):
    """Seed-deterministic symmetric n×n matrix (G + G.T)/2, G ~ N(0, 1), on the xp backend."""
    rng = np.random.default_rng(seed)
    H = rng.normal(0,1,(n,n))
    H = (H + H.T)/2
    return xp.array(H)

def current_Jbar(Hx, p_vec):
    """Mean of |H_ij| * sqrt(|Re(p_i * p_j)|) over pairs i < j with H_ij != 0.

    NOTE(review): Re(p_i * p_j) (no complex conjugate) depends on the global
    phase of p; the GPU-safe current_Jbar elsewhere in this notebook uses
    |p_i| * |p_j| instead. Kept here to preserve this cell's numbers — confirm
    which proxy is intended.

    Vectorised on the host: NumPy or CuPy inputs are copied once (CuPy arrays
    expose ``.get()``), replacing the O(N²) Python loop with per-element
    device syncs. Returns 0.0 when there is no non-zero off-diagonal coupling.
    """
    H = np.asarray(Hx.get() if hasattr(Hx, "get") else Hx)
    p = np.asarray(p_vec.get() if hasattr(p_vec, "get") else p_vec)
    ii, jj = np.triu_indices(H.shape[0], 1, H.shape[1])
    hij = H[ii, jj]
    keep = hij != 0
    if not keep.any():
        return 0.0
    ii, jj = ii[keep], jj[keep]
    vals = np.abs(hij[keep]) * np.sqrt(np.abs((p[ii] * p[jj]).real))
    return float(vals.mean())

# ---------- Load & unify shapes ----------
# Load the three states, then keep only the leading rows all three share
# (rows are matched by index).
XH, XI, XJ = (load_xyz_cif(cif) for cif in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (coords[:N] for coords in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Propagate a 70-state wavefunction while morphing the geometry H→I→J.

    The wavefunction evolves under the fixed random Hamiltonian random_H(70)
    with the exact linear unitary step U = V exp(-i E dt) V^†, built once
    because H0 is time independent. An elementwise update
    psi * exp(-i (H psi) dt) would be nonlinear in psi and is not a
    Schrödinger step. The morph only feeds the rows-0/1 distance monitor.

    Returns:
        (T, S, J, D, formed): NumPy arrays sampled every 500 steps, and a flag
        latched True at the first sample with d(O–O) < 2.0 Å and
        ⟨|J|⟩ > J_THRESH.
    """
    H0 = random_H(70)
    E, V = xp.linalg.eigh(H0)
    U = (V * xp.exp(-1j * E * DT_FS)) @ V.conj().T
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    formed = False
    T, S, J, D = [], [], [], []
    for step in range(STEPS):
        t = step * DT_FS
        # Only rows 0 and 1 feed oo_distance; morph just those.
        if step < STEPS/2:
            s = sigmoid(t, WARP_HI * TAU_HI)
            X = mix(XH[:2], XI[:2], s)
        else:
            # NOTE(review): offset is WARP_HI*TAU_HI, not the phase-2 start
            # time, so s is already ~1 when this branch begins. Confirm intent.
            s = sigmoid(t - WARP_HI * TAU_HI, WARP_IJ * TAU_IJ)
            X = mix(XI[:2], XJ[:2], s)
        psi = U @ psi
        psi /= xp.linalg.norm(psi) + 1e-15   # guard against round-off drift
        if step % 500 == 0:
            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance(X)
            if (not formed) and (dOO < 2.0) and (Jbar > J_THRESH):
                formed = True
            T.append(t); S.append(s); J.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:6.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(T), np.array(S), np.array(J), np.array(D), formed

# ---------- Run ----------
# Run the propagation (tdse_run samples every 500 steps).
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save ----------
# CSV columns: time (fs), morph fraction, ⟨|J|⟩ proxy, d(O–O).
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack([t, s, J, d]),
    delimiter=",",
    header="t_fs,s,Jbar,dOO",
    comments="",
)
# Both traces share one linear y-axis.
plt.figure(figsize=(8, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.legend()
plt.xlabel("t (fs)")
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

# ============================================================
#  GQR TDSE Morph Engine (H→I→J) — stabilized dynamics version
# ============================================================
import os, math, json
import numpy as np
import gemmi
import matplotlib.pyplot as plt

try:
    import cupy as cp
    xp = cp
    GPU = True
except Exception:
    xp = np
    GPU = False

# ---------- File paths ----------
# Input structures (PDB 8F4H / 8F4I / 8F4J) and the output directory.
CIF_H, CIF_I, CIF_J = "/content/8F4H.cif", "/content/8F4I.cif", "/content/8F4J.cif"
OUT = "." if not os.path.isdir("/content") else "/content"

# ---------- Timing & morph params ----------
DT_FS, STEPS = 0.5, 40000
TAU_HI, TAU_IJ = 150.0, 250.0
WARP_HI, WARP_IJ = 0.04, 0.10
J_THRESH = 0.0015          # relaxed ⟨|J|⟩ threshold
HOLD_STEPS = 2000          # early steps to ignore: 2000 × 0.5 fs = first 1000 fs
PERSIST_HITS = 3           # consecutive triggers required
C0 = 6.0                   # sigmoid onset delay: gate is centred at t = C0 * tau

# ---------- Helpers ----------
def _cif_float(text):
    """Parse a CIF numeric field, dropping a trailing standard uncertainty such as ``(4)``.

    Raises ValueError for non-numeric placeholders like ``?`` or ``.``.
    """
    return float(text.split("(", 1)[0])


def _cell_matrix(a, b, c, alpha, beta, gamma):
    """Return the 3×3 fractional→Cartesian matrix (a along x, b in the xy-plane).

    Angles are in degrees. The matrix multiplies column vectors of fractional coordinates.
    """
    ca, cb, cg = np.cos(np.radians([alpha, beta, gamma]))
    sg = np.sin(np.radians(gamma))
    vol = np.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * vol / sg],
    ])


def _read_cell(block):
    """Return (a, b, c, alpha, beta, gamma) from ``_cell_*`` / ``_cell.*`` tags, or None if incomplete."""
    values = []
    for name in ("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma"):
        raw = block.find_value(f"_cell_{name}") or block.find_value(f"_cell.{name}")
        if raw is None:
            return None
        try:
            values.append(_cif_float(raw))
        except ValueError:
            return None
    return tuple(values)


def load_xyz_cif(path):
    """Load ``_atom_site`` coordinates from a CIF file as an (N, 3) float64 array.

    If all three small-molecule ``_atom_site_fract_{x,y,z}`` columns exist, they are used.
    They are converted to Cartesian coordinates with the ``_cell`` parameters, so
    distances come out in the cell's length unit. If the cell is incomplete, the
    fractional values are returned unchanged and a warning is printed.

    Otherwise the mmCIF ``_atom_site.Cartn_{x,y,z}`` columns are read as-is.

    Rows with a non-numeric field (e.g. ``?`` or ``.``) are skipped. Standard
    uncertainties such as ``1.234(5)`` are stripped before parsing.
    """
    block = gemmi.cif.read_file(path).sole_block()
    fract_cols = [block.find_values(f"_atom_site_fract_{ax}") for ax in "xyz"]
    use_fract = all(fract_cols)
    cols = fract_cols if use_fract else [block.find_values(f"_atom_site.Cartn_{ax}") for ax in "xyz"]
    n_rows = min(len(col) for col in cols)  # tolerate ragged columns instead of IndexError
    coords = []
    for i in range(n_rows):
        try:
            coords.append([_cif_float(col[i]) for col in cols])
        except ValueError:
            continue
    xyz = np.array(coords, dtype=np.float64).reshape(-1, 3)
    if use_fract:
        cell = _read_cell(block)
        if cell is None:
            print(f"[WARN] {path}: fractional coordinates without a complete _cell; returned unconverted")
        else:
            # Fractional values are unitless; downstream distances (dOO < 2.0) need Cartesian.
            xyz = xyz @ _cell_matrix(*cell).T
    return xyz

def oo_distance(X):
    """Return the Euclidean distance between the first two rows of X as a Python float."""
    first, second = X[0], X[1]
    return float(np.linalg.norm(first - second))
def sigmoid(t, tau):
    """Logistic ramp in t/tau whose midpoint is delayed to t = C0 * tau."""
    z = t / tau - C0
    return 1 / (1 + np.exp(-z))
def mix(Xa, Xb, s):
    """Linearly interpolate between Xa (s = 0) and Xb (s = 1)."""
    keep = 1 - s
    return Xa * keep + Xb * s

def random_H(n=70, seed=1):
    """Return a seeded random real-symmetric n×n matrix as a complex128 `xp` array.

    Entries are standard normals, symmetrised as (A + Aᵀ)/2 and divided by sqrt(n).
    """
    gen = np.random.default_rng(seed)
    raw = gen.normal(0, 1, (n, n))
    sym = (raw + raw.T) / 2
    sym /= np.sqrt(n)  # scale so mean J stays ~1e-3–1e-2
    return xp.array(sym, dtype=xp.complex128)

def current_Jbar(Hx, p_vec):
    """Mean coupling strength over the non-zero upper-triangle entries of Hx.

    For every pair i < j with Hx[i, j] != 0, the pair value is
    ``|Hx[i, j]| * sqrt(|Re(p_vec[i] * p_vec[j])|)``. The function returns the mean of
    these values as a Python float, or 0.0 when no such pair exists.

    The computation is vectorised, which avoids a Python loop over ~n²/2 elements.
    Under CuPy, that loop forced a host/device sync per element.
    """
    iu, ju = xp.triu_indices(Hx.shape[0], 1, Hx.shape[1])
    h = Hx[iu, ju]
    keep = h != 0
    if not bool(keep.any()):
        return 0.0
    pair = (p_vec[iu] * p_vec[ju]).real
    vals = xp.abs(h[keep]) * xp.sqrt(xp.abs(pair[keep]))
    mean_val = vals.mean()
    return float(mean_val.get() if GPU else mean_val)

# ---------- Load & unify shapes ----------
# Load the three states, then truncate to a common atom count.
# NOTE(review): pairing rows by index assumes the three files list atoms in the same order.
XH, XI, XJ = (load_xyz_cif(path) for path in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (arr[:N] for arr in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE runner ----------
def tdse_run(XH, XI, XJ):
    """Morph H→I→J coordinates while evolving a 70-state wavefunction; sample every 500 steps.

    The wavefunction starts uniform. It evolves under the fixed random Hamiltonian
    ``H0 = random_H(70)`` using the exact unitary step ``exp(-i·H0·DT_FS)``.

    At each sample the function records the morph fraction ``s``,
    ``current_Jbar(H0, psi)`` and ``oo_distance`` of the interpolated geometry.
    ``formed`` latches True once ``PERSIST_HITS`` consecutive samples, taken at or after
    ``HOLD_STEPS``, satisfy ``dOO < 2.0`` and ``Jbar > J_THRESH``.

    Returns:
        ``(T, S, J, D, formed)``: NumPy arrays of sample times (fs), morph fractions,
        ⟨|J|⟩ values and O–O distances, plus the final ``formed`` flag.
    """
    H0 = random_H(70)
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    # H0 is Hermitian and time independent, so the exact one-step propagator
    # U = V·diag(exp(-i E dt))·V† is built once.  An element-wise phase
    # psi * exp(-i (H0 psi) dt) is not a solution of the Schrödinger equation.
    E, V = xp.linalg.eigh(H0)
    U = (V * xp.exp(-1j * E * DT_FS)) @ V.conj().T
    formed = False
    hit_run = 0
    T, S, J, D = [], [], [], []
    for step in range(STEPS):
        t = step * DT_FS
        psi = U @ psi
        psi /= xp.linalg.norm(psi) + 1e-15  # U is unitary; this only trims round-off drift
        if step % 500 == 0:
            # The morph is only used for the samples, so evaluate it here rather than every step.
            if step < STEPS/2:
                s = sigmoid(t, WARP_HI * TAU_HI)
                X = mix(XH, XI, s)
            else:
                s = sigmoid(t - WARP_HI * TAU_HI, WARP_IJ * TAU_IJ)
                X = mix(XI, XJ, s)
            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance(X)
            if step >= HOLD_STEPS:
                if (dOO < 2.0) and (Jbar > J_THRESH):
                    hit_run += 1
                else:
                    hit_run = 0
                if (not formed) and (hit_run >= PERSIST_HITS):
                    formed = True
            T.append(t); S.append(s); J.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:6.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(T), np.array(S), np.array(J), np.array(D), formed

# ---------- Run ----------
# Run the simulation once and keep the sampled series.
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save & plot ----------
# CSV columns: sample time (fs), morph fraction, ⟨|J|⟩, O–O distance (header without "# ").
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack((t, s, J, d)),
    delimiter=",",
    header="t_fs,s,Jbar,dOO",
    comments="",
)
plt.figure(figsize=(8, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.xlabel("t (fs)")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

# ============================================================
#   GQR–TDSE Morph Engine (H→I→J)
#   Stable sigmoid-centered morph & correct TDSE integration
# ============================================================
import os, math, numpy as np, matplotlib.pyplot as plt, gemmi

# ---------- GPU detection ----------
# Pick the array backend: CuPy if it imports, otherwise NumPy.
# `xp` is the array module used by the helpers below; `GPU` records which one was chosen.
try:
    import cupy as cp
except Exception:
    # CuPy missing or failed to import: fall back to NumPy on the CPU.
    xp, GPU = np, False
else:
    xp, GPU = cp, True

# ---------- File paths ----------
CIF_H, CIF_I, CIF_J = (
    "/content/8F4H.cif",
    "/content/8F4I.cif",
    "/content/8F4J.cif",
)
if os.path.isdir("/content"):
    OUT = "/content"
else:
    OUT = "."

# ---------- Timing & morph params ----------
DT_FS, STEPS = 0.5, 40000                  # time step (fs) and number of steps
TAU_HI_FS, CENTER_HI_FS = 3000.0, 5000.0   # H→I sigmoid width / midpoint (fs)
TAU_IJ_FS, CENTER_IJ_FS = 4000.0, 6000.0   # I→J sigmoid width / midpoint (fs)
J_THRESH = 0.0015                          # ⟨|J|⟩ trigger threshold
HOLD_STEPS = 2000                          # no verdict before this step
PERSIST_HITS = 3                           # consecutive triggers required

# ---------- Helpers ----------
def _cif_float(text):
    """Parse a CIF numeric field, dropping a trailing standard uncertainty such as ``(4)``.

    Raises ValueError for non-numeric placeholders like ``?`` or ``.``.
    """
    return float(text.split("(", 1)[0])


def _cell_matrix(a, b, c, alpha, beta, gamma):
    """Return the 3×3 fractional→Cartesian matrix (a along x, b in the xy-plane).

    Angles are in degrees. The matrix multiplies column vectors of fractional coordinates.
    """
    ca, cb, cg = np.cos(np.radians([alpha, beta, gamma]))
    sg = np.sin(np.radians(gamma))
    vol = np.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * vol / sg],
    ])


def _read_cell(block):
    """Return (a, b, c, alpha, beta, gamma) from ``_cell_*`` / ``_cell.*`` tags, or None if incomplete."""
    values = []
    for name in ("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma"):
        raw = block.find_value(f"_cell_{name}") or block.find_value(f"_cell.{name}")
        if raw is None:
            return None
        try:
            values.append(_cif_float(raw))
        except ValueError:
            return None
    return tuple(values)


def load_xyz_cif(path):
    """Load Cartesian coordinates from a CIF as an (N, 3) float64 array.

    If all three small-molecule ``_atom_site_fract_{x,y,z}`` columns exist, they are used.
    They are converted to Cartesian coordinates with the ``_cell`` parameters, so
    distances come out in the cell's length unit. If the cell is incomplete, the
    fractional values are returned unchanged and a warning is printed.

    Otherwise the mmCIF ``_atom_site.Cartn_{x,y,z}`` columns are read as-is.

    Rows with a non-numeric field (e.g. ``?`` or ``.``) are skipped. Standard
    uncertainties such as ``1.234(5)`` are stripped before parsing.
    """
    block = gemmi.cif.read_file(path).sole_block()
    fract_cols = [block.find_values(f"_atom_site_fract_{ax}") for ax in "xyz"]
    use_fract = all(fract_cols)
    cols = fract_cols if use_fract else [block.find_values(f"_atom_site.Cartn_{ax}") for ax in "xyz"]
    n_rows = min(len(col) for col in cols)  # tolerate ragged columns instead of IndexError
    coords = []
    for i in range(n_rows):
        try:
            coords.append([_cif_float(col[i]) for col in cols])
        except ValueError:
            continue
    xyz = np.array(coords, dtype=np.float64).reshape(-1, 3)
    if use_fract:
        cell = _read_cell(block)
        if cell is None:
            print(f"[WARN] {path}: fractional coordinates without a complete _cell; returned unconverted")
        else:
            # Fractional values are unitless; downstream distances (dOO < 2.0) need Cartesian.
            xyz = xyz @ _cell_matrix(*cell).T
    return xyz

def oo_distance(X):
    """Return the Euclidean distance between the first two rows of X as a Python float."""
    first, second = X[0], X[1]
    return float(np.linalg.norm(first - second))

def sigmoid(t, tau_fs, center_fs):
    """Logistic ramp from 0 to 1: exactly 0.5 at t = center_fs, width set by tau_fs."""
    exponent = -(t - center_fs) / tau_fs
    return 1.0 / (1.0 + np.exp(exponent))

def mix(Xa, Xb, s):
    """Linearly interpolate between Xa (s = 0) and Xb (s = 1)."""
    keep = 1 - s
    return Xa * keep + Xb * s

def random_H(n=70, seed=1):
    """Return a seeded random real-symmetric n×n complex128 matrix (CuPy when GPU is set).

    Entries are standard normals, symmetrised as (A + Aᵀ)/2, then divided by sqrt(n)·50.
    """
    gen = np.random.default_rng(seed)
    raw = gen.normal(0, 1, (n, n))
    sym = (raw + raw.T) / 2
    sym /= np.sqrt(n) * 50.0  # scale down for stability
    if GPU:
        return cp.asarray(sym, dtype=cp.complex128)
    return np.asarray(sym, dtype=np.complex128)

def current_Jbar(Hx, p_vec):
    """Compute mean absolute coupling weighted by wavefunction overlap.

    For every pair i < j with Hx[i, j] != 0, the pair value is
    ``|Hx[i, j]| * sqrt(|Re(p_vec[i] * p_vec[j])|)``. The function returns the mean of
    these values as a Python float, or 0.0 when no such pair exists.

    The computation is vectorised, which avoids a Python loop over ~n²/2 elements.
    Under CuPy, that loop forced a host/device sync per element.
    """
    n = Hx.shape[0]
    iu, ju = xp.triu_indices(n, 1, n)
    h = Hx[iu, ju]
    keep = h != 0
    if not bool(keep.any()):
        return 0.0
    pair = (p_vec[iu] * p_vec[ju]).real
    arr = xp.abs(h[keep]) * xp.sqrt(xp.abs(pair[keep]))
    m = arr.mean()
    return float(m.get() if GPU else m)

# ---------- Load and unify shapes ----------
# Load the three states, then truncate to a common atom count.
# NOTE(review): pairing rows by index assumes the three files list atoms in the same order.
XH, XI, XJ = (load_xyz_cif(path) for path in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (arr[:N] for arr in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE Runner ----------
def tdse_run(XH, XI, XJ):
    """Two-stage H→I, I→J morph alongside a 70-state wavefunction; sample every 500 steps.

    The wavefunction starts uniform and evolves under the fixed ``H0 = random_H(70)``.
    Each step applies the exact unitary ``exp(-i·H0·DT_FS)``. An explicit Euler step
    ``psi + (-i)H0 psi dt`` amplifies every eigenmode by sqrt(1 + (E dt)²) > 1; with
    renormalisation, that drifts the state toward the largest-|E| eigenvectors.

    The first STEPS//2 steps morph H→I. The rest morph I→J, with the sigmoid clock
    restarted at (STEPS//2)·DT_FS.
    NOTE(review): s is not continuous across that handoff.

    ``formed`` latches True once ``PERSIST_HITS`` consecutive samples, taken at or after
    ``HOLD_STEPS``, satisfy ``dOO < 2.0`` and ``Jbar > J_THRESH``.

    Returns:
        ``(T, S, Jb, D, formed)``: NumPy arrays of sample times (fs), morph fractions,
        ⟨|J|⟩ values and O–O distances, plus the final ``formed`` flag.
    """
    H0 = random_H(70)
    psi = xp.ones(70, dtype=xp.complex128) / math.sqrt(70)
    # H0 is Hermitian and constant: build U = V·diag(exp(-i E dt))·V† once.
    E, V = xp.linalg.eigh(H0)
    U = (V * xp.exp(-1j * E * DT_FS)) @ V.conj().T
    formed = False
    hit_run = 0
    T, S, Jb, D = [], [], [], []

    for step in range(STEPS):
        t = step * DT_FS

        # --- TDSE step (exact unitary propagation) ---
        psi = U @ psi
        psi /= xp.linalg.norm(psi) + 1e-15  # U is unitary; this only trims round-off drift

        if step % 500 == 0:
            # --- Smooth two-stage morph (only needed for the samples) ---
            if step < STEPS//2:
                s = sigmoid(t, TAU_HI_FS, CENTER_HI_FS)
                X = mix(XH, XI, s)
            else:
                t2 = t - (STEPS//2)*DT_FS
                s = sigmoid(t2, TAU_IJ_FS, CENTER_IJ_FS)
                X = mix(XI, XJ, s)

            Jbar = current_Jbar(H0, psi)
            dOO  = oo_distance(X)
            if step >= HOLD_STEPS:
                if (dOO < 2.0) and (Jbar > J_THRESH):
                    hit_run += 1
                else:
                    hit_run = 0
                if (not formed) and (hit_run >= PERSIST_HITS):
                    formed = True
            T.append(t); S.append(s); Jb.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s:5.3f} | ⟨|J|⟩={Jbar:7.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")

    return np.array(T), np.array(S), np.array(Jb), np.array(D), formed

# ---------- Run ----------
# Run the simulation once and keep the sampled series.
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save & Plot ----------
# CSV columns: sample time (fs), morph fraction, ⟨|J|⟩, O–O distance (header without "# ").
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack((t, s, J, d)),
    delimiter=",",
    header="t_fs,s,Jbar,dOO",
    comments="",
)

plt.figure(figsize=(8, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.xlabel("t (fs)")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

# ============================================================
#   GQR–TDSE Morph Engine (H→I→J) — continuous morph, stronger J
# ============================================================
import os, math, numpy as np, matplotlib.pyplot as plt, gemmi

# ---------- GPU detection ----------
# Pick the array backend: CuPy if it imports, otherwise NumPy.
# `xp` is the array module used by the helpers below; `GPU` records which one was chosen.
try:
    import cupy as cp
except Exception:
    # CuPy missing or failed to import: fall back to NumPy on the CPU.
    xp, GPU = np, False
else:
    xp, GPU = cp, True

# ---------- File paths ----------
CIF_H, CIF_I, CIF_J = (
    "/content/8F4H.cif",
    "/content/8F4I.cif",
    "/content/8F4J.cif",
)
if os.path.isdir("/content"):
    OUT = "/content"
else:
    OUT = "."

# ---------- Timing & morph params ----------
DT_FS, STEPS = 0.5, 40000

# Single continuous ramp spanning both segments (s_total 0 → 2): logistic width / midpoint (fs)
TAU_ALL_FS, CENTER_ALL_FS = 4500.0, 8000.0

# Trigger logic
J_THRESH = 3.0e-4       # lowered so signal can pass when H strengthened
HOLD_STEPS = 2000       # no verdict before this step
PERSIST_HITS = 3        # consecutive hits required

# Optional transient pump to mimic gate coupling at the I→J handoff
PUMP_CENTER, PUMP_WIDTH = 1.05, 0.10   # Gaussian centre / width in s_total units (≈ I→J handoff)
PUMP_GAIN = 1.6                        # peak multiplier on H during the pump

# ---------- Helpers ----------
def _cif_float(text):
    """Parse a CIF numeric field, dropping a trailing standard uncertainty such as ``(4)``.

    Raises ValueError for non-numeric placeholders like ``?`` or ``.``.
    """
    return float(text.split("(", 1)[0])


def _cell_matrix(a, b, c, alpha, beta, gamma):
    """Return the 3×3 fractional→Cartesian matrix (a along x, b in the xy-plane).

    Angles are in degrees. The matrix multiplies column vectors of fractional coordinates.
    """
    ca, cb, cg = np.cos(np.radians([alpha, beta, gamma]))
    sg = np.sin(np.radians(gamma))
    vol = np.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * vol / sg],
    ])


def _read_cell(block):
    """Return (a, b, c, alpha, beta, gamma) from ``_cell_*`` / ``_cell.*`` tags, or None if incomplete."""
    values = []
    for name in ("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma"):
        raw = block.find_value(f"_cell_{name}") or block.find_value(f"_cell.{name}")
        if raw is None:
            return None
        try:
            values.append(_cif_float(raw))
        except ValueError:
            return None
    return tuple(values)


def load_xyz_cif(path):
    """Load Cartesian coordinates from a CIF as an (N, 3) float64 array.

    If all three small-molecule ``_atom_site_fract_{x,y,z}`` columns exist, they are used.
    They are converted to Cartesian coordinates with the ``_cell`` parameters, so
    distances come out in the cell's length unit. If the cell is incomplete, the
    fractional values are returned unchanged and a warning is printed.

    Otherwise the mmCIF ``_atom_site.Cartn_{x,y,z}`` columns are read as-is.

    Rows with a non-numeric field (e.g. ``?`` or ``.``) are skipped. Standard
    uncertainties such as ``1.234(5)`` are stripped before parsing.
    """
    block = gemmi.cif.read_file(path).sole_block()
    fract_cols = [block.find_values(f"_atom_site_fract_{ax}") for ax in "xyz"]
    use_fract = all(fract_cols)
    cols = fract_cols if use_fract else [block.find_values(f"_atom_site.Cartn_{ax}") for ax in "xyz"]
    n_rows = min(len(col) for col in cols)  # tolerate ragged columns instead of IndexError
    coords = []
    for i in range(n_rows):
        try:
            coords.append([_cif_float(col[i]) for col in cols])
        except ValueError:
            continue
    xyz = np.array(coords, dtype=np.float64).reshape(-1, 3)
    if use_fract:
        cell = _read_cell(block)
        if cell is None:
            print(f"[WARN] {path}: fractional coordinates without a complete _cell; returned unconverted")
        else:
            # Fractional values are unitless; downstream distances (dOO < 2.0) need Cartesian.
            xyz = xyz @ _cell_matrix(*cell).T
    return xyz

def oo_distance(X):
    """Naive O–O monitor: Euclidean distance between the first two rows of X.

    NOTE(review): the first two rows are whatever atoms the CIF lists first; they are
    not necessarily the oxygens of interest.
    """
    first, second = X[0], X[1]
    return float(np.linalg.norm(first - second))

def sigmoid(t, tau_fs, center_fs):
    """Logistic ramp from 0 to 1: exactly 0.5 at t = center_fs, width set by tau_fs."""
    exponent = -(t - center_fs) / tau_fs
    return 1.0 / (1.0 + np.exp(exponent))

def mix(Xa, Xb, s):
    """Linearly interpolate between Xa (s = 0) and Xb (s = 1)."""
    keep = 1.0 - s
    return Xa * keep + Xb * s

def random_H(n=70, seed=1):
    """Return a seeded random real-symmetric n×n complex128 matrix (CuPy when GPU is set).

    Entries are standard normals, symmetrised as (A + Aᵀ)/2, then divided by sqrt(n)·8.
    The smaller divisor (earlier versions used 50) gives stronger couplings and a larger ⟨|J|⟩.
    """
    gen = np.random.default_rng(seed)
    raw = gen.normal(0, 1, (n, n))
    sym = (raw + raw.T) / 2
    scaled = sym / (np.sqrt(n) * 8.0)
    as_complex = scaled.astype(np.complex128)
    if GPU:
        return cp.asarray(as_complex)
    return as_complex

def pump_factor(s_total):
    """Gaussian bump multiplier on the coupling near the I→J handoff.

    Equals 1 + (PUMP_GAIN - 1) at s_total == PUMP_CENTER and decays to 1 with
    standard deviation PUMP_WIDTH (both in s_total units).
    """
    z = (s_total - PUMP_CENTER) / PUMP_WIDTH
    bump = np.exp(-0.5 * z**2)
    return 1.0 + (PUMP_GAIN - 1.0) * bump

def current_Jbar(Hx, psi):
    """Vectorised ⟨|J|⟩ over off-diagonal entries; GPU/CPU safe.

    Returns the mean of ``|Hx[i, j]| * |psi[i]| * |psi[j]|`` over all i != j as a Python
    float, or 0.0 when Hx has fewer than two rows.

    Diagonal entries are excluded from both the sum and the count. Previously the zeroed
    diagonal was still counted in the denominator (n² instead of n(n-1)).
    """
    n = Hx.shape[0]
    if n < 2:
        return 0.0
    amp = xp.abs(psi)                                 # |psi| == sqrt(|psi|^2)
    A = xp.abs(Hx) * (amp[:, None] * amp[None, :])
    off_diag_sum = A.sum() - xp.trace(A)
    m = off_diag_sum / (n * (n - 1))
    return float(m.get() if GPU else m)

# ---------- Load & unify shapes ----------
# Load the three states, then truncate to a common atom count.
# NOTE(review): pairing rows by index assumes the three files list atoms in the same order.
XH, XI, XJ = (load_xyz_cif(path) for path in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (arr[:N] for arr in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE Runner (continuous morph) ----------
def tdse_run(XH, XI, XJ):
    """Continuous H→I→J morph alongside a 70-state wavefunction; sample every 500 steps.

    ``s_total = 2·sigmoid(t)``, clipped to [0, 2]. The range 0–1 interpolates H→I and
    1–2 interpolates I→J.

    Each step evolves psi under ``pf·H0``, where ``pf = pump_factor(s_total)``. Because
    pf is a scalar, ``exp(-i·pf·H0·dt) = V·diag(exp(-i·pf·E·dt))·V†`` is exact, with
    H0 diagonalised once. An explicit Euler step amplifies each eigenmode by
    sqrt(1 + (pf·E·dt)²) > 1; with renormalisation, that drifts the state toward the
    largest-|E| eigenvectors.

    ``formed`` latches True once ``PERSIST_HITS`` consecutive samples, taken at or after
    ``HOLD_STEPS``, satisfy ``dOO < 2.0`` and ``Jbar > J_THRESH``.

    Returns:
        ``(T, S, Jb, D, formed)``: NumPy arrays of sample times (fs), s_total values,
        ⟨|J|⟩ values and O–O distances, plus the final ``formed`` flag.
    """
    H0 = random_H(70)
    psi = xp.ones(70, dtype=xp.complex128)/math.sqrt(70)
    E, V = xp.linalg.eigh(H0)   # H0 is Hermitian: real E, unitary V
    Vh = V.conj().T
    formed = False
    hit_run = 0
    T, S, Jb, D = [], [], [], []

    for step in range(STEPS):
        t = step*DT_FS

        # s_total ramps 0→2 across the whole window
        s_total = 2.0 * sigmoid(t, TAU_ALL_FS, CENTER_ALL_FS)
        s_total = float(np.clip(s_total, 0.0, 2.0))

        # TDSE step with optional transient pump (coupling boost near handoff), exact propagator
        pf = pump_factor(s_total)
        psi = V @ (xp.exp(-1j * pf * DT_FS * E) * (Vh @ psi))
        psi /= xp.linalg.norm(psi) + 1e-15  # propagator is unitary; this only trims round-off

        # Log every 500 steps
        if step % 500 == 0:
            # Map to coordinates: 0–1 uses H→I, 1–2 uses I→J (only needed for samples)
            if s_total <= 1.0:
                X = mix(XH, XI, s_total)
            else:
                X = mix(XI, XJ, s_total - 1.0)

            Jbar = current_Jbar(H0*pf, psi)
            dOO  = oo_distance(X)

            if step >= HOLD_STEPS:
                if (dOO < 2.0) and (Jbar > J_THRESH):
                    hit_run += 1
                else:
                    hit_run = 0
                if (not formed) and (hit_run >= PERSIST_HITS):
                    formed = True

            T.append(t); S.append(s_total); Jb.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s_total={s_total:5.3f} | ⟨|J|⟩={Jbar:8.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")

    return np.array(T), np.array(S), np.array(Jb), np.array(D), formed

# ---------- Run ----------
# Run the simulation once and keep the sampled series.
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save & Plot ----------
# CSV columns: sample time (fs), s_total, ⟨|J|⟩, O–O distance (header without "# ").
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack((t, s, J, d)),
    delimiter=",",
    header="t_fs,s_total,Jbar,dOO",
    comments="",
)

plt.figure(figsize=(9, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.plot(t, s, label="s_total (0→2)", linestyle="--")
plt.xlabel("t (fs)")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

# ============================================================
#   GQR–TDSE Morph Engine (H→I→J)
#   v3: stronger H amplitude + shorter τ + adaptive J-threshold
# ============================================================
import os, math, numpy as np, matplotlib.pyplot as plt, gemmi

# ---------- GPU detection ----------
# Pick the array backend: CuPy if it imports, otherwise NumPy.
# `xp` is the array module used by the helpers below; `GPU` records which one was chosen.
try:
    import cupy as cp
except Exception:
    # CuPy missing or failed to import: fall back to NumPy on the CPU.
    xp, GPU = np, False
else:
    xp, GPU = cp, True

# ---------- CIFs ----------
CIF_H, CIF_I, CIF_J = (
    "/content/8F4H.cif",
    "/content/8F4I.cif",
    "/content/8F4J.cif",
)
if os.path.isdir("/content"):
    OUT = "/content"
else:
    OUT = "."

# ---------- Timing ----------
DT_FS, STEPS = 0.5, 40000
TAU_ALL_FS, CENTER_ALL_FS = 3500.0, 7000.0   # s_total ramp width / midpoint (fs)

# ---------- Logic ----------
J_THRESH = 2.0e-4
HOLD_STEPS, PERSIST_HITS = 2000, 3
PUMP_CENTER, PUMP_WIDTH, PUMP_GAIN = 1.05, 0.08, 2.0

# ---------- Helpers ----------
def _cif_float(text):
    """Parse a CIF numeric field, dropping a trailing standard uncertainty such as ``(4)``.

    Raises ValueError for non-numeric placeholders like ``?`` or ``.``.
    """
    return float(text.split("(", 1)[0])


def _cell_matrix(a, b, c, alpha, beta, gamma):
    """Return the 3×3 fractional→Cartesian matrix (a along x, b in the xy-plane).

    Angles are in degrees. The matrix multiplies column vectors of fractional coordinates.
    """
    ca, cb, cg = np.cos(np.radians([alpha, beta, gamma]))
    sg = np.sin(np.radians(gamma))
    vol = np.sqrt(1.0 - ca * ca - cb * cb - cg * cg + 2.0 * ca * cb * cg)
    return np.array([
        [a, b * cg, c * cb],
        [0.0, b * sg, c * (ca - cb * cg) / sg],
        [0.0, 0.0, c * vol / sg],
    ])


def _read_cell(block):
    """Return (a, b, c, alpha, beta, gamma) from ``_cell_*`` / ``_cell.*`` tags, or None if incomplete."""
    values = []
    for name in ("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma"):
        raw = block.find_value(f"_cell_{name}") or block.find_value(f"_cell.{name}")
        if raw is None:
            return None
        try:
            values.append(_cif_float(raw))
        except ValueError:
            return None
    return tuple(values)


def load_xyz_cif(path):
    """Load Cartesian coordinates from a CIF as an (N, 3) float64 array.

    If all three small-molecule ``_atom_site_fract_{x,y,z}`` columns exist, they are used.
    They are converted to Cartesian coordinates with the ``_cell`` parameters, so
    distances come out in the cell's length unit. If the cell is incomplete, the
    fractional values are returned unchanged and a warning is printed.

    Otherwise the mmCIF ``_atom_site.Cartn_{x,y,z}`` columns are read as-is.

    Rows with a non-numeric field (e.g. ``?`` or ``.``) are skipped. Only ValueError is
    caught; the previous bare ``except:`` also swallowed KeyboardInterrupt. Standard
    uncertainties such as ``1.234(5)`` are stripped before parsing.
    """
    block = gemmi.cif.read_file(path).sole_block()
    fract_cols = [block.find_values(f"_atom_site_fract_{ax}") for ax in "xyz"]
    use_fract = all(fract_cols)
    cols = fract_cols if use_fract else [block.find_values(f"_atom_site.Cartn_{ax}") for ax in "xyz"]
    n_rows = min(len(col) for col in cols)  # tolerate ragged columns instead of IndexError
    coords = []
    for i in range(n_rows):
        try:
            coords.append([_cif_float(col[i]) for col in cols])
        except ValueError:
            continue
    xyz = np.array(coords, dtype=np.float64).reshape(-1, 3)
    if use_fract:
        cell = _read_cell(block)
        if cell is None:
            print(f"[WARN] {path}: fractional coordinates without a complete _cell; returned unconverted")
        else:
            # Fractional values are unitless; downstream distances (dOO < 2.0) need Cartesian.
            xyz = xyz @ _cell_matrix(*cell).T
    return xyz

def oo_distance(X):
    """Return the Euclidean distance between the first two rows of X as a Python float."""
    first, second = X[0], X[1]
    return float(np.linalg.norm(first - second))
def sigmoid(t, tau, cent):
    """Logistic ramp from 0 to 1: exactly 0.5 at t = cent, width set by tau."""
    exponent = -(t - cent) / tau
    return 1 / (1 + np.exp(exponent))
def mix(A, B, s):
    """Linearly interpolate between A (s = 0) and B (s = 1)."""
    keep = 1 - s
    return A * keep + B * s

def random_H(n=70, seed=1):
    """Return a seeded random real-symmetric n×n complex128 matrix (CuPy when GPU is set).

    Entries are standard normals, symmetrised as (A + Aᵀ)/2, then divided by sqrt(n)·3.
    Earlier versions used divisors of 8 or 50, so these couplings are much stronger.
    """
    gen = np.random.default_rng(seed)
    raw = gen.normal(0, 1, (n, n))
    sym = (raw + raw.T) / 2
    sym /= (np.sqrt(n) * 3.0)
    as_complex = sym.astype(np.complex128)
    if GPU:
        return cp.asarray(as_complex)
    return as_complex

def pump_factor(s_total):
    """Gaussian bump multiplier on the coupling near the I→J handoff.

    Equals 1 + (PUMP_GAIN - 1) at s_total == PUMP_CENTER and decays to 1 with
    standard deviation PUMP_WIDTH (both in s_total units).
    """
    z = (s_total - PUMP_CENTER) / PUMP_WIDTH
    bump = np.exp(-0.5 * z**2)
    return 1.0 + (PUMP_GAIN - 1.0) * bump

def current_Jbar(Hx, psi):
    """Mean off-diagonal coupling ⟨|J|⟩ weighted by wavefunction amplitudes (GPU/CPU safe).

    Returns the mean of ``|Hx[i, j]| * |psi[i]| * |psi[j]|`` over all i != j as a Python
    float, or 0.0 when Hx has fewer than two rows.

    Diagonal entries are excluded from both the sum and the count. Previously the zeroed
    diagonal was still counted in the denominator (n² instead of n(n-1)).
    """
    n = Hx.shape[0]
    if n < 2:
        return 0.0
    amp = xp.abs(psi)                                 # |psi| == sqrt(|psi|^2)
    A = xp.abs(Hx) * (amp[:, None] * amp[None, :])
    off_diag_sum = A.sum() - xp.trace(A)
    m = off_diag_sum / (n * (n - 1))
    return float(m.get() if GPU else m)

# ---------- Load geometry ----------
# Load the three states, then truncate to a common atom count.
# NOTE(review): pairing rows by index assumes the three files list atoms in the same order.
XH, XI, XJ = (load_xyz_cif(path) for path in (CIF_H, CIF_I, CIF_J))
N = min(map(len, (XH, XI, XJ)))
XH, XI, XJ = (arr[:N] for arr in (XH, XI, XJ))
print(f"[GPU] CuPy={'ON' if GPU else 'OFF'} | atoms unified={N}")
print(f"[O–O] H={oo_distance(XH):.3f} Å | I={oo_distance(XI):.3f} Å | J={oo_distance(XJ):.3f} Å")

# ---------- TDSE main ----------
def tdse_run(XH, XI, XJ):
    """Continuous H→I→J morph alongside a 70-state wavefunction; sample every 500 steps.

    ``s_total = 2·sigmoid(t)``, clipped to [0, 2]. The range 0–1 interpolates H→I and
    1–2 interpolates I→J.

    Each step evolves psi under ``pf·H0``, where ``pf = pump_factor(s_total)``. Because
    pf is a scalar, ``exp(-i·pf·H0·dt) = V·diag(exp(-i·pf·E·dt))·V†`` is exact, with
    H0 diagonalised once. An explicit Euler step amplifies each eigenmode by
    sqrt(1 + (pf·E·dt)²) > 1, which is large for this strongly scaled H0; with
    renormalisation, that drifts the state toward the largest-|E| eigenvectors.

    ``formed`` latches True once ``PERSIST_HITS`` consecutive samples, taken at or after
    ``HOLD_STEPS``, satisfy ``dOO < 2.0`` and ``Jbar > J_THRESH``.

    Returns:
        ``(T, S, Jb, D, formed)``: NumPy arrays of sample times (fs), s_total values,
        ⟨|J|⟩ values and O–O distances, plus the final ``formed`` flag.
    """
    H0 = random_H(70)
    psi = xp.ones(70, dtype=xp.complex128)/math.sqrt(70)
    E, V = xp.linalg.eigh(H0)   # H0 is Hermitian: real E, unitary V
    Vh = V.conj().T
    formed = False; hit_run = 0
    T, S, Jb, D = [], [], [], []

    for step in range(STEPS):
        t = step*DT_FS
        s_total = 2*sigmoid(t, TAU_ALL_FS, CENTER_ALL_FS)
        s_total = float(np.clip(s_total, 0, 2))

        pf = pump_factor(s_total)
        psi = V @ (xp.exp(-1j * pf * DT_FS * E) * (Vh @ psi))
        psi /= xp.linalg.norm(psi) + 1e-15  # propagator is unitary; this only trims round-off

        if step % 500 == 0:
            # Geometry is only needed for the samples.
            if s_total <= 1: X = mix(XH, XI, s_total)
            else:            X = mix(XI, XJ, s_total-1)
            Jbar = current_Jbar(H0*pf, psi)
            dOO = oo_distance(X)
            if step >= HOLD_STEPS:
                if (dOO < 2.0) and (Jbar > J_THRESH): hit_run += 1
                else: hit_run = 0
                if (not formed) and (hit_run >= PERSIST_HITS): formed = True
            T.append(t); S.append(s_total); Jb.append(Jbar); D.append(dOO)
            print(f"[{step:6d}/{STEPS}] t={t:7.1f} fs | s={s_total:5.3f} | ⟨|J|⟩={Jbar:7.4f} | d(O–O)={dOO:5.3f} Å | formed={formed}")
    return np.array(T), np.array(S), np.array(Jb), np.array(D), formed

# ---------- Run ----------
# Run the simulation once and keep the sampled series.
t, s, J, d, formed = tdse_run(XH, XI, XJ)

# ---------- Save & plot ----------
# CSV columns: sample time (fs), s_total, ⟨|J|⟩, O–O distance (header without "# ").
np.savetxt(
    os.path.join(OUT, "HIJ_tdse_timeseries.csv"),
    np.column_stack((t, s, J, d)),
    delimiter=",",
    header="t_fs,s_total,Jbar,dOO",
    comments="",
)
plt.figure(figsize=(9, 5))
plt.plot(t, J, label="⟨|J|⟩")
plt.plot(t, d, label="d(O–O)")
plt.plot(t, s, label="s_total", ls="--")
plt.xlabel("t (fs)")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(OUT, "HIJ_tdse_summary.png"), dpi=300)
print(f"[SAVE] CSV + Plot complete | formed={formed}")

