# MLIP Distillation Tutorial

Learn how to distill a machine-learning interatomic potential (MLIP), here ORB, into classical GULP potentials.

**Workflow:**
1. Collect diverse training data using ORB
2. Filter similar structures for dataset quality
3. Fit GULP potentials (QEq or Shell model)
4. Validate the fitted potential

**New Features:**
- `collect_comprehensive()` - all-in-one data collection
- `DuplicateFilter` - remove similar structures
- `fit_native()` - fast GULP internal fitting

## Setup

In [None]:
import os
from pathlib import Path
import numpy as np

# Set GULP paths (modify for your system)
os.environ["GULP_EXE"] = "/path/to/gulp"  # Change this!
os.environ["GULP_LIB"] = "/path/to/gulp/Libraries"  # Change this!

from ase.io import read
from ase.db import connect

---
## Step 1: Comprehensive Data Collection

Use `collect_comprehensive()` for diverse training data:
- Original structure MD
- Volume-scaled MD (compression/expansion)
- Shear-strained MD
- Rattling perturbations
- Strain perturbations
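
To make the perturbation types above concrete, here is a minimal NumPy sketch of volume scaling, shear, and rattling applied to a cell and a set of positions. It is not the `ggen` implementation: the helper names (`scale_volume`, `shear_cell`, `rattle`) are hypothetical, the cell is assumed to store lattice vectors as rows, and the factor is treated as a volume ratio, which may differ from how `scale_factors` is interpreted by `collect_comprehensive()`.

```python
import numpy as np

def scale_volume(cell, positions, factor):
    """Scale the cell volume by `factor`, keeping fractional coordinates fixed."""
    s = factor ** (1.0 / 3.0)  # linear scale that gives the requested volume ratio
    return cell * s, positions * s

def shear_cell(cell, positions, gamma):
    """Apply a simple xy shear of magnitude `gamma` to lattice vectors and atoms."""
    deformation = np.eye(3)
    deformation[0, 1] = gamma
    # Row vectors v transform as v @ F.T
    return cell @ deformation.T, positions @ deformation.T

def rattle(positions, stdev, seed=0):
    """Add Gaussian noise with standard deviation `stdev` (Å) to every coordinate."""
    rng = np.random.default_rng(seed)
    return positions + rng.normal(scale=stdev, size=positions.shape)

# Toy two-atom cubic cell (illustrative values only)
cell = 4.0 * np.eye(3)
positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]])

scaled_cell, scaled_pos = scale_volume(cell, positions, 0.90)
sheared_cell, sheared_pos = shear_cell(cell, positions, 0.05)
rattled_pos = rattle(positions, stdev=0.05)
```

Scaling and shear move the atoms together with the cell, so fractional coordinates are unchanged; rattling displaces atoms independently at fixed cell.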

In [None]:
from ggen import get_orb_calculator, MLIPCollector, CollectionConfig

# Create ORB calculator
calc = get_orb_calculator(device="cuda")

# Configure collection
config = CollectionConfig(
    db_path="comprehensive_training.db",
    md_temperatures=[300, 600, 1000],
    md_steps_per_temp=50,
    md_save_interval=5,
    rattle_stdev=[0.01, 0.05],
    rattle_n_configs=5,
)

collector = MLIPCollector(calculator=calc, config=config)

In [None]:
# Load CIF files
cif_folder = Path("./cif_files")
cif_files = sorted(cif_folder.glob("*.cif"))  # sorted for a reproducible order

# Process only the first three files
for cif_file in cif_files[:3]:
    atoms = read(cif_file)
    name = cif_file.stem
    
    # Comprehensive collection (all methods at once!)
    collector.collect_comprehensive(
        atoms,
        name=name,
        include_original=True,   # MD on original
        include_scaled=True,     # Volume scaling + MD
        include_sheared=True,    # Shear strain + MD
        include_rattling=True,   # Random displacements
        include_strain=True,     # Strain perturbations
        scale_factors=[0.90, 0.95, 1.05, 1.10],
    )

print(collector.stats.summary())

---
## Step 2: Filter Similar Structures

Use `DuplicateFilter` to remove redundant structures and improve dataset diversity.
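
The idea behind threshold-based filtering can be sketched with plain NumPy. This is an illustration only, not how `DuplicateFilter` is implemented: it assumes each structure has already been reduced to a fixed-length descriptor vector (for example an averaged SOAP vector), uses cosine similarity, and keeps a structure only if it is less similar than the threshold to everything already kept. The function names are hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(descriptors):
    """Pairwise cosine similarity between row vectors."""
    x = np.asarray(descriptors, dtype=float)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    return x @ x.T

def greedy_filter(descriptors, threshold=0.95):
    """Return indices of structures kept after greedy duplicate removal."""
    sim = cosine_similarity_matrix(descriptors)
    kept = []
    for i in range(len(sim)):
        # Keep structure i only if it is below the threshold
        # with respect to every structure kept so far.
        if all(sim[i, j] < threshold for j in kept):
            kept.append(i)
    return kept

# Toy descriptors: rows 0 and 1 are nearly identical
descriptors = np.array([
    [1.0, 0.0, 0.0],
    [0.999, 0.01, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.7, 0.7],
])

kept_loose = greedy_filter(descriptors, threshold=0.95)
kept_strict = greedy_filter(descriptors, threshold=0.70)
```

In this scheme a lower threshold removes more structures, because pairs that are only moderately similar already count as duplicates.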

In [None]:
from ggen import DuplicateFilter, filter_training_data

# Quick way
result = filter_training_data(
    "comprehensive_training.db",
    "filtered_training.db",
    method="soap",      # or "coulomb"
    threshold=0.95,     # similarity threshold
)

print(result.summary())

In [None]:
# Or use DuplicateFilter directly for more control
# (named dup_filter to avoid shadowing Python's built-in filter())
dup_filter = DuplicateFilter(
    method="soap",
    threshold=0.90,  # More strict filtering
)

# Get similarity matrix for the first 20 structures (for visualization)
db = connect("filtered_training.db")
atoms_list = [row.toatoms() for row in db.select(limit=20)]
sim_matrix = dup_filter.get_similarity_matrix(atoms_list)

import matplotlib.pyplot as plt
plt.imshow(sim_matrix, cmap='viridis')
plt.colorbar(label='Similarity')
plt.title('Structure Similarity Matrix')
plt.show()

---
## Step 3: Prepare Fitting Targets

In [None]:
from ggen import FitTarget

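db = connect("filtered_training.db")
MAX_FORCE = 50.0  # Force cutoff; presumably eV/Å (ASE's default force unit)

targets = []
for row in db.select():
    energy = row.get("total_energy")
    # Skip rows without a stored energy or with very large forces
    if energy is None or row.get("max_force", 0) >= MAX_FORCE:
        continue
    targets.append(FitTarget(
        name=f"s_{row.id}",
        atoms=row.toatoms(),
        energy=energy,
        energy_weight=1.0,
    ))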

print(f"Prepared {len(targets)} training targets")
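
The `max_force` key used for screening is read from the database as-is. As a rough sketch of how such a value could be computed, assuming it means the largest per-atom force magnitude (the exact definition used by `ggen` is not stated here), one might write:

```python
import numpy as np

def max_force_norm(forces):
    """Largest per-atom force magnitude for an (n_atoms, 3) force array."""
    forces = np.asarray(forces, dtype=float)
    return float(np.linalg.norm(forces, axis=1).max())

# Toy forces for three atoms (illustrative values only)
forces = np.array([
    [0.0, 0.0, 0.1],
    [3.0, 4.0, 0.0],
    [-1.0, 0.0, 0.0],
])

value = max_force_norm(forces)
keep = value < 50.0  # same cutoff as MAX_FORCE above
```

With ASE, the force array would typically come from `atoms.get_forces()`.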

---
## Step 4: Configure Potential (QEq with Auto-Loaded Parameters)

In [None]:
from ggen import (
    PotentialConfig, ChargeModel, PotentialType,
    BuckinghamParams, build_qeq_config, QEQ_PARAMS_DATABASE
)

# Check available QEq parameters (103 elements)
elements = ["Nb", "W", "O"]
for el in elements:
    p = QEQ_PARAMS_DATABASE[el]
    print(f"{el}: chi={p.chi:.3f} eV, mu={p.mu:.3f} eV")

# Auto-build QEq config
qeq_params = build_qeq_config(elements)

config_qeq = PotentialConfig(
    name="NbWO_qeq",
    charge_model=ChargeModel.QEQ,
    potential_type=PotentialType.BUCKINGHAM,
    buckingham={
        ("Nb", "O"): BuckinghamParams(A=5000, rho=0.35, C=0),
        ("W", "O"): BuckinghamParams(A=6000, rho=0.33, C=0),
        ("O", "O"): BuckinghamParams(A=22000, rho=0.15, C=28),
    },
    qeq_params=qeq_params,
    fit_buckingham=True,
    fit_charges=True,
)
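
For reference, the `A`, `rho`, and `C` values above enter the standard Buckingham form E(r) = A·exp(−r/ρ) − C/r⁶, with A in eV, ρ in Å, and C in eV·Å⁶. The short sketch below evaluates that form for the starting parameters; the `buckingham` helper is written here for illustration and is not part of `ggen`.

```python
import math

def buckingham(r, A, rho, C):
    """Buckingham pair energy in eV at separation r in Å."""
    if r <= 0:
        raise ValueError("separation must be positive")
    return A * math.exp(-r / rho) - C / r**6

# Starting parameters taken from the configuration above
start_params = {
    ("Nb", "O"): (5000.0, 0.35, 0.0),
    ("W", "O"): (6000.0, 0.33, 0.0),
    ("O", "O"): (22000.0, 0.15, 28.0),
}

for pair, (A, rho, C) in start_params.items():
    energies = [buckingham(r, A, rho, C) for r in (1.8, 2.0, 2.5, 3.0)]
    print(pair, [f"{e:.4f}" for e in energies])
```

With `C=0` the term is purely repulsive; for O–O the dispersion term dominates at typical oxygen separations, so the pair energy there is slightly negative.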

---
## Step 5: Fit Potential (Native GULP Fitting - FAST!)

Use `fit_native()` to run the fit with GULP's internal optimizer in a single GULP call, which is much faster than driving GULP from an external optimizer.
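
Parameter fitting of this kind minimizes a weighted sum of squared differences between calculated and reference observables. The toy example below illustrates only that objective: it generates synthetic pair energies from a known Buckingham `rho` and scans a few trial values. It is a sketch under these assumptions and does not reproduce GULP's optimizer or the observables that `fit_native()` actually uses.

```python
import numpy as np

def buckingham(r, A, rho, C):
    """Buckingham pair energy in eV for separations r in Å."""
    return A * np.exp(-r / rho) - C / r**6

def sum_of_squares(calc, obs, weights):
    """Weighted sum of squared residuals."""
    return float(np.sum(weights * (calc - obs) ** 2))

# Synthetic reference data from known parameters (illustrative only)
r = np.linspace(1.8, 3.0, 13)
reference = buckingham(r, A=5000.0, rho=0.35, C=0.0)
weights = np.ones_like(r)

# Evaluate the objective for a few trial rho values and pick the smallest
trial_rhos = [0.30, 0.35, 0.40]
objectives = [
    sum_of_squares(buckingham(r, 5000.0, rho, 0.0), reference, weights)
    for rho in trial_rhos
]
best_rho = trial_rhos[int(np.argmin(objectives))]
```

A real fit varies all selected parameters at once and uses gradient-based updates rather than a grid, but the quantity being minimized has the same shape.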

In [None]:
from ggen import GULPFitter

fitter = GULPFitter(
    gulp_command=os.environ["GULP_EXE"],
    gulp_lib=os.environ["GULP_LIB"],
    verbose=True,
)

# NATIVE fitting (fast! - single GULP call)
result = fitter.fit_native(
    config=config_qeq,
    targets=targets,
    fit_cycles=100,        # GULP internal cycles
    simultaneous=True,     # Fit all structures at once
    relax_structures=True, # Optimize during fitting
)

print(f"\nFitting {'converged' if result.converged else 'did not converge'}")
print(f"Objective: {result.objective_value:.4e}")

In [None]:
# Compare with external scipy fitting (slower, for reference)
# result_scipy = fitter.fit(
#     config=config_qeq,
#     targets=targets[:50],
#     method="dual_annealing",
#     maxiter=100,
# )

In [None]:
# Save fitted potential
fitter.save_library(result.config, "NbWO_fitted.lib")

print("\nFitted Buckingham parameters:")
for pair, params in result.config.buckingham.items():
    print(f"  {pair}: A={params.A:.1f}, rho={params.rho:.4f}, C={params.C:.1f}")

---
## Step 6: Validation

In [None]:
from ggen import get_gulp_calculator
from sklearn.metrics import mean_absolute_error, r2_score
import matplotlib.pyplot as plt

# Create calculator with fitted potential
gulp_calc = get_gulp_calculator(
    library="NbWO_fitted.lib",
    keywords="conp gradients qeq",
)

orb_e, gulp_e = [], []

for target in targets[:30]:
    atoms = target.atoms.copy()
    n = len(atoms)
    orb_e.append(target.energy / n)
    
    atoms.calc = gulp_calc
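    try:
        gulp_e.append(atoms.get_potential_energy() / n)
    except Exception:  # GULP run failed; mark this structure as missing
        gulp_e.append(np.nan)

# Convert once to arrays and drop structures where GULP failed
orb_e, gulp_e = np.array(orb_e), np.array(gulp_e)
valid = ~np.isnan(gulp_e)
mae = mean_absolute_error(orb_e[valid], gulp_e[valid])
r2 = r2_score(orb_e[valid], gulp_e[valid])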

print(f"MAE: {mae:.4f} eV/atom")
print(f"R²: {r2:.4f}")
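
The two metrics reported above have simple definitions, written out here in NumPy as a cross-check on the scikit-learn calls. The helper names and the toy energies are illustrative only.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean(np.abs(y_true - y_pred)))

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Toy per-atom energies in eV/atom (illustrative values only)
reference = [-8.0, -7.5, -7.0]
predicted = [-7.9, -7.6, -7.0]

toy_mae = mae(reference, predicted)
toy_r2 = r_squared(reference, predicted)
```

Because R² is normalized by the spread of the reference energies, it can look poor on a narrow energy range even when the MAE is small, so it helps to read the two together.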

In [None]:
# Parity plot
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(np.array(orb_e)[valid], np.array(gulp_e)[valid], alpha=0.6)
lims = [min(np.array(orb_e)[valid]), max(np.array(orb_e)[valid])]
ax.plot(lims, lims, 'k--', alpha=0.5)
ax.set_xlabel('ORB Energy (eV/atom)')
ax.set_ylabel('GULP Energy (eV/atom)')
ax.set_title(f'Energy Parity (MAE={mae:.4f}, R²={r2:.4f})')
plt.tight_layout()
plt.show()

---
## Summary

### Key Functions

| Function | Description |
|----------|-------------|
| `collect_comprehensive()` | All-in-one data collection (scaled, sheared, rattled...) |
| `filter_training_data()` | Remove similar structures using SOAP/Coulomb |
| `build_qeq_config()` | Auto-load QEq params for elements |
| `fit_native()` | Fast GULP internal fitting |

### Speed Comparison

| Method | 100 targets | Notes |
|--------|-------------|-------|
| `fit()` (scipy) | ~30 min | External optimizer, many GULP calls |
| `fit_native()` | ~1 min | GULP internal optimizer, single call |

### Installation

```bash
pip install "ggen[orb,fingerprint]"  # ORB + SOAP/Coulomb filtering (quotes keep shells like zsh from expanding the brackets)
```