# Train a SchNet Model
Define and save the architectures of the deep learning models to be trained.

For each model, we need to store the architecture and the outputs of the network. The network designs will be stored in `networks/<name>` where `<name>` is the name of the model.

In [1]:
from schnetpack.atomistic import Atomwise, AtomisticModel, DeltaLearning
from schnetpack.representation import SchNet
from schnetpack import nn
from jcesr_ml.utils import get_atomref_vector
import pickle as pkl
import numpy as np
import shutil
import torch
import json
import os

Default parameters

In [2]:
# Settings shared by the model definitions in this notebook
max_size = 117232  # Names the sub-directory of networks/u0 that holds the trained u0 model
n_atom_basis = 256  # Passed to SchNet as `n_atom_basis` and to the output layers as their input width
n_filters = 256  # Passed to SchNet as `n_filters`
cutoff = 5  # Same value the SchNet constructors receive as `cutoff`
n_gaussians = 25  # Passed to SchNet as `n_gaussians`
max_z = 10  # Passed to SchNet as `max_z` and to get_atomref_vector

## Load in Pre-Trained Models
Some of the models use another, fully-trained model during construction.

Load in the `u0` model

In [3]:
# Directory of the u0 training run: networks/u0/<max_size>
best_u0_model_path = os.path.join('networks', 'u0', str(max_size))

In [4]:
# Load the trained u0 model if a `finished` file exists in its run directory;
#  otherwise `best_u0_model` stays None and the cells that need it are skipped
best_u0_model = None
if not os.path.isfile(os.path.join(best_u0_model_path, 'finished')):
    print('u0 model not yet trained. Re-run this notebook later')
else:
    # The architecture file sits one level above the run directory
    best_u0_model = torch.load(os.path.join(best_u0_model_path, '..', 'architecture.pth'),
                               map_location='cpu')
    # NOTE(review): presumably makes the output module return per-atom contributions,
    #  as needed by the stacked model at the end of this notebook - confirm
    best_u0_model.output_modules.return_contributions = True

    # Restore the weights stored in the `best_model` file of the run directory
    best_u0_weights = torch.load(os.path.join(best_u0_model_path, 'best_model'),
                                 map_location='cpu')
    best_u0_model.load_state_dict(best_u0_weights)

## Create Utility Operations
Helper function for saving the models in the required format.

In [5]:
def save_model(output, name, output_props=['g4mp2_0k'], overwrite=False, delta=None):
    """Save a model and its training options to ``networks/<name>``

    Writes ``architecture.pth`` (the whole model), ``options.json`` (output
    properties and delta baseline) and, if the output layer has atomrefs,
    ``atomref.npy``.

    Args:
        output (torch.nn.Module): Model to be saved. Either a model that exposes
            ``output_modules`` (e.g., an AtomisticModel) or a sequence of models
            whose last entry does
        name (string): Name of model
        output_props ([string]): List of output properties
        overwrite (bool): Whether to overwrite existing models. If False and the
            output directory already exists, nothing is written
        delta (str): Baseline property (None if not delta)
    """

    # Get the output directory, clearing or keeping an earlier copy as requested
    out_dir = os.path.join('networks', name)
    if os.path.isdir(out_dir):
        if not overwrite:
            print('Model already saved. Skipping.')
            return
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    # Save the model
    torch.save(output, os.path.join(out_dir, 'architecture.pth'))

    # Find the output layer of the model being saved. This must use the `output`
    #  argument, not whatever `model` happens to be in the notebook namespace
    holder = output if hasattr(output, 'output_modules') else output[-1]
    output_mods = holder.output_modules
    if isinstance(output_mods, torch.nn.Sequential):
        output_mods = output_mods[-1]

    # If needed, save the atomrefs (used when setting the mean and std of training set)
    if output_mods.atomref is not None:
        # `.cpu()` keeps this working if the model lives on another device
        weights = output_mods.atomref.weight.detach().cpu().numpy()
        np.save(os.path.join(out_dir, 'atomref.npy'), weights)

    # Save the training details
    with open(os.path.join(out_dir, 'options.json'), 'w') as fp:
        json.dump({
            'output_props': list(output_props),
            'delta': delta
        }, fp)

