In [101]:
%load_ext autoreload
%autoreload 2
import torch
import os
import math
import torch.utils.tensorboard
import tqdm
import pickle
import nci
from sklearn.model_selection import train_test_split
import torchani
import numpy as np
from torchani.aev import AEVComputer
# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [102]:
# Dataset location; the same string is used as the key into
# `nci.dataset_elements` and as the argument to `nci.load_data`.
dataset = '../NCIA_HB375x10'
elements = nci.dataset_elements[dataset]
# Keyword arguments forwarded to `nci.get_constants`.
# NOTE(review): the variable name suggests these follow the ANI-1x AEV
# hyper-parameters — confirm against the `nci` helper.
ani1x_values = {'radial_cutoff': 5.2,
                'angular_cutoff': 3.5,
                'radial_eta': 16.0,
                'angular_eta': 8.0,
                'radial_dist_divisions': 16,
                'angular_dist_divisions': 4,
                'zeta': 32.0,
                'angle_sections': 8,
                'num_species': len(elements)}

# `constants` is unpacked directly into `AEVComputer(...)` below.
constants = nci.get_constants(**ani1x_values)
data = nci.load_data(dataset)
# Split configuration: fixed seed for the train/validation split only
# (torch's RNG is not seeded in this cell).
random_state = 0
train_size = 0.8
batch_size = 10
training, validation = train_test_split(data, train_size=train_size, random_state=random_state)
# Loaders built from the two splits; `train()` iterates these, while the
# later `validate()` / legacy loop cells iterate `training` / `validation` directly.
trainloader, validloader = nci.get_data_loaders(training, validation, batch_size=batch_size)
aev_computer = AEVComputer(**constants)
###############################################################################
# Atomic neural networks.
#
# One feed-forward network per element. Every network takes an AEV of length
# `aev_dim` and ends in a single-output linear layer; only the hidden widths
# differ between elements.
aev_dim = aev_computer.aev_length


def _atomic_network(hidden_sizes):
    """Return Sequential(Linear, CELU(0.1), ..., Linear(last_hidden, 1)).

    The resulting module has Linear layers at the even indices (0, 2, 4, 6 for
    three hidden layers) and CELU activations at the odd indices, which is the
    layout the optimizer cell relies on when it indexes the networks.
    """
    layers = []
    in_features = aev_dim
    for out_features in hidden_sizes:
        layers.append(torch.nn.Linear(in_features, out_features))
        layers.append(torch.nn.CELU(0.1))
        in_features = out_features
    layers.append(torch.nn.Linear(in_features, 1))
    return torch.nn.Sequential(*layers)


# Built in H, C, N, O order so parameter creation order is unchanged.
H_network = _atomic_network((160, 128, 96))
C_network = _atomic_network((144, 112, 96))
N_network = _atomic_network((128, 112, 96))
O_network = _atomic_network((128, 112, 96))

# NOTE(review): the list order (H, C, N, O) presumably has to match the species
# ordering used by `elements` / the AEV computer — confirm.
nn = torchani.ANIModel([H_network, C_network, N_network, O_network])


In [103]:
# Weight and bias initialisation.
#   PyTorch's default initialisation for linear layers is Kaiming uniform
#   (see the `torch.nn.Linear` source). Here the weights are drawn from the
#   Kaiming *normal* distribution instead, and the biases are set to zero.

