In [1]:
"""
NH₃ QML K-Fold Cross-Validation Comparison
============================================

Compares generalizability of four approaches using k-fold cross-validation:
1. Rotationally Equivariant QML - EQNN-style with SO(3) equivariant encoding
2. Non-Equivariant QML - Simple QNN with basic rotations (no symmetry)
3. Graph Permutation Equivariant QML - Graph-based encoding with permutation symmetry
4. Classical Rotationally Equivariant NN - Classical MLP on pairwise distances (E(3) invariant)

This script evaluates model generalizability by:
- Splitting data into k folds
- Training on k-1 folds, testing on held-out fold
- Repeating for each fold
- Computing variance across folds (measures sensitivity to data splits)

Usage (command line):
    python run_kfold_comparison_nh3.py --k_folds 5 --n_epochs 300 --output_dir kfold_results

Usage (Jupyter):
    from run_kfold_comparison_nh3 import main
    results = main(k_folds=5, n_epochs=300, output_dir='kfold_results')
"""

import pennylane as qml
import numpy as np
import json
import os
import argparse
from datetime import datetime
from sklearn.model_selection import KFold

import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

from jax import numpy as jnp
from jax.example_libraries import optimizers
from sklearn.preprocessing import MinMaxScaler
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')


# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def huber(residual, delta=1.0):
    """Elementwise Huber loss for robust force training."""
    abs_r = jnp.abs(residual)
    quad = 0.5 * residual**2
    lin = delta * (abs_r - 0.5 * delta)
    return jnp.where(abs_r <= delta, quad, lin)


# =============================================================================
# 1. ROTATIONALLY EQUIVARIANT QML MODEL (SO(3))
# =============================================================================