## `U0` Model
The model trained on `u0`, to recreate the SchNet Paper.

In [6]:
# Representation, built from the shared settings (including `cutoff`, which was
#  previously hard-coded here and so ignored changes to the default parameters)
reps = SchNet(n_atom_basis=n_atom_basis, n_filters=n_filters, n_interactions=6,
              cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z)
# Atom-wise output layer using the b3lyp atomrefs
output = Atomwise(n_atom_basis, atomref=get_atomref_vector('b3lyp', max_z), train_embeddings=True)
model = AtomisticModel(reps, output)

In [7]:
# Write the model to networks/u0 with `u0` as its only output property;
#  save_model skips this if that directory already exists
save_model(model, 'u0', ['u0'])

Model already saved. Skipping.


## G4MP2 Model
Same architecture as `u0`, but using `g4mp2_0k` as the target variable. Also need to update atomrefs

In [8]:
# Replace the output layer of the existing `model` in place with a new one that uses
#  the g4mp2 atomrefs; the representation part of `model` is left as it was
model.output_modules = Atomwise(n_atom_basis, atomref=get_atomref_vector('g4mp2', max_z), train_embeddings=True)  

In [9]:
# Write the model to networks/g4mp2 using save_model's default output property (g4mp2_0k)
save_model(model, 'g4mp2')

Model already saved. Skipping.


## G4MP2 Transfer-Learned Model
Use the weights from the `u0` model as a starting point

In [10]:
# Only possible once the trained u0 model has been loaded
if best_u0_model is None:
    print('Model not created because u0 has not finished training. Run this cell again later')
else:
    # Start from the trained u0 weights.
    # NOTE(review): load_state_dict copies every entry of the u0 state dict, which
    #  presumably includes the output layer's atomref weights - confirm that
    #  replacing the g4mp2 atomrefs set above is intended
    model.load_state_dict(best_u0_model.state_dict())
    save_model(model, 'g4mp2-transfer')

## G4MP2 with B3LYP Charges in Embedding
Use the charges from B3LYP as additional features in the representation for each atom. 

In [11]:
# The representation is one feature narrower than `n_atom_basis`; presumably the
#  atomic charge supplies the remaining input of the output layer - TODO confirm.
# `cutoff` and `max_z` come from the shared settings instead of being hard-coded
reps = SchNet(n_atom_basis=n_atom_basis - 1, n_filters=n_filters, n_interactions=6,
              cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z,
              additional_features=['atomic_charges'])
output = Atomwise(n_atom_basis, atomref=get_atomref_vector('g4mp2'), train_embeddings=True)
model = AtomisticModel(reps, output)

In [12]:
# overwrite=True: any existing networks/g4mp2-charges directory is deleted and
#  rewritten, unlike the other save_model calls in this notebook
save_model(model, 'g4mp2-charges', overwrite=True)

## G4MP2 with B3LYP Charges in Output Network
Use the charges from B3LYP as additional features as inputs into the output network of the model

In [13]:
# Representation built from the shared settings (`cutoff` and `max_z` were
#  previously hard-coded here)
reps = SchNet(n_atom_basis=n_atom_basis, n_filters=n_filters, n_interactions=6,
              cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z)

In [14]:
# Output network: the first stage gathers the representation together with the
#  `atomic_charges` property, the second is an MLP with a single output.
# The MLP input width is the representation width plus one (for the charge), so it
#  follows `n_atom_basis` instead of a hard-coded 256
out_net = torch.nn.Sequential(
    nn.base.GetRepresentationAndProperties(['atomic_charges']),
    nn.blocks.MLP(n_atom_basis + 1, 1, None, 2, nn.activations.shifted_softplus)
)
output = Atomwise(outnet=out_net, atomref=get_atomref_vector('g4mp2'), train_embeddings=True)

In [15]:
# Combine the representation with the output layer that also reads the atomic charges
model = AtomisticModel(reps, output)

