In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ase.io
import torch
import schnetpack

from schnetpack_custom.atoms import MultitaskAtomsData
import data.paths

In [None]:
# Device used as `map_location` when loading models and as the target for evaluation batches.
device = torch.device('cpu')

In [None]:
def predict_on_xyz(model, xyzfile):
    """Run `model` on the single structure stored in `xyzfile`.

    The file is read with `ase.io.read`, converted into schnetpack's input
    dictionary, collated into a one-element batch and handed to the model.

    Args:
        model: object exposing a `forward(batch)` method.
        xyzfile: path of the structure file to read.

    Returns:
        Whatever `model.forward` returns for the one-molecule batch.
    """
    run_model = model.forward
    atoms = ase.io.read(xyzfile)
    # NOTE: `_convert_atoms` and `_collate_aseatoms` are private schnetpack helpers.
    inputs = schnetpack.data.atoms.torchify_dict(
        schnetpack.data.atoms._convert_atoms(atoms))
    batch = schnetpack.data.loader._collate_aseatoms([inputs])
    return run_model(batch)

In [None]:
def model_dir(model_name, base_dir='/home/cgaul/MaLTOSe2020/schnetpack_exps/models/'):
    """Return the path of the 'best_model' file of a trained model.

    Args:
        model_name: name of the model's sub-directory.
        base_dir: directory that contains one sub-directory per model.
            Defaults to the original author's location; pass another
            directory to use the notebook on a different machine.

    Returns:
        The path `<base_dir>/<model_name>/best_model` as a string.
    """
    return os.path.join(base_dir, model_name, 'best_model')

In [None]:
# For example, load the multitask_model:
model_name = 'multitask_model'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model = torch.load(model_dir(model_name), map_location=device)

## Evaluate model on unified test set

In [None]:
# QM9
# The (name, column) pairs presumably map unified property names to the
# property names stored in the database -- confirm in MultitaskAtomsData.
dataset_qm9 = MultitaskAtomsData(
    schnetpack.datasets.QM9(data.paths.QM9.db, load_only=['homo', 'lumo', 'gap']), [
        ('HOMO-B3LYP', 'homo'),
        ('LUMO-B3LYP', 'lumo'),
        ('Gap-B3LYP', 'gap')],
    validity_column=False)
# Only the third part of the split defined by the split file is kept (as the test set).
_, _, test_qm9 = schnetpack.data.partitioning.train_test_split(
    data=dataset_qm9, split_file=data.paths.QM9.split_v2)

In [None]:
# Alchemy
# NOTE: `split_file` and `dataset` are generic names that other cells reassign;
# only `test_alchemy` is meant to be used later.
split_file = data.paths.Alchemy.split_v2
dataset = MultitaskAtomsData(
    schnetpack.data.atoms.AtomsData(data.paths.Alchemy.db, load_only=['homo', 'lumo', 'gap']), [
        ('HOMO-B3LYP', 'homo'),
        ('LUMO-B3LYP', 'lumo'),
        ('Gap-B3LYP', 'gap')],
    validity_column=False)
# Only the third part of the split defined by the split file is kept (as the test set).
_, _, test_alchemy = schnetpack.data.partitioning.train_test_split(
    data=dataset, split_file=split_file)

In [None]:
# OE62 (PBE0 vacuum values)
# NOTE: `split_file` and `dataset` are generic names that other cells reassign;
# only `test_oe62` is meant to be used later.
split_file = data.paths.OE62.split_v2
dataset = MultitaskAtomsData(
    schnetpack.data.atoms.AtomsData(data.paths.OE62.db), [
        ('HOMO-PBE0', 'homo PBE0_vacuum'),
        ('LUMO-PBE0', 'lumo PBE0_vacuum'),
        ('Gap-PBE0', 'gap PBE0_vacuum')],
    validity_column=False)
# Only the third part of the split defined by the split file is kept (as the test set).
_, _, test_oe62 = schnetpack.data.partitioning.train_test_split(
    data=dataset, split_file=split_file)