class RotationallyEquivariantQML:
    """Rotationally Equivariant QML for NH₃."""
    
    def __init__(self, depth=6, rep=2, active_atoms=3, seed=42):
        self.depth = depth
        self.rep = rep
        self.active_atoms = active_atoms
        self.n_qubits = active_atoms * rep
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=self.n_qubits)
        self.observable = (
            qml.PauliX(0) @ qml.PauliX(1)
            + qml.PauliY(0) @ qml.PauliY(1)
            + qml.PauliZ(0) @ qml.PauliZ(1)
        )
        
        self._build_circuit()
        self._init_params()
    
    def _singlet(self, wires):
        w0, w1 = wires
        qml.Hadamard(wires=w0)
        qml.PauliZ(wires=w0)
        qml.PauliX(wires=w1)
        qml.CNOT(wires=[w0, w1])
    
    def _equivariant_encoding(self, alpha, vec3, wire):
        r = jnp.array(vec3, dtype=jnp.float64)
        norm = jnp.linalg.norm(r) + 1e-12
        n = r / norm
        theta = alpha * norm
        qml.Rot(theta * n[0], theta * n[1], theta * n[2], wires=wire)
    
    def _pair_layer(self, weight, wires):
        qml.IsingXX(weight, wires=wires)
        qml.IsingYY(weight, wires=wires)
        qml.IsingZZ(weight, wires=wires)
    
    def _build_circuit(self):
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(coords, params):
            weights = params["weights"]
            alphas = params["alphas"]
            
            for i in range(0, self.n_qubits - 1, 2):
                self._singlet([i, i + 1])
            
            for i in range(self.n_qubits):
                self._equivariant_encoding(alphas[i, 0], coords[i % self.active_atoms], i)
            
            for d in range(self.depth):
                qml.Barrier()
                for i in range(0, self.n_qubits - 1, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                for i in range(1, self.n_qubits, 2):
                    self._pair_layer(weights[i, d], [i, (i + 1) % self.n_qubits])
                for i in range(self.n_qubits):
                    self._equivariant_encoding(alphas[i, d + 1], coords[i % self.active_atoms], i)
            
            return qml.expval(self.observable)
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        np.random.seed(self.seed)
        weights0 = np.zeros((self.n_qubits, self.depth), dtype=np.float64)
        weights0[0] = np.random.uniform(0.0, np.pi, size=(self.depth,))
        alphas0 = np.random.uniform(0.5, 1.5, size=(self.n_qubits, self.depth + 1))
        
        self.params = {
            "weights": jnp.array(weights0),
            "alphas": jnp.array(alphas0),
            "head_scale": jnp.array(1.0),
            "head_bias": jnp.array(0.0),
        }
    
    def energy(self, coords, params):
        raw = self.circuit(coords, params)
        return params["head_scale"] * raw + params["head_bias"]
    
    def force(self, coords, params):
        grad_fn = jax.grad(self.energy, argnums=0)
        return -grad_fn(coords, params)
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 2. NON-EQUIVARIANT QML MODEL
# =============================================================================

class NonEquivariantQML:
    """Non-Equivariant QML for NH₃."""
    
    def __init__(self, n_qubits=6, depth=4, seed=42):
        self.n_qubits = n_qubits
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
        self._init_params()
    
    def _build_circuit(self):
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(distances, params):
            weights = params["weights"]
            
            for i in range(self.n_qubits):
                qml.Hadamard(wires=i)
            
            for i, d in enumerate(distances[:3]):
                qml.RY(d * np.pi, wires=i)
                qml.RY(d * np.pi, wires=i + 3)
            
            for layer in range(self.depth):
                for i in range(self.n_qubits):
                    qml.RX(weights[layer, i, 0], wires=i)
                    qml.RY(weights[layer, i, 1], wires=i)
                    qml.RZ(weights[layer, i, 2], wires=i)
                for i in range(self.n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                qml.CNOT(wires=[self.n_qubits - 1, 0])
            
            return qml.expval(qml.PauliZ(0))
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        np.random.seed(self.seed)
        weights = np.random.uniform(-np.pi, np.pi, (self.depth, self.n_qubits, 3))
        self.params = {
            "weights": jnp.array(weights),
            "head_scale": jnp.array(1.0),
            "head_bias": jnp.array(0.0),
        }
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 3. GRAPH PERMUTATION EQUIVARIANT QML MODEL
# =============================================================================

class GraphPermutationEquivariantQML:
    """
    Graph Permutation Equivariant QML for NH₃.
    Encodes N-H bonds as graph edges with geometric features.
    Uses permutation-symmetric encoding with learnable parameters.
    """
    
    def __init__(self, n_qubits=6, depth=4, seed=42):
        self.n_qubits = n_qubits  # 3 bonds × 2 qubits per bond
        self.depth = depth
        self.seed = seed
        
        self.dev = qml.device("default.qubit", wires=n_qubits)
        self._build_circuit()
        self._init_params()
    
    def _build_circuit(self):
        """Build the graph-based quantum circuit."""
        
        @qml.qnode(self.dev, interface="jax", diff_method="backprop")
        def circuit(positions, params):
            """
            positions: (4, 3) - [N, H1, H2, H3] coordinates
            params: {"weights": (depth, n_qubits, 3)}
            """
            weights = params["weights"]
            
            # N at index 0, H atoms at indices 1, 2, 3
            N_pos = positions[0]
            H_positions = positions[1:]  # (3, 3)
            
            # Compute bond vectors and distances
            bonds = H_positions - N_pos[None, :]  # (3, 3)
            distances = jnp.linalg.norm(bonds, axis=1)  # (3,)
            
            # Compute angles between bonds
            def compute_angle(v1, v2):
                cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + 1e-12)
                return jnp.arccos(jnp.clip(cos_angle, -1.0, 1.0))
            
            angle_01 = compute_angle(bonds[0], bonds[1])
            angle_02 = compute_angle(bonds[0], bonds[2])
            angle_12 = compute_angle(bonds[1], bonds[2])
            
            # Initialize qubits
            for i in range(self.n_qubits):
                qml.RY(0.5, wires=i)
            
            # Apply layers
            for layer in range(self.depth):
                # Encode bond distances (symmetric across bond pairs)
                qml.RY(weights[layer, 0, 0] * distances[0], wires=0)
                qml.RY(weights[layer, 1, 0] * distances[0], wires=1)
                qml.RY(weights[layer, 2, 0] * distances[1], wires=2)
                qml.RY(weights[layer, 3, 0] * distances[1], wires=3)
                qml.RY(weights[layer, 4, 0] * distances[2], wires=4)
                qml.RY(weights[layer, 5, 0] * distances[2], wires=5)
                
                # Entangle within bonds
                qml.CNOT(wires=[0, 1])
                qml.CNOT(wires=[2, 3])
                qml.CNOT(wires=[4, 5])
                
                # Encode angular information
                qml.RZ(weights[layer, 0, 1] * angle_01, wires=0)
                qml.RZ(weights[layer, 2, 1] * angle_01, wires=2)
                qml.RZ(weights[layer, 0, 2] * angle_02, wires=0)
                qml.RZ(weights[layer, 4, 2] * angle_02, wires=4)
                qml.RZ(weights[layer, 2, 2] * angle_12, wires=2)
                qml.RZ(weights[layer, 4, 2] * angle_12, wires=4)
                
                # Cross-bond entanglement
                qml.CNOT(wires=[1, 2])
                qml.CNOT(wires=[3, 4])
                qml.CNOT(wires=[5, 0])
                
                # Additional rotations
                for i in range(self.n_qubits):
                    qml.RZ(weights[layer, i, 1], wires=i)
                    qml.RY(weights[layer, i, 2], wires=i)
            
            return qml.expval(qml.sum(*(qml.PauliZ(i) for i in range(self.n_qubits))))
        
        self.circuit = circuit
        self.vec_circuit = jax.vmap(circuit, in_axes=(0, None), out_axes=0)
    
    def _init_params(self):
        """Initialize parameters."""
        np.random.seed(self.seed)
        self.params = {
            "weights": jnp.array(np.random.normal(0, 0.1, (self.depth, self.n_qubits, 3)))
        }
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# 4. CLASSICAL ROTATIONALLY EQUIVARIANT NN (IMPROVED)
# =============================================================================

class ClassicalRotationallyEquivariantNN:
    """Improved Classical Rotationally Equivariant Neural Network for NH₃."""
    
    def __init__(self, hidden_dims=[128, 128, 64], seed=42):
        self.hidden_dims = hidden_dims
        self.seed = seed
        self._init_params()
        self._create_model()
    
    def _init_params(self):
        np.random.seed(self.seed)
        n_features = 24
        layer_sizes = [n_features] + self.hidden_dims + [1]
        
        params = {
            "weights": [],
            "biases": [],
            "skip_weights": [],
            "output_scale": jnp.array(1.0),
            "output_bias": jnp.array(0.0),
        }
        
        for i in range(len(layer_sizes) - 1):
            fan_in = layer_sizes[i]
            fan_out = layer_sizes[i + 1]
            std = np.sqrt(2.0 / fan_in)
            W = np.random.normal(0, std, (fan_in, fan_out))
            b = np.zeros(fan_out)
            params["weights"].append(jnp.array(W))
            params["biases"].append(jnp.array(b))
        
        skip_std = np.sqrt(2.0 / n_features)
        params["skip_weights"] = jnp.array(np.random.normal(0, skip_std, (n_features, self.hidden_dims[-1])))
        params["rbf_coeffs"] = jnp.array(np.random.normal(0, 0.1, (6, 8)))
        params["rbf_output"] = jnp.array(np.random.normal(0, 0.1, (6,)))
        
        self.params = params
    
    def _create_model(self):
        def silu(x):
            return x * jax.nn.sigmoid(x)
        
        def compute_features(positions):
            eps = 1e-8
            d_NH1 = jnp.linalg.norm(positions[1] - positions[0]) + eps
            d_NH2 = jnp.linalg.norm(positions[2] - positions[0]) + eps
            d_NH3 = jnp.linalg.norm(positions[3] - positions[0]) + eps
            d_H1H2 = jnp.linalg.norm(positions[2] - positions[1]) + eps
            d_H1H3 = jnp.linalg.norm(positions[3] - positions[1]) + eps
            d_H2H3 = jnp.linalg.norm(positions[3] - positions[2]) + eps
            
            distances = jnp.array([d_NH1, d_NH2, d_NH3, d_H1H2, d_H1H3, d_H2H3])
            dist_norm = distances / 1.5
            inv_dist = 1.0 / distances
            inv_dist_norm = inv_dist / 1.0
            
            r_eq = jnp.array([1.01, 1.01, 1.01, 1.63, 1.63, 1.63])
            alpha = 2.0
            morse = jnp.exp(-alpha * (distances - r_eq))
            
            def compute_angle(p1, p2, p_center):
                v1 = p1 - p_center
                v2 = p2 - p_center
                cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2) + eps)
                return jnp.arccos(jnp.clip(cos_angle, -1.0 + eps, 1.0 - eps))
            
            angle_H1NH2 = compute_angle(positions[1], positions[2], positions[0])
            angle_H1NH3 = compute_angle(positions[1], positions[3], positions[0])
            angle_H2NH3 = compute_angle(positions[2], positions[3], positions[0])
            
            angles = jnp.array([angle_H1NH2, angle_H1NH3, angle_H2NH3])
            angles_norm = angles / jnp.pi
            cos_angles = jnp.cos(angles)
            
            features = jnp.concatenate([dist_norm, inv_dist_norm, morse, angles_norm, cos_angles])
            return features, distances
        
        def rbf_energy(distances, params):
            centers = jnp.linspace(0.8, 2.5, 8)
            width = 0.2
            rbf = jnp.exp(-((distances[:, None] - centers[None, :]) ** 2) / (2 * width ** 2))
            pair_energies = jnp.sum(rbf * params["rbf_coeffs"], axis=1)
            return jnp.dot(pair_energies, params["rbf_output"])
        
        def mlp_forward(features, params):
            weights = params["weights"]
            biases = params["biases"]
            h = features
            
            for i in range(len(weights) - 1):
                h = jnp.dot(h, weights[i]) + biases[i]
                h = silu(h)
                if i == len(weights) - 2:
                    skip = jnp.dot(features, params["skip_weights"])
                    h = h + 0.1 * skip
            
            h = jnp.dot(h, weights[-1]) + biases[-1]
            return h.squeeze(-1)
        
        def energy_from_positions(positions, params):
            features, distances = compute_features(positions)
            mlp_energy = mlp_forward(features, params)
            rbf_contrib = rbf_energy(distances, params)
            total_energy = params["output_scale"] * mlp_energy + rbf_contrib + params["output_bias"]
            return total_energy
        
        def force_from_positions(positions, params):
            grad_fn = jax.grad(energy_from_positions, argnums=0)
            return -grad_fn(positions, params)
        
        self.energy_fn = energy_from_positions
        self.force_fn = force_from_positions
        self.vec_energy = jax.vmap(energy_from_positions, (0, None), 0)
        self.vec_force = jax.vmap(force_from_positions, (0, None), 0)
    
    def get_params(self):
        return self.params
    
    def set_params(self, params):
        self.params = params


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_rotationally_equivariant(model, pos_H, E_train, F_train, n_epochs=300, lr=3e-3,
                                    wE=1.0, wF_max=5.0, warmup_frac=0.4):
    """Train Rotationally Equivariant QML with force warmup curriculum."""
    warmup_epochs = int(n_epochs * warmup_frac)
    
    def raw_energy(coords, params):
        return model.circuit(coords, params)
    
    vec_raw_energy = jax.vmap(raw_energy, (0, None), 0)
    
    def vec_force_fn(coords_batch, params):
        def single_force(coords):
            grad_fn = jax.grad(raw_energy, argnums=0)
            return -grad_fn(coords, params)
        return jax.vmap(single_force)(coords_batch)
    
    @jax.jit
    def loss_fn(params, coords, E_target, F_target, wF):
        E_raw = vec_raw_energy(coords, params)
        E_pred = params["head_scale"] * E_raw + params["head_bias"]
        L_E = jnp.mean((E_pred - E_target) ** 2)
        
        F_raw = vec_force_fn(coords, params)
        F_pred = params["head_scale"] * F_raw
        L_F = jnp.mean((F_pred - F_target) ** 2)
        
        L_E = jnp.where(jnp.isnan(L_E), 1.0, L_E)
        L_F = jnp.where(jnp.isnan(L_F), 1.0, L_F)
        
        return wE * L_E + wF * L_F, (L_E, L_F)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        if epoch < warmup_epochs:
            wF = wF_max * (epoch / warmup_epochs)
        else:
            wF = wF_max
        
        (loss, (L_E, L_F)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            get_params(opt_state), pos_H, E_train, F_train, wF
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(L_E))
            history["test_F_loss"].append(float(L_F))
    
    model.set_params(get_params(opt_state))
    return history


