In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ase.io
import torch
import schnetpack

from maltose.atoms import MultitaskAtomsData

import multitask_data
import evaluation

In [None]:
# Paths are relative to the working directory (presumably the notebook's own
# folder, with data/ and models/ one level up). Evaluation runs on the CPU.
base_dir = '..'
data_base_dir = os.path.join(base_dir, 'data')
device = torch.device('cpu')

In [None]:
def model_dir(model_name):
    """Return the directory ``<base_dir>/models/<model_name>`` of a trained model."""
    models_root = os.path.join(base_dir, 'models')
    return os.path.join(models_root, model_name)

In [None]:
import importlib
def load_model(model_name):
    """Load the trained model called `model_name` and switch it to eval mode.

    Two formats inside ``model_dir(model_name)`` are supported, checked in
    this order:

    * ``best_model_state.pth``: a state dict. The architecture is rebuilt with
      ``configs.<model_name>.build_model()`` before the weights are loaded.
    * ``best_model``: a whole model object saved with ``torch.save``.

    Tensors are mapped onto the notebook-global `device` in both cases.

    Raises
    ------
    FileNotFoundError
        If the model directory does not exist or contains neither file.
    """
    model_path = model_dir(model_name)
    files = os.listdir(model_path)
    if 'best_model_state.pth' in files:
        module_name = 'configs.{}'.format(model_name)
        print('Importing module {}...'.format(module_name))
        config = importlib.import_module(module_name)
        model = config.build_model()
        path = os.path.join(model_path, 'best_model_state.pth')
        print('Loading', path)
        # map_location lets weights saved on a GPU load on a CPU-only machine,
        # matching the 'best_model' branch below.
        model.load_state_dict(torch.load(path, map_location=device))
    elif 'best_model' in files:
        path = os.path.join(model_path, 'best_model')
        print('Loading', path)
        # NOTE(review): this unpickles a full model object, so only load trusted
        # files. Newer PyTorch versions may need weights_only=False here; confirm.
        model = torch.load(path, map_location=device)
    else:
        # Fail loudly instead of returning None, which only crashed later cells.
        raise FileNotFoundError(
            'No best_model_state.pth or best_model found in {}'.format(model_path))
    model.eval()
    return model

In [None]:
# Model for the quick look below. A preceding load of 'multitask_model_v08_sum'
# was immediately overwritten by this one, so only this model is loaded.
model_name = 'multitask_model_v01'
model = load_model(model_name)

## Evaluate model on unified test set

In [None]:
# Quick look at a small sample of each test set. With seed=None the sample is
# not reproducible; the seeded evaluation further below is the reference.
tgt_est = evaluation.compute_regular_data(
    model,
    seed=None,
    n_points=10,
)

## Evaluate the model on the Kuzmich2017 data set

In [None]:
def predict_on_xyz(model, xyzfile):
    """Run `model` on the single structure stored in `xyzfile`.

    The structure is read with ASE and turned into a one-element batch with
    schnetpack's conversion and collation helpers. Several of these helpers are
    private (underscore-prefixed) and may change between schnetpack versions.
    Returns whatever ``model.forward`` returns for that batch. The caller in this
    notebook iterates the result with ``.items()`` and converts each value with
    ``float``.
    """
    atoms = ase.io.read(xyzfile)
    converted = schnetpack.data.atoms._convert_atoms(atoms)
    sample = schnetpack.data.atoms.torchify_dict(converted)
    batch = schnetpack.data.loader._collate_aseatoms([sample])
    return model.forward(batch)

