In [1]:
import os

import numpy as onp
import jax.numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import GridSearchCV

from dipole import VectorValuedKRR
from utils import matern, coulomb

In [2]:
data = np.load('data/HOOH.DFT.PBE-TS.light.MD.500K.50k.R_E_F_D_Q.npz')
X = np.array(data['R'])
y = np.array(data['D'])



In [3]:
import schnetpack as spk
from ase.db.jsondb import JSONDatabase
from schnetpack.data import AtomsData
atomsdb = AtomsData('data/data.json')
print(len(atomsdb))

2000


In [4]:
import torch
import schnetpack as spk

# loss function
def squared_norm(batch, result):
    diff = batch['dipole'].squeeze() - result['dipole'].squeeze()
    err_sq = torch.linalg.norm(diff, axis=1)**2
    return err_sq.mean().item()


def train_schnet(train, val, size=50):

    train_loader = spk.AtomsLoader(train, batch_size=2048, shuffle=True)
    val_loader = spk.AtomsLoader(val, batch_size=2048)

    schnet = spk.representation.SchNet(
        n_atom_basis=30, n_filters=30, n_gaussians=20, n_interactions=5,
        cutoff=4., cutoff_network=spk.nn.cutoff.CosineCutoff
    )

    output_alpha = spk.atomistic.DipoleMoment(n_in=30, property='dipole')
    model = spk.AtomisticModel(representation=schnet, output_modules=output_alpha)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss = spk.train.build_mse_loss(['dipole'])

    metrics = [spk.metrics.MeanAbsoluteError('dipole')]
    hooks = [
        spk.train.ReduceLROnPlateauHook(
            optimizer,
            patience=5, factor=0.8, min_lr=1e-6,
            stop_after_min=True
        )
    ]

    model_dir = f'model-{size}'
    trainer = spk.train.Trainer(
        model_path=model_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    trainer.train(device='cpu', n_epochs=20)

    return torch.load(os.path.join(model_dir, 'best_model'))

In [5]:
data_subset_sizes = sigma_choices = list(onp.linspace(10, 100, 10, dtype=int))
parameters = {'sigma': sigma_choices}

start, stop = 1900, 2000
test_data = atomsdb.create_subset(list(range(start, stop)))
test_batch = list(spk.AtomsLoader(test_data, batch_size=2048))[0]
test = slice(start, stop)
Xtest, ytest = X[test], y[test]
errors_gdml, errors_schnet = [], []

for size in data_subset_sizes:

    indices = onp.random.choice(stop, size=size, replace=False)
    Xtrain, ytrain = X[indices], y[indices]

    cross_validation = GridSearchCV(VectorValuedKRR(), parameters)
    cross_validation.fit(Xtrain, ytrain)
    results = cross_validation.cv_results_
    best = np.argmin(results['rank_test_score'])
    best_params = results['params'][best]
    print(f'best params: {best_params}')
    best_model = VectorValuedKRR(**best_params)
    best_model.fit(Xtrain, ytrain)
    best_test_error = -best_model.score(Xtest, ytest).item()
    errors_gdml.append(best_test_error)
    print(f'gdml error: {best_test_error:.6f}')

    data_cut = atomsdb.create_subset(indices)
    val_size = size // 10
    train, val, test = spk.train_test_split(
        data=data_cut,
        num_train=size - val_size,
        num_val=val_size
    )
    best_model = train_schnet(train, val, size=size)
    prediction = best_model(test_batch)
    test_error_schnet = squared_norm(prediction, test_batch)
    print(f'schnet error: {test_error_schnet:.6f}')
    errors_schnet.append(test_error_schnet)




best params: {'sigma': 10}
gdml error: 0.158623
schnet error: 0.019171
best params: {'sigma': 80}
gdml error: 0.157207
schnet error: 0.007129
best params: {'sigma': 10}
gdml error: 0.157602
schnet error: 0.023377
best params: {'sigma': 10}
gdml error: 0.175012
schnet error: 0.187183
best params: {'sigma': 10}
gdml error: 0.179518
schnet error: 0.095198
best params: {'sigma': 90}
gdml error: 0.177824
schnet error: 0.160959
best params: {'sigma': 70}
gdml error: 0.159283
schnet error: 0.068554
best params: {'sigma': 10}
gdml error: 0.172675
schnet error: 0.034088
best params: {'sigma': 40}
gdml error: 0.184032
schnet error: 0.072784
best params: {'sigma': 50}
gdml error: 0.184105
schnet error: 0.078543