def init_params(m):
    """Re-initialise `m` in place if it is a `torch.nn.Linear`; otherwise do nothing.

    Intended to be passed to `Module.apply`, which calls it on every submodule.
    Weights use `kaiming_normal_` with `a=1.0`; biases are zeroed.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
    torch.nn.init.zeros_(m.bias)

# Run `init_params` over every submodule of the ANI model (in place). As the
# last expression of the cell, the returned model's repr is what is displayed.
nn.apply(init_params)

ANIModel(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=160, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=160, out_features=128, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=128, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
  (1): Sequential(
    (0): Linear(in_features=384, out_features=144, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=144, out_features=112, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=112, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features=96, out_features=1, bias=True)
  )
  (2): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): CELU(alpha=0.1)
    (2): Linear(in_features=128, out_features=112, bias=True)
    (3): CELU(alpha=0.1)
    (4): Linear(in_features=112, out_features=96, bias=True)
    (5): CELU(alpha=0.1)
    (6): Linear(in_features

In [104]:
# Chain the AEV computer and the per-element networks into one module
# (AEV computer --> neural networks) and move it to `device`.
model = torchani.nn.Sequential(aev_computer, nn).to(device)

In [105]:
# batch = next(iter(trainloader))
# ids, labels, species_coordinates, index_diff = batch

# # model(species_coordinates, index_diff)
# species, coordinates = species_coordinates
# energies = np.array([entry['energies'] for entry in data])
# energies
# energies.var()
# species, aev = aev_computer.forward(species_coordinates, index_diff)

In [112]:
###############################################################################
# Optimizer setup. NeuroChem uses Adam with decoupled weight decay to update
# the weights and Stochastic Gradient Descent (SGD) to update the biases.
# Different layers get different weight decay rates.
#
# Note:
#   The weight decay in NeuroChem's `inputtrain.ipt` is named "l2", but it is
#   not L2 regularization. Confusing L2 with weight decay is a common mistake
#   in deep learning; see "Decoupled Weight Decay Regularization",
#   https://arxiv.org/abs/1711.05101
#   Weight decay is applied to weights only, never to biases.

# Networks in the order their parameter groups are registered.
_atomic_networks = (H_network, C_network, N_network, O_network)

# (index of the Linear layer inside each Sequential, explicit weight decay).
# `None` means the group carries no 'weight_decay' key and therefore inherits
# the optimizer-level default.
# NOTE(review): torch.optim.AdamW's default weight_decay is 0.01, not 0, so the
# first and last layers are decayed more strongly than the middle ones —
# confirm this is intended.
_layer_weight_decay = ((0, None), (2, 0.00001), (4, 0.000001), (6, None))


def _weight_group(network, layer, decay):
    """Build the AdamW param group for one Linear layer's weight matrix."""
    group = {'params': [network[layer].weight]}
    if decay is not None:
        group['weight_decay'] = decay
    return group


# One group per (network, layer), network-major: H0, H2, H4, H6, C0, ...
AdamW = torch.optim.AdamW([
    _weight_group(_net, _layer, _decay)
    for _net in _atomic_networks
    for _layer, _decay in _layer_weight_decay
])

# Biases of the same layers, in the same order, updated by plain SGD.
SGD = torch.optim.SGD([
    {'params': [_net[_layer].bias]}
    for _net in _atomic_networks
    for _layer, _decay in _layer_weight_decay
], lr=1e-3)

###############################################################################
# Learning-rate decay: halve the rate when the monitored metric plateaus.
_plateau_kwargs = dict(factor=0.5, patience=100, threshold=0)
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, **_plateau_kwargs)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(SGD, **_plateau_kwargs)


In [113]:
from typing import Collection, Dict, List, Optional, Tuple, Union

def _batch_loss(model, loss_function, batch, device):
    """Move one loader batch to `device`, run the model and return the loss tensor.

    `batch` is the 4-tuple yielded by the loaders:
    (ids, labels, (species, coordinates), index_diff). The ids are ignored.
    Predictions are cast to float64 before the loss is computed.
    """
    _, labels, species_coordinates, index_diff = batch

    labels = labels.to(device)
    species = species_coordinates[0].to(device)
    coordinates = species_coordinates[1].to(device)
    index_diff = index_diff.to(device)

    _, predicted_energies = model((species, coordinates), index_diff)
    return loss_function(predicted_energies.to(torch.float64), labels)


