# MLIP Distillation Tutorial

Learn how to distill a Machine Learning Interatomic Potential (MLIP) — here ORB — into a classical GULP potential.

**Workflow:**
1. Collect training data from structures using ORB
2. Optionally explore new structures with GGen and merge them with the manually collected data
3. Fit a GULP Buckingham + QEq potential
4. Validate the fitted potential

## Setup

In [None]:
import os
from pathlib import Path
import numpy as np

import shutil

# Set GULP paths (modify for your system)
os.environ["GULP_EXE"] = "/path/to/gulp"  # Change this!
os.environ["GULP_LIB"] = "/path/to/gulp/Libraries"  # Change this!

# Warn right away if the placeholders above were not edited. Otherwise the
# problem only surfaces much later, as a failed GULP call during fitting or
# validation. These are warnings, not errors, so the data-collection steps
# (which do not need GULP) can still run.
if shutil.which(os.environ["GULP_EXE"]) is None:
    print(f"Warning: GULP executable not found: {os.environ['GULP_EXE']}")
if not os.path.isdir(os.environ["GULP_LIB"]):
    print(f"Warning: GULP library directory not found: {os.environ['GULP_LIB']}")

from ase.io import read
from ase.db import connect

## Step 1: Collect Data from Existing Structures

Use `MLIPCollector` with ORB calculator to generate training data.

In [None]:
import torch

from ggen import get_orb_calculator, MLIPCollector, CollectionConfig

# Run ORB on the GPU when one is available, otherwise fall back to the CPU,
# so the notebook also works on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
calc = get_orb_calculator(device=device)

# Configure collection. Training data from this step is written to db_path.
config = CollectionConfig(
    db_path="manual_training.db",
    md_temperatures=[300, 600, 1000],
    md_steps_per_temp=100,
    md_save_interval=10,
    rattle_stdev=[0.01, 0.05],
    rattle_n_configs=5,
)

collector = MLIPCollector(calculator=calc, config=config)

In [None]:
# Load your CIF files
cif_folder = Path("./cif_files")  # Change to your folder
if not cif_folder.is_dir():
    # glob() on a missing directory silently yields nothing; say so explicitly.
    print(f"Warning: CIF folder not found: {cif_folder.resolve()}")

# glob() returns files in filesystem order, which differs between machines;
# sort so that "the first N structures" is reproducible.
cif_files = sorted(cif_folder.glob("*.cif"))

print(f"Found {len(cif_files)} CIF files")

N_DEMO = 3  # Number of structures processed in this demo; increase for real runs.

for cif_file in cif_files[:N_DEMO]:
    atoms = read(cif_file)
    name = cif_file.stem

    print(f"\nProcessing: {name}")

    # Each sampling mode gets its own copy of the structure as read from the
    # CIF. This keeps the modes independent even if a collector method
    # modifies the Atoms object it is given (e.g. an in-place cell relaxation).

    # Single point
    collector.collect_single_point(atoms.copy(), name=f"{name}_sp")

    # Optimization
    collector.collect_optimization(atoms.copy(), name=f"{name}_opt", optimize_cell=True)

    # MD sampling
    collector.collect_md_temperatures(atoms.copy(), name=f"{name}_md")

    # Rattling
    collector.collect_rattling(atoms.copy(), name=f"{name}_rattle")

print(collector.stats.summary())

## Step 2: Explore New Structures with GGen (Optional)

Generate additional structures through exploration.

In [None]:
from ggen import ChemistryExplorer, GGen, StructureDatabase

# Explored structures go into their own database; the next cell reads them
# back from it by chemical system.
explore_db = StructureDatabase("exploration.db")
explorer = ChemistryExplorer(calculator=calc, database=explore_db)

# Small demo search over the Nb-W-O system (at most 8 atoms, 10 trials).
exploration_settings = {
    "elements": ["Nb", "W", "O"],
    "max_atoms": 8,
    "num_trials": 10,
}
result = explorer.explore(**exploration_settings)

print(f"Found {result.num_successful} structures")

In [None]:
# Collect data from explored structures into a separate training database.
#
# Use a dedicated CollectionConfig instead of reusing Step 1's `config`, for
# two reasons. That config carries db_path="manual_training.db", which
# contradicts the db_path given to this collector. Also, `config` is rebound
# to a PotentialConfig in Step 5, so re-running this cell afterwards would
# break. Keep these sampling settings in sync with Step 1.
explore_config = CollectionConfig(
    db_path="exploration_training.db",
    md_temperatures=[300, 600, 1000],
    md_steps_per_temp=100,
    md_save_interval=10,
    rattle_stdev=[0.01, 0.05],
    rattle_n_configs=5,
)

