# Determine Offsets for SchNet
SchNet needs a reference energy for each element and a mean/standard deviation for the energy

In [1]:
from sklearn.linear_model import RANSACRegressor, LinearRegression
from collections import Counter
from ase.db import connect
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
import json

Configuration

In [2]:
# Where the fitted reference energies and offsets are accumulated (one entry per method)
ref_file = Path('reference_energies.json')

# Which dataset to process; also used as the sub-directory name under data/
method = 'wb97x_dz'

## Load in the Data
Get the energy and composition of each entry in the training set. We only store the lowest energy for each composition

In [3]:
lowest_by_name = {}  # name -> lowest-energy entry seen so far
train_db_path = Path('data/') / method / 'train.db'
with connect(train_db_path) as db:
    for row in tqdm(db.select('')):
        # Skip rows that do not improve on the energy already stored for this name
        best = lowest_by_name.get(row.name)
        if best is not None and not (row.energy < best['energy']):
            continue

        # Element counts first, then the bookkeeping fields
        counts = Counter(row.symbols)
        lowest_by_name[row.name] = {
            **counts,
            'name': row.name,
            'n_atoms': sum(counts.values()),
            'energy': row.energy,
        }

# One row per name; elements missing from a composition show up as NaN
records = pd.DataFrame(lowest_by_name).T

3890235it [03:10, 20470.26it/s]


In [4]:
# Peek at the first two rows to check the layout of the table
records.iloc[:2]

Unnamed: 0,C,H,name,n_atoms,energy,N,O
C10H10,10,10,C10H10,20,-5264.883878,,
C10H10N2,10,10,C10H10N2,22,-6754.892183,2.0,


## Fit Atomic Reference Energies
Fit a linear model that predicts energy as a function of the number of atoms of each element. The coefficients are our atomic reference energies.

In [5]:
# Element symbols are at most two characters long, whereas the non-element
#  columns ("name", "n_atoms", "energy") all have longer names
elem_columns = list(filter(lambda col: len(col) < 3, records.columns))
print(f'We found {len(elem_columns)} elements: {elem_columns}')

We found 4 elements: ['C', 'H', 'N', 'O']


Get the element counts as a numeric array and replace NaNs (elements absent from a composition) with zeros

In [6]:
# Element counts as a float matrix: one row per composition, one column per element
elem_counts = records[elem_columns].to_numpy(dtype=float)

# A NaN means the element does not occur in that composition, i.e. a count of zero
x = np.where(np.isnan(elem_counts), 0.0, elem_counts)

Fit and extract coefficients

In [7]:
# RANSAC fits on randomly-drawn subsets of the data, so pin the seed: without it
#  the reference energies change slightly on every re-run of the notebook
model = RANSACRegressor(
    # No intercept: the energy is modeled purely as a sum of per-element contributions
    estimator=LinearRegression(fit_intercept=False),
    random_state=1,
).fit(x, records['energy'])

# The coefficient fitted for each element column is that element's reference energy
ref_energies = dict(zip(elem_columns, model.estimator_.coef_))
ref_energies

{'C': -518.2976535294367,
 'H': -8.206374313766219,
 'N': -744.5895241965858,
 'O': -1023.0426190845573}

## Compute the Mean and Standard Deviation
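Get the mean and standard deviation of the energy that remains after subtracting the fitted reference energies. These are currently computed on the total energy of each composition; normalizing per atom is still a TODO.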

In [8]:
# Energy predicted from the composition alone, using the fitted reference energies
records = records.assign(ref_energy=model.predict(x))

In [9]:
# What is left of the energy once the per-element reference contribution is removed
records = records.assign(norm_energy=records['energy'].sub(records['ref_energy']))

In [10]:
# Spot-check the first five rows: norm_energy should be small compared to energy
records.iloc[:5]

Unnamed: 0,C,H,name,n_atoms,energy,N,O,ref_energy,norm_energy
C10H10,10,10,C10H10,20,-5264.883878,,,-5265.040278,0.156401
C10H10N2,10,10,C10H10N2,22,-6754.892183,2.0,,-6754.219327,-0.672856
C10H10N2O1,10,10,C10H10N2O1,23,-7778.299828,2.0,1.0,-7777.261946,-1.037882
C10H10N4O1,10,10,C10H10N4O1,25,-9267.21088,4.0,1.0,-9266.440994,-0.769886
C10H10O1,10,10,C10H10O1,21,-6288.633512,,1.0,-6288.082898,-0.550614


In [11]:
# Statistics of the residual (reference-subtracted) energy over all compositions
# TODO (wardlt): Normalize energy/atom
norm_energy = records['norm_energy']
mean, std = norm_energy.mean(), norm_energy.std()

## Save them
We'll keep a single JSON document that holds the reference energies and offsets for every method, adding or replacing the entry for the current one

In [12]:
# Start from whatever was saved by earlier runs so entries for other methods are preserved
ref_data = json.loads(ref_file.read_text()) if ref_file.exists() else {}

In [13]:
# Add (or overwrite) the entry for the method processed in this notebook
ref_data[method] = dict(
    ref_energies=ref_energies,
    offsets=dict(mean=mean, std=std),
)

In [14]:
# Serialize first, then write; the cell displays the number of characters written
serialized = json.dumps(ref_data, indent=2)
ref_file.write_text(serialized)

266