In [None]:
# HOPV: the only test set here that lists both B3LYP and PBE0 properties.
hopv_split_file = data.paths.HOPV.split_v2
dataset_hopv = MultitaskAtomsData(
    schnetpack.data.atoms.AtomsData(data.paths.HOPV.db), [
        ('HOMO-B3LYP', 'HOMO B3LYP/def2-SVP'),
        ('LUMO-B3LYP', 'LUMO B3LYP/def2-SVP'),
        ('Gap-B3LYP', 'Gap B3LYP/def2-SVP'),
        ('HOMO-PBE0', 'HOMO PBE0/def2-SVP'),
        ('LUMO-PBE0', 'LUMO PBE0/def2-SVP'),
        ('Gap-PBE0', 'Gap PBE0/def2-SVP')],
    validity_column=False)
# Only the third part of the split defined by the split file is kept (as the test set).
_, _, test_hopv = schnetpack.data.partitioning.train_test_split(
    data=dataset_hopv, split_file=hopv_split_file)

In [None]:
# TABS
# NOTE: `split_file` and `dataset` are generic names that other cells reassign;
# only `test_tabs` is meant to be used later.
split_file = data.paths.TABS.split_v2
dataset = MultitaskAtomsData(
    schnetpack.data.atoms.AtomsData(data.paths.TABS.db), [
        ('HOMO-B3LYP', 'homo'),
        ('LUMO-B3LYP', 'lumo'),
        ('Gap-B3LYP', 'gap')],
    validity_column=False)
# Only the third part of the split defined by the split file is kept (as the test set).
_, _, test_tabs = schnetpack.data.partitioning.train_test_split(
    data=dataset, split_file=split_file)

In [None]:
# Registry of test sets, looked up by name in `evaluate_unified`.
# The same keys are used further down for the per-dataset plot colors and
# for `skiptests`, so renaming a key here requires renaming it there too.
datasets = {
    "QM9_test": test_qm9,
    "Alchemy_test": test_alchemy,
    "OE62_test": test_oe62,
    "HOPV_test": test_hopv,
    "TABS": test_tabs,
}

In [None]:
def get_available_properties(model):
    """List the property names predicted by the output modules of `model`.

    An output module either exposes a list `properties` (Set2Set output
    module) or a single name `property` (Atomwise output module). Each module
    is inspected on its own, so only a missing attribute selects the fallback;
    any other error raised by a module is propagated instead of being hidden.

    Args:
        model: object with an iterable attribute `output_modules`.

    Returns:
        Flat list of property names, in output-module order.

    Raises:
        AttributeError: if a module has neither `properties` nor `property`.
    """
    names = []
    for om in model.output_modules:
        if hasattr(om, 'properties'):  # for Set2Set output module
            names.extend(om.properties)
        else:  # for Atomwise output module
            names.append(om.property)
    return names

In [None]:
def evaluate_unified(model, dataset_name, n_points=None, seed=None):
    """Collect targets and model estimates for a random sample of one test set.

    Args:
        model: callable that takes a batch dict and returns a dict of tensors
            keyed by property name.
        dataset_name: key into the module-level `datasets` dict.
        n_points: stop once at least this many molecules have been evaluated
            (the count is rounded up to whole batches of 10); None evaluates
            the whole test set.
        seed: seed for the sampling order. Any integer, including 0, gives a
            reproducible order; None draws a fresh random seed.

    Returns:
        dict with the keys 'tgt' and 'est', each mapping a property name to a
        flat numpy array (targets from the dataset, estimates from the model).
    """
    dataset = datasets[dataset_name]
    batch_size = 10

    gen = torch.Generator()
    # Compare against None explicitly: a truthiness test would treat seed=0
    # as "no seed" and silently produce a non-reproducible sampling order.
    if seed is not None:
        gen.manual_seed(seed)
    else:
        gen.seed()
    sampler = torch.utils.data.RandomSampler(dataset, replacement=False, generator=gen)
    batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=False)

    ret = {
        'tgt': {p: np.array([]) for p in dataset.available_properties},
        'est': {p: np.array([]) for p in get_available_properties(model=model)},
    }
    test_loader = schnetpack.data.loader.AtomsLoader(dataset, batch_sampler=batch_sampler)
    for i, b in enumerate(test_loader):
        # Record the targets before the batch is moved to the model's device.
        for p in ret['tgt'].keys():
            ret['tgt'][p] = np.append(ret['tgt'][p], b[p])
        b = {k: v.to(device) for k, v in b.items()}
        est = model(b)
        for p in ret['est'].keys():
            ret['est'][p] = np.append(ret['est'][p], est[p].detach().to('cpu').numpy())
        # Stop after the batch that reaches n_points; the last batch is kept whole.
        if n_points is not None and (i+1) * batch_size >= n_points:
            break
    return ret