def train(
    model,
    loss_function,
    trainloader,
    testloader,
    epochs: int = 15,
    savepath: Optional[str] = None,
    idx: Optional[int] = None,
    device=None,
    optimizers: Optional[Collection[torch.optim.Optimizer]] = None,
) -> Tuple[List[float], List[float]]:
    """
    Train model.

    Parameters
    ----------
    model
        Model, called as ``model((species, coordinates), index_diff)`` and
        expected to return a pair whose second item is the predicted energies
    loss_function
        Loss function, called as ``loss_function(predicted, labels)``
    trainloader:
        Train set loader
    testloader:
        Test (validation) set loader
    epochs: int
        Number of training epochs
    savepath:
        Directory in which the best performing model's ``state_dict`` is
        saved (``best.pth`` or ``best_{idx}.pth``); nothing is saved if None
    idx: int
        Index (for multiple trainings), used in the checkpoint file name
    device:
        Computation device; defaults to CUDA when available, else CPU
    optimizers:
        Optimizers to zero/step on every batch. If None, the notebook-level
        ``AdamW`` and ``SGD`` optimizers are used (the previous behaviour)

    Returns
    -------
    Tuple[List[float], List[float]]
        Per-epoch mean train loss and mean validation loss
    """

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if optimizers is None:
        # Backward-compatible default: the optimizers defined in the notebook.
        optimizers = (AdamW, SGD)
    optimizers = tuple(optimizers)

    model.to(device)

    # Resolve the checkpoint file once, up front.
    checkpoint_path: Optional[str] = None
    if savepath is not None:
        os.makedirs(savepath or ".", exist_ok=True)
        filename = "best.pth" if idx is None else f"best_{idx}.pth"
        checkpoint_path = os.path.join(savepath, filename)

    train_losses: List[float] = []
    valid_losses: List[float] = []

    best_valid_loss = np.inf

    for _epoch in tqdm.trange(epochs, desc="Training"):

        # Training pass
        model.train()
        epoch_loss: float = 0.0

        for batch in trainloader:
            loss = _batch_loss(model, loss_function, batch, device)

            for optimizer in optimizers:
                optimizer.zero_grad()
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()

            epoch_loss += loss.item()

        # Validation pass
        model.eval()
        valid_loss: float = 0.0

        with torch.no_grad():
            for batch in testloader:
                valid_loss += _batch_loss(model, loss_function, batch, device).item()

        # Normalise losses by the number of batches
        epoch_loss /= len(trainloader)
        valid_loss /= len(testloader)

        # Track the best epoch and persist it. A NaN loss never compares as
        # smaller, so a diverged epoch cannot overwrite the best checkpoint.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            if checkpoint_path is not None:
                # TODO: Save Optimiser
                torch.save(model.state_dict(), checkpoint_path)

        train_losses.append(epoch_loss)
        valid_losses.append(valid_loss)
        print(epoch_loss, valid_loss)

    return train_losses, valid_losses

In [114]:
# Train model
# Mean-reduced MSE between predicted energies and labels, 40 epochs.
# NOTE(review): the recorded run below printed `nan nan` for every epoch before
# being interrupted, i.e. the loss was already NaN in the first epoch —
# investigate (labels/inputs, learning rate, initialisation) before trusting
# any result from this cell.
mse = torch.nn.MSELoss()
epochs = 40
train_losses, valid_losses = train(model, mse, trainloader, validloader, epochs=epochs, savepath='./')

Training:   2%|▎         | 1/40 [00:04<03:11,  4.92s/it]

nan nan


Training:   5%|▌         | 2/40 [00:09<03:04,  4.85s/it]

nan nan


Training:   8%|▊         | 3/40 [00:14<02:58,  4.81s/it]

nan nan


Training:   8%|▊         | 3/40 [00:16<03:22,  5.48s/it]


KeyboardInterrupt: 

In [67]:
valid_losses

[10.345996567549538,
 10.135153795757192,
 10.017717671913946,
 10.0829897331543,
 9.934686945138225,
 10.123971644934258,
 9.830558752681853,
 9.905816489287313,
 9.880975896356455,
 9.728642803488265,
 9.880769683113169,
 9.862523204394945,
 10.153111565354525,
 9.99571707568444,
 9.940830473949921]

In [68]:
train_losses

