# Re-estimating E0s using a Linear System Approach

## The Problem Formulation

When fine-tuning a foundation model for interatomic potentials, we need to re-estimate the atomic reference energies (E0s) so that they match the target dataset. Instead of using simple averages, we can frame this as a linear least-squares problem.

## Mathematical Formulation

For each configuration $i$ in our training set, the energy prediction error is:

$$\text{error}_i = E^{\text{true}}_i - E^{\text{predicted}}_i$$

Our hypothesis is that this error can be systematically corrected by adjusting the E0 values for each element:

$$\text{error}_i = \sum_{j} n_{ij} \times \text{correction}_j$$

Where:
- $n_{ij}$ is the number of atoms of element $j$ in configuration $i$
- $\text{correction}_j$ is the energy correction needed for element $j$
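
The linear form follows if the model's predicted energy contains an additive per-atom reference term. This is an assumption here, although it is how MACE-style models are usually built. Writing $E^{\text{int}}_i$ (a symbol introduced only for this sketch) for the part of the prediction that does not depend on the E0s:

$$E^{\text{predicted}}_i = E^{\text{int}}_i + \sum_{j} n_{ij} \times E0_j$$

Replacing each $E0_j$ by $E0_j + \text{correction}_j$ shifts the prediction by $\sum_{j} n_{ij} \times \text{correction}_j$, so asking the shifted prediction to reproduce $E^{\text{true}}_i$ gives the equation above. For example, a configuration with two water molecules (4 H atoms, 2 O atoms) contributes:

$$\text{error}_i = 4 \times \text{correction}_H + 2 \times \text{correction}_O$$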

## Linear System Representation

This can be written as a matrix equation $A \mathbf{x} = \mathbf{b}$, where:

- $A_{ij}$ is the number of atoms of element $j$ in configuration $i$
- $\mathbf{x}_j$ is the correction for element $j$ (what we want to solve for)
- $\mathbf{b}_i$ is the error for configuration $i$

Visualizing the matrix for a system with $m$ configurations and $n$ elements:

$$
\begin{bmatrix}
n_{11} & n_{12} & \cdots & n_{1n} \\
n_{21} & n_{22} & \cdots & n_{2n} \\
\vdots & \vdots & \ddots & \vdots \\
n_{m1} & n_{m2} & \cdots & n_{mn} \\
\end{bmatrix}
\begin{bmatrix}
\text{correction}_1 \\
\text{correction}_2 \\
\vdots \\
\text{correction}_n \\
\end{bmatrix}
=
\begin{bmatrix}
\text{error}_1 \\
\text{error}_2 \\
\vdots \\
\text{error}_m \\
\end{bmatrix}
$$
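
As a minimal sketch of how $A$ and $\mathbf{b}$ can be assembled, the snippet below uses a hypothetical toy dataset: plain dictionaries with made-up energies stand in for real configurations, and the keys `numbers`, `e_true` and `e_pred` are illustrative names only.

```python
import numpy as np

# Hypothetical toy data: atomic numbers plus made-up reference ("true")
# and model-predicted energies in eV.
configs = [
    {"numbers": [1, 1, 8], "e_true": -14.0, "e_pred": -13.0},
    {"numbers": [1, 1, 8, 1, 1, 8], "e_true": -28.5, "e_pred": -26.0},
    {"numbers": [8, 8], "e_true": -10.0, "e_pred": -9.0},
]
elements = [1, 8]  # column order of A: H, O

A = np.zeros((len(configs), len(elements)))
b = np.zeros(len(configs))
for i, cfg in enumerate(configs):
    numbers = np.asarray(cfg["numbers"])
    for j, z in enumerate(elements):
        # A[i, j]: number of atoms of element j in configuration i
        A[i, j] = np.count_nonzero(numbers == z)
    # b[i]: error (true - predicted) for configuration i
    b[i] = cfg["e_true"] - cfg["e_pred"]

print(A)
print(b)
```

Each row of `A` is the composition of one configuration, and the matching entry of `b` is its energy error.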





## Solving the System

In most practical cases, this system is overdetermined (more equations than unknowns) because we have more configurations than element types. We use least squares optimization to find the corrections that minimize the sum of squared errors:

$$\min_{\mathbf{x}} ||A\mathbf{x} - \mathbf{b}||^2_2$$

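When $A$ has full column rank, the solution is given by the normal equations:

$$\mathbf{x} = (A^T A)^{-1} A^T \mathbf{b}$$

In practice we compute it with the `scipy.linalg.lstsq` function, which also handles rank-deficient systems by returning the minimum-norm solution.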
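
The following sketch compares `scipy.linalg.lstsq` with the closed-form expression on a small synthetic system. The matrix, the "true" corrections and the noise level are all made up for illustration.

```python
import numpy as np
from scipy.linalg import lstsq

rng = np.random.default_rng(0)

# Hypothetical overdetermined system: 6 configurations, 2 elements.
A = np.array([[2, 1], [4, 2], [0, 2], [2, 0], [6, 1], [2, 3]], dtype=float)
x_true = np.array([-0.5, 1.25])  # made-up corrections (eV)
b = A @ x_true + rng.normal(scale=0.01, size=A.shape[0])  # small noise

# Least-squares solve.
x_lstsq, residues, rank, sing_vals = lstsq(A, b)

# Closed-form solution; only valid because this A has full column rank.
x_normal = np.linalg.solve(A.T @ A, A.T @ b)

print("rank:", rank)
print("lstsq:           ", x_lstsq)
print("normal equations:", x_normal)
```

The two answers agree here. `lstsq` is usually the safer choice because it does not form $A^T A$ explicitly and still returns a solution when $A$ is rank deficient.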


## Example

For a dataset of single water molecules (2 H atoms and 1 O atom per configuration), the system looks like:

$$
\begin{bmatrix}
2 & 1 \\
2 & 1 \\
2 & 1 \\
\vdots & \vdots \\
\end{bmatrix}
\begin{bmatrix}
\text{correction}_H \\
\text{correction}_O \\
\end{bmatrix}
=
\begin{bmatrix}
\text{error}_1 \\
\text{error}_2 \\
\text{error}_3 \\
\vdots \\
\end{bmatrix}
$$
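
Every row of this matrix is proportional to $(2, 1)$, so $A$ has rank 1 and the data fix only the combination $2 \times \text{correction}_H + \text{correction}_O$. The sketch below illustrates this with made-up errors for configurations of one, two and three water molecules. The explicit `cond=1e-10` cutoff is a choice made here so that the rank is detected robustly.

```python
import numpy as np
from scipy.linalg import lstsq

# Rows: 1, 2 and 3 water molecules -> (n_H, n_O) = (2, 1), (4, 2), (6, 3).
A = np.array([[2, 1], [4, 2], [6, 3]], dtype=float)
b = np.array([-5.0, -10.0, -15.0])  # made-up errors (eV): -5 eV per molecule

x, residues, rank, sing_vals = lstsq(A, b, cond=1e-10)
print("rank:", rank)
print("minimum-norm corrections (H, O):", x)

# A different split between H and O reproduces the errors equally well.
x_alt = np.array([-1.0, -3.0])
print("alternative also fits:", np.allclose(A @ x_alt, b))
```

Only the per-molecule total is meaningful in this situation. The individual H and O values are one choice among infinitely many, namely the one with the smallest norm.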

The solution provides the optimal corrections to apply to foundation model E0s:

$$E0^{\text{new}}_j = E0^{\text{foundation}}_j + \text{correction}_j$$
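
A minimal sketch of this update with dictionaries keyed by atomic number. The foundation values are the rounded H and O entries used later in this notebook, and the corrections are made-up numbers.

```python
# Rounded foundation E0s (eV) for H and O, keyed by atomic number.
foundation_e0s = {1: -3.6672, 8: -7.2846}

# Hypothetical corrections (eV), for illustration only.
corrections = {1: 0.10, 8: -0.25}

# E0_new = E0_foundation + correction, element by element.
new_e0s = {z: foundation_e0s[z] + corrections[z] for z in corrections}
print(new_e0s)
```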

This approach is particularly valuable when fine-tuning across different levels of theory or when dealing with datasets that have systematic energy shifts relative to the foundation model's training data.

### The `reestimate_e0s_linear_system` function

In [1]:
import numpy as np
from ase.atoms import Atoms
from ase.calculators.calculator import Calculator
from scipy.linalg import lstsq

def reestimate_e0s_linear_system(
    foundation_model: Calculator,
    foundation_e0s: dict,
    training_configs: list,
    elements: list
) -> dict:
    """
    Re-estimate atomic reference energies (E0s) by solving a linear system
    that optimally corrects foundation model predictions.
    
    Args:
        foundation_model: Calculator object for the foundation model
        foundation_e0s: Dictionary mapping element atomic numbers to original E0 values
        training_configs: List of configurations with energy and atomic_numbers
        elements: List of element atomic numbers to consider
        
    Returns:
        Dictionary with re-estimated E0 values for each element
    """
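    # Keep only configurations that carry a reference energy.
    # ASE raises (rather than returning None) when no energy is available.
    def has_energy(config):
        try:
            return config.get_potential_energy() is not None
        except (RuntimeError, NotImplementedError):
            return False

    valid_configs = [config for config in training_configs if has_energy(config)]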
    
    if not valid_configs:
        print("No configurations with energy found.")
        return foundation_e0s.copy()
    

    # A matrix: each row contains atom counts for each element
    # b vector: each entry is the prediction error for a configuration
    A = np.zeros((len(valid_configs), len(elements)))
    b = np.zeros(len(valid_configs))
    
    print(f"Solving linear system with {len(valid_configs)} equations and {len(elements)} unknowns")
    

    # - A[i,j] is the count of element j in configuration i
    # - b[i] is the error (true - predicted) for configuration i
    # - x[j] will be the energy correction for element j
    for i, config in enumerate(valid_configs):
        # Get foundation model prediction
        atoms = Atoms(
            numbers=config.get_atomic_numbers(),
            positions=config.positions,
            cell=config.cell,
            pbc=config.pbc
        )
        atoms.calc = foundation_model
        predicted_energy = atoms.get_potential_energy()
        
        # calc error
        error = config.get_potential_energy() - predicted_energy
        b[i] = error
        
        # atom counts for each element
        for j, element in enumerate(elements):
            A[i, j] = np.sum(config.get_atomic_numbers() == element)
    
    # Solve the system using least squares (handles overdetermined systems)
    try:
        corrections, residuals, rank, s = lstsq(A, b)
        
        # new E0s by adding corrections to foundation E0s
        new_e0s = {}
        for i, element in enumerate(elements):
            correction = corrections[i]
            new_e0s[element] = foundation_e0s[element] + correction
            print(f"Element {element}: foundation E0 = {foundation_e0s[element]:.4f}, correction = {correction:.4f}, new E0 = {new_e0s[element]:.4f}")
        
        # statistics about the fit
        mse_before = np.mean(b**2)
        b_after = b - A @ corrections
        mse_after = np.mean(b_after**2)
        improvement = (1 - mse_after/mse_before) * 100
        
        print(f"\nMean squared error before correction: {mse_before:.4f} eV²")
        print(f"Mean squared error after correction: {mse_after:.4f} eV²")
        print(f"Improvement: {improvement:.1f}%")
        
        if rank < len(elements):
            print(f"\nWarning: System is rank deficient (rank {rank}/{len(elements)})")
            print("Some elements may be linearly dependent or not sufficiently represented in the dataset.")
        
        return new_e0s
        
    except np.linalg.LinAlgError as e:
        print(f"Error solving the linear system: {e}")
        print("Falling back to foundation model E0s")
        return foundation_e0s.copy()

### Applying the function to estimate new E0s

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import ase.io as io

from mace.calculators import mace_mp
macemp = mace_mp()

Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument
Using Materials Project MACE for MACECalculator with /Users/joehart/.cache/mace/macempa0mediummodel
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.



In [42]:
macemp.heads

['default']

In [83]:
E0s_revPBE_file = io.read("caco3-water/revPBE-D3/E0.xyz", index=":")
E0s_revPBE0_file = io.read("caco3-water/revPBE0-D3/E0.xyz", index=":")
E0s_MP2_file = io.read("caco3-water/MP2/E0.xyz", index=":")

E0s_revPBE = {E0s_revPBE_file[i].get_atomic_numbers()[0]: E0s_revPBE_file[i].info["energy_ref"] for i in range(len(E0s_revPBE_file))}
E0s_revPBE0 = {E0s_revPBE0_file[i].get_atomic_numbers()[0]: E0s_revPBE0_file[i].info["energy_ref"] for i in range(len(E0s_revPBE0_file))}
E0s_MP2 = {E0s_MP2_file[i].get_atomic_numbers()[0]: E0s_MP2_file[i].info["energy_ref"] for i in range(len(E0s_MP2_file))}

In [92]:
training_configs_revPBE = io.read("training_val_sets/training_set_30_revPBE_D3.xyz", index=":")
training_configs_revPBE0 = io.read("training_val_sets/training_set_30_revPBE0_D3.xyz", index=":")
training_configs_MP2 = io.read("training_val_sets/training_set_30_MP2.xyz", index=":")

In [None]:
elements = [1, 6, 8, 20]

In [78]:
foundation_E0s = {1: -3.667168021358939, 3: -3.482100566595956, 4: -4.736697230897597, 5: -7.724935420523256, 6: -8.405573550273285, 7: -7.360100452662763, 8: -7.28459863421322, 9: -4.896490881731322, 11: -2.7593613569762425, 12: -2.814047612069227, 13: -4.846881245288104, 14: -7.694793133351899, 15: -6.9632957911820235, 16: -4.672630400190884, 17: -2.8116892814008096, 18: -0.06259504416367478, 19: -2.6176454856894793, 20: -5.390461060484104, 21: -7.8857952163517675, 22: -10.268392986214433, 23: -8.665147785496703, 24: -9.233050763772013, 25: -8.304951520770791, 26: -7.0489865771593765, 27: -5.577439766222147, 28: -5.172747618813715, 29: -3.2520726958619472, 30: -1.2901611618726314, 31: -3.527082192997912, 32: -4.70845955030298, 33: -3.9765109025623238, 34: -3.886231055836541, 35: -2.5184940099633986, 36: 6.766947645687137, 37: -2.5634958965928316, 38: -4.938005211501922, 39: -10.149818838085771, 40: -11.846857579882572, 41: -12.138896361658485, 42: -8.791678800595722, 43: -8.78694939675911, 44: -7.78093221529871, 45: -6.850021409115055, 46: -4.891019073240479, 47: -2.0634296773864045, 48: -0.6395695518943755, 49: -2.7887442084286693, 50: -3.818604275441892, 51: -3.587068329278862, 52: -2.8804045971118897, 53: -1.6355986842433357, 54: 9.846723842807721, 55: -2.765284507132287, 56: -4.990956432167774, 57: -8.933684809576345, 58: -8.735591176647514, 59: -8.018966025544966, 60: -8.251491970213372, 61: -7.591719594359237, 62: -8.169659881166858, 63: -13.592664636171698, 64: -18.517523458456985, 65: -7.647396572993602, 66: -8.122981037851925, 67: -7.607787319678067, 68: -6.85029094445494, 69: -7.8268821327130365, 70: -3.584786591677161, 71: -7.455406192077973, 72: -12.796283502572146, 73: -14.108127281277586, 74: -9.354916969477486, 75: -11.387537567890853, 76: -9.621909492152557, 77: -7.324393429417677, 78: -5.3046964808341945, 79: -2.380092582080244, 80: 0.24948924158195362, 81: -2.3239789120665026, 82: -3.730042357127322, 83: -3.438792347649683, 89: 
-5.062878214511315, 90: -11.02462566385297, 91: -12.265613551943261, 92: -13.855648206100362, 93: -14.933092020258243, 94: -15.282826131998245}

In [93]:
correctedE0s_revPBE = reestimate_e0s_linear_system(macemp, foundation_E0s, training_configs_revPBE, elements)


Solving linear system with 30 equations and 4 unknowns
Element 1: foundation E0 = -3.6672, correction = -189.5552, new E0 = -193.2224
Element 6: foundation E0 = -8.4056, correction = -0.9671, new E0 = -9.3727
Element 8: foundation E0 = -7.2846, correction = -97.6789, new E0 = -104.9635
Element 20: foundation E0 = -5.3905, correction = -0.9671, new E0 = -6.3576

Mean squared error before correction: 2210910962.5683 eV²
Mean squared error after correction: 0.5968 eV²
Improvement: 100.0%

Some elements may be linearly dependent or not sufficiently represented in the dataset.


In [94]:
correctedE0s_revPBE0 = reestimate_e0s_linear_system(macemp, foundation_E0s, training_configs_revPBE0, elements)

Solving linear system with 30 equations and 4 unknowns
Element 1: foundation E0 = -3.6672, correction = -189.4592, new E0 = -193.1263
Element 6: foundation E0 = -8.4056, correction = -0.9666, new E0 = -9.3722
Element 8: foundation E0 = -7.2846, correction = -97.6295, new E0 = -104.9141
Element 20: foundation E0 = -5.3905, correction = -0.9666, new E0 = -6.3571

Mean squared error before correction: 2208671817.9019 eV²
Mean squared error after correction: 19.6298 eV²
Improvement: 100.0%

Some elements may be linearly dependent or not sufficiently represented in the dataset.


In [95]:
correctedE0s_MP2 = reestimate_e0s_linear_system(macemp, foundation_E0s, training_configs_MP2, elements)

Solving linear system with 30 equations and 4 unknowns
Element 1: foundation E0 = -3.6672, correction = -188.7688, new E0 = -192.4359
Element 6: foundation E0 = -8.4056, correction = -0.9631, new E0 = -9.3687
Element 8: foundation E0 = -7.2846, correction = -97.2737, new E0 = -104.5583
Element 20: foundation E0 = -5.3905, correction = -0.9631, new E0 = -6.3536

Mean squared error before correction: 2192603615.4124 eV²
Mean squared error after correction: 8.2917 eV²
Improvement: 100.0%

Some elements may be linearly dependent or not sufficiently represented in the dataset.


In [96]:
correctedE0s_revPBE, correctedE0s_revPBE0, correctedE0s_MP2

({1: -193.22235627360988,
  6: -9.372691857682717,
  8: -104.96354768256703,
  20: -6.357579367893548},
 {1: -193.1263432216216,
  6: -9.372201995172574,
  8: -104.91407156904245,
  20: -6.357089505383404},
 {1: -192.4359230170113,
  6: -9.368679443108235,
  8: -104.5582938105443,
  20: -6.353566953319065})
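
In all three runs, elements 6 and 20 receive identical corrections, and the printed values satisfy $\text{correction}_O \approx \text{correction}_H / 2 + 3 \times \text{correction}_C$ (for revPBE: $-189.5552 / 2 + 3 \times (-0.9671) \approx -97.6789$). This is the pattern a minimum-norm solution produces if every configuration consists of whole water molecules plus whole CaCO3 units. That is an assumption about the dataset which is consistent with the output but has not been checked here. The sketch below reproduces the pattern on hypothetical compositions with made-up errors.

```python
import numpy as np
from scipy.linalg import lstsq

rng = np.random.default_rng(0)

# Hypothetical compositions: w water molecules plus k CaCO3 units.
# Columns follow elements = [1, 6, 8, 20] (H, C, O, Ca).
units = [(10, 1), (20, 1), (15, 2), (30, 3), (8, 0)]
A = np.array([[2 * w, k, w + 3 * k, k] for w, k in units], dtype=float)

# Made-up errors (eV); the pattern below holds for any right-hand side.
b = rng.normal(size=len(units))

# cond=1e-10 is a cutoff chosen here so the rank is detected robustly.
x, residues, rank, sing_vals = lstsq(A, b, cond=1e-10)

print("rank:", rank)  # 2 independent building blocks, 4 unknowns
print("C and Ca corrections:", x[1], x[3])
print("O vs H/2 + 3*C:", x[2], x[0] / 2 + 3 * x[1])
```

With such data only two combinations of the four corrections are determined (one per building block), so the per-element E0s above should be read as one minimum-norm choice rather than as uniquely fitted values.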