In [16]:
# Write the model to networks/g4mp2-charges-in-outnet (skipped if it already exists)
save_model(model, 'g4mp2-charges-in-outnet')

Model already saved. Skipping.


## G4MP2 with Multi-Task Learning
Include additional properties as outputs of the model, to see whether the model learns a more generalizable representation

In [17]:
# Properties predicted alongside `g4mp2_0k` by the multi-task model
additional_tasks = [
    'u0',
    'homo',
    'lumo',
    'zpe',
]

In [18]:
# One atomref column per output: g4mp2 first, then b3lyp (lines up with the `u0`
#  task), then zeros for the remaining tasks
multitask_rfs = np.hstack((get_atomref_vector('g4mp2'),
                           get_atomref_vector('b3lyp')))
# Take the row count from the atomrefs themselves rather than hard-coding 11
multitask_rfs = np.hstack((multitask_rfs,
                           np.zeros((multitask_rfs.shape[0], len(additional_tasks) - 1))))

In [19]:
# Representation built from the shared settings (`cutoff` and `max_z` were
#  previously hard-coded here)
reps = SchNet(n_atom_basis=n_atom_basis, n_filters=n_filters, n_interactions=6,
              cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z)
# One output per property (g4mp2_0k plus the additional tasks), with mean 0 and
#  standard deviation 1 for each; the input width follows `n_atom_basis`
output = Atomwise(n_atom_basis, n_out=1 + len(additional_tasks), atomref=multitask_rfs,
                  mean=torch.Tensor([0] * (1 + len(additional_tasks))),
                  stddev=torch.Tensor([1] * (1 + len(additional_tasks))),
                  train_embeddings=True)
model = AtomisticModel(reps, output)

In [20]:
# Record `g4mp2_0k` followed by the additional tasks as the output properties
save_model(model, 'g4mp2-multitask', ['g4mp2_0k'] + additional_tasks)

Model already saved. Skipping.


## G4MP2 $\Delta$-Learning Model
Train model to predict the difference between `u0` and `g4mp2`

In [21]:
# Representation built from the shared settings (`cutoff` and `max_z` were
#  previously hard-coded here)
reps = SchNet(n_atom_basis=n_atom_basis, n_filters=n_filters, n_interactions=6,
              cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z)
# Delta-learning output on top of `u0`; the atomrefs are the difference between the
#  g4mp2 and b3lyp values, and the input width follows `n_atom_basis`
output = DeltaLearning('u0', False, n_in=n_atom_basis,
                       atomref=get_atomref_vector('g4mp2') - get_atomref_vector('b3lyp'),
                       return_contributions=True, train_embeddings=True)
model = AtomisticModel(reps, output)

In [22]:
# delta='u0' is written to options.json as the baseline property
save_model(model, 'g4mp2-delta', delta='u0')

Model already saved. Skipping.


## Stacked $\Delta$-Learning Model
Use atom-wise energies predicted from the `u0` model, rather than the actual `u0` energy. Unlike the `g4mp2-delta` model, this model does not require a B3LYP energy to be computed first

Build the model only if the trained `u0` model is available

In [23]:
if best_u0_model is not None:
    # Make the representation used in the network, from the shared settings
    #  (`cutoff` and `max_z` were previously hard-coded here)
    reps = SchNet(n_atom_basis=n_atom_basis, n_filters=n_filters, n_interactions=6,
                  cutoff=cutoff, n_gaussians=n_gaussians, max_z=max_z)
    # Delta-learning output on top of `u0_yi`; atomrefs are the g4mp2 - b3lyp difference
    output = DeltaLearning('u0_yi', True, n_in=n_atom_basis,
                           atomref=get_atomref_vector('g4mp2') - get_atomref_vector('b3lyp'),
                           train_embeddings=True)
    # Chain the trained u0 model (wrapped as a stacked output named `u0`) with the new model
    model = torch.nn.Sequential(
        nn.base.StackedOutputModel(best_u0_model, 'u0'),
        AtomisticModel(reps, output)
    )
    save_model(model, 'g4mp2-stacked-delta', delta='u0')
else:
    print('Model not created because u0 has not finished training. Run this cell again later')