collector2 = MLIPCollector(
    calculator=calc,
    db_path="exploration_training.db",
    config=explore_config,
)

N_EXPLORE = 10  # Number of explored structures sampled in this demo.

explored_structures = explore_db.get_structures_by_system("Nb-W-O", valid_only=True)

for s in explored_structures[:N_EXPLORE]:
    atoms = s.get_structure().to_ase_atoms()
    collector2.collect_all([atoms], name_prefix=f"explore_{s.formula}")

print(collector2.stats.summary())

## Step 3: Merge Databases

In [None]:
# Rebuild the merged database from scratch. ASE's Database.write() appends a
# new row, so re-running this cell on an existing file would duplicate every
# structure.
Path("merged_training.db").unlink(missing_ok=True)
merged_db = connect("merged_training.db")

# (source label, per-source training database); missing databases are skipped.
SOURCE_DBS = [
    ("manual", "manual_training.db"),
    ("exploration", "exploration_training.db"),
]

for source, db_path in SOURCE_DBS:
    if not Path(db_path).exists():
        continue
    source_db = connect(db_path)
    for row in source_db.select():
        # Copy the row's data dict as well as its key/value pairs, because
        # Step 4 reads per-structure forces from row.data. Building the
        # key/value dict explicitly (rather than `source=..., **kvp`) avoids a
        # duplicate-keyword TypeError when a row already has a "source" key;
        # the label assigned here wins.
        merged_db.write(
            row.toatoms(),
            key_value_pairs={**row.key_value_pairs, "source": source},
            data=row.data,
        )
    print(f"Added {len(source_db)} from {source}")

print(f"Total: {len(merged_db)} structures")

## Step 4: Prepare Fitting Targets

In [None]:
from ggen import FitTarget

db = connect("merged_training.db")

# Filter by max force: drop configurations whose largest force reaches this.
MAX_FORCE = 50.0  # eV/Å

targets = []
n_skipped_force = 0
n_skipped_energy = 0
for row in db.select():
    # Rows without a stored max_force default to 0 and are kept. The negated
    # `<` also drops rows whose max_force is NaN.
    if not (row.get("max_force", 0) < MAX_FORCE):
        n_skipped_force += 1
        continue

    # Prefer the "total_energy" key; fall back to the energy ASE stores with
    # the row. Every target needs a reference energy: a None here would make
    # the validation step fail on `target.energy / n`.
    energy = row.get("total_energy")
    if energy is None:
        energy = row.get("energy")
    if energy is None:
        n_skipped_energy += 1
        continue

    # Forces come from the row's data dict when present, otherwise from the
    # forces ASE stores with the row (None if neither exists).
    data = getattr(row, "data", None) or {}
    forces = data.get("forces")
    if forces is None:
        forces = row.get("forces")

    targets.append(FitTarget(
        name=f"s_{row.id}",
        atoms=row.toatoms(),
        energy=energy,
        forces=forces,
        energy_weight=1.0,
        force_weight=0.1,
    ))

print(f"Prepared {len(targets)} valid training targets")
print(f"Skipped {n_skipped_force} (max force >= {MAX_FORCE}) and {n_skipped_energy} (no energy)")

## Step 5: Configure Potential

In [None]:
from ggen import (
    BuckinghamParams,
    ChargeModel,
    PotentialConfig,
    PotentialType,
    QEqParams,
)

# Initial Buckingham guesses and the search bounds for A and rho, per pair.
# The cation-O pairs start with C = 0; only O-O has a nonzero C term.
buckingham_pairs = {
    ("Nb", "O"): BuckinghamParams(
        A=5000, rho=0.35, C=0, A_bounds=(1000, 20000), rho_bounds=(0.2, 0.5),
    ),
    ("W", "O"): BuckinghamParams(
        A=6000, rho=0.33, C=0, A_bounds=(1000, 20000), rho_bounds=(0.2, 0.5),
    ),
    ("O", "O"): BuckinghamParams(
        A=22000, rho=0.15, C=28, A_bounds=(10000, 50000), rho_bounds=(0.1, 0.25),
    ),
}

# Per-element (chi, mu) starting values passed to QEqParams.
qeq_inputs = {"Nb": (3.0, 6.0), "W": (4.0, 7.0), "O": (8.5, 13.0)}
qeq_params = {
    element: QEqParams(element, chi=chi, mu=mu)
    for element, (chi, mu) in qeq_inputs.items()
}

# Nb-W-O with Buckingham + QEq, fitting both the pair terms and the charges.
config = PotentialConfig(
    name="NbWO_distilled",
    charge_model=ChargeModel.QEQ,
    potential_type=PotentialType.BUCKINGHAM,
    buckingham=buckingham_pairs,
    qeq_params=qeq_params,
    fit_buckingham=True,
    fit_charges=True,
)