[10.645850110657591,
 9.93084139366282,
 9.853266364283586,
 9.81873092948209,
 9.726423168743883,
 9.724835540207359,
 9.682365661474405,
 9.671551871652099,
 9.61286624174033,
 9.649055442748033,
 9.630369121647812,
 9.604778037375649,
 9.576343490779585,
 9.57597172225012,
 9.56329573008717]

In [19]:
###############################################################################
# Train the model by minimizing the MSE loss. When the validation RMSE has not
# improved for a certain number of steps, decay the learning rate and repeat;
# stop once the learning rate falls below a threshold.
#
# We first read the checkpoint file to restart training. `latest.pt` stores
# the current training state.
latest_checkpoint = 'latest.pt'

###############################################################################
# Resume training from a previously saved checkpoint, if one exists:
if os.path.isfile(latest_checkpoint):
    # map_location remaps the stored tensors onto the current `device`, so a
    # checkpoint written on a CUDA machine can still be resumed on CPU.
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    SGD.load_state_dict(checkpoint['SGD'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])
    SGD_scheduler.load_state_dict(checkpoint['SGD_scheduler'])

###############################################################################
# During training, we need to validate on validation set and if validation error
# is better than the best, then save the new best model to a checkpoint



In [None]:
def validate():
    """Return the validation RMSE, passed through `hartree2kcalmol`.

    Iterates the notebook-level `validation` collection, sums the squared
    errors between predicted and reference energies, and divides by the number
    of predictions. The model is put in eval mode for the pass and switched
    back to training mode afterwards.

    NOTE(review): this reads `validation` (the raw split) with dict-style keys
    and calls `model` without `index_diff`, unlike `train()`, which consumes
    the loaders — confirm the two data paths are compatible.
    """
    squared_error = torch.nn.MSELoss(reduction='sum')
    sum_sq_err = 0.0
    n_predictions = 0

    model.eval()
    with torch.no_grad():
        for batch in validation:
            species = batch['species'].to(device)
            coordinates = batch['coordinates'].to(device).float()
            reference = batch['energies'].to(device).float()
            _, predicted = model((species, coordinates))
            sum_sq_err += squared_error(predicted, reference).item()
            n_predictions += predicted.shape[0]
    model.train()

    rmse = math.sqrt(sum_sq_err / n_predictions)
    return hartree2kcalmol(rmse)


###############################################################################
# We will also use TensorBoard to visualize our training process
tensorboard = torch.utils.tensorboard.SummaryWriter()

###############################################################################
# Finally, we come to the training loop.
#
# In this tutorial, we are setting the maximum epoch to a very small number,
# only to make this demo terminate fast. For serious training, this should be
# set to a much larger value.
#
# NOTE: this rebinds the notebook-level name `mse`, which the earlier training
# cell bound to a mean-reduced MSELoss. Here the loss is unreduced
# (per-element) so it can be scaled before averaging.
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 10
early_stopping_learning_rate = 1.0E-5
best_model_checkpoint = 'best.pt'

# The epoch counter lives in the scheduler state (`last_epoch`), so a run
# resumed from `latest_checkpoint` continues where it stopped.
for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    # Validate first, so the scheduler sees the metric of the current weights.
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    # The learning rate of the first AdamW param group is used for the
    # early-stopping test and for logging.
    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint: save the network weights whenever the scheduler considers
    # this RMSE better than its recorded best (checked before `step` updates it)
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    # Both schedulers are driven by the same validation RMSE.
    AdamW_scheduler.step(rmse)
    SGD_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    # NOTE(review): this iterates `training` (the raw split) with dict-style
    # keys and calls `model` without `index_diff`, whereas `train()` consumes
    # `trainloader` 4-tuples — confirm this loop matches the current data format.
    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        # Count of entries with species >= 0 along dim 1 (negative values are
        # presumably padding — TODO confirm).
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        # Squared error scaled by 1/sqrt(num_atoms), then averaged.
        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        SGD.zero_grad()
        loss.backward()
        AdamW.step()
        SGD.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    # Persist the full training state after every epoch so the resume cell
    # can reload it from `latest_checkpoint`.
    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'SGD': SGD.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
        'SGD_scheduler': SGD_scheduler.state_dict(),
    }, latest_checkpoint)
