# Learning to Train with the New SPK
A few SchNetPack (SPK) APIs have changed for the better; this notebook walks through training a model with the new ones.

In [1]:
from random import Random
from random import shuffle

import schnetpack as spk
import torch
from ase.db import connect
from schnetpack import train as trn
from tqdm import tqdm

from fff.learning.spk import ase_to_spkdata

  from .autonotebook import tqdm as notebook_tqdm


## Create the Data Loader
We'll connect to the dataset I made with an external script, so that we train on the same data as before.

Load the initial training set and shuffle it; it is converted to SchNetPack (spk) training data below.

In [2]:
# Seed for the shuffle below, so the train/validation split is the same on every
# Restart & Run All (an unseeded shuffle gives a different split each time).
SPLIT_SEED = 1

# Read every structure from the initial database as ASE Atoms objects.
with connect('../initial-database/initial-psi4-631g.db') as db:
    all_strcs = [a.toatoms() for a in db.select('')]

# Shuffle in place with a private, seeded generator; the global `random` state is untouched.
Random(SPLIT_SEED).shuffle(all_strcs)

Hold out the first 10% of the shuffled structures as the validation set.

In [3]:
# Validation set size: the first 10% (rounded down) of the shuffled structures.
cut = len(all_strcs) // 10
# Guard against an empty validation split (fewer than 10 structures), which would
# leave the validation loader with no batches to evaluate.
assert cut > 0, f'need at least 10 structures for a 10% validation split, got {len(all_strcs)}'

In [4]:
# Validation split: convert the first `cut` shuffled structures into spk data backed by
# data/valid.db. NOTE(review): confirm whether ase_to_spkdata overwrites or appends to an
# existing file at that path when this cell is re-run.
val_data = ase_to_spkdata(all_strcs[0:cut], 'data/valid.db')

In [5]:
# Training split: every structure after the validation cut, backed by data/train.db.
# NOTE(review): as with the validation file, confirm re-running does not append duplicates.
data = ase_to_spkdata(all_strcs[cut:len(all_strcs)], 'data/train.db')

Make the data loaders for training and validation.

In [6]:
# Training loader: shuffled batches of 64 structures, loaded by 8 worker processes;
# drop_last=True discards the final incomplete batch.
loader = spk.AtomsLoader(
    data,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

In [7]:
# Validation loader: same batch size and worker count as training; no shuffle or
# drop_last arguments are passed, so the loader's defaults apply.
val_loader = spk.AtomsLoader(
    val_data,
    batch_size=64,
    num_workers=8,
)

## Load the Model
Load the model that we have already converted.

In [8]:
# Load the previously converted model, mapping its stored tensors onto the GPU.
# NOTE(review): torch.load unpickles a whole model object, so only load trusted files.
# Recent PyTorch releases default to weights_only=True, which may reject a fully pickled
# model; confirm against the installed torch version.
model = torch.load('starting-psi4-model', map_location=torch.device('cuda'))

## Train the model
Fit to both the energies and the forces.

First, define the loss function: a weighted sum of the energy and force mean squared errors.

In [9]:
# Weight on the energy term of the loss; the force term receives (1 - rho_tradeoff).
rho_tradeoff = 0.9


def loss(batch, result):
    """Weighted mean-squared-error loss over energies and forces.

    Returns ``rho * MSE(energy) + (1 - rho) * MSE(forces)``, where ``rho`` is the
    module-level ``rho_tradeoff``; it is looked up at call time, not at definition.

    Args:
        batch: mapping holding the reference ``'energy'`` and ``'forces'`` tensors.
        result: mapping holding the predicted ``'energy'`` and ``'forces'`` tensors.

    Returns:
        A scalar (0-dim) tensor with the combined loss.
    """
    # NOTE(review): assumes reference and predicted tensors share a shape; mismatched
    # shapes (e.g. (B,) vs (B, 1)) would silently broadcast. Confirm against spk outputs.
    energy_mse = ((batch['energy'] - result['energy']) ** 2).mean()
    forces_mse = ((batch['forces'] - result['forces']) ** 2).mean()
    return rho_tradeoff * energy_mse + (1 - rho_tradeoff) * forces_mse

Now the optimizer

In [10]:
# Adam over all model parameters, starting from a learning rate of 5e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0005)

Now the training metrics and hooks (CSV logging and a reduce-on-plateau learning-rate schedule).

In [11]:
# Metrics to track: mean absolute error on each fitted property.
metrics = [spk.metrics.MeanAbsoluteError(prop) for prop in ('energy', 'forces')]

# Training hooks.
hooks = [
    # Log the metrics above to CSV under the 'test' directory.
    trn.CSVHook(log_path='test', metrics=metrics),
    # Multiply the learning rate by 0.8 when the monitored loss stops improving
    # (patience=1), never going below 1e-6; stop_after_min=True presumably ends
    # training once that floor is reached.
    trn.ReduceLROnPlateauHook(
        optimizer,
        factor=0.8,
        patience=1,
        min_lr=1e-6,
        stop_after_min=True,
    ),
]

Run the fitting

In [12]:
# Assemble the trainer; checkpoints and outputs go under 'test', the same directory the
# CSVHook logs to. NOTE(review): if 'test' already holds a checkpoint from an earlier run,
# the trainer may resume from it. Confirm before re-running from a fresh kernel.
trainer = trn.Trainer(
    model_path='test',
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=loader,
    # Validate on the held-out split. Passing the training loader here would make the
    # validation metrics (and, presumably, the plateau-based LR schedule) track training error.
    validation_loader=val_loader,
)

In [13]:
# Fit on the GPU for up to 128 epochs (the LR hook may stop training earlier).
trainer.train(n_epochs=128, device='cuda')