In [None]:
def compute_regular_data(n_points=100, seed=None):
    """Evaluate the global `model` on every registered test set.

    Args:
        n_points: forwarded to `evaluate_unified` for each test set.
        seed: forwarded to `evaluate_unified` for each test set.

    Returns:
        dict mapping each key of the module-level `datasets` dict to the
        result of `evaluate_unified` for that test set.
    """
    results = {}
    for name in datasets.keys():
        results[name] = evaluate_unified(model, name, n_points, seed=seed)
    return results

In [None]:
# Exploratory run: seed=None, so the sampled molecules differ between runs.
tgt_est = compute_regular_data(n_points=100, seed=None)

## Evaluate model on data of Kuzmich2017

In [None]:
def evaluate_kuzmich(model):
    """Predict on every Kuzmich2017 structure and pair it with the tabulated LUMO.

    Reads the reference table '../Data/Kuzmich2017/table1.csv' and every
    '.xyz' file in '../Data/Kuzmich2017/fixed_xyz/'. Labels listed as
    ambiguous are reported on stdout and skipped.

    Args:
        model: model passed on to `predict_on_xyz`.

    Returns:
        dict mapping the label taken from the file name to
        {'tgt': {'LUMO-B3LYP': float}, 'est': {property name: float}}.

    Raises:
        ValueError: if a label matches no row, or more than one row, of the
            reference table.
    """
    ret = {}
    xyz_dir = '../Data/Kuzmich2017/fixed_xyz/'
    df = pd.read_csv('../Data/Kuzmich2017/table1.csv')
    # File-name labels that are spelled differently in the table.
    mapping = {
        'DTDfBTTDPP2': 'DTDfBT(TDPP)2',
        '10_DBFI-MTT': 'DBFI-MTT',
    }
    ambiguous = ['M10']
    for f in sorted(os.listdir(xyz_dir)):
        if not f.endswith('.xyz'):
            continue
        # The label is the file name without its first 3 and last 13 characters.
        lb = f[3:-13]
        label = mapping.get(lb, lb)
        if label in ambiguous:
            print('id: {} ambiguous'.format(lb))
            continue
        pred = predict_on_xyz(model, os.path.join(xyz_dir, f))
        est = {k: float(v) for k, v in pred.items()}
        # Select the matching row(s) explicitly: calling float() on a Series is
        # deprecated in pandas and fails opaquely for zero or several matches.
        lumo = df.loc[df['Acceptor’s Label'] == label, 'LUMO (eV)']
        if len(lumo) != 1:
            raise ValueError(
                'Expected exactly one table row for label {!r}, found {}'.format(
                    label, len(lumo)))
        ret[lb] = {
            'tgt': {'LUMO-B3LYP': float(lumo.iloc[0])},
            'est': est,
        }
    return ret

In [None]:
# Per-molecule targets and estimates, keyed by the label taken from the xyz file name.
tgt_est_kuzmich = evaluate_kuzmich(model)

In [None]:
# Property names the currently loaded model can estimate.
est_properties = get_available_properties(model=model)

