# MLIP Distillation Tutorial

Learn how to distill a machine learning interatomic potential (MLIP), here ORB, into classical GULP potentials.

**Workflow:**
1. Collect training data from structures using ORB
2. Prepare fitting targets from the collected structures
3. Fit GULP Buckingham + QEq or Shell potential
4. Validate the fitted potential

**Charge Models:**
- **QEq (Part A)**: Charge equilibration - good for reactive/flexible bonding
- **Shell (Part B)**: Core-shell polarizable ions - best for oxides/halides

## Setup

In [None]:
import os
from pathlib import Path
import numpy as np

# Set GULP paths (modify for your system)
os.environ["GULP_EXE"] = "/path/to/gulp"  # Change this!
os.environ["GULP_LIB"] = "/path/to/gulp/Libraries"  # Change this!

from ase.io import read
from ase.db import connect

## Step 1: Collect Data from Existing Structures

Use `MLIPCollector` with ORB calculator to generate training data.

In [None]:
from ggen import get_orb_calculator, MLIPCollector, CollectionConfig

# Create ORB calculator (use 'cpu' if no GPU)
calc = get_orb_calculator(device="cuda")

# Configure collection
config = CollectionConfig(
    db_path="manual_training.db",
    md_temperatures=[300, 600, 1000],
    md_steps_per_temp=100,
    md_save_interval=10,
    rattle_stdev=[0.01, 0.05],
    rattle_n_configs=5,
)

collector = MLIPCollector(calculator=calc, config=config)

In [None]:
# Load your CIF files
cif_folder = Path("./cif_files")  # Change to your folder
cif_files = list(cif_folder.glob("*.cif"))

print(f"Found {len(cif_files)} CIF files")

for cif_file in cif_files[:3]:  # Process first 3 for demo
    atoms = read(cif_file)
    name = cif_file.stem
    
    print(f"\nProcessing: {name}")
    
    # Single point
    collector.collect_single_point(atoms, name=f"{name}_sp")
    
    # Optimization
    collector.collect_optimization(atoms, name=f"{name}_opt", optimize_cell=True)
    
    # MD sampling
    collector.collect_md_temperatures(atoms, name=f"{name}_md")
    
    # Rattling
    collector.collect_rattling(atoms, name=f"{name}_rattle")

print(collector.stats.summary())
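
For intuition about the `rattle_stdev` and `rattle_n_configs` settings, the sketch below shows what rattling usually means: adding Gaussian noise of a given width to the atomic positions. It uses plain NumPy and a hypothetical two-atom cell; `rattle_positions` is an illustrative helper, and whether `MLIPCollector` perturbs structures in exactly this way is an assumption.

```python
import numpy as np

def rattle_positions(positions, stdev, n_configs, seed=0):
    """Return n_configs copies of positions with Gaussian noise of width stdev (Å)."""
    rng = np.random.default_rng(seed)
    positions = np.asarray(positions, dtype=float)
    return [positions + rng.normal(scale=stdev, size=positions.shape)
            for _ in range(n_configs)]

# Hypothetical two-atom toy cell (coordinates in Å)
base = np.array([[0.0, 0.0, 0.0], [2.1, 2.1, 2.1]])

configs = []
for stdev in [0.01, 0.05]:  # same amplitudes as rattle_stdev in the config
    configs += rattle_positions(base, stdev, n_configs=5)

print(len(configs))  # 2 amplitudes x 5 configs = 10
```

Small amplitudes sample the region near equilibrium, while larger ones probe the repulsive wall that the Buckingham `A` and `rho` terms have to reproduce.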

## Step 2: Prepare Fitting Targets

In [None]:
from ggen import FitTarget

db = connect("manual_training.db")

# Filter by max force (rows without a stored max_force default to 0 and are kept)
MAX_FORCE = 50.0  # eV/Å

targets = []
for row in db.select():
    if row.get("max_force", 0) < MAX_FORCE:
        atoms = row.toatoms()
        forces = row.data.get("forces") if hasattr(row, "data") else None
        
        targets.append(FitTarget(
            name=f"s_{row.id}",
            atoms=atoms,
            energy=row.get("total_energy"),
            forces=forces,
            energy_weight=1.0,
            force_weight=0.1,
        ))

print(f"Prepared {len(targets)} valid training targets")
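
To make the roles of `energy_weight` and `force_weight` concrete, here is a minimal sketch of one common way to combine energy and force errors into a single per-structure loss. The function `weighted_loss` and its exact form (per-atom energy error squared plus mean squared force error) are assumptions for illustration; the objective that GGen actually minimizes may differ.

```python
import numpy as np

def weighted_loss(e_ref, e_fit, f_ref, f_fit, n_atoms,
                  energy_weight=1.0, force_weight=0.1):
    """Weighted squared per-atom energy error plus mean squared force error."""
    de = (e_fit - e_ref) / n_atoms                 # eV/atom
    df = np.asarray(f_fit) - np.asarray(f_ref)     # eV/Å, shape (n_atoms, 3)
    return energy_weight * de**2 + force_weight * np.mean(df**2)

# Toy numbers: 0.2 eV total energy error and a uniform 0.1 eV/Å force error
f_ref = np.zeros((2, 3))
f_fit = np.full((2, 3), 0.1)
loss = weighted_loss(e_ref=-10.0, e_fit=-9.8, f_ref=f_ref, f_fit=f_fit, n_atoms=2)
print(loss)
```

With `force_weight=0.1`, a force error contributes ten times less than an energy error of the same squared magnitude, so the fit prioritizes energies.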

---
# Part A: Buckingham + QEq Model

QEq (Charge Equilibration) dynamically calculates atomic charges based on electronegativity equalization.
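
The idea behind electronegativity equalization can be sketched in a few lines: charges minimize a quadratic energy under a total-charge constraint, which reduces to one linear solve. The sketch below makes several simplifying assumptions: it uses a bare Coulomb interaction (real QEq implementations use shielded integrals), treats the `mu` parameter as the coefficient in `0.5 * mu * q**2`, and handles an isolated pair of atoms rather than a periodic crystal. `qeq_charges` is an illustrative helper, not part of GGen or GULP.

```python
import numpy as np

COULOMB_K = 14.3996  # e^2 / (4*pi*eps0) in eV*Å

def qeq_charges(chi, hardness, positions, total_charge=0.0):
    """Minimise sum(chi*q) + 0.5*q.H.q subject to sum(q) = total_charge."""
    chi = np.asarray(chi, dtype=float)
    positions = np.asarray(positions, dtype=float)
    n = len(chi)
    # Off-diagonal: bare Coulomb interaction; diagonal: atomic hardness
    dist = np.linalg.norm(positions[:, None] - positions[None, :], axis=-1)
    H = np.zeros((n, n))
    off = ~np.eye(n, dtype=bool)
    H[off] = COULOMB_K / dist[off]
    H[np.diag_indices(n)] = hardness
    # Stationarity (H q - lam = -chi) plus the charge constraint as one system
    A = np.zeros((n + 1, n + 1))
    A[:n, :n] = H
    A[:n, n] = -1.0
    A[n, :n] = 1.0
    b = np.append(-chi, total_charge)
    return np.linalg.solve(A, b)[:n]

# Nb-O pair 2 Å apart, using the chi/mu values shown later in this tutorial
q = qeq_charges(chi=[3.55, 8.74], hardness=[6.76, 13.36],
                positions=[[0, 0, 0], [0, 0, 2.0]])
print(q)
```

The less electronegative atom (Nb) ends up positive and the charges sum to the requested total, which is the behavior the fitted QEq model relies on.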

## A1: Auto-load QEq Parameters

GGen provides a database of QEq parameters (103 elements from Open Babel). Use `build_qeq_config()` to auto-generate parameters.

In [None]:
from ggen import (
    PotentialConfig, ChargeModel, PotentialType,
    BuckinghamParams, build_qeq_config, QEQ_PARAMS_DATABASE
)

# Check available QEq parameters
elements = ["Nb", "W", "O"]
for el in elements:
    p = QEQ_PARAMS_DATABASE[el]
    print(f"{el}: chi={p.chi:.3f} eV, mu={p.mu:.3f} eV")

In [None]:
# Auto-build QEq config from element list
qeq_params = build_qeq_config(["Nb", "W", "O"])

config_qeq = PotentialConfig(
    name="NbWO_qeq",
    charge_model=ChargeModel.QEQ,
    potential_type=PotentialType.BUCKINGHAM,
    buckingham={
        ("Nb", "O"): BuckinghamParams(A=5000, rho=0.35, C=0),
        ("W", "O"): BuckinghamParams(A=6000, rho=0.33, C=0),
        ("O", "O"): BuckinghamParams(A=22000, rho=0.15, C=28),
    },
    qeq_params=qeq_params,  # Auto-loaded!
    fit_buckingham=True,
    fit_charges=True,
)

print(f"QEq config ready with {len(qeq_params)} elements")
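
The `A`, `rho`, and `C` values above enter the standard Buckingham form, E(r) = A·exp(−r/ρ) − C/r⁶, with `A` in eV, `rho` in Å, and `C` in eV·Å⁶. The standalone sketch below evaluates that expression for two of the starting guesses; `buckingham` is an illustrative helper, not a GGen function.

```python
import numpy as np

def buckingham(r, A, rho, C):
    """Buckingham pair energy in eV: A*exp(-r/rho) - C/r**6 (r and rho in Å)."""
    r = np.asarray(r, dtype=float)
    return A * np.exp(-r / rho) - C / r**6

r = np.array([1.0, 2.0, 2.8])  # pair distances in Å
e_oo = buckingham(r, A=22000, rho=0.15, C=28)   # O-O starting guess
e_nbo = buckingham(r, A=5000, rho=0.35, C=0)    # Nb-O starting guess

for ri, eo, en in zip(r, e_oo, e_nbo):
    print(f"r={ri:.1f} Å  O-O={eo:+.4f} eV  Nb-O={en:+.4f} eV")
```

With `C=0` the cation-oxygen term is purely repulsive, while the O-O term turns attractive at typical O-O separations because of the dispersion contribution.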

In [None]:
# Or specify custom QEq parameters manually (pass as qeq_params=custom_qeq)
from ggen import QEqParams

custom_qeq = {
    "Nb": QEqParams("Nb", chi=3.55, mu=6.76),
    "W": QEqParams("W", chi=4.63, mu=6.62),
    "O": QEqParams("O", chi=8.74, mu=13.36),
}

## A2: Fit QEq Potential

In [None]:
from ggen import GULPFitter

fitter = GULPFitter(
    gulp_command=os.environ["GULP_EXE"],
    gulp_lib=os.environ["GULP_LIB"],
    verbose=True,
)

result_qeq = fitter.fit(
    config=config_qeq,
    targets=targets[:100],
    method="dual_annealing",
    maxiter=200,
)

print(f"\nQEq Fitting: {'converged' if result_qeq.converged else 'did not converge'}")
print(f"Objective: {result_qeq.objective_value:.4f}")

fitter.save_library(result_qeq.config, "NbWO_qeq.lib")
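
The `method="dual_annealing"` option suggests a bounded global optimizer of the kind SciPy provides as `scipy.optimize.dual_annealing`; whether `GULPFitter` calls SciPy internally is an assumption. As a toy illustration of the technique, the sketch below fits `A` and `rho` of a purely repulsive Buckingham term to synthetic reference energies. The reference data, bounds, and starting point are all made up for the example.

```python
import numpy as np
from scipy.optimize import dual_annealing

r = np.linspace(1.6, 3.0, 15)           # synthetic pair distances in Å
e_ref = 1509.0 * np.exp(-r / 0.3625)    # synthetic "reference" energies in eV

def objective(params):
    """Sum of squared energy errors for a trial (A, rho) pair."""
    A, rho = params
    return np.sum((A * np.exp(-r / rho) - e_ref) ** 2)

bounds = [(500, 5000), (0.25, 0.45)]    # search box for A (eV) and rho (Å)
x0 = np.array([3000.0, 0.35])           # starting guess inside the bounds
result = dual_annealing(objective, bounds, maxiter=200, x0=x0)

print("A, rho:", result.x)
print("objective:", result.fun)
```

Because `A` and `rho` are strongly correlated, several parameter pairs can give nearly the same objective value; tight, physically motivated bounds help keep the fit meaningful.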

---
# Part B: Buckingham + Shell Model

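The shell model represents each polarizable ion as a core and a shell coupled by a harmonic spring. It works best for:
- Oxides (MgO, SiO₂, Al₂O₃)
- Halides (NaCl, CaF₂)
- Properties that depend on electronic polarization, such as the high-frequency dielectric constant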

## B1: Configure Shell Model

The shell model splits each ion's charge between a core and a shell, connected by a harmonic spring.

In [None]:
from ggen import ShellParams

# Shell parameters for Nb-W-O system
# Core and shell charges must sum to formal charge
shell_params = {
    # Oxygen: formal charge -2, split into core and shell
    "O": ShellParams(
        element="O",
        core_charge=0.86902,       # Positive core
        shell_charge=-2.86902,     # Negative shell (core + shell = -2)
        spring_k=74.92,            # Spring constant (eV/Å²)
        spring_k_bounds=(10, 200),
    ),
}

# Fixed charges for cations (no shell, just core)
fixed_charges = {
    "Nb": 5.0,   # Nb5+
    "W": 6.0,    # W6+
}

print("Shell model parameters:")
for el, sp in shell_params.items():
    print(f"  {el}: core={sp.core_charge:.3f}, shell={sp.shell_charge:.3f}, k={sp.spring_k:.1f}")
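
A quick sanity check on shell parameters is the free-ion polarizability they imply, α = Y²/k, where Y is the shell charge and k the spring constant. The sketch below evaluates it for the oxygen values above and confirms that core and shell charges sum to the formal charge. `shell_polarizability` is an illustrative helper, and the Coulomb constant converts the result to Å³.

```python
COULOMB_K = 14.3996  # e^2 / (4*pi*eps0) in eV*Å

def shell_polarizability(shell_charge, spring_k):
    """Free-ion polarizability in Å^3 from shell charge (e) and spring constant (eV/Å^2)."""
    return COULOMB_K * shell_charge**2 / spring_k

core_charge, shell_charge, spring_k = 0.86902, -2.86902, 74.92
total = core_charge + shell_charge
alpha = shell_polarizability(shell_charge, spring_k)

print(f"total charge = {total:.1f} e, alpha = {alpha:.2f} Å^3")
```

A softer spring (smaller `spring_k`) gives a more polarizable ion, which is why `spring_k_bounds` effectively limits how polarizable the fitted oxygen can become.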

In [None]:
config_shell = PotentialConfig(
    name="NbWO_shell",
    charge_model=ChargeModel.SHELL,
    potential_type=PotentialType.BUCKINGHAM,
    buckingham={
        # Metal-O interactions (with shell)
        ("Nb", "O"): BuckinghamParams(
            A=1509.0, rho=0.3625, C=0,
            A_bounds=(500, 5000), rho_bounds=(0.25, 0.45)
        ),
        ("W", "O"): BuckinghamParams(
            A=1800.0, rho=0.3400, C=0,
            A_bounds=(500, 5000), rho_bounds=(0.25, 0.45)
        ),
        # O-O interaction (shell-shell)
        ("O", "O"): BuckinghamParams(
            A=22764.0, rho=0.149, C=27.88,
            A_bounds=(10000, 50000), rho_bounds=(0.1, 0.2)
        ),
    },
    shell_params=shell_params,
    fixed_charges=fixed_charges,
    fit_buckingham=True,
    fit_charges=False,  # Keep charges fixed for shell model
)

print("Shell model configured!")

## B2: Fit Shell Potential

In [None]:
result_shell = fitter.fit(
    config=config_shell,
    targets=targets[:100],
    method="dual_annealing",
    maxiter=200,
)

print(f"\nShell Fitting: {'converged' if result_shell.converged else 'did not converge'}")
print(f"Objective: {result_shell.objective_value:.4f}")

fitter.save_library(result_shell.config, "NbWO_shell.lib")

---
## Comparison: QEq vs Shell

Compare per-atom energies from both fitted potentials against the ORB reference on the first 20 targets.

In [None]:
from ggen import get_gulp_calculator
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt

# Create calculators
calc_qeq = get_gulp_calculator(library="NbWO_qeq.lib", keywords="conp gradients qeq")
calc_shell = get_gulp_calculator(library="NbWO_shell.lib", keywords="conp gradients")

orb_e, qeq_e, shell_e = [], [], []

for target in targets[:20]:
    atoms = target.atoms.copy()
    n = len(atoms)
    orb_e.append(target.energy / n)
    
    # QEq
    atoms.calc = calc_qeq
    try:
        qeq_e.append(atoms.get_potential_energy() / n)
    except Exception:  # GULP run failed; record NaN so the mask below skips it
        qeq_e.append(np.nan)
    
    # Shell
    atoms.calc = calc_shell
    try:
        shell_e.append(atoms.get_potential_energy() / n)
    except Exception:  # GULP run failed; record NaN so the mask below skips it
        shell_e.append(np.nan)

# Calculate MAE
orb_e = np.array(orb_e)
qeq_valid = ~np.isnan(qeq_e)
shell_valid = ~np.isnan(shell_e)

mae_qeq = mean_absolute_error(orb_e[qeq_valid], np.array(qeq_e)[qeq_valid])
mae_shell = mean_absolute_error(orb_e[shell_valid], np.array(shell_e)[shell_valid])

print(f"QEq MAE: {mae_qeq:.4f} eV/atom")
print(f"Shell MAE: {mae_shell:.4f} eV/atom")
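
Classical lattice energies and MLIP total energies are often measured against different zero points, in which case a raw MAE is dominated by a constant offset. Whether that applies here depends on how the fit handles energy references, so treat the following as an optional check under that assumption. The sketch computes both the raw MAE and the MAE after removing the mean offset, skipping NaN entries; `mae_with_shift` is an illustrative helper and the numbers are made up.

```python
import numpy as np

def mae_with_shift(ref, pred):
    """Return (raw MAE, MAE after removing the mean offset), ignoring NaN entries."""
    ref = np.asarray(ref, dtype=float)
    pred = np.asarray(pred, dtype=float)
    ok = ~np.isnan(ref) & ~np.isnan(pred)
    diff = pred[ok] - ref[ok]
    return np.mean(np.abs(diff)), np.mean(np.abs(diff - diff.mean()))

# Toy per-atom energies in eV; the NaN stands for a failed calculation
ref = [-8.00, -7.90, -7.80]
pred = [-40.00, -39.95, np.nan]
raw, shifted = mae_with_shift(ref, pred)

print(f"raw MAE = {raw:.3f} eV/atom, shift-corrected MAE = {shifted:.3f} eV/atom")
```

A large raw MAE with a small shift-corrected MAE indicates that relative energies are reproduced well even though the absolute scales differ.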

In [None]:
# Parity plot comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# QEq
ax1.scatter(orb_e[qeq_valid], np.array(qeq_e)[qeq_valid], alpha=0.6, c='blue')
lims = [orb_e.min(), orb_e.max()]
ax1.plot(lims, lims, 'k--', alpha=0.5)
ax1.set_xlabel('ORB (eV/atom)')
ax1.set_ylabel('GULP (eV/atom)')
ax1.set_title(f'QEq Model (MAE={mae_qeq:.4f})')

# Shell
ax2.scatter(orb_e[shell_valid], np.array(shell_e)[shell_valid], alpha=0.6, c='red')
ax2.plot(lims, lims, 'k--', alpha=0.5)
ax2.set_xlabel('ORB (eV/atom)')
ax2.set_ylabel('GULP (eV/atom)')
ax2.set_title(f'Shell Model (MAE={mae_shell:.4f})')

plt.tight_layout()
plt.show()

---
## Summary

### QEq Model
- ✅ Auto-load parameters with `build_qeq_config(elements)`
- ✅ 103 elements in database (`QEQ_PARAMS_DATABASE`)
- ✅ Good for systems with variable oxidation states

### Shell Model
- ✅ Better for polarizable ions (oxides, halides)
- ✅ Captures high-frequency dielectric response
- ⚠️ Requires careful core/shell charge partitioning

### Choosing a Model
| System | Recommended |
|--------|-------------|
| Metal oxides | Shell |
| Halides | Shell |
| Metals/alloys | Fixed charges |
| Reactive systems | QEq |
| Mixed-valence | QEq |