# Reset the Mean Layers
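Our implementation of SchNet does not use atomrefs in the usual way; instead, its per-element reference layer (`atom_ref`) plays the role of normalizing the energy by its mean. Here we reset that layer so the model's absolute energies match the reference calculations.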

In [1]:
from fff.learning.gc.ase import SchnetCalculator
from fff.simulation.utils import read_from_string
from ase.optimize import QuasiNewton
from ase.calculators.psi4 import Psi4
from ase.db import connect
from ase import Atoms, units, build
from pathlib import Path
import numpy as np
import torch
import json

  from .autonotebook import tqdm as notebook_tqdm


Configuration

In [2]:
# Name of the pretrained model to start from (a directory under best-models/)
# and the basis set whose reference energies define the new atomrefs
start_name, basis = 'dft_mctbp-25000_ttm-100k', 'aug-cc-pvtz'

## Load in the model
Load the best model trained on the initial TTM data and use it as the starting point.

In [3]:
# Directory holding the pretrained model selected in the configuration cell
start_dir = Path('best-models').joinpath(start_name)

In [4]:
# Load the full pickled model object (it is used below as a module with an
# `atom_ref` layer and re-saved whole with torch.save). Map it onto the CPU because
# every calculator below runs on 'cpu'; without map_location, a model saved from a
# GPU session cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which may reject full-module
# pickles; pass weights_only=False there if loading fails (TODO confirm torch version).
model = torch.load(start_dir / 'best_model', map_location='cpu')

## Make a calculator to show how poorly it works
We expect the energies to be off by a large constant offset (an additive shift, not a multiplicative factor).

In [5]:
# Reference energies and geometries, keyed by molecule name and then basis set
reference_energies = json.loads(Path('reference-energies.json').read_text())

In [6]:
# Reference water geometry (stored as an XYZ string) and its reference energy
h2o_ref = reference_energies['H2O'][basis]
water = read_from_string(h2o_ref['xyz'], 'xyz')
mp2_eng_h2o = h2o_ref['energy']

Get the energies with SchNet

In [7]:
# Wrap the loaded model in a calculator running on the CPU (presumably an ASE
# calculator, given the fff.learning.gc.ase module) and predict the energy of the
# reference water geometry before any atomrefs are changed.
spk_calc = SchnetCalculator(model, 'cpu')
spk_eng_h2o = spk_calc.get_potential_energy(water)

In [8]:
# Compare the reference (Psi4) energy with the SchNet prediction. The signed
# difference makes the size of the constant offset explicit.
print(f'Energy of water - Psi4: {mp2_eng_h2o:.2f} eV, SchNet {spk_eng_h2o:.2f} eV, '
      f'difference (SchNet - Psi4) {spk_eng_h2o - mp2_eng_h2o:+.2f} eV')

Energy of water - Psi4: -2075.26 eV, SchNet -2077.85 eV


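The energies disagree by several eV. The TTM energies the model was first trained on have no per-atom reference energy, but the quantum-chemistry reference energies do, so the per-atom offsets must be reset.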

## Get atomic reference energies
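We determine them from the total energies of H2 and H2O: each H atom gets half the H2 energy, and O gets whatever remains of the H2O energy after subtracting two H atoms.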

In [9]:
# Hydrogen (Z=1): half the energy of the H2 molecule
h2_energy = reference_energies['H2'][basis]['energy']
isolated_eng = {1: h2_energy / 2}

In [10]:
# Oxygen (Z=8): whatever remains of the water energy after removing two hydrogens,
# so the atomrefs of H2O sum to its reference energy
h2o_energy = reference_energies['H2O'][basis]['energy']
isolated_eng[8] = h2o_energy - 2 * isolated_eng[1]

## Update these values in the network
The atomrefs are stored as an "Embedding" layer indexed by atomic number, which we can update manually.

In [11]:
# Overwrite the rows of the atomref embedding for each element in place.
# no_grad is needed because autograd rejects in-place writes to a parameter that
# requires grad (presumably the case here, as for nn.Embedding weights by default).
ref_weight = model.atom_ref.weight
with torch.no_grad():
    for atomic_number in sorted(isolated_eng):
        ref_weight[atomic_number] = isolated_eng[atomic_number]

In [12]:
# Build a new calculator around the updated model and re-predict the water energy.
# NOTE(review): a new calculator presumably avoids reusing any result cached from
# the first evaluation (ASE calculators typically cache per-Atoms results); confirm.
spk_calc = SchnetCalculator(model, 'cpu')
spk_eng_h2o = spk_calc.get_potential_energy(water)

In [13]:
# Re-compare after resetting the atomrefs. The signed residual should now be small,
# confirming that the per-element offsets account for the earlier discrepancy.
print(f'Energy of water - Psi4: {mp2_eng_h2o:.2f} eV, SchNet {spk_eng_h2o:.2f} eV, '
      f'difference (SchNet - Psi4) {spk_eng_h2o - mp2_eng_h2o:+.2f} eV')

Energy of water - Psi4: -2075.26 eV, SchNet -2075.25 eV


## Save Updated Model
For us to use later

In [14]:
# File-name tag for the basis set, e.g. 'aug-cc-pvtz' -> 'avtz'. Only add the 'a'
# prefix when the basis really is augmented, so a non-augmented basis such as
# 'cc-pvtz' is not mislabelled as 'avtz'. The name is unchanged for aug-cc-pvXz.
basis_tag = ('a' if basis.startswith('aug-') else '') + basis[-3:]
torch.save(model, start_dir / f'{basis_tag}-starting-model')