In [None]:
def evaluate_kuzmich(model, seed=None, n_points=-1):
    """Evaluate `model` on the Kuzmich2017 acceptor structures.

    Reference LUMO energies are read from ``<data>/Kuzmich2017/table1.csv`` and
    structures from ``<data>/Kuzmich2017/xyz/*.xyz``. Each structure is passed
    to ``predict_on_xyz`` and its predictions are paired with the reference value.

    Parameters
    ----------
    model
        Model forwarded to ``predict_on_xyz``.
    seed : int or None
        If given, the label-sorted list of structures is shuffled with
        ``np.random.RandomState(seed)`` before the subset is taken.
    n_points : int or None
        Number of structures to evaluate. ``None`` or any negative value
        (including the default ``-1``) evaluates all of them.

    Returns
    -------
    dict
        ``{label: {'tgt': {'LUMO-B3LYP': float}, 'est': {prop: float, ...}}}``,
        where ``label`` is parsed from the xyz file name.

    Raises
    ------
    ValueError
        If a structure does not match exactly one row of table1.csv.
    """
    kuzmich_dir = os.path.join(data_base_dir, 'Kuzmich2017')
    table = pd.read_csv(os.path.join(kuzmich_dir, 'table1.csv'))
    # File-name labels that are spelled differently in the table's label column.
    mapping = {
        'DTDfBTTDPP2': 'DTDfBT(TDPP)2',
        '10_DBFI-MTT': 'DBFI-MTT',
    }
    # Table labels that cannot be matched unambiguously; these are skipped.
    ambiguous = ['M10']

    # Collect valid files, keyed by the label parsed from the file name.
    files = {}
    for f in sorted(os.listdir(os.path.join(kuzmich_dir, 'xyz'))):
        if not f.endswith('.xyz'):
            continue
        # NOTE(review): assumes a fixed 3-character prefix and 13-character
        # suffix (including '.xyz') around the label. Confirm against the data.
        lb = f[3:-13]
        acceptor_id = mapping.get(lb, lb)
        if acceptor_id in ambiguous:
            print('id: {} ambiguous'.format(lb))
            continue
        files[lb] = (f, acceptor_id)
    # Canonical order, sorted by label.
    fs = sorted(files.items())

    # Shuffle order
    if seed is not None:
        random_state = np.random.RandomState(seed=seed)
        random_state.shuffle(fs)

    # Only slice for a real count: fs[:-1] would silently drop the last structure.
    if n_points is not None and n_points >= 0:
        fs = fs[:n_points]

    ret = {}
    for lb, (f, acceptor_id) in fs:
        xyzfile = os.path.join(kuzmich_dir, 'xyz', f)
        pred = predict_on_xyz(model, xyzfile)
        est = {k: float(v) for k, v in pred.items()}
        # .item() demands exactly one matching row (ValueError otherwise) and
        # avoids the deprecated float(Series) conversion.
        lumo = table.loc[table['Acceptor’s Label'] == acceptor_id, 'LUMO (eV)'].item()
        ret[lb] = {
            'tgt': {'LUMO-B3LYP': float(lumo)},
            'est': est,
        }
    return ret

In [None]:
# Bring data into the regular format and add it to the
# target-estimates collection:
def add_kuzmich(tgt_est: dict, seed: int = None, n_points: int = -1, model=None):
    """Evaluate a model on Kuzmich2017 and store it in `tgt_est['Kuzmich2017']`.

    ``evaluate_kuzmich`` returns one entry per molecule. This function regroups
    those entries into the layout used for the other test sets,
    ``{'tgt': {prop: array}, 'est': {prop: array}}``, and stores the result
    under ``tgt_est['Kuzmich2017']``. `tgt_est` is mutated in place and the
    function returns None.

    Parameters
    ----------
    tgt_est : dict
        Target-estimate collection to extend.
    seed, n_points
        Forwarded to ``evaluate_kuzmich``.
    model : optional
        Model to evaluate. Defaults to the notebook-global ``model`` so that
        existing calls keep working.
    """
    if model is None:
        model = globals()['model']
    tgt_est_kuzmich = evaluate_kuzmich(model, seed=seed, n_points=n_points)
    # The per-molecule labels are not part of the regular format; drop them.
    k_data = list(tgt_est_kuzmich.values())
    if not k_data:
        # Nothing was evaluated (e.g. n_points=0); there is nothing to add.
        print('Kuzmich2017: no data points evaluated, not added.')
        return
    # Property names come from the first molecule; all molecules are expected
    # to share them. Arrays are built once instead of via repeated np.append.
    tgt_est['Kuzmich2017'] = {
        part: {
            p: np.array([kd[part][p] for kd in k_data], dtype=float)
            for p in k_data[0][part].keys()
        }
        for part in ('tgt', 'est')
    }

## Streamlined evaluation and plotting

In [None]:
# Fixed seed so that the sampled test-set subsets are the same on every run.
RANDOMSEED = 26_463_461