In [None]:
# Plot Kuzmich2017 data alone:
# The target is always the B3LYP LUMO; if the model has no 'LUMO-B3LYP' output,
# its 'LUMO-PBE0' estimate is compared against that target instead.
qe = 'LUMO-B3LYP' if 'LUMO-B3LYP' in est_properties else 'LUMO-PBE0'
qt = 'LUMO-B3LYP'
x = [v['tgt'][qt] for v in tgt_est_kuzmich.values()]
y = [v['est'][qe] for v in tgt_est_kuzmich.values()]
plt.scatter(x, y)
# Reference line y = x (slope 1 through a point on the diagonal).
plt.axline((np.mean(x), np.mean(x)), slope=1)
plt.xlabel('{} target (eV)'.format(qt))
plt.ylabel('{} estimate (eV)'.format(qe))
plt.title(model_name)
plt.show()
# Deviation target - estimate; MAE and RMSE are computed over all plotted molecules.
dev = np.array(x) - np.array(y)
print('MAE={:.2f}(eV), RMSE={:.2f}eV'.format(
    np.mean(np.abs(dev)),
    np.sqrt(np.mean(np.square(dev)))))

In [None]:
# Bring data into the regular format:
def add_kuzmich(tgt_est, seed=None):
    """Evaluate the global `model` on Kuzmich2017 and store the result in `tgt_est`.

    The per-molecule results of `evaluate_kuzmich` are sorted by label,
    shuffled with a `RandomState` seeded by `seed`, and turned into one flat
    array per property, i.e. the same layout `evaluate_unified` returns.

    Args:
        tgt_est: dict that is modified in place; the result is stored under
            the key 'Kuzmich2017'.
        seed: seed for the shuffle; None gives a different order on each call.

    Returns:
        None.
    """
    rng = np.random.RandomState(seed=seed)
    per_molecule = evaluate_kuzmich(model)
    # Sort by label first, so that the shuffled order depends only on the seed.
    entries = sorted(per_molecule.items())
    rng.shuffle(entries)
    records = [record for _, record in entries]
    # The property names are taken from the first record.
    first = records[0]
    collected = {}
    for part in ('tgt', 'est'):
        collected[part] = {
            prop: np.append(np.array([]), [rec[part][prop] for rec in records])
            for prop in first[part].keys()
        }
    tgt_est['Kuzmich2017'] = collected

In [None]:
# Adds the 'Kuzmich2017' entry to `tgt_est` in place (unseeded shuffle).
add_kuzmich(tgt_est, seed=None)

## Streamlined evaluation and plotting

