In [None]:
import numpy as np
import pandas as pd
from ase.io.xyz import read_xyz
from ase.io import read
from io import StringIO
import os
import schnetpack as spk
from torch.optim import Adam
import matplotlib.pyplot as plt
import schnetpack.train as trn
from schnetpack import AtomsData
import torch
from sklearn.model_selection import KFold

In [None]:
# Load the FreeSolv table; its 'xyz' geometries and 'expt' targets feed the cells below.
freesolv_file = 'data/FreeSolv_with_3D.csv'
freesolv_data = pd.read_csv(filepath_or_buffer=freesolv_file)

In [None]:
# Output directory for the dataset DB, split file, checkpoints and training logs.
freesolvmod = "./FreeSolvModel"
# exist_ok=True makes the cell idempotent and creates exactly the path stored in
# `freesolvmod` (the previous check tested a different literal than it created).
os.makedirs(freesolvmod, exist_ok=True)

In [None]:
# Parse each row's XYZ text and keep its first frame, giving one ase.Atoms per molecule
# in the same order as the rows of `freesolv_data`.
atoms = [
    next(read_xyz(StringIO(xyz_text), slice(None)))
    for xyz_text in freesolv_data['xyz']
]

In [None]:
# Experimental targets as a float array (an independent copy), row-aligned with `atoms`.
freesolv_expt = freesolv_data["expt"].to_numpy(dtype=float, copy=True)

In [None]:
# One property dict per molecule, in the form passed to AtomsData.add_systems below.
property_list = [{'expt': float(value)} for value in freesolv_expt]

# Geometries and targets must stay aligned one-to-one.
assert len(property_list) == len(atoms), "atoms and properties are misaligned"

# Show a short preview instead of dumping every molecule's properties.
print('Number of property entries:', len(property_list))
print('First properties:', property_list[:5])

In [None]:
# Build the SchNetPack database from scratch on every run. add_systems writes new
# rows, so re-running this cell against an existing file would (presumably, per
# schnetpack's ASE-db backend) duplicate every molecule; deleting the stale file
# keeps the cell idempotent.
dataset_path = os.path.join(freesolvmod, 'FreeSolv_SchNet_dataset.db')
if os.path.exists(dataset_path):
    os.remove(dataset_path)

new_dataset = AtomsData(dataset_path, available_properties=['expt'])
new_dataset.add_systems(atoms, property_list)

In [None]:
# Sanity check: dataset size, stored property names, and tensor shapes of entry 0.
print('Number of reference calculations:', len(new_dataset))

print('Available properties:')
for prop_name in new_dataset.available_properties:
    print('-', prop_name)
print()

first_entry = new_dataset[0]
print('Properties of molecule with id 0:')
for key, tensor in first_entry.items():
    print('-', key, ':', tensor.shape)

In [None]:
# 500 training / 100 validation molecules; the remaining molecules form the test set.
# The split is persisted in the model directory so it can be inspected and reused.
# NOTE(review): if this file already exists, schnetpack presumably reloads it and
# ignores num_train/num_val — delete it to draw a fresh split.
split_path = os.path.join(freesolvmod, "freesolv_split.npz")
train, val, test = spk.train_test_split(
    data=new_dataset,
    num_train=500,
    num_val=100,
    split_file=split_path,
)

In [None]:
# Mini-batch loaders: training batches are shuffled, validation order stays fixed.
batch_size = 100
train_loader = spk.AtomsLoader(train, shuffle=True, batch_size=batch_size)
val_loader = spk.AtomsLoader(val, batch_size=batch_size)

In [None]:
# SchNet representation: 30-dim atom embeddings and filters, 5 interaction blocks,
# 20 Gaussians (presumably the radial expansion of interatomic distances) and a
# cosine-smoothed cutoff of 4.0 in the units of the xyz coordinates.
schnet = spk.representation.SchNet(
    n_atom_basis=30,
    n_filters=30,
    n_gaussians=20,
    n_interactions=5,
    cutoff=4.0,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
)

In [None]:
# Atom-wise readout predicting 'expt'; n_in must equal SchNet's n_atom_basis (30).
output = spk.atomistic.Atomwise(property='expt', n_in=30)

model = spk.AtomisticModel(output_modules=output, representation=schnet)

In [None]:
# Adam over all model weights; the plateau hook in the next cell may lower this starting rate.
optimizer = Adam(params=model.parameters(), lr=0.01)

In [None]:
# Objective: mean-squared error on 'expt'; mean absolute error is tracked for logging.
loss = trn.build_mse_loss(['expt'])
metrics = [spk.metrics.MeanAbsoluteError('expt')]

# Hooks: CSV logging of the metrics into the model directory, and a plateau-based
# LR schedule (factor 0.8, patience 5, floor 1e-6; stop_after_min presumably ends
# training once the floor is reached).
csv_logger = trn.CSVHook(log_path=freesolvmod, metrics=metrics)
lr_schedule = trn.ReduceLROnPlateauHook(
    optimizer,
    patience=5,
    factor=0.8,
    min_lr=1e-6,
    stop_after_min=True,
)
hooks = [csv_logger, lr_schedule]

# NOTE(review): the Trainer may resume from checkpoints already present in
# `freesolvmod` — confirm, and clear that directory for a truly fresh run.
trainer = trn.Trainer(
    model_path=freesolvmod,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
)

In [None]:
# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Original note: about 10 min on a notebook GPU. Reduce for quick experiments.
n_epochs = 25

print('training')
trainer.train(device=device, n_epochs=n_epochs)



In [None]:
# Evaluate the best checkpoint on the held-out test molecules.
# The `test` subset from the split cell is reused directly; re-splitting here would
# need a persisted split file and previously referenced an undefined `hivmod`.
# The model only predicts 'expt', so only that error is reported (no forces).
# torch.load unpickles a full model: only load checkpoints this notebook wrote itself.
best_model = torch.load(os.path.join(freesolvmod, 'best_model'), map_location=device)
best_model.eval()

test_loader = spk.AtomsLoader(test, batch_size=100)

expt_abs_error = 0.0
with torch.no_grad():  # inference only: no autograd graph needed
    for count, batch in enumerate(test_loader):
        # move batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}

        pred = best_model(batch)

        # summed absolute error over the batch; averaged per molecule after the loop
        expt_abs_error += torch.sum(torch.abs(pred['expt'] - batch['expt'])).item()

        # log progress (count + 1 so the last batch reports 100%)
        percent = '{:3.2f}'.format((count + 1) / len(test_loader) * 100)
        print('Progress:', percent + '%' + ' ' * (5 - len(percent)), end="\r")

expt_mae = expt_abs_error / len(test)

print('\nTest MAE:')
# NOTE(review): FreeSolv 'expt' values are presumably kcal/mol — confirm before labelling units.
print('    expt: {:10.3f}'.format(expt_mae))