# Learning to Train with the New SPK
A few APIs have changed for the better

In [1]:
from schnetpack import train as trn
import schnetpack as spk
from tqdm import tqdm
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Create the Data Loader
We'll connect to the dataset I made via an external script and use the "same data".

In [2]:
# Open the dataset stored at ./data/test.db (relative to the notebook's working directory).
data = spk.AtomsData('./data/test.db')

In [3]:
# Shuffled mini-batches of 64 entries, served by 8 loader workers.
loader = spk.AtomsLoader(
    data,
    batch_size=64,
    shuffle=True,
    num_workers=8,
)

Get the per-atom mean and standard deviation of the energy, which we'll need later when building the output layer

In [4]:
%%time
means, stddevs = loader.get_statistics(
    'energy', divide_by_atoms=True
)

CPU times: user 9.36 s, sys: 5.81 s, total: 15.2 s
Wall time: 31.2 s


## Make the Model
Use some default settings

In [5]:
# Shared width used for both n_atom_basis and n_filters below.
n_features = 32

schnet = spk.representation.SchNet(
    n_atom_basis=n_features,
    n_filters=n_features,
    n_interactions=3,
    n_gaussians=25,
    cutoff=5.0,
    # the cutoff function is handed over as a class, not as an instance
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)

We need both the forces and energy as outputs

In [6]:
# Output module for the 'energy' property. 'forces' are requested as its
# derivative (with negative_dr=True), and the dataset statistics computed
# earlier are passed in as mean / stddev.
energy_model = spk.atomistic.Atomwise(
    n_in=n_features,
    property='energy',
    derivative='forces',
    negative_dr=True,
    mean=means['energy'],
    stddev=stddevs['energy'],
)

Combine it all together

In [7]:
# Chain the representation and the output module into a single model.
model = spk.AtomisticModel(
    representation=schnet,
    output_modules=energy_model,
)

## Train the model
Use both the energy and the forces

First define the loss function

In [8]:
# weight of the energy term in the loss; the force term gets (1 - rho_tradeoff)
rho_tradeoff = 0.1


def loss(batch, result):
    """Return the weighted sum of the energy and force mean squared errors.

    Args:
        batch: mapping holding the reference values under the keys
            'energy' and 'forces'.
        result: mapping holding the model predictions under the same keys.

    Returns:
        rho_tradeoff * MSE(energy) + (1 - rho_tradeoff) * MSE(forces),
        where each MSE is the mean over all elements of the squared
        difference between reference and prediction.
    """
    def _mean_squared_error(key):
        # reference minus prediction, squared and averaged over every element
        residual = batch[key] - result[key]
        return torch.mean(residual ** 2)

    energy_term = rho_tradeoff * _mean_squared_error('energy')
    forces_term = (1 - rho_tradeoff) * _mean_squared_error('forces')

    return energy_term + forces_term

Now the optimizer

In [9]:
# Adam over every parameter of the model, with a learning rate of 5e-4.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-4,
)

Now the training metrics

In [10]:
# track the mean absolute error of each predicted property
metrics = [
    spk.metrics.MeanAbsoluteError(prop) for prop in ('energy', 'forces')
]

# hooks: log the metrics to CSV under 'test', and reduce the learning rate
# on plateau using the settings below
hooks = [
    trn.CSVHook(metrics=metrics, log_path='test'),
    trn.ReduceLROnPlateauHook(
        optimizer,
        factor=0.8,
        patience=5,
        min_lr=1e-6,
        stop_after_min=True,
    ),
]

Run the fitting

In [11]:
# NOTE: the same loader is passed for both training and validation, so the
# validation metrics are not computed on held-out data.
trainer = trn.Trainer(
    model=model,
    model_path='test',
    optimizer=optimizer,
    loss_fn=loss,
    hooks=hooks,
    train_loader=loader,
    validation_loader=loader,
)

In [12]:
# Train on the GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer.train(device=device, n_epochs=2)