In [None]:
# Fixed seed passed to compute_regular_data and add_kuzmich in the streamlined evaluation.
RANDOMSEED = 26463461

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`, so run only the one you want (on "Run All" the last one wins).
model_name = 'multitask_model_v08'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`, so run only the one you want (on "Run All" the last one wins).
model_name = 'multitask_model_v06_prelim'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`, so run only the one you want (on "Run All" the last one wins).
model_name = 'multitask_model_only_b3lyp'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`, so run only the one you want (on "Run All" the last one wins).
model_name = 'multitask_model'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`, so run only the one you want (on "Run All" the last one wins).
model_name = 'multitask_model_v05'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# One of several alternative model-selection cells: each reassigns `model_name`
# and `model`. This is the last of them, so it is the one in effect after "Run All".
model_name = 'multitask_model_only_pbe0'
model = torch.load(model_dir(model_name), map_location=device)

In [None]:
# Re-evaluate everything for the currently selected model with a fixed seed.
est_properties = get_available_properties(model=model)
tgt_est = compute_regular_data(n_points=200, seed=RANDOMSEED)
# Adds the 'Kuzmich2017' entry to `tgt_est` in place.
add_kuzmich(tgt_est, seed=RANDOMSEED)

## Analyse and plot

In [None]:
# Define a fixed color code for each test set
# The color is stored in place under tgt_est[<test set>]['color'], where make_plot
# reads it; test sets missing from tgt_est are skipped.
for k, color in {
    'QM9_test': 'orange',
    'Alchemy_test': 'red',
    'OE62_test': 'purple',
    'HOPV_test': 'blue',
    'TABS': 'green',
    'Kuzmich2017': 'black'
}.items():
    if k in tgt_est:
        tgt_est[k]['color'] = color

In [None]:
# Building blocks of the property names, which have the form '<property>-<theory>',
# e.g. 'HOMO-B3LYP'.
properties = ['HOMO', 'LUMO', 'Gap']
theories = ['B3LYP', 'PBE0']

In [None]:
def make_plot(tgt_est, qt_tgt, qt_est, n_points=-1, skiptests=None):
    """Scatter-plot estimates against targets for every test set that has both.

    One figure is drawn for the pair (qt_tgt, qt_est), with one scatter series
    per test set. MAE and RMSE are printed per test set and are always
    computed over all available points, even when fewer are plotted. A
    non-empty figure is saved as '<plotname>.png' and shown; an empty one is
    reported on stdout and closed.

    Args:
        tgt_est: dict mapping test-set name to a dict with the keys 'tgt',
            'est' (property name -> array) and 'color'.
        qt_tgt: name of the target property (x axis).
        qt_est: name of the estimated property (y axis).
        n_points: maximum number of points plotted per test set; a negative
            value (the default) or None plots all points.
        skiptests: test-set names to leave out of the plot; None (the
            default) skips nothing.

    Returns:
        None.
    """
    skiptests = [] if skiptests is None else skiptests
    # A plain x[:n_points] with the default of -1 would drop the last point,
    # so "no limit" is translated into an open-ended slice.
    limit = None if n_points is None or n_points < 0 else n_points
    plot_empty = True
    plt.figure(figsize=(5, 5))
    plotname = '{}-{}'.format(model_name, qt_tgt)
    if qt_est != qt_tgt:
        plotname += '-cross'
    for dataset_name, te in tgt_est.items():
        try:
            x = te['tgt'][qt_tgt]
            y = te['est'][qt_est]
        except KeyError:
            # This test set lacks the target or the estimate.
            continue
        if dataset_name in skiptests:
            # Add a tag, but only if skipped due to skiptests:
            plotname = '{}-skip{}'.format(plotname, dataset_name)
            continue
        dev = np.array(x) - np.array(y)
        mae = np.mean(np.abs(dev))
        rmse = np.sqrt(np.mean(np.square(dev)))
        print('{}: MAE={:.2f}(eV), RMSE={:.2f}eV'.format(dataset_name, mae, rmse))
        plt.scatter(x[:limit], y[:limit], color=te['color'], label='{dataset} (MAE={mae:.2f}eV)'.format(
            dataset=dataset_name, mae=mae))
        # Reference line y = x.
        plt.axline((np.mean(x), np.mean(x)), slope=1)
        plt.xlabel('{} target (eV)'.format(qt_tgt))
        plt.ylabel('{} estimate (eV)'.format(qt_est))
        plot_empty = False
    if plot_empty:
        print("{}/{} empty for {}.".format(qt_tgt, qt_est, model_name))
        # Close the unused figure so that empty figures do not pile up.
        plt.close()
    else:
        plt.title(model_name)
        plt.grid()
        plt.legend()
        plt.savefig('{}.png'.format(plotname), dpi=200)
        plt.show()

In [None]:
# Target-estimate plots for each property and theory (diagonal and cross)
n_points = 25
# Two passes: once with all test sets, once with 'TABS' left out.
for skiptests in [[], ['TABS']]:
    for a in properties:
        for t in theories:
            # The "other theory" lookup below only makes sense for exactly two theories.
            assert len(theories)==2
            t_cross = [th for th in theories if th != t][0]
            q = a + '-' + t
            q_cross = a + '-' + t_cross
            # Cross plots (target from one theory, estimate from the other) get a gray background.
            plt.rcParams.update({'axes.facecolor': 'lightgray'})
            make_plot(tgt_est, q, q_cross, n_points=n_points, skiptests=skiptests)
            # Diagonal plots (same theory for target and estimate) get a white background.
            plt.rcParams.update({'axes.facecolor': 'white'})
            make_plot(tgt_est, q, q, n_points=n_points, skiptests=skiptests)