In [None]:
# Model evaluated by the rest of the notebook. Set `model_name` to one of:
#   multitask_model_v08, multitask_model_v08_avg, multitask_model_v08_sum,
#   multitask_model_v06, multitask_model_only_b3lyp, multitask_model_v07,
#   multitask_model_v01, multitask_model_v01_sum, multitask_model_v01b_avg,
#   multitask_model_v05, multitask_model_only_pbe0
# Previously each candidate had its own cell, so "Run All" loaded every model
# just to keep the last one.
model_name = 'multitask_model_only_pbe0'
model = load_model(model_name)

In [None]:
# Also store the weights as a plain state dict. load_model checks for this
# file before the pickled 'best_model'.
# NOTE(review): once this file exists, load_model imports configs.<model_name>
# to rebuild the architecture, so that module must exist.
state_path = os.path.join(model_dir(model_name), 'best_model_state.pth')
if not os.path.exists(state_path):
    print("Saving the model also as state dictionary:", state_path)
    torch.save(model.state_dict(), state_path)

## Compute and dump the full error distribution

In [None]:
target_file = os.path.join(model_dir(model_name), 'deviations.npz')
summary_file = os.path.join(model_dir(model_name), 'deviations_summary.json')
if not os.path.exists(summary_file) and not os.path.exists(target_file):
    est_properties = evaluation.get_available_properties(model=model)
    tgt_est = evaluation.compute_regular_data(model, n_points=None, seed=RANDOMSEED)
    add_kuzmich(tgt_est, seed=RANDOMSEED)
    devs = {}
    for test, data in tgt_est.items():
        print(test)
        for p in data['tgt'].keys():
            if p in data['est']:
                print('  ', p)
                devs[test + ':' + p] = data['est'][p] - data['tgt'][p]
    np.savez(target_file, **devs)

## Summarize the deviations (DataFrame, JSON)

In [None]:
if not os.path.exists(summary_file):
    devs = np.load(target_file)
    import pandas as pd
    summary = pd.DataFrame(columns=['test', 'property', 'mean(error)', 'std(error)', 'MAE', 'RMSE', 'size'])
    for k, dev in devs.items():
        test, p = k.split(':')
        summary = pd.concat([
            summary,
            pd.DataFrame({
                'test': test,
                'property': p,
                'mean(error)': np.mean(dev),
                'std(error)': np.std(dev),
                'MAE': np.mean(np.abs(dev)),
                'RMSE': np.sqrt(np.mean(np.square(dev))),
                'size': len(dev),
            }, index=[0])], ignore_index=True)
    summary.to_json(summary_file, indent=2, orient='records')
else:
    summary = pd.read_json(os.path.join(model_dir(model_name), 'deviations_summary.json'))
summary

## Analyse and plot

In [None]:
# Small, seeded subset of every test set (including Kuzmich2017) for the
# scatter plots below; make_plot also reads this n_points.
n_points = 25
est_properties = evaluation.get_available_properties(model=model)
tgt_est = evaluation.compute_regular_data(
    model, seed=RANDOMSEED, n_points=n_points)
add_kuzmich(tgt_est, n_points=n_points, seed=RANDOMSEED)

In [None]:
# Fixed color per test set, so a test set looks the same in every plot.
TEST_SET_COLORS = {
    'qm9': 'orange',
    'alchemy': 'red',
    'oe62': 'purple',
    'hopv': 'blue',
    'TABS': 'green',
    'Kuzmich2017': 'black',
}
for test_name, test_color in TEST_SET_COLORS.items():
    if test_name not in tgt_est:
        continue
    tgt_est[test_name]['color'] = test_color

In [None]:
# Quantities and DFT functionals; data keys are built as '<property>-<theory>'.
properties = 'HOMO LUMO Gap'.split()
theories = 'B3LYP PBE0'.split()