print("Potential configured:")
print(f"  Charge model: {config.charge_model.value}")
print(f"  Pairs: {list(config.buckingham)}")

## Step 6: Fit Potential

In [None]:
from ggen import GULPFitter

# The fitter drives the GULP binary/library configured in the setup cell.
gulp_exe = os.environ["GULP_EXE"]
gulp_lib = os.environ["GULP_LIB"]
fitter = GULPFitter(gulp_command=gulp_exe, gulp_lib=gulp_lib, verbose=True)

# Dual annealing (global optimization) on a subset of targets to keep the demo fast.
fit_subset = targets[:100]
result = fitter.fit(
    config=config,
    targets=fit_subset,
    method="dual_annealing",
    maxiter=200,
)

status = "converged" if result.converged else "did not converge"
print(f"\nFitting {status}")
print(f"Final objective: {result.objective_value:.4f}")

In [None]:
# Write the fitted potential to a GULP library file; the validation cell loads it.
library_path = "NbWO_distilled.lib"
fitter.save_library(result.config, library_path)
print(f"Saved: {library_path}")

# Report the fitted Buckingham parameters for each pair.
print("\nFitted Buckingham parameters:")
fitted_pairs = result.config.buckingham
for pair in fitted_pairs:
    p = fitted_pairs[pair]
    print(f"  {pair}: A={p.A:.1f}, rho={p.rho:.4f}, C={p.C:.1f}")

## Step 7: Validation

In [None]:
from ggen import get_gulp_calculator
from sklearn.metrics import mean_absolute_error, r2_score
import matplotlib.pyplot as plt

# Create calculator with fitted potential
gulp_calc = get_gulp_calculator(
    library="NbWO_distilled.lib",
    keywords="conp gradients",
)

# The fitting cell used targets[:100]. Validating on those same structures
# only shows how well the fit reproduces its own training data, so prefer
# held-out targets. Fall back to training targets (in-sample) only when there
# are none.
N_FIT = 100  # Must match the subset passed to fitter.fit(...).
N_VALIDATE = 30
validation_targets = targets[N_FIT:N_FIT + N_VALIDATE]
if not validation_targets:
    print("No held-out targets available; validating on training targets (in-sample).")
    validation_targets = targets[:N_VALIDATE]

# Compare per-atom energies
orb_energies = []
gulp_energies = []

for target in validation_targets:
    atoms = target.atoms.copy()
    n = len(atoms)

    # ORB energy (already computed)
    orb_energies.append(target.energy / n)

    # GULP energy. A failed GULP run marks this structure as NaN instead of
    # aborting the loop. `except Exception` (not a bare except) still lets
    # KeyboardInterrupt stop the cell.
    atoms.calc = gulp_calc
    try:
        gulp_energies.append(atoms.get_potential_energy() / n)
    except Exception as exc:
        print(f"GULP failed for {target.name}: {exc}")
        gulp_energies.append(np.nan)

# Filter valid
orb_array = np.asarray(orb_energies, dtype=float)
gulp_array = np.asarray(gulp_energies, dtype=float)
valid = ~np.isnan(gulp_array)
orb_valid = orb_array[valid]
gulp_valid = gulp_array[valid]

print(f"Valid GULP evaluations: {orb_valid.size}/{len(validation_targets)}")
if orb_valid.size == 0:
    raise RuntimeError(
        "GULP failed for every validation structure; check GULP_EXE, "
        "GULP_LIB and the fitted library file."
    )

mae = mean_absolute_error(orb_valid, gulp_valid)
r2 = r2_score(orb_valid, gulp_valid)

print(f"MAE: {mae:.4f} eV/atom")
print(f"R²: {r2:.4f}")

In [None]:
# Parity plot
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(orb_valid, gulp_valid, alpha=0.6)
lims = [min(orb_valid.min(), gulp_valid.min()), max(orb_valid.max(), gulp_valid.max())]
ax.plot(lims, lims, 'k--', alpha=0.5, label='y=x')
ax.set_xlabel('ORB Energy (eV/atom)')
ax.set_ylabel('GULP Energy (eV/atom)')
ax.set_title(f'Energy Parity (MAE={mae:.4f}, R²={r2:.4f})')
ax.legend()
plt.tight_layout()
plt.show()

## Summary

You've learned how to:
1. **Collect training data** from CIF files using ORB calculator
2. **Augment data** with geometry optimization, MD, and rattling perturbations
3. **Configure GULP potentials** with Buckingham + QEq
4. **Fit parameters** using global optimization
5. **Validate** with energy parity plots

Next steps:
- Add more training data for better accuracy
- Try a shell model instead of QEq for certain systems
- Validate with force predictions and phonon calculations