

Train Neural Network Potential To Energies, Forces, and Hessians
================================================================

This tutorial shows how to train a Neural Network Potential based on ANI-1 to energies, forces, and Hessians.


In [None]:
import math
import os
import random

import numpy as np
import torch
import torchani
import tqdm
import wandb

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol


# Seed Python, NumPy and PyTorch RNGs so dataset shuffling and weight
# initialization are reproducible.
SEED = 7289038
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Newly created floating-point tensors (AEV constants, network parameters)
# default to float64.
torch.set_default_dtype(torch.float64)

# Prefer the GPU, but fall back to the CPU instead of failing outright on a
# machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# AEV hyperparameters for torchani.AEVComputer: radial cutoff (Rcr), angular
# cutoff (Rca), radial terms (EtaR, ShfR) and angular terms (Zeta, ShfZ,
# EtaA, ShfA).
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
                     1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
                     3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
                     4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00],
                    device=device)
Zeta = torch.tensor([3.2000000e+01], device=device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
                     1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00],
                    device=device)
EtaA = torch.tensor([8.0000000e+00], device=device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00],
                    device=device)
species_order = ['H', 'C', 'N', 'O']
num_species = len(species_order)
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
# No self energies are given; `subtract_self_energies` below presumably fits
# them from the dataset (they are printed at the end of this cell).
energy_shifter = torchani.utils.EnergyShifter(None)

# Look for the dataset next to this script, or in the working directory when
# running interactively (where __file__ is undefined).
try:
    path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    path = os.getcwd()
# Original tutorial dataset: os.path.join(path, '../dataset/ai-1x/sample.h5')
dspath = os.path.join(path, './molecules-RTP.h5')

config = {"max_epochs": 5000, "batch_size": 400, "hidden_layer_sizes": [256, 64, 256]}

# Split fractions: 80% training, 0% for the unused `skip` split, 10% validation
# and the remainder (None) as the test set. Change the proportions as needed.
training, skip, validation, test = (
    torchani.data.load(dspath, additional_properties=('forces', 'hessian'))
    .subtract_self_energies(energy_shifter, species_order)
    .species_to_indices(species_order)
    .shuffle()
    .split(0.8, 0, 0.1, None)
)

training = training.collate(config["batch_size"]).cache()
validation = validation.collate(config["batch_size"]).cache()
test = test.collate(config["batch_size"]).cache()

print('Self atomic energies: ', energy_shifter.self_energies)

Self atomic energies:  tensor([ -0.6138, -38.0539, -54.7216, -75.1908])


The code to define the networks and optimizers is mostly the same as for energy-only training.



In [3]:
aev_dim = aev_computer.aev_length


def make_atomic_network(input_dim, hidden_sizes, alpha=0.1):
    """Return an element-specific MLP that maps an AEV to one atomic energy.

    Each hidden size adds a ``Linear`` layer followed by ``CELU(alpha)``; a
    final ``Linear(..., 1)`` produces the scalar output. With three hidden
    sizes the Linear layers sit at indices 0, 2, 4 and 6, which the optimizer
    cell indexes explicitly.

    Args:
        input_dim: length of the AEV fed to the first layer.
        hidden_sizes: iterable of hidden-layer widths.
        alpha: CELU ``alpha`` parameter.

    Returns:
        A ``torch.nn.Sequential`` module.
    """
    layers = []
    in_features = input_dim
    for out_features in hidden_sizes:
        layers += [torch.nn.Linear(in_features, out_features), torch.nn.CELU(alpha)]
        in_features = out_features
    layers.append(torch.nn.Linear(in_features, 1))
    return torch.nn.Sequential(*layers)


# Build every element network from the same configuration that is passed to
# wandb.init, so the logged hyperparameters always match the actual model.
H_network = make_atomic_network(aev_dim, config["hidden_layer_sizes"])
C_network = make_atomic_network(aev_dim, config["hidden_layer_sizes"])
N_network = make_atomic_network(aev_dim, config["hidden_layer_sizes"])
O_network = make_atomic_network(aev_dim, config["hidden_layer_sizes"])

# One network per element, in the order of `species_order` (H, C, N, O).
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
print(nn)

ANIModel(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
  (2): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=

Initialize the weights and biases.

<div class="alert alert-info"><h4>Note</h4><p>By default, PyTorch initializes the weights of linear layers
  with Kaiming uniform initialization and the biases from a uniform distribution scaled by the fan-in
  (see the <a href="https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear">torch.nn.Linear source</a>).
  We initialize the weights similarly, but from a normal distribution,
  and we initialize the biases to zero.</p></div>



In [4]:
def init_params(m):
    """Re-initialize a ``torch.nn.Linear`` layer in place; other modules are ignored.

    Weights are drawn with ``kaiming_normal_`` using ``a=1.0``. With the
    default leaky-ReLU nonlinearity this gives a gain of
    sqrt(2 / (1 + a**2)) = 1. Biases, when present, are set to zero.

    Meant to be passed to ``Module.apply`` so that it visits every submodule.

    Args:
        m: any ``torch.nn.Module``.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        # Linear(..., bias=False) stores bias=None; zeros_(None) would raise.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


# Re-initialize every Linear layer in all four element networks. `apply`
# returns the module itself, which is displayed as this cell's output.
nn.apply(init_params)

ANIModel(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
  (2): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=64, out_features=256, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=

Let's now chain the AEV computer and the neural networks into a single pipeline: AEV Computer --> Neural Networks.



In [6]:
# Chain the AEV computer and the per-element networks into one module, then
# move its parameters and buffers onto the selected device.
model = torchani.nn.Sequential(aev_computer, nn)
model = model.to(device)

To use multiple GPUs for training, wrap the model in `torch.nn.DataParallel`, which splits each input batch across the available GPUs:

In [7]:
# Wrap the pipeline so each input batch is split across the visible GPUs.
# Checkpoints are saved from `nn` (the unwrapped ANIModel), so the wrapper
# does not change the saved parameter names.
# NOTE(review): `model` was already moved to `device` in the previous cell, so
# the `.to(device)` below is presumably redundant; as the last expression it
# also makes the cell display the wrapped model.
model = torch.nn.DataParallel(model)
model.to(device)

DataParallel(
  (module): Sequential(
    (0): AEVComputer()
    (1): ANIModel(
      (0): Sequential(
        (0): Linear(in_features=384, out_features=256, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=64, out_features=256, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=256, out_features=1, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=384, out_features=256, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): CELU(alpha=0.1)
        (4): Linear(in_features=64, out_features=256, bias=True)
        (5): CELU(alpha=0.1)
        (6): Linear(in_features=256, out_features=1, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=384, out_features=256, bias=True)
        (1): CELU(alpha=0.1)
        (2): Linear(in_features=256, out_featu

Here we use AdamW (Adam with decoupled weight decay) for the weights and
stochastic gradient descent (SGD) for the biases.



In [8]:
# Every Linear layer sits at index 0, 2, 4 or 6 of its element network.
# Each weight gets its own AdamW parameter group (default AdamW
# hyperparameters), ordered network by network: H, C, N, then O.
AdamW = torch.optim.AdamW([
    {'params': [network[layer].weight]}
    for network in (H_network, C_network, N_network, O_network)
    for layer in (0, 2, 4, 6)
])

# The matching biases are optimized with plain SGD, in the same group order.
SGD = torch.optim.SGD([
    {'params': [network[layer].bias]}
    for network in (H_network, C_network, N_network, O_network)
    for layer in (0, 2, 4, 6)
], lr=1e-3)

# Halve the learning rate after 100 scheduler steps without improvement (the
# training loop steps once per epoch); threshold=0 means any decrease counts.
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    AdamW, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    SGD, factor=0.5, patience=100, threshold=0)

If a training checkpoint file has been saved, store its path in `latest_checkpoint`. The training loop below also writes the latest checkpoint to this path at the end of every epoch.

In [None]:
# Path of the rolling checkpoint: restored in the next cell if the file exists
# and overwritten by the training loop at the end of every epoch.
latest_checkpoint = 'torch-checkpoint-files/Training-RTP-EFH-latest.pt'

Resume training from previously saved checkpoints:



In [10]:
# Resume from the rolling checkpoint written at the end of each epoch. Model,
# optimizer and scheduler states are all restored, so the epoch counter
# (`AdamW_scheduler.last_epoch`) that drives the training loop resumes too.
if os.path.isfile(latest_checkpoint):
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored onto whatever `device` is currently selected.
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

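During training, we evaluate the model on the validation set after every epoch; whenever the validation
energy RMSE is better than the best value so far, the model is saved to a checkpoint. A second function computes the same metrics on the test set.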



In [9]:
def validate():
    """Compute energy, force and Hessian RMSEs over the validation set.

    Forces and Hessians are obtained by differentiating the model energies with
    respect to the input coordinates, so autograd must stay enabled here.
    Parameter ``.grad`` buffers are not touched (``torch.autograd.grad`` is used).

    Returns:
        Tuple ``(energy_rmse, force_rmse, hessian_rmse)`` of floats, each
        converted from Hartree-based units with ``hartree2kcalmol``.

    Raises:
        ValueError: if the validation set contains no molecules.
    """
    mse_sum = torch.nn.MSELoss(reduction='sum')
    energy_sse = 0.0
    force_sse = 0.0
    hessian_sse = 0.0
    n_molec = 0
    n_force_elements = 0.0
    n_hessian_elements = 0.0

    for properties in validation:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).requires_grad_(True)
        true_energies = properties['energies'].to(device)
        true_forces = properties['forces'].to(device)
        true_hessian = properties['hessian'].to(device)

        _, predicted_energies = model((species, coordinates))

        # create_graph=True keeps the force graph so the Hessian can be
        # obtained by differentiating the forces once more.
        forces = -torch.autograd.grad(predicted_energies.sum(), coordinates,
                                      create_graph=True, retain_graph=True)[0]
        # NOTE(review): passes `retain_graph` to torchani.utils.hessian; confirm
        # the installed torchani version accepts this keyword.
        hessian = torchani.utils.hessian(coordinates, forces=forces, retain_graph=True)

        # Real atoms per molecule; padding entries have species < 0.
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        n_molec += predicted_energies.shape[0]
        # Tensor reductions instead of the builtin sum(), which iterated the
        # tensor element by element (one tiny device op per molecule).
        n_force_elements += (3 * num_atoms).sum().item()
        n_hessian_elements += (9 * num_atoms ** 2).sum().item()

        # NOTE(review): squared errors are summed over padded entries as well;
        # this assumes padded force/Hessian entries are zero in both the
        # prediction and the reference, so only real atoms contribute.
        energy_sse += mse_sum(predicted_energies, true_energies).item()
        force_sse += mse_sum(forces, true_forces).item()
        hessian_sse += mse_sum(hessian, true_hessian).item()

    if n_molec == 0:
        raise ValueError('validation set is empty; cannot compute RMSE')

    return (hartree2kcalmol(math.sqrt(energy_sse / n_molec)),
            hartree2kcalmol(math.sqrt(force_sse / n_force_elements)),
            hartree2kcalmol(math.sqrt(hessian_sse / n_hessian_elements)))

In [10]:
def test_rmse():
    """Compute energy, force and Hessian RMSEs over the test set.

    Mirrors ``validate`` but iterates over ``test``. Forces and Hessians are
    obtained by differentiating the model energies with respect to the input
    coordinates, so autograd must stay enabled here.

    Returns:
        Tuple ``(energy_rmse, force_rmse, hessian_rmse)`` of floats, each
        converted from Hartree-based units with ``hartree2kcalmol``.

    Raises:
        ValueError: if the test set contains no molecules.
    """
    # Evaluate on the held-out test set.
    mse_sum = torch.nn.MSELoss(reduction='sum')
    energy_sse = 0.0
    force_sse = 0.0
    hessian_sse = 0.0
    n_molec = 0
    n_force_elements = 0.0
    n_hessian_elements = 0.0

    for properties in test:
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).requires_grad_(True)
        true_energies = properties['energies'].to(device)
        true_forces = properties['forces'].to(device)
        true_hessian = properties['hessian'].to(device)

        _, predicted_energies = model((species, coordinates))

        # create_graph=True keeps the force graph so the Hessian can be
        # obtained by differentiating the forces once more.
        forces = -torch.autograd.grad(predicted_energies.sum(), coordinates,
                                      create_graph=True, retain_graph=True)[0]
        # NOTE(review): passes `retain_graph` to torchani.utils.hessian; confirm
        # the installed torchani version accepts this keyword.
        hessian = torchani.utils.hessian(coordinates, forces=forces, retain_graph=True)

        # Real atoms per molecule; padding entries have species < 0.
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        n_molec += predicted_energies.shape[0]
        # Tensor reductions instead of the builtin sum(), which iterated the
        # tensor element by element (one tiny device op per molecule).
        n_force_elements += (3 * num_atoms).sum().item()
        n_hessian_elements += (9 * num_atoms ** 2).sum().item()

        # NOTE(review): squared errors are summed over padded entries as well;
        # this assumes padded force/Hessian entries are zero in both the
        # prediction and the reference, so only real atoms contribute.
        energy_sse += mse_sum(predicted_energies, true_energies).item()
        force_sse += mse_sum(forces, true_forces).item()
        hessian_sse += mse_sum(hessian, true_hessian).item()

    if n_molec == 0:
        raise ValueError('test set is empty; cannot compute RMSE')

    return (hartree2kcalmol(math.sqrt(energy_sse / n_molec)),
            hartree2kcalmol(math.sqrt(force_sse / n_force_elements)),
            hartree2kcalmol(math.sqrt(hessian_sse / n_hessian_elements)))

In the training loop, we need to compute the forces, the Hessian, and the corresponding force and Hessian losses.



**Additional** 

Note that the forces and the Hessian are computed in the cell below, and the loss combines the energy, force, and Hessian errors.

`force_coefficient` and `hessian_coefficient` set the weight of the force loss and of the Hessian loss relative to the energy loss, respectively.

For more details on how the different properties are calculated, see `gao2020`, especially Table 1 and Example 3 (Listing 3).

In [None]:
# Element-wise squared errors; per-molecule normalization is applied below.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
# We only train 3 epochs here so that the docs can be generated quickly.
# Real training should take many more than 3 epochs.
max_epochs = 3  # or try: config['max_epochs']
early_stopping_learning_rate = 1.0E-5
force_coefficient = 0.08  # weight of the force loss relative to the energy loss
hessian_coefficient = 0.02  # weight of the Hessian loss relative to the energy loss
best_model_checkpoint = 'torch-checkpoint-files/Training-RTP-EFH-best.pt'

# torch.save does not create missing directories, so make sure both checkpoint
# locations exist before the first save.
os.makedirs(os.path.dirname(best_model_checkpoint) or '.', exist_ok=True)
os.makedirs(os.path.dirname(latest_checkpoint) or '.', exist_ok=True)

# Initialize wandb to keep track of the training; fill each kwarg with your
# custom run information. To resume a previous run, also pass
# id="run_id", resume="must", replacing "run_id" with your wandb run ID.
wandb.init(project="project name", entity="user", config=config, tags=["tag 1", "tag 2"],
           name='EFH-training')

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    e_rmse, fc_rmse, hess_rmse = validate()

    print('Energy RMSE:', e_rmse, 'Force RMSE:', fc_rmse, 'and Hessian RMSE:', hess_rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # Keep the model with the lowest validation energy RMSE seen so far.
    if AdamW_scheduler.is_better(e_rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    # Both schedulers track the validation energy RMSE only.
    AdamW_scheduler.step(e_rmse)
    SGD_scheduler.step(e_rmse)

    wandb.log({'validation_energy_rmse': e_rmse, 'validation_forces_rmse': fc_rmse, 'validation_hessian_rmse': hess_rmse,
               'best_validation_energy_rmse': AdamW_scheduler.best,
               'learning_rate': learning_rate},
              step=AdamW_scheduler.last_epoch + 1)

    # Test-set RMSEs are only logged for monitoring; they do not influence
    # model selection or the learning-rate schedule.
    e_rmse, fc_rmse, hess_rmse = test_rmse()
    wandb.log({'test_energy_rmse': e_rmse, 'test_forces_rmse': fc_rmse, 'test_hessian_rmse': hess_rmse},
              step=AdamW_scheduler.last_epoch + 1)

    for properties in tqdm.tqdm(
        training,
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).requires_grad_(True)
        true_energies = properties['energies'].to(device)
        true_forces = properties['forces'].to(device)
        true_hessian = properties['hessian'].to(device)
        # Real atoms per molecule; padding entries have species < 0.
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))  # 1 forward pass

        # Use torch.autograd.grad to compute the forces and the Hessian.
        # create_graph=True lets the force and Hessian losses contribute to
        # the parameter gradients, and retain_graph=True keeps the graph
        # alive for the final backward pass through the parameters.

        # Atomic forces: 1 backward pass.
        forces = -torch.autograd.grad(predicted_energies.sum(), coordinates,
                                      create_graph=True, retain_graph=True)[0]

        # Hessian: 3N backward passes.
        # NOTE(review): relies on torchani.utils.hessian accepting the
        # `retain_graph`/`create_graph` keywords; confirm for your version.
        hessian = torchani.utils.hessian(coordinates, forces=forces, retain_graph=True, create_graph=True)

        # The total loss has three parts: energy, force and Hessian losses.
        energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
        force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) / (3 * num_atoms)).mean()
        hessian_loss = (mse(true_hessian, hessian).sum(dim=(1, 2)) / (9 * num_atoms ** 2)).mean()

        loss = energy_loss + force_coefficient * force_loss + hessian_coefficient * hessian_loss

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()  # 1 backward pass through the parameters
        AdamW.step()
        SGD.step()

        # Optional: uncomment to log per-batch losses to WandB.
        # NOTE(review): mixing this with the explicit per-epoch `step=` used
        # above may produce wandb step-ordering warnings; check before enabling.
        # wandb.log({'Total loss': loss, 'Energy loss': energy_loss, 'Force loss': force_loss, 'Hessian loss': hessian_loss})

    # Save model, optimizer and scheduler states so training can be resumed.
    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)


wandb.finish()