# Reset the Mean Layers
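Our implementation of SchNet was trained without atomrefs (per-element reference energies), which is equivalent to normalizing the energies by a per-atom mean, given training data with a fixed elemental composition.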

In [1]:
from fff.learning.gc.ase import SchnetCalculator
from fff.learning.gc.models import load_pretrained_model
from ase.calculators.psi4 import Psi4
from ase.db import connect
from ase import Atoms, units, build
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


Configuration

In [2]:
# Spin multiplicity (2S+1) passed to Psi4 for each isolated-atom calculation,
# keyed by element symbol.
multiplicity = dict(
    H=2, He=1, Li=2, C=3, N=4, O=3,
    F=2, Si=3, P=4, S=3, Cl=2, Br=2,
)

## Load in the model
We load the model in its present state.

In [3]:
# Load the fine-tuned weights as they currently are. `std` is 1 kcal/mol
# expressed in ASE (eV) units and `mean` is zero.
model = load_pretrained_model(
    '../starting-weights/finetune_ttm_alstep_023.pt',
    mean=0,
    std=units.kcal / units.mol,
)

## Make a calculator to show how poorly it works
We expect to be off by a large constant offset.

In [4]:
# Two small test molecules built with ase.build.molecule
water, methane = build.molecule('H2O'), build.molecule('CH4')

In [5]:
# Reference DFT energies for the two test molecules. Each molecule gets its
# own Psi4 calculator, built from the same level-of-theory settings.
_psi4_settings = {'method': 'pbe0-d3', 'basis': 'aug-cc-pvdz'}
psi4_eng_h2o = Psi4(**_psi4_settings).get_potential_energy(water)
psi4_eng_ch4 = Psi4(**_psi4_settings).get_potential_energy(methane)

  Threads set to 1 by Python driver.
  Threads set to 1 by Python driver.


  jobrec = qcng.compute(
  jobrec = qcng.compute(


In [6]:
# Evaluate the model as loaded (before any atom-ref changes) on both molecules
spk_calc = SchnetCalculator(model, 'cpu')
spk_eng_h2o, spk_eng_ch4 = [spk_calc.get_potential_energy(mol) for mol in (water, methane)]

In [7]:
# Side-by-side comparison of DFT and SchNet total energies
print(
    f'Energy of water - Psi4: {psi4_eng_h2o:.2f} eV, SchNet {spk_eng_h2o:.2f} eV',
    f'Energy of methane - Psi4: {psi4_eng_ch4:.2f} eV, SchNet {spk_eng_ch4:.2f} eV',
    sep='\n',
)

Energy of water - Psi4: -2077.86 eV, SchNet -0.00 eV
Energy of methane - Psi4: -1100.98 eV, SchNet 2.45 eV


We're far off: TTM has no "per-atom energy", but DFT does.

## Get the isolated atom energies
Used to normalize the energies of SchNetPack models

In [8]:
# Compute the energy of each isolated atom with the same functional and basis
# set as the molecules above. D3 is a pairwise dispersion correction, so for a
# single atom it should vanish, which is why plain 'pbe0' is used here.
isolated_eng = {}
for elem, mult in multiplicity.items():
    atoms = Atoms(symbols=[elem], positions=[[0, 0, 0]])
    # Most isolated atoms are open-shell: use an unrestricted reference and the
    # element's spin multiplicity from the table above.
    atoms.calc = Psi4(atoms=atoms,
                      method='pbe0', basis='aug-cc-pvdz',
                      reference='uhf',
                      multiplicity=mult)
    # Key by atomic number so the values can index the atom-ref embedding rows
    isolated_eng[atoms.get_atomic_numbers()[0]] = atoms.get_potential_energy()

## Update these values in the network
The atomrefs are stored as an "Embedding" layer, which we can update manually

In [9]:
# Overwrite each element's reference energy (the atom_ref embedding row indexed
# by atomic number) with its isolated-atom DFT energy. The write happens under
# no_grad() so autograd neither tracks nor rejects the in-place parameter edit.
with torch.no_grad():
    for atomic_number in isolated_eng:
        model.atom_ref.weight[atomic_number] = isolated_eng[atomic_number]

In [10]:
# Rebuild the calculator around the updated model and re-evaluate both molecules
spk_calc = SchnetCalculator(model, 'cpu')
spk_eng_h2o, spk_eng_ch4 = (spk_calc.get_potential_energy(mol) for mol in (water, methane))

In [11]:
# Compare again now that the isolated-atom references are in place
print('\n'.join([
    f'Energy of water - Psi4: {psi4_eng_h2o:.2f} eV, SchNet {spk_eng_h2o:.2f} eV',
    f'Energy of methane - Psi4: {psi4_eng_ch4:.2f} eV, SchNet {spk_eng_ch4:.2f} eV',
]))

Energy of water - Psi4: -2077.86 eV, SchNet -2068.11 eV
Energy of methane - Psi4: -1100.98 eV, SchNet -1080.65 eV


Much closer, but not quite. The original TTM potential has the energy of an isolated water as 0, which is (as we see here) non-zero in DFT.

We are going to correct for the absence of bond energies by updating the atom refs of O and C. These atoms only appear bonded to H, so this will make the molecules correct. This will break if we have molecules besides water or methane, but it is OK for now.

In [12]:
# Residual (DFT minus SchNet) energy for each molecule, i.e. the part the
# isolated-atom references do not capture
per_water_diff, per_methane_diff = psi4_eng_h2o - spk_eng_h2o, psi4_eng_ch4 - spk_eng_ch4

In [13]:
# Fold each molecule's residual into its heavy atom's reference energy
# (atomic number 8 = O for water, 6 = C for methane). Each molecule contains
# exactly one of that atom, so the shift moves its total energy by the residual.
with torch.no_grad():
    for atomic_number, shift in ((8, per_water_diff), (6, per_methane_diff)):
        model.atom_ref.weight[atomic_number] += shift

In [14]:
# Final re-evaluation after the O/C reference corrections
spk_calc = SchnetCalculator(model, 'cpu')
spk_eng_h2o, spk_eng_ch4 = map(spk_calc.get_potential_energy, (water, methane))

In [15]:
# Final comparison; both molecules should now match by construction
for _report in (
    f'Energy of water - Psi4: {psi4_eng_h2o:.2f} eV, SchNet {spk_eng_h2o:.2f} eV',
    f'Energy of methane - Psi4: {psi4_eng_ch4:.2f} eV, SchNet {spk_eng_ch4:.2f} eV',
):
    print(_report)

Energy of water - Psi4: -2077.86 eV, SchNet -2077.86 eV
Energy of methane - Psi4: -1100.98 eV, SchNet -1100.98 eV


We're now right on, by definition. 

## Save Updated Model
For us to use later

In [16]:
# Pickle the whole module (not just its state_dict). Loading it back requires
# the model's class definition to be importable.
torch.save(obj=model, f='starting-model')