In [None]:
def make_plot(tgt_est: dict, qt_tgt: str, qt_est: str, n_points: int = -1, skiptests=()):
    """Scatter-plot targets of `qt_tgt` against estimates of `qt_est` per test set.

    Each test set in `tgt_est` that has both ``te['tgt'][qt_tgt]`` and
    ``te['est'][qt_est]`` is drawn in its color (``te['color']``, if set).
    Its MAE goes in the legend, and its MAE/RMSE are printed.

    Where the errors come from:

    * Diagonal plots (``qt_tgt == qt_est``) look the errors up in the global
      `summary` table, which this notebook computes on the full test sets. If
      the lookup fails, they are measured from the plotted data instead.
    * Cross plots always measure the errors from the plotted data.

    The figure is saved to ``<base_dir>/figures/tgt-est/<plotname>.pdf`` and
    shown. The function relies on the notebook globals `summary`, `model_name`
    and `base_dir`.

    Parameters
    ----------
    tgt_est : dict
        ``{test_name: {'tgt': {prop: array}, 'est': {prop: array}, 'color': ...}}``.
    qt_tgt, qt_est : str
        Target and estimated property keys, e.g. ``'HOMO-B3LYP'``.
    n_points : int or None
        Number of points per test set to scatter. ``None`` or negative (the
        default) plots all of them.
    skiptests : iterable of str
        Test sets to leave out. Each one adds ``-skip<name>`` to the file name.
    """
    def measure_errors(x, y):
        # MAE/RMSE from the given arrays, formatted with 2 decimals.
        dev = np.array(x) - np.array(y)
        mae = '{:.2f}eV'.format(np.mean(np.abs(dev)))
        rmse = '{:.2f}eV'.format(np.sqrt(np.mean(np.square(dev))))
        return mae, rmse

    def lookup_errors(test, prop):
        # MAE/RMSE from the global `summary` table, formatted with 3 decimals.
        test_row = summary[(summary['test'] == test) & (summary['property'] == prop)]
        assert len(test_row) == 1
        # .item() instead of float(Series), which is deprecated in pandas.
        mae = '{:.3f}eV'.format(test_row['MAE'].item())
        rmse = '{:.3f}eV'.format(test_row['RMSE'].item())
        return mae, rmse

    # Negative counts would slice off the last point (x[:-1]); treat as "all".
    n_shown = None if n_points is None or n_points < 0 else n_points

    plot_empty = True
    fig = plt.figure(figsize=(5, 5))
    plotname = '{}-{}'.format(model_name, qt_tgt)
    if qt_est != qt_tgt:
        plotname += '-cross'
    for dataset_name, te in tgt_est.items():
        try:
            x = te['tgt'][qt_tgt]
            y = te['est'][qt_est]
        except KeyError:
            # This test set has no data for the requested quantities.
            continue
        if dataset_name in skiptests:
            # Add a tag, but only if skipped due to skiptests: 
            plotname = '{}-skip{}'.format(plotname, dataset_name)
            continue
        if qt_tgt == qt_est:
            try:
                mae, rmse = lookup_errors(test=dataset_name, prop=qt_est)
            except Exception as e:
                print(e)
                print("""Something went wrong looking up {}, {}. Measure 
                errors from plot data.""".format(dataset_name, qt_est))
                mae, rmse = measure_errors(x, y)
        else:
            mae, rmse = measure_errors(x, y)
        print('{}: MAE={}, RMSE={}'.format(dataset_name, mae, rmse))
        plt.scatter(x[:n_shown], y[:n_shown], color=te.get('color'), label='{dataset} (MAE={mae})'.format(
            dataset=dataset_name, mae=mae))
        plot_empty = False
    if plot_empty:
        print("{}/{} empty for {}.".format(qt_tgt, qt_est, model_name))
        # Close the unused figure; the inline backend would otherwise show it blank.
        plt.close(fig)
    else:
        # Identity line y = x (perfect prediction), drawn once for all test sets.
        plt.axline((0, 0), slope=1)
        plt.xlabel('{} target (eV)'.format(qt_tgt))
        plt.ylabel('{} estimate (eV)'.format(qt_est))
        plt.title(model_name)
        plt.grid()
        plt.legend()
        tgt_dir = os.path.join(base_dir, 'figures', 'tgt-est')
        os.makedirs(tgt_dir, exist_ok=True)
        plt.savefig(os.path.join(tgt_dir, '{}.pdf'.format(plotname)), dpi=200)
        plt.show()

In [None]:
# Target-estimate plots for each property and theory. Two plots per pair:
# - cross plot (gray background): target of one functional vs. estimate of the other;
# - diagonal plot (white background): target and estimate of the same functional.
assert len(theories) == 2, 'cross plots pair each theory with the single other one'
for a in properties:
    for t in theories:
        t_cross = next(th for th in theories if th != t)
        q = a + '-' + t
        q_cross = a + '-' + t_cross
        # rc_context restores the facecolor even if make_plot raises.
        with plt.rc_context({'axes.facecolor': 'lightgray'}):
            make_plot(tgt_est, q, q_cross, n_points=n_points)
        with plt.rc_context({'axes.facecolor': 'white'}):
            make_plot(tgt_est, q, q, n_points=n_points)