def train_non_equivariant(model, distances, E_train, F_train_flat, n_epochs=300,
                          lr=1e-3, lambda_E=2.0, lambda_F=1.0):
    """Train Non-Equivariant QML with combined loss."""
    
    def energy_fn(dists, params):
        raw = model.circuit(dists, params)
        return params["head_scale"] * raw + params["head_bias"]
    
    vec_energy = jax.vmap(energy_fn, (0, None), 0)
    
    def force_fn(dists, params):
        return -jax.grad(energy_fn, argnums=0)(dists, params)
    
    vec_force = jax.vmap(force_fn, (0, None), 0)
    
    @jax.jit
    def combined_loss(params, dists, E_target, F_target):
        E_pred = vec_energy(dists, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred = vec_force(dists, params)
        F_loss = jnp.mean((F_pred - F_target) ** 2)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return lambda_E * E_loss + lambda_F * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for epoch in range(n_epochs):
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            get_params(opt_state), distances, E_train, F_train_flat
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_graph_permutation(model, positions, E_train, F_train_H, n_epochs=300,
                            lr=0.01, n_epochs_energy=None, n_epochs_combined=None):
    """
    Train Graph Permutation Equivariant QML with two-phase training.
    Uses raw positions and computes forces on H atoms.
    
    Args:
        model: GraphPermutationEquivariantQML instance
        positions: (N, 4, 3) - atomic positions [N, H1, H2, H3]
        E_train: (N,) - energies
        F_train_H: (N, 3, 3) - forces on H atoms only
        n_epochs: total epochs (split evenly if phase epochs not specified)
        lr: learning rate
    """
    if n_epochs_energy is None:
        n_epochs_energy = n_epochs // 2
    if n_epochs_combined is None:
        n_epochs_combined = n_epochs - n_epochs_energy
    
    # Phase 1: Energy only
    @jax.jit
    def energy_loss(params, positions, E_target):
        E_pred = model.vec_circuit(positions, params)
        return jnp.mean((E_pred - E_target)**2)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    for step in range(n_epochs_energy):
        params = get_params(opt_state)
        loss, grads = jax.value_and_grad(energy_loss)(params, positions, E_train)
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(step, grads, opt_state)
        
        if (step + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(step + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(loss))
            history["test_F_loss"].append(0.0)
    
    # Phase 2: Combined energy + forces
    trained_params = get_params(opt_state)
    
    def force_single(coords, params):
        grad_fn = jax.grad(lambda c, p: model.circuit(c, p), argnums=0)
        return -grad_fn(coords, params)
    
    vec_force = jax.vmap(force_single, in_axes=(0, None), out_axes=0)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target):
        E_pred = model.vec_circuit(positions, params)
        E_loss = jnp.mean((E_pred - E_target)**2)
        
        F_pred_full = vec_force(positions, params)
        F_pred_H = F_pred_full[:, 1:, :]  # Forces on H atoms only
        F_loss = jnp.mean((F_pred_H - F_target)**2)
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return 2.0 * E_loss + 1.0 * F_loss, (E_loss, F_loss)
    
    opt_state = opt_init(trained_params)
    
    for step in range(n_epochs_combined):
        params = get_params(opt_state)
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            params, positions, E_train, F_train_H
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads)))
        if grad_norm > 10.0:
            grads = jax.tree.map(lambda g: g * (10.0 / grad_norm), grads)
        
        opt_state = opt_update(step, grads, opt_state)
        
        if (step + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(n_epochs_energy + step + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


def train_classical_equivariant(model, positions, E_train, F_train, n_epochs=300,
                                 lr=0.003, lambda_E=1.0, lambda_F=2.0, warmup_frac=0.3):
    """Train Classical Equivariant NN with two-phase training."""
    warmup_epochs = int(n_epochs * warmup_frac)
    
    @jax.jit
    def energy_loss(params, positions, E_target):
        E_pred = model.vec_energy(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        return jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
    
    @jax.jit
    def combined_loss(params, positions, E_target, F_target, wF):
        E_pred = model.vec_energy(positions, params)
        E_loss = jnp.mean((E_pred - E_target) ** 2)
        
        F_pred_full = model.vec_force(positions, params)
        F_pred_H = F_pred_full[:, 1:, :]
        F_residual = F_pred_H - F_target
        F_loss = jnp.mean(huber(F_residual, delta=0.5))
        
        E_loss = jnp.where(jnp.isnan(E_loss), 1.0, E_loss)
        F_loss = jnp.where(jnp.isnan(F_loss), 1.0, F_loss)
        
        return lambda_E * E_loss + wF * F_loss, (E_loss, F_loss)
    
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(model.params)
    
    history = {"epoch": [], "train_loss": [], "test_E_loss": [], "test_F_loss": []}
    
    # Phase 1: Energy warmup
    for epoch in range(warmup_epochs):
        loss, grads = jax.value_and_grad(energy_loss)(
            get_params(opt_state), positions, E_train
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 5.0:
            grads = jax.tree.map(lambda g: g * (5.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(loss))
            history["test_F_loss"].append(0.0)
    
    # Phase 2: Combined with force ramp
    for epoch in range(warmup_epochs, n_epochs):
        progress = (epoch - warmup_epochs) / max(1, n_epochs - warmup_epochs)
        wF = lambda_F * min(1.0, progress * 2)
        
        (loss, (E_loss, F_loss)), grads = jax.value_and_grad(combined_loss, has_aux=True)(
            get_params(opt_state), positions, E_train, F_train, wF
        )
        
        grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads) if g is not None))
        if grad_norm > 5.0:
            grads = jax.tree.map(lambda g: g * (5.0 / grad_norm) if g is not None else g, grads)
        
        opt_state = opt_update(epoch, grads, opt_state)
        
        if (epoch + 1) % max(1, n_epochs // 20) == 0:
            history["epoch"].append(epoch + 1)
            history["train_loss"].append(float(loss))
            history["test_E_loss"].append(float(E_loss))
            history["test_F_loss"].append(float(F_loss))
    
    model.set_params(get_params(opt_state))
    return history


# =============================================================================
# EVALUATION FUNCTION
# =============================================================================

def evaluate_fold(E_pred_train, E_pred_test, F_pred_train, F_pred_test,
                  E_train_true, E_test_true, F_train_true, F_test_true,
                  energy_scaler, force_scaler):
    """
    Evaluate predictions on a single fold with post-correction.
    """
    # Post-correction for energy (quadratic fit on training data)
    def corr_E(E, a, b, c):
        return a * E**2 + b * E + c
    
    try:
        popt_E, _ = curve_fit(corr_E, E_pred_train, E_train_true, maxfev=5000)
        E_pred_test_corr = corr_E(E_pred_test, *popt_E)
    except:
        E_pred_test_corr = E_pred_test
    
    # Post-correction for forces (linear fit on training data)
    try:
        F_pred_train_flat = F_pred_train.flatten().reshape(-1, 1)
        F_train_true_flat = F_train_true.flatten().reshape(-1, 1)
        lr_model = LinearRegression()
        lr_model.fit(F_pred_train_flat, F_train_true_flat)
        F_pred_test_corr = lr_model.predict(F_pred_test.flatten().reshape(-1, 1)).reshape(F_pred_test.shape)
    except:
        F_pred_test_corr = F_pred_test
    
    # Inverse transform to original units
    E_pred_final = energy_scaler.inverse_transform(E_pred_test_corr.reshape(-1, 1)).flatten()
    E_true_final = energy_scaler.inverse_transform(E_test_true.reshape(-1, 1)).flatten()
    
    F_pred_final = force_scaler.inverse_transform(F_pred_test_corr.flatten().reshape(-1, 1)).reshape(F_pred_test.shape)
    F_true_final = force_scaler.inverse_transform(F_test_true.flatten().reshape(-1, 1)).reshape(F_test_true.shape)
    
    # Compute metrics
    E_mae = np.mean(np.abs(E_pred_final - E_true_final))
    E_rmse = np.sqrt(np.mean((E_pred_final - E_true_final)**2))
    ss_res_E = np.sum((E_pred_final - E_true_final)**2)
    ss_tot_E = np.sum((E_true_final - E_true_final.mean())**2)
    E_r2 = 1 - ss_res_E / ss_tot_E if ss_tot_E > 0 else 0
    
    F_mae = np.mean(np.abs(F_pred_final - F_true_final))
    F_rmse = np.sqrt(np.mean((F_pred_final - F_true_final)**2))
    ss_res_F = np.sum((F_pred_final - F_true_final)**2)
    ss_tot_F = np.sum((F_true_final - F_true_final.mean())**2)
    F_r2 = 1 - ss_res_F / ss_tot_F if ss_tot_F > 0 else 0
    
    metrics = {
        "E_r2": float(E_r2),
        "E_mae_Ha": float(E_mae),
        "E_rmse_Ha": float(E_rmse),
        "E_mae_eV": float(E_mae * 27.2114),
        "F_r2": float(F_r2),
        "F_mae": float(F_mae),
        "F_rmse": float(F_rmse),
    }
    
    predictions = {
        "E_pred": E_pred_final.tolist(),
        "E_true": E_true_final.tolist(),
        "F_pred": F_pred_final.tolist(),
        "F_true": F_true_final.tolist(),
    }
    
    return metrics, predictions


# =============================================================================
# K-FOLD CROSS-VALIDATION FUNCTION
# =============================================================================

def run_kfold_cv(k_folds, n_epochs, data_dir, output_dir, seed=42):
    """
    Run k-fold cross-validation for all four methods.
    
    Args:
        k_folds: Number of folds for cross-validation
        n_epochs: Number of training epochs per fold
        data_dir: Directory containing NH₃ dataset
        output_dir: Directory to save results
        seed: Random seed for reproducibility
    
    Returns:
        Dictionary with all results
    """
    print(f"\n{'='*80}")
    print(f"NH₃ K-FOLD CROSS-VALIDATION COMPARISON")
    print(f"{'='*80}")
    print(f"K-Folds: {k_folds}")
    print(f"Epochs per fold: {n_epochs}")
    print(f"Data directory: {data_dir}")
    print(f"Output directory: {output_dir}")
    print(f"{'='*80}\n")
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Load data
    print("Loading NH₃ dataset...")
    positions = np.load(os.path.join(data_dir, "Positions.npy"))
    energy = np.load(os.path.join(data_dir, "Energy.npy"))
    forces = np.load(os.path.join(data_dir, "Forces.npy"))
    
    N_samples = len(energy)
    print(f"  Loaded {N_samples} samples")
    print(f"  Positions shape: {positions.shape}")
    print(f"  Energy shape: {energy.shape}")
    print(f"  Forces shape: {forces.shape}")
    
    # Prepare scalers (fit on entire dataset for consistency)
    energy_scaler = MinMaxScaler((-1, 1))
    energy_scaled = energy_scaler.fit_transform(energy.reshape(-1, 1)).flatten()
    
    forces_H = forces[:, 1:, :]  # H atoms only
    force_scaler = MinMaxScaler((-1, 1))
    forces_flat = forces_H.reshape(-1, 1)
    forces_scaled = force_scaler.fit_transform(forces_flat).reshape(forces_H.shape)
    
    # Prepare different feature representations
    positions_H = positions[:, 1:, :]  # H positions relative to N
    
    # Non-equivariant: N-H distances
    distances_NH = np.array([
        [np.linalg.norm(positions[i, 1] - positions[i, 0]),
         np.linalg.norm(positions[i, 2] - positions[i, 0]),
         np.linalg.norm(positions[i, 3] - positions[i, 0])]
        for i in range(N_samples)
    ])
    
    # Graph features: distances + angles
    def compute_graph_features(pos):
        d1 = np.linalg.norm(pos[1] - pos[0])
        d2 = np.linalg.norm(pos[2] - pos[0])
        d3 = np.linalg.norm(pos[3] - pos[0])
        
        def angle_at_center(p1, p2, pc):
            v1, v2 = p1 - pc, p2 - pc
            cos_a = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-8)
            return np.arccos(np.clip(cos_a, -1, 1))
        
        a1 = angle_at_center(pos[1], pos[2], pos[0])
        a2 = angle_at_center(pos[1], pos[3], pos[0])
        a3 = angle_at_center(pos[2], pos[3], pos[0])
        
        return np.array([d1, d2, d3, a1, a2, a3])
    
    graph_features = np.array([compute_graph_features(positions[i]) for i in range(N_samples)])
    
    # Force gradients for non-equivariant and graph models
    forces_grad_NH = np.zeros((N_samples, 3))
    for i in range(N_samples):
        for j in range(3):
            F_atom = forces_H[i, j]
            r_vec = positions[i, j+1] - positions[i, 0]
            r_norm = np.linalg.norm(r_vec)
            if r_norm > 1e-8:
                forces_grad_NH[i, j] = -np.dot(F_atom, r_vec) / r_norm
    
    forces_grad_NH_scaled = force_scaler.fit_transform(forces_grad_NH.reshape(-1, 1)).reshape(forces_grad_NH.shape)
    
    # Graph force gradients
    forces_grad_graph = np.zeros((N_samples, 6))
    forces_grad_graph[:, :3] = forces_grad_NH
    forces_grad_graph_scaled = force_scaler.fit_transform(forces_grad_graph.reshape(-1, 1)).reshape(forces_grad_graph.shape)
    
    # Initialize k-fold
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=seed)
    
    # Results storage
    results = {
        "config": {
            "k_folds": k_folds,
            "n_epochs": n_epochs,
            "n_samples": N_samples,
            "seed": seed,
            "timestamp": datetime.now().isoformat(),
        },
        "rotationally_equivariant": {"folds": [], "summary": {}},
        "non_equivariant": {"folds": [], "summary": {}},
        "graph_permutation_equivariant": {"folds": [], "summary": {}},
        "classical_equivariant": {"folds": [], "summary": {}},
    }
    
    method_names = [
        "rotationally_equivariant",
        "non_equivariant", 
        "graph_permutation_equivariant",
        "classical_equivariant"
    ]
    
    # Run k-fold CV
    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(positions)):
        print(f"\n{'='*60}")
        print(f"FOLD {fold_idx + 1}/{k_folds}")
        print(f"{'='*60}")
        print(f"  Train samples: {len(train_idx)}, Test samples: {len(test_idx)}")
        
        fold_seed = seed + fold_idx * 100
        
        # Prepare fold data
        E_train = energy_scaled[train_idx]
        E_test = energy_scaled[test_idx]
        F_train = forces_scaled[train_idx]
        F_test = forces_scaled[test_idx]
        
        pos_H_train = jnp.array(positions_H[train_idx])
        pos_H_test = jnp.array(positions_H[test_idx])
        
        pos_full_train = jnp.array(positions[train_idx])
        pos_full_test = jnp.array(positions[test_idx])
        
        dist_train = jnp.array(distances_NH[train_idx])
        dist_test = jnp.array(distances_NH[test_idx])
        
        graph_train = jnp.array(graph_features[train_idx])
        graph_test = jnp.array(graph_features[test_idx])
        
        F_grad_train = jnp.array(forces_grad_NH_scaled[train_idx])
        F_grad_test = jnp.array(forces_grad_NH_scaled[test_idx])
        
        F_grad_graph_train = jnp.array(forces_grad_graph_scaled[train_idx])
        F_grad_graph_test = jnp.array(forces_grad_graph_scaled[test_idx])
        
        # ==================== 1. Rotationally Equivariant ====================
        print(f"\n  [1/4] Rotationally Equivariant QML...")
        rot_model = RotationallyEquivariantQML(depth=6, rep=2, active_atoms=3, seed=fold_seed)
        
        rot_history = train_rotationally_equivariant(
            rot_model, pos_H_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=3e-3, wE=1.0, wF_max=5.0, warmup_frac=0.4
        )
        
        # Evaluate
        E_pred_train = np.array([rot_model.energy(pos_H_train[i], rot_model.params) for i in range(len(train_idx))])
        E_pred_test = np.array([rot_model.energy(pos_H_test[i], rot_model.params) for i in range(len(test_idx))])
        F_pred_train = np.array([rot_model.force(pos_H_train[i], rot_model.params) for i in range(len(train_idx))])
        F_pred_test = np.array([rot_model.force(pos_H_test[i], rot_model.params) for i in range(len(test_idx))])
        
        rot_metrics, rot_preds = evaluate_fold(
            E_pred_train, E_pred_test, F_pred_train, F_pred_test,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["rotationally_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": rot_metrics,
            "history": rot_history,
        })
        print(f"    Energy R²: {rot_metrics['E_r2']:.4f}, Force R²: {rot_metrics['F_r2']:.4f}")
        
        # ==================== 2. Non-Equivariant ====================
        print(f"\n  [2/4] Non-Equivariant QML...")
        non_eq_model = NonEquivariantQML(n_qubits=6, depth=4, seed=fold_seed)
        
        non_eq_history = train_non_equivariant(
            non_eq_model, dist_train, jnp.array(E_train), F_grad_train,
            n_epochs=n_epochs, lr=1e-3, lambda_E=2.0, lambda_F=1.0
        )
        
        # Evaluate
        def non_eq_energy(dists, params):
            raw = non_eq_model.circuit(dists, params)
            return params["head_scale"] * raw + params["head_bias"]
        
        def non_eq_force(dists, params):
            return -jax.grad(non_eq_energy, argnums=0)(dists, params)
        
        E_pred_train_ne = np.array([non_eq_energy(dist_train[i], non_eq_model.params) for i in range(len(train_idx))])
        E_pred_test_ne = np.array([non_eq_energy(dist_test[i], non_eq_model.params) for i in range(len(test_idx))])
        F_pred_train_ne = np.array([non_eq_force(dist_train[i], non_eq_model.params) for i in range(len(train_idx))])
        F_pred_test_ne = np.array([non_eq_force(dist_test[i], non_eq_model.params) for i in range(len(test_idx))])
        
        # Expand forces to match shape
        F_pred_train_ne_full = np.zeros((len(train_idx), 3, 3))
        F_pred_test_ne_full = np.zeros((len(test_idx), 3, 3))
        for i in range(len(train_idx)):
            for j in range(3):
                r_vec = positions[train_idx[i], j+1] - positions[train_idx[i], 0]
                r_norm = np.linalg.norm(r_vec)
                if r_norm > 1e-8:
                    F_pred_train_ne_full[i, j] = -F_pred_train_ne[i, j] * r_vec / r_norm
        for i in range(len(test_idx)):
            for j in range(3):
                r_vec = positions[test_idx[i], j+1] - positions[test_idx[i], 0]
                r_norm = np.linalg.norm(r_vec)
                if r_norm > 1e-8:
                    F_pred_test_ne_full[i, j] = -F_pred_test_ne[i, j] * r_vec / r_norm
        
        non_eq_metrics, non_eq_preds = evaluate_fold(
            E_pred_train_ne, E_pred_test_ne, F_pred_train_ne_full, F_pred_test_ne_full,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["non_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": non_eq_metrics,
            "history": non_eq_history,
        })
        print(f"    Energy R²: {non_eq_metrics['E_r2']:.4f}, Force R²: {non_eq_metrics['F_r2']:.4f}")
        
        # ==================== 3. Graph Permutation Equivariant ====================
        print(f"\n  [3/4] Graph Permutation Equivariant QML...")
        graph_model = GraphPermutationEquivariantQML(n_qubits=6, depth=4, seed=fold_seed)
        
        # Use raw positions for training (forces on H atoms only)
        F_train_H = F_train  # Already (N, 3, 3) for H atoms
        
        graph_history = train_graph_permutation(
            graph_model, pos_full_train, jnp.array(E_train), jnp.array(F_train_H),
            n_epochs=n_epochs, lr=0.01
        )
        
        # Evaluate using raw positions
        def graph_force_single(coords, params):
            return -jax.grad(lambda c, p: graph_model.circuit(c, p), argnums=0)(coords, params)
        
        vec_force_graph = jax.vmap(graph_force_single, in_axes=(0, None), out_axes=0)
        
        E_pred_train_g = np.array(graph_model.vec_circuit(pos_full_train, graph_model.params))
        E_pred_test_g = np.array(graph_model.vec_circuit(pos_full_test, graph_model.params))
        
        # Forces directly from positions (H atoms only)
        F_pred_train_g_all = np.array(vec_force_graph(pos_full_train, graph_model.params))
        F_pred_test_g_all = np.array(vec_force_graph(pos_full_test, graph_model.params))
        F_pred_train_g_full = F_pred_train_g_all[:, 1:, :]  # H atoms only
        F_pred_test_g_full = F_pred_test_g_all[:, 1:, :]
        
        graph_metrics, graph_preds = evaluate_fold(
            E_pred_train_g, E_pred_test_g, F_pred_train_g_full, F_pred_test_g_full,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["graph_permutation_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": graph_metrics,
            "history": graph_history,
        })
        print(f"    Energy R²: {graph_metrics['E_r2']:.4f}, Force R²: {graph_metrics['F_r2']:.4f}")
        
        # ==================== 4. Classical Equivariant ====================
        print(f"\n  [4/4] Classical Rotationally Equivariant NN...")
        classical_model = ClassicalRotationallyEquivariantNN(hidden_dims=[128, 128, 64], seed=fold_seed)
        
        classical_history = train_classical_equivariant(
            classical_model, pos_full_train, jnp.array(E_train), jnp.array(F_train),
            n_epochs=n_epochs, lr=0.003, lambda_E=1.0, lambda_F=2.0, warmup_frac=0.3
        )
        
        # Evaluate
        E_pred_train_c = np.array(classical_model.vec_energy(pos_full_train, classical_model.params))
        E_pred_test_c = np.array(classical_model.vec_energy(pos_full_test, classical_model.params))
        F_pred_train_c = np.array(classical_model.vec_force(pos_full_train, classical_model.params))[:, 1:, :]
        F_pred_test_c = np.array(classical_model.vec_force(pos_full_test, classical_model.params))[:, 1:, :]
        
        classical_metrics, classical_preds = evaluate_fold(
            E_pred_train_c, E_pred_test_c, F_pred_train_c, F_pred_test_c,
            E_train, E_test, F_train, F_test,
            energy_scaler, force_scaler
        )
        
        results["classical_equivariant"]["folds"].append({
            "fold": fold_idx + 1,
            "metrics": classical_metrics,
            "history": classical_history,
        })
        print(f"    Energy R²: {classical_metrics['E_r2']:.4f}, Force R²: {classical_metrics['F_r2']:.4f}")
    
    # Compute summary statistics across folds
    print(f"\n{'='*80}")
    print("COMPUTING SUMMARY STATISTICS")
    print(f"{'='*80}")
    
    metrics_keys = ["E_r2", "E_mae_Ha", "E_rmse_Ha", "F_r2", "F_mae", "F_rmse"]
    
    for method in method_names:
        folds_data = results[method]["folds"]
        summary = {}
        
        for metric in metrics_keys:
            values = [fold["metrics"][metric] for fold in folds_data]
            summary[metric] = {
                "mean": float(np.mean(values)),
                "std": float(np.std(values)),
                "min": float(np.min(values)),
                "max": float(np.max(values)),
                "values": values,
            }
        
        # Compute coefficient of variation (CV) for generalizability measure
        for metric in ["E_r2", "F_r2"]:
            values = summary[metric]["values"]
            mean = summary[metric]["mean"]
            if mean != 0:
                summary[f"{metric}_cv"] = float(np.std(values) / abs(mean))  # CV
            else:
                summary[f"{metric}_cv"] = float('inf')
        
        results[method]["summary"] = summary
    
    # Print summary table
    print("\n" + "="*100)
    print("K-FOLD CROSS-VALIDATION SUMMARY")
    print("="*100)
    
    print(f"\n{'Method':<35} {'E_R² (mean±std)':<20} {'F_R² (mean±std)':<20} {'E_CV':<10} {'F_CV':<10}")
    print("-"*100)
    
    method_labels = {
        "rotationally_equivariant": "Rot. Equiv. QML",
        "non_equivariant": "Non-Equiv. QML",
        "graph_permutation_equivariant": "Graph Perm. QML",
        "classical_equivariant": "Classical Equiv. NN",
    }
    
    for method in method_names:
        summary = results[method]["summary"]
        e_r2 = summary["E_r2"]
        f_r2 = summary["F_r2"]
        e_cv = summary.get("E_r2_cv", 0)
        f_cv = summary.get("F_r2_cv", 0)
        
        print(f"{method_labels[method]:<35} "
              f"{e_r2['mean']:.4f}±{e_r2['std']:.4f}       "
              f"{f_r2['mean']:.4f}±{f_r2['std']:.4f}       "
              f"{e_cv:.4f}     {f_cv:.4f}")
    
    print("="*100)
    print("CV = Coefficient of Variation (lower = more consistent across folds)")
    
    # Save results
    results_path = os.path.join(output_dir, "kfold_results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    
    # Save NPZ for easy loading
    npz_data = {}
    for method in method_names:
        for metric in metrics_keys:
            key = f"{method}_{metric}"
            npz_data[key] = np.array(results[method]["summary"][metric]["values"])
    
    npz_path = os.path.join(output_dir, "kfold_metrics.npz")
    np.savez(npz_path, **npz_data)
    print(f"Metrics saved to: {npz_path}")
    
    return results


def main(k_folds=5, n_epochs=300, output_dir="kfold_results", 
         data_dir="eqnn_force_field_data_nh3_new", seed=42):
    """
    Main function for k-fold cross-validation.
    
    Args:
        k_folds: Number of folds (default: 5)
        n_epochs: Epochs per fold (default: 300)
        output_dir: Output directory
        data_dir: Data directory
        seed: Random seed
    
    Returns:
        Results dictionary
    """
    return run_kfold_cv(k_folds, n_epochs, data_dir, output_dir, seed)


if __name__ == "__main__":
    import sys
    
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter. Call main() with parameters:")
        print("  results = main(k_folds=5, n_epochs=300, output_dir='kfold_results')")
    else:
        parser = argparse.ArgumentParser(description="NH₃ K-Fold Cross-Validation Comparison")
        parser.add_argument("--k_folds", type=int, default=5, help="Number of folds")
        parser.add_argument("--n_epochs", type=int, default=300, help="Epochs per fold")
        parser.add_argument("--output_dir", type=str, default="kfold_results", help="Output directory")
        parser.add_argument("--data_dir", type=str, default="eqnn_force_field_data_nh3_new", help="Data directory")
        parser.add_argument("--seed", type=int, default=42, help="Random seed")
        
        args = parser.parse_args()
        
        main(
            k_folds=args.k_folds,
            n_epochs=args.n_epochs,
            output_dir=args.output_dir,
            data_dir=args.data_dir,
            seed=args.seed
        )



Running in Jupyter. Call main() with parameters:
  results = main(k_folds=5, n_epochs=300, output_dir='kfold_results')


In [2]:
results = main(k_folds=5, n_epochs=300, output_dir='kfold_results')




NH₃ K-FOLD CROSS-VALIDATION COMPARISON
K-Folds: 5
Epochs per fold: 300
Data directory: eqnn_force_field_data_nh3_new
Output directory: kfold_results

Loading NH₃ dataset...
  Loaded 2400 samples
  Positions shape: (2400, 4, 3)
  Energy shape: (2400,)
  Forces shape: (2400, 4, 3)

FOLD 1/5
  Train samples: 1920, Test samples: 480

  [1/4] Rotationally Equivariant QML...
    Energy R²: 0.9728, Force R²: 0.8772

  [2/4] Non-Equivariant QML...
    Energy R²: 0.0600, Force R²: 0.2080

  [3/4] Graph Permutation Equivariant QML...
    Energy R²: 0.9363, Force R²: 0.9384

  [4/4] Classical Rotationally Equivariant NN...
    Energy R²: 0.9764, Force R²: 0.9522

FOLD 2/5
  Train samples: 1920, Test samples: 480

  [1/4] Rotationally Equivariant QML...
    Energy R²: 0.9609, Force R²: 0.8121

  [2/4] Non-Equivariant QML...
    Energy R²: 0.1453, Force R²: 0.1908

  [3/4] Graph Permutation Equivariant QML...
    Energy R²: 0.9316, Force R²: 0.9420

  [4/4] Classical Rotationally Equivariant NN...