In [1]:
"""
LiH QML K-Fold Cross-Validation Comparison
============================================

Compares generalizability of four approaches using k-fold cross-validation:
1. Rotationally Equivariant QML - Uses SO(3) equivariant encoding with Heisenberg observable
2. Non-Equivariant QML - Simple QNN with basic rotations
3. Graph Permutation Equivariant QML - Uses graph-based permutation-symmetric encoding
4. Classical Rotationally Equivariant NN - Classical MLP on pairwise distances (E(3) invariant)

This script evaluates model generalizability by:
- Splitting data into k folds
- Training on k-1 folds, testing on held-out fold
- Repeating for each fold
- Computing variance across folds (measures sensitivity to data splits)

Usage (command line):
    python run_kfold_comparison_lih.py --k_folds 5 --n_epochs 200 --output_dir kfold_results_lih

Usage (Jupyter):
    from run_kfold_comparison_lih import main
    results = main(k_folds=5, n_epochs=200, output_dir='kfold_results_lih')
"""

import pennylane as qml
import numpy as np
import json
import os
import argparse
from datetime import datetime
from sklearn.model_selection import KFold

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')


# =============================================================================
# 1. ROTATIONALLY EQUIVARIANT QML MODEL (SO(3))
# =============================================================================

class RotationallyEquivariantQML:
    """
    Rotationally Equivariant Quantum Machine Learning model for LiH.
    Uses SO(3) equivariant encoding with native PennyLane gates.
    
    Architecture adapted from NH₃ implementation with:
    - Multiple qubits with singlet initialization
    - Learnable head_scale and head_bias for output
    - Proper alpha initialization in [0.5, 1.5] range
    """
    
    def __init__(self, n_qubits=6, depth=6, seed=42):
        self.n_qubits = n_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        
        # Heisenberg observable
        self.observable = (
            qml.PauliX(0) @ qml.PauliX(1)
            + qml.PauliY(0) @ qml.PauliY(1)
            + qml.PauliZ(0) @ qml.PauliZ(1)
        )
        
        self._build_circuit()
        self._init_params()
    
    def _singlet(self, wires):
        """Create singlet state on two qubits."""
        w0, w1 = wires
        qml.Hadamard(wires=w0)
        qml.PauliZ(wires=w0)
        qml.PauliX(wires=w1)
        qml.CNOT(wires=[w0, w1])
    
    def _equivariant_encoding(self, alpha, vec3, wire):
        """SO(3) equivariant encoding using qml.Rot."""
        r = jnp.array(vec3, dtype=jnp.float64)
        norm = jnp.linalg.norm(r) + 1e-12
        n = r / norm
        theta = alpha * norm
        qml.Rot(theta * n[0], theta * n[1], theta * n[2], wires=wire)
    
    def _pair_layer(self, weight, wires):
        """Trainable Heisenberg-like interaction."""
        qml.IsingXX(weight, wires=wires)
        qml.IsingYY(weight, wires=wires)
        qml.IsingZZ(weight, wires=wires)
    
    def _build_circuit(self):
        """Build the quantum circuit."""
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(coords, params):
            """
            coords: (1, 3) - H position relative to Li
            params: {"weights", "alphas", "head_scale", "head_bias"}
            """
            weights = params["weights"]
            alphas = params["alphas"]
            
            # Initialize singlets on pairs of qubits
            for i in range(0, self.n_qubits - 1, 2):
                self._singlet([i, i + 1])
            
            # Initial encoding - all qubits encode the same H position
            for i in range(self.n_qubits):
                self._equivariant_encoding(alphas[i, 0], coords[0], i)
            
            # Variational layers
            for d in range(self.depth):
                qml.Barrier()
                # Even pairs
                for i in range(0, self.n_qubits - 1, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                # Odd pairs
                for i in range(1, self.n_qubits, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                # Re-encoding
                for i in range(self.n_qubits):
                    self._equivariant_encoding(alphas[i, d + 1], coords[0], i)
            
            return qml.expval(self.observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        """Initialize parameters with proper ranges."""
        np.random.seed(self.seed)
        
        # Weights: small initial values, only first row non-zero initially
        weights = np.zeros((self.n_qubits, self.depth), dtype=np.float64)
        weights[0] = np.random.uniform(0.0, np.pi, size=(self.depth,))
        
        # Alphas: in [0.5, 1.5] range for stable encoding
        alphas = np.random.uniform(0.5, 1.5, size=(self.n_qubits, self.depth + 1))
        
        self.params = {
            "weights": jnp.array(weights),
            "alphas": jnp.array(alphas),
            "head_scale": jnp.array(1.0),
            "head_bias": jnp.array(0.0),
        }
    
    def energy(self, coords, params):
        """Compute energy with head transformation."""
        raw = self.circuit(coords, params)
        return params["head_scale"] * raw + params["head_bias"]
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 2. NON-EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """Non-Equivariant QML model for LiH."""
    
    def __init__(self, n_qubits=4, depth=4, seed=42):
        self.n_qubits = n_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
        self._init_params()
    
    def _build_circuit(self):
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            weights = params["weights"]
            
            # Compute bond length
            dist = jnp.linalg.norm(positions[1] - positions[0])
            
            # Encode distance
            for i in range(self.n_qubits):
                qml.Hadamard(wires=i)
                qml.RY(dist * np.pi, wires=i)
            
            # Variational layers
            for layer in range(self.depth):
                for i in range(self.n_qubits):
                    qml.RX(weights[layer, i, 0], wires=i)
                    qml.RY(weights[layer, i, 1], wires=i)
                    qml.RZ(weights[layer, i, 2], wires=i)
                for i in range(self.n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                qml.CNOT(wires=[self.n_qubits - 1, 0])
            
            return qml.expval(qml.PauliZ(0))
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        np.random.seed(self.seed)
        weights = np.random.uniform(-np.pi, np.pi, (self.depth, self.n_qubits, 3))
        self.params = {
            "weights": jnp.array(weights),
        }
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 3. GRAPH PERMUTATION EQUIVARIANT QML MODEL
# =============================================================================

class GraphPermutationEquivariantQML:
    """
    Graph Permutation Equivariant Quantum Machine Learning model.
    
    Uses permutation-symmetric encoding based on graph structure:
    - Encodes interatomic distances (permutation invariant features)
    - Uses symmetric pooling operations
    - Circuit structure respects graph connectivity
    """
    
    def __init__(self, n_qubits=4, depth=4, seed=42):
        self.n_qubits = n_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
        self._init_params()
    
    def _build_circuit(self):
        """Create the graph permutation equivariant circuit."""
        n_qubits = self.n_qubits
        depth = self.depth
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """Graph-based permutation equivariant circuit."""
            weights = params["weights"]       # (depth, n_qubits, 4)
            edge_weights = params["edge_weights"]  # (depth, num_edges, 2)
            global_weights = params["global_weights"]  # (depth, 3)
            
            # Compute pairwise distances (permutation invariant features)
            dist_LiH = jnp.linalg.norm(positions[1] - positions[0])
            
            # Compute direction cosines for directional information
            direction = (positions[1] - positions[0]) / (dist_LiH + 1e-8)
            
            # Symmetric aggregated features
            center_of_mass = jnp.mean(positions, axis=0)
            spread = jnp.std(positions)
            
            # Feature vector (all permutation invariant/equivariant)
            features = jnp.array([
                dist_LiH,
                spread,
                jnp.linalg.norm(center_of_mass),
                direction[2]  # z-component for force direction
            ])
            
            # === Symmetric Initial State ===
            for i in range(n_qubits):
                qml.Hadamard(wires=i)
            
            # === Graph-based Encoding Layers ===
            for layer in range(depth):
                # Node update
                for i in range(n_qubits):
                    angle_y = weights[layer, i, 0] * features[0] + weights[layer, i, 1] * features[1]
                    angle_z = weights[layer, i, 2] * features[2] + weights[layer, i, 3] * features[3]
                    
                    qml.RY(angle_y, wires=i)
                    qml.RZ(angle_z, wires=i)
                
                # Edge operations (ring topology)
                edges = [(i, (i+1) % n_qubits) for i in range(n_qubits)]
                
                for e_idx, (i, j) in enumerate(edges):
                    edge_angle = edge_weights[layer, e_idx % edge_weights.shape[1], 0] * dist_LiH
                    
                    qml.CNOT(wires=[i, j])
                    qml.RZ(edge_angle, wires=j)
                    qml.CNOT(wires=[i, j])
                
                # Global pooling layer
                global_angle = global_weights[layer, 0] * dist_LiH + global_weights[layer, 1]
                for i in range(n_qubits):
                    qml.RY(global_angle * global_weights[layer, 2], wires=i)
            
            # === Permutation Symmetric Measurement ===
            obs = sum(qml.PauliZ(i) for i in range(n_qubits))
            return qml.expval(obs)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        """Initialize parameters with Xavier-like initialization."""
        np.random.seed(self.seed)
        
        num_edges = self.n_qubits
        
        limit = np.sqrt(2.0 / (self.n_qubits + 4))
        weights = np.random.uniform(-limit, limit, (self.depth, self.n_qubits, 4))
        edge_weights = np.random.uniform(-0.5, 0.5, (self.depth, num_edges, 2))
        global_weights = np.random.uniform(-0.3, 0.3, (self.depth, 3))
        
        self.params = {
            "weights": jnp.array(weights),
            "edge_weights": jnp.array(edge_weights),
            "global_weights": jnp.array(global_weights)
        }
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 4. CLASSICAL ROTATIONALLY EQUIVARIANT NN (E(3) INVARIANT)
# =============================================================================

class ClassicalRotationallyEquivariantNN:
    """
    Classical Rotationally Equivariant Neural Network for LiH.
    Uses physics-inspired E(3) invariant features with smooth activations.
    
    Key improvements over basic version:
    - SiLU activation (smooth for autodiff force computation)
    - Multiple physics-inspired features (not just raw distance)
    - Larger network with skip connections
    - Two-phase training with force warmup
    """
    
    def __init__(self, hidden_dims=[128, 128, 64], seed=42):
        self.hidden_dims = hidden_dims
        self.seed = seed
        
        # Physics parameters for LiH
        self.r_eq = 1.6  # Equilibrium Li-H distance in Å
        self.morse_alpha = 2.0
        
        # RBF parameters
        self.rbf_centers = jnp.linspace(0.8, 3.0, 8)  # 8 Gaussians
        self.rbf_width = 0.3
        
        # Number of features: distance + 1/r + Morse + 8 RBF = 11
        self.n_features = 11
        
        self._init_params()
        self._create_model()
    
    def _init_params(self):
        """Initialize MLP parameters with Xavier initialization."""
        np.random.seed(self.seed)
        
        # Feature dimension -> hidden -> output
        layer_sizes = [self.n_features] + self.hidden_dims + [1]
        
        params = {"weights": [], "biases": []}
        
        for i in range(len(layer_sizes) - 1):
            fan_in = layer_sizes[i]
            fan_out = layer_sizes[i + 1]
            limit = np.sqrt(6.0 / (fan_in + fan_out))
            
            W = np.random.uniform(-limit, limit, (fan_in, fan_out))
            b = np.zeros(fan_out)
            
            params["weights"].append(jnp.array(W))
            params["biases"].append(jnp.array(b))
        
        # Skip connection weights (from input features to last hidden layer)
        skip_dim = layer_sizes[-2]  # Last hidden layer dimension
        params["skip_weight"] = jnp.array(
            np.random.uniform(-0.1, 0.1, (self.n_features, skip_dim))
        )
        
        self.params = params
    
    def _create_model(self):
        """Create the forward pass function with physics-inspired features."""
        
        def compute_features(positions):
            """Compute physics-inspired invariant features from positions."""
            # Li at index 0, H at index 1
            r_vec = positions[1] - positions[0]
            r = jnp.linalg.norm(r_vec) + 1e-12
            
            # Feature 1: Normalized distance
            f_dist = r / 2.0  # Normalize by typical scale
            
            # Feature 2: Inverse distance (Coulomb-like)
            f_inv = 1.0 / r
            
            # Feature 3: Morse-like term
            f_morse = jnp.exp(-self.morse_alpha * (r - self.r_eq))
            
            # Features 4-11: RBF encoding (8 Gaussians)
            f_rbf = jnp.exp(-((r - self.rbf_centers) ** 2) / (2 * self.rbf_width ** 2))
            
            # Concatenate all features
            features = jnp.concatenate([
                jnp.array([f_dist, f_inv, f_morse]),
                f_rbf
            ])
            
            return features
        
        def mlp_forward(x, params):
            """MLP forward pass with SiLU activation and skip connection."""
            weights = params["weights"]
            biases = params["biases"]
            skip_weight = params["skip_weight"]
            
            h = x
            for i in range(len(weights) - 1):
                h = jnp.dot(h, weights[i]) + biases[i]
                # SiLU activation: x * sigmoid(x) - smooth for autodiff!
                h = h * jax.nn.sigmoid(h)
                
                # Add skip connection to last hidden layer
                if i == len(weights) - 2:
                    h = h + 0.1 * jnp.dot(x, skip_weight)
            
            # Output layer (no activation)
            h = jnp.dot(h, weights[-1]) + biases[-1]
            return h.squeeze(-1)
        
        def energy_from_positions(positions, params):
            """Compute energy from atomic positions."""
            features = compute_features(positions)
            energy = mlp_forward(features, params)
            return energy
        
        def force_from_positions(positions, params):
            """Compute forces as negative gradient of energy."""
            grad_fn = jax.grad(energy_from_positions, argnums=0)
            return -grad_fn(positions, params)
        
        self.compute_features = compute_features
        self.energy_fn = energy_from_positions
        self.force_fn = force_from_positions
        self.vec_energy = jax.vmap(energy_from_positions, (0, None), 0)
        self.vec_force = jax.vmap(force_from_positions, (0, None), 0)
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_rotationally_equivariant(model, data_train, E_train, F_train, n_epochs=200, 
                                    lr=3e-3, wE=1.0, wF_max=5.0, warmup_frac=0.4):
    """Train the rotationally equivariant QML model with force warmup curriculum."""
    warmup_epochs = int(n_epochs * warmup_frac)
    
    def raw_energy(coords, params):
        """Raw circuit output."""
        return model.circuit(coords, params)
    
    vec_raw_energy = jax.vmap(raw_energy, (0, None), 0)
    
    def vec_force_fn(coords_batch, params):
        """Compute forces as -grad(energy)."""
        def single_force(coords):
            grad_fn = jax.grad(raw_energy, argnums=0)
            return -grad_fn(coords, params)
        return jax.vmap(single_force)(coords_batch)
    
    @jax.jit
    def loss_fn(params, coords, E_target, F_target, wF):
        # Energy with head transformation
        E_raw = vec_raw_energy(coords, params)
        E_pred = params["head_scale"] * E_raw + params["head_bias"]
        L_E = jnp.mean((E_pred - E_target) ** 2)
        
        # Forces (scaled by head_scale)
        F_raw = vec_force_fn(coords, params)
        F_pred = params["head_scale"] * F_raw
        F_pred_z = F_pred[:, 0, 2]  # H atom z-component
        L_F = jnp.mean((F_pred_z - F_target) ** 2)
        
        # Handle NaNs
        L_E = jnp.where(jnp.isnan(L_E), 1.0, L_E)
        L_F = jnp.where(jnp.isnan(L_F), 1.0, L_F)
        
        return wE * L_E + wF * L_F, (L_E, L_F)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        # Warmup curriculum: gradually increase force weight
        if epoch < warmup_epochs:
            wF = wF_max * (epoch / warmup_epochs)
        else:
            wF = wF_max
        
        (loss, (L_E, L_F)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), data_train, E_train, F_train, wF
        )
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(L_E))
            history["test_F_loss"].append(float(L_F))
    
    model.set_params(get_params(opt_state))
    return history


def train_non_equivariant(model, positions, E_train, F_train, n_epochs=200, 
                          lr=0.01, lambda_E=2.0, lambda_F=1.0):
    """Train the non-equivariant QML model."""
    
    def energy_single(pos, params):
        return model.circuit(pos, params)
    
    def force_single(pos, params):
        return -jax.grad(energy_single, argnums=0)(pos, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def loss_fn(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred = vec_force(positions, params)
        F_pred_z = F_pred[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)
        
        return lambda_E * E_loss + lambda_F * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), positions, E_train, F_train
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_graph_permutation(model, positions, E_train, F_train, n_epochs=200, 
                            lr=0.01, lambda_E=1.5, lambda_F=1.5):
    """Train the graph permutation equivariant QML model."""
    
    def energy_single(pos, params):
        return model.circuit(pos, params)
    
    def force_single(pos, params):
        return -jax.grad(energy_single, argnums=0)(pos, params)
    
    vec_force = jax.vmap(force_single, (0, None), 0)
    
    @jax.jit
    def loss_fn(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred = vec_force(positions, params)
        F_pred_z = F_pred[:, 1, 2]
        F_loss = jnp.mean((F_pred_z - F_target) ** 2)
        
        # Handle NaNs
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return lambda_E * E_loss + lambda_F * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), positions, E_train, F_train
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_classical_equivariant(model, positions, E_train, F_train, n_epochs=200, 
                                 lr=3e-3, wE=1.0, wF_max=2.0, warmup_frac=0.3):
    """
    Train the classical equivariant NN model with two-phase training.
    
    Phase 1 (warmup): Energy-only training to establish good features
    Phase 2: Combined energy + forces with gradual force weight ramp
    """
    warmup_epochs = int(n_epochs * warmup_frac)
    
    def huber_loss(pred, target, delta=0.5):
        """Huber loss - robust to outliers."""
        diff = pred - target
        abs_diff = jnp.abs(diff)
        return jnp.mean(jnp.where(abs_diff <= delta, 
                                   0.5 * diff**2, 
                                   delta * (abs_diff - 0.5 * delta)))
    
    @jax.jit
    def loss_fn(params, positions, E_target, F_target, wF):
        E_pred = model.vec_energy(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred = model.vec_force(positions, params)
        F_pred_z = F_pred[:, 1, 2]  # H atom z-component
        F_loss = huber_loss(F_pred_z, F_target, delta=0.5)
        
        # Handle NaNs
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return wE * E_loss + wF * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        # Two-phase training:
        # Phase 1 (warmup): Energy only (wF = 0)
        # Phase 2: Gradually increase force weight
        if epoch < warmup_epochs:
            wF = 0.0  # Energy-only warmup
        else:
            # Ramp force weight from 0 to wF_max over first half of phase 2
            phase2_progress = (epoch - warmup_epochs) / max(1, (n_epochs - warmup_epochs) / 2)
            wF = min(wF_max, wF_max * phase2_progress)
        
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), positions, E_train, F_train, wF
        )
        
        # Gradient clipping
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 5.0:
            grads = jax.tree.map(lambda g: g * (5.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# EVALUATION FUNCTION
# =============================================================================

def evaluate_fold(E_pred_train, E_pred_test, F_pred_train, F_pred_test,
                  E_train_true, E_test_true, F_train_true, F_test_true,
                  energy_scaler, force_scaler):
    """Evaluate predictions on a single fold with post-correction."""
    
    # Post-correction for energy (quadratic fit)
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c
    
    try:
        popt_E, _ = curve_fit(corr_E, E_pred_train, E_train_true, maxfev=5000)
        E_pred_test_corr = corr_E(E_pred_test, *popt_E)
    except:
        E_pred_test_corr = E_pred_test
    
    # Post-correction for forces (linear fit)
    try:
        lr_model = LinearRegression()
        lr_model.fit(F_pred_train.reshape(-1, 1), F_train_true.reshape(-1, 1))
        F_pred_test_corr = lr_model.predict(F_pred_test.reshape(-1, 1)).flatten()
    except:
        F_pred_test_corr = F_pred_test
    
    # Inverse transform to original units
    E_pred_final = energy_scaler.inverse_transform(E_pred_test_corr.reshape(-1, 1)).flatten()
    E_true_final = energy_scaler.inverse_transform(E_test_true.reshape(-1, 1)).flatten()
    
    F_pred_final = force_scaler.inverse_transform(F_pred_test_corr.reshape(-1, 1)).flatten()
    F_true_final = force_scaler.inverse_transform(F_test_true.reshape(-1, 1)).flatten()
    
    # Compute metrics
    E_mae = np.mean(np.abs(E_pred_final - E_true_final))
    E_rmse = np.sqrt(np.mean((E_pred_final - E_true_final)**2))
    ss_res_E = np.sum((E_pred_final - E_true_final)**2)
    ss_tot_E = np.sum((E_true_final - E_true_final.mean())**2)
    E_r2 = 1 - ss_res_E / ss_tot_E if ss_tot_E > 0 else 0
    
    F_mae = np.mean(np.abs(F_pred_final - F_true_final))
    F_rmse = np.sqrt(np.mean((F_pred_final - F_true_final)**2))
    ss_res_F = np.sum((F_pred_final - F_true_final)**2)
    ss_tot_F = np.sum((F_true_final - F_true_final.mean())**2)
    F_r2 = 1 - ss_res_F / ss_tot_F if ss_tot_F > 0 else 0
    
    metrics = {
        "E_r2": float(E_r2),
        "E_mae_Ha": float(E_mae),
        "E_rmse_Ha": float(E_rmse),
        "E_mae_eV": float(E_mae * 27.2114),
        "F_r2": float(F_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
    }
    
    predictions = {
        "E_pred": E_pred_final.tolist(),
        "E_true": E_true_final.tolist(),
        "F_pred": F_pred_final.tolist(),
        "F_true": F_true_final.tolist(),
    }
    
    return metrics, predictions


# =============================================================================
# K-FOLD CROSS-VALIDATION FUNCTION
# =============================================================================

def run_kfold_cv(k_folds, n_epochs, data_dir, output_dir, seed=42):
    """
    Run k-fold cross-validation for all four methods on LiH.
    
    Args:
        k_folds: Number of folds for cross-validation
        n_epochs: Number of training epochs per fold
        data_dir: Directory containing LiH dataset
        output_dir: Directory to save results
        seed: Random seed for reproducibility
    
    Returns:
        Dictionary with all results
    """
    print(f"\n{'='*80}")
    print(f"LiH K-FOLD CROSS-VALIDATION COMPARISON")
    print(f"{'='*80}")
    print(f"K-Folds: {k_folds}")
    print(f"Epochs per fold: {n_epochs}")
    print(f"Data directory: {data_dir}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*80}\n")
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Load data
    print("Loading LiH dataset...")
    energy = np.load(os.path.join(data_dir, "Energy.npy"))
    forces = np.load(os.path.join(data_dir, "Forces.npy"))
    positions = np.load(os.path.join(data_dir, "Positions.npy"))
    
    N_samples = len(energy)
    print(f"  Loaded {N_samples} samples")
    print(f"  Positions shape: {positions.shape}")
    print(f"  Energy shape: {energy.shape}")
    print(f"  Forces shape: {forces.shape}")
    
    # Prepare scalers (fit on entire dataset for consistency)
    energy_scaler = MinMaxScaler((-1, 1))
    if energy.ndim == 1:
        energy = energy.reshape(-1, 1)
    energy_scaled = energy_scaler.fit_transform(energy).flatten()
    
    # Forces: H atom z-component
    forces_H = forces[:, 1:, :]
    force_scaler = MinMaxScaler((-1, 1))
    forces_z = forces_H[:, 0, 2].reshape(-1, 1)
    forces_z_scaled = force_scaler.fit_transform(forces_z).flatten()
    
    # Centered positions for equivariant model (H relative to Li)
    positions_centered = np.zeros((N_samples, 1, 3))
    positions_centered[:, 0, :] = positions[:, 1, :] - positions[:, 0, :]
    
    # Initialize k-fold
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=seed)
    
    # Results storage
    results = {
        "config": {
            "k_folds": k_folds,
            "n_epochs": n_epochs,
            "n_samples": N_samples,
            "seed": seed,
            "timestamp": datetime.now().isoformat(),
        },
        "rotationally_equivariant": {"folds": [], "summary": {}},
        "non_equivariant": {"folds": [], "summary": {}},
        "graph_permutation_equivariant": {"folds": [], "summary": {}},
        "classical_equivariant": {"folds": [], "summary": {}},
    }
    
    method_names = [
        "rotationally_equivariant",
        "non_equivariant", 
        "graph_permutation_equivariant",
        "classical_equivariant"
    ]
    
    # Run k-fold CV
    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(positions)):
        print(f"\n{'='*60}")
        print(f"FOLD {fold_idx + 1}/{k_folds}")
        print(f"{'='*60}")
        print(f"  Train samples: {len(train_idx)}, Test samples: {len(test_idx)}")
        
        fold_seed = seed + fold_idx * 100
        
        # Prepare fold data
        E_train = energy_scaled[train_idx]
        E_test = energy_scaled[test_idx]
        F_train = forces_z_scaled[train_idx]
        F_test = forces_z_scaled[test_idx]
        
        # Centered positions for equivariant model
        pos_centered_train = jnp.array(positions_centered[train_idx])
        pos_centered_test = jnp.array(positions_centered[test_idx])
        
        # Raw positions for other models
        pos_raw_train = jnp.array(positions[train_idx])
        pos_raw_test = jnp.array(positions[test_idx])
        
        # ==================== 1. Rotationally Equivariant ====================
        print(f"\n  [1/4] Rotationally Equivariant QML...")
        rot_model = RotationallyEquivariantQML(n_qubits=6, depth=6, seed=fold_seed)
        
        rot_history = train_rotationally_equivariant(
            rot_model, pos_centered_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=3e-3, wE=1.0, wF_max=5.0, warmup_frac=0.4
        )
        
        # Evaluate with head transformation
        params = rot_model.params
        E_raw_train = np.array(rot_model.vec_circuit(pos_centered_train, params))
        E_raw_test = np.array(rot_model.vec_circuit(pos_centered_test, params))
        E_pred_train_rot = float(params["head_scale"]) * E_raw_train + float(params["head_bias"])
        E_pred_test_rot = float(params["head_scale"]) * E_raw_test + float(params["head_bias"])
        
        def raw_energy(coords, params):
            return rot_model.circuit(coords, params)
        def force_single(coords, params):
            return -jax.grad(raw_energy, argnums=0)(coords, params)
        vec_force_rot = jax.vmap(force_single, (0, None), 0)
        
        F_raw_train = np.array(vec_force_rot(pos_centered_train, params))[:, 0, 2]
        F_raw_test = np.array(vec_force_rot(pos_centered_test, params))[:, 0, 2]
        F_pred_train_rot = float(params["head_scale"]) * F_raw_train
        F_pred_test_rot = float(params["head_scale"]) * F_raw_test
        
        rot_metrics, rot_preds = evaluate_fold(
            E_pred_train_rot, E_pred_test_rot, F_pred_train_rot, F_pred_test_rot,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["rotationally_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": rot_metrics,
            "history": rot_history,
        })
        print(f"    Energy R²: {rot_metrics['E_r2']:.4f}, Force R²: {rot_metrics['F_r2']:.4f}")
        
        # ==================== 2. Non-Equivariant ====================
        print(f"\n  [2/4] Non-Equivariant QML...")
        non_eq_model = NonEquivariantQML(n_qubits=4, depth=4, seed=fold_seed)
        
        non_eq_history = train_non_equivariant(
            non_eq_model, pos_raw_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=0.01, lambda_E=2.0, lambda_F=1.0
        )
        
        # Evaluate
        E_pred_train_ne = np.array(non_eq_model.vec_circuit(pos_raw_train, non_eq_model.params))
        E_pred_test_ne = np.array(non_eq_model.vec_circuit(pos_raw_test, non_eq_model.params))
        
        def energy_single_ne(pos, params):
            return non_eq_model.circuit(pos, params)
        def force_single_ne(pos, params):
            return -jax.grad(energy_single_ne, argnums=0)(pos, params)
        vec_force_ne = jax.vmap(force_single_ne, (0, None), 0)
        
        F_pred_train_ne = np.array(vec_force_ne(pos_raw_train, non_eq_model.params))[:, 1, 2]
        F_pred_test_ne = np.array(vec_force_ne(pos_raw_test, non_eq_model.params))[:, 1, 2]
        
        non_eq_metrics, non_eq_preds = evaluate_fold(
            E_pred_train_ne, E_pred_test_ne, F_pred_train_ne, F_pred_test_ne,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["non_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": non_eq_metrics,
            "history": non_eq_history,
        })
        print(f"    Energy R²: {non_eq_metrics['E_r2']:.4f}, Force R²: {non_eq_metrics['F_r2']:.4f}")
        
        # ==================== 3. Graph Permutation Equivariant ====================
        print(f"\n  [3/4] Graph Permutation Equivariant QML...")
        graph_model = GraphPermutationEquivariantQML(n_qubits=4, depth=4, seed=fold_seed)
        
        graph_history = train_graph_permutation(
            graph_model, pos_raw_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=0.01, lambda_E=1.5, lambda_F=1.5
        )
        
        # Evaluate
        E_pred_train_g = np.array(graph_model.vec_circuit(pos_raw_train, graph_model.params))
        E_pred_test_g = np.array(graph_model.vec_circuit(pos_raw_test, graph_model.params))
        
        def energy_single_g(pos, params):
            return graph_model.circuit(pos, params)
        def force_single_g(pos, params):
            return -jax.grad(energy_single_g, argnums=0)(pos, params)
        vec_force_g = jax.vmap(force_single_g, (0, None), 0)
        
        F_pred_train_g = np.array(vec_force_g(pos_raw_train, graph_model.params))[:, 1, 2]
        F_pred_test_g = np.array(vec_force_g(pos_raw_test, graph_model.params))[:, 1, 2]
        
        graph_metrics, graph_preds = evaluate_fold(
            E_pred_train_g, E_pred_test_g, F_pred_train_g, F_pred_test_g,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["graph_permutation_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": graph_metrics,
            "history": graph_history,
        })
        print(f"    Energy R²: {graph_metrics['E_r2']:.4f}, Force R²: {graph_metrics['F_r2']:.4f}")
        
        # ==================== 4. Classical Equivariant ====================
        print(f"\n  [4/4] Classical Rotationally Equivariant NN...")
        classical_model = ClassicalRotationallyEquivariantNN(hidden_dims=[128, 128, 64], seed=fold_seed)
        
        classical_history = train_classical_equivariant(
            classical_model, pos_raw_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=3e-3, wE=1.0, wF_max=2.0, warmup_frac=0.3
        )
        
        # Evaluate
        E_pred_train_c = np.array(classical_model.vec_energy(pos_raw_train, classical_model.params))
        E_pred_test_c = np.array(classical_model.vec_energy(pos_raw_test, classical_model.params))
        F_pred_train_c = np.array(classical_model.vec_force(pos_raw_train, classical_model.params))[:, 1, 2]
        F_pred_test_c = np.array(classical_model.vec_force(pos_raw_test, classical_model.params))[:, 1, 2]
        
        classical_metrics, classical_preds = evaluate_fold(
            E_pred_train_c, E_pred_test_c, F_pred_train_c, F_pred_test_c,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["classical_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": classical_metrics,
            "history": classical_history,
        })
        print(f"    Energy R²: {classical_metrics['E_r2']:.4f}, Force R²: {classical_metrics['F_r2']:.4f}")
    
    # Compute summary statistics
    print(f"\n{'='*80}")
    print("COMPUTING SUMMARY STATISTICS")
    print(f"{'='*80}")
    
    metrics_keys = ["E_r2", "E_mae_Ha", "E_rmse_Ha", "F_r2", "F_mae", "F_rmse"]
    
    for method in method_names:
        folds_data = results[method]["folds"]
        summary = {}
        
        for metric in metrics_keys:
            values = [fold["metrics"][metric] for fold in folds_data]
            summary[metric] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "min": float(np.min(values)),
                "max": float(np.max(values)),
                "values": values,
            }
        
        # Coefficient of variation
        for metric in ["E_r2", "F_r2"]:
            values = summary[metric]["values"]
            mean = summary[metric]["mean"]
            if mean != 0:
                summary[f"{metric}_cv"] = float(np.std(values) / abs(mean))
            else:
                summary[f"{metric}_cv"] = float('inf')
        
        results[method]["summary"] = summary
    
    # Print summary table
    print("\n" + "="*100)
    print("K-FOLD CROSS-VALIDATION SUMMARY")
    print("="*100)
    
    print(f"\n{'Method':<35} {'E_R² (mean±std)':<20} {'F_R² (mean±std)':<20} {'E_CV':<10} {'F_CV':<10}")
    print("-"*100)
    
    method_labels = {
        "rotationally_equivariant": "Rot. Equiv. QML",
        "non_equivariant": "Non-Equiv. QML",
        "graph_permutation_equivariant": "Graph Perm. QML",
        "classical_equivariant": "Classical Equiv. NN",
    }
    
    for method in method_names:
        summary = results[method]["summary"]
        e_r2 = summary["E_r2"]
        f_r2 = summary["F_r2"]
        e_cv = summary.get("E_r2_cv", 0)
        f_cv = summary.get("F_r2_cv", 0)
        
        print(f"{method_labels[method]:<35} "
              f"{e_r2['mean']:.4f}±{e_r2['std']:.4f}       "
              f"{f_r2['mean']:.4f}±{f_r2['std']:.4f}       "
              f"{e_cv:.4f}     {f_cv:.4f}")
    
    print("="*100)
    print("CV = Coefficient of Variation (lower = more consistent across folds)")
    
    # Save results
    results_path = os.path.join(output_dir, "kfold_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    
    # Save NPZ
    npz_data = {}
    for method in method_names:
        for metric in metrics_keys:
            key = f"{method}_{metric}"
            npz_data[key] = np.array(results[method]["summary"][metric]["values"])
    
    npz_path = os.path.join(output_dir, "kfold_metrics.npz")
    np.savez(npz_path, **npz_data)
    print(f"Metrics saved to: {npz_path}")
    
    return results


def main(k_folds=5, n_epochs=200, output_dir="kfold_results_lih", 
         data_dir="eqnn_force_field_data_LiH", seed=42):
    """
    Main function for k-fold cross-validation on LiH.
    
    Args:
        k_folds: Number of folds (default: 5)
        n_epochs: Epochs per fold (default: 200)
        output_dir: Output directory
        data_dir: Data directory
        seed: Random seed
    
    Returns:
        Results dictionary
    """
    return run_kfold_cv(k_folds, n_epochs, data_dir, output_dir, seed)


if __name__ == "__main__":
    import sys
    
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter. Call main() with parameters:")
        print("  results = main(k_folds=5, n_epochs=200, output_dir='kfold_results_lih')")
    else:
        parser = argparse.ArgumentParser(description="LiH K-Fold Cross-Validation Comparison")
        parser.add_argument("--k_folds", type=int, default=5, help="Number of folds")
        parser.add_argument("--n_epochs", type=int, default=200, help="Epochs per fold")
        parser.add_argument("--output_dir", type=str, default="kfold_results_lih", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_LiH", help="Data directory")
        parser.add_argument("--seed", type=int, default=42, help="Random seed")
        
        args = parser.parse_args()
        
        main(
            k_folds=args.k_folds,
            n_epochs=args.n_epochs,
            output_dir=args.output_dir,
            data_dir=args.data_dir,
            seed=args.seed
        )



Running in Jupyter. Call main() with parameters:
  results = main(k_folds=5, n_epochs=200, output_dir='kfold_results_lih')


In [2]:
results = main(k_folds=5, n_epochs=200, output_dir='kfold_results_lih')




LiH K-FOLD CROSS-VALIDATION COMPARISON
K-Folds: 5
Epochs per fold: 200
Data directory: eqnn_force_field_data_LiH
Output directory: kfold_results_lih

Loading LiH dataset...
  Loaded 2400 samples
  Positions shape: (2400, 2, 3)
  Energy shape: (2400,)
  Forces shape: (2400, 2, 3)

FOLD 1/5
  Train samples: 1920, Test samples: 480

  [1/4] Rotationally Equivariant QML...
    Energy R²: 0.9963, Force R²: 0.9911

  [2/4] Non-Equivariant QML...
    Energy R²: 0.9976, Force R²: 0.8255

  [3/4] Graph Permutation Equivariant QML...
    Energy R²: 0.9965, Force R²: 0.9612

  [4/4] Classical Rotationally Equivariant NN...
    Energy R²: 0.9980, Force R²: 0.9955

FOLD 2/5
  Train samples: 1920, Test samples: 480

  [1/4] Rotationally Equivariant QML...
    Energy R²: 0.9969, Force R²: 0.9975

  [2/4] Non-Equivariant QML...
    Energy R²: 0.9974, Force R²: 0.7992

  [3/4] Graph Permutation Equivariant QML...
    Energy R²: 0.9967, Force R²: 0.9735

  [4/4] Classical Rotationally Equivariant NN...