# MODNet 'matbench_jdft2d' benchmarking

Matbench v0.1 test dataset for predicting exfoliation energies from crystal structure (computed with the OptB88vdW and TBmBJ functionals). Adapted from the JARVIS DFT database. For benchmarking with nested cross-validation, the order of the dataset must be identical to the retrieved data; refer to the Automatminer/Matbench publication for more details.

In [None]:
from collections import defaultdict
import itertools
import os
import pandas as pd
import matplotlib.pyplot as plt 
import numpy as np
from matminer.datasets import load_dataset, get_all_dataset_info
from IPython.display import Markdown
from modnet.preprocessing import MODData
from modnet.models import MODNetModel
from modnet.featurizers import MODFeaturizer
from modnet.featurizers.presets import DeBreuck2020Featurizer

# Restrict CUDA-based libraries to the GPU with index 0.
# NOTE(review): this is set after `modnet` has already been imported above;
# if that import (presumably pulling in TensorFlow) has already initialised
# the GPU devices, the setting has no effect. Consider moving it before the
# imports — TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Render the README file next to this notebook inline as Markdown.
Markdown(filename="./README.md")

## Data exploration

In [None]:
# Load the Matbench JDFT-2D dataset through matminer for data exploration.
df = load_dataset("matbench_jdft2d")

In [None]:
# Show the dataset's column labels (for a DataFrame, keys() is its columns Index).
df.keys()

### Target space

In [None]:
# Summary statistics (count, mean, std, quantiles) of the numeric columns.
df.describe()

In [None]:
# Distribution of the regression target over the whole dataset.
fig, ax = plt.subplots(facecolor="w")
# density=True normalises the histogram to unit area, so the y-axis is a
# probability density rather than a count/frequency.
ax.hist(df["exfoliation_en"], bins=100, density=True);
ax.set_ylabel("Probability density")
# matbench_jdft2d reports exfoliation energies per atom.
ax.set_xlabel("Exfoliation energy (meV/atom)")

## Featurization and feature selection

First, we featurize the crystal structures with the `DeBreuck2020Featurizer` preset and run MODNet's feature selection to keep 650 features. The resulting `MODData` object is cached under `./precomputed/` and is reloaded from there on subsequent runs.

In [None]:
PRECOMPUTED_MODDATA = "./precomputed/matbench_jdft2d.pkl.gz"

if os.path.isfile(PRECOMPUTED_MODDATA):
    # Reuse the cached, already featurized and feature-selected data.
    data = MODData.load(PRECOMPUTED_MODDATA)
else:
    # Use a fresh copy of the dataset so its row order is exactly that of the
    # retrieved Matbench data (required for the nested cross-validation folds).
    df = load_dataset("matbench_jdft2d")

    data = MODData(
        structures=df["structure"].tolist(),
        targets=df["exfoliation_en"].tolist(),
        target_names=["Exfoliation energy (meV)"],
        featurizer=DeBreuck2020Featurizer(n_jobs=8),
    )
    data.featurize()
    # MODNet feature selection, keeping 650 features.
    data.feature_selection(n=650)
    # The cache directory may not exist on a fresh checkout; create it so that
    # saving does not fail after the (expensive) featurization has finished.
    os.makedirs(os.path.dirname(PRECOMPUTED_MODDATA), exist_ok=True)
    data.save(PRECOMPUTED_MODDATA)

In [None]:
#data.optimal_features=None
#data.cross_nmi = None
#data.num_classes = {"Exfoliation energy (meV)":0}
#data.feature_selection(n=-1)
#data.save("./precomputed/matbench_jdft2d_MPCNMI.pkl.gz")

## Training

In [None]:
# Import the shared benchmarking helpers only if an earlier cell has not
# already done so. `plot_benchmark` (and, presumably, `matbench_benchmark`)
# come from modnet_matbench.utils in the parent directory.
try:
    plot_benchmark
except NameError:
    # Only a missing name should trigger the import; a bare `except:` would
    # also swallow KeyboardInterrupt and unrelated errors.
    import sys
    sys.path.append('..')
    from modnet_matbench.utils import *
from sklearn.model_selection import KFold
from modnet.models import MODNetModel
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Rename the single target column to the short name "E" used in the
# matbench_benchmark arguments below.
data.df_targets.rename(columns={data.target_names[0]: "E"}, inplace=True)

# Hyperparameters forwarded to matbench_benchmark; their exact interpretation
# is defined in modnet_matbench.utils (not visible here).
best_settings = {
    "increase_bs": False,
    "num_neurons": [[256], [64], [32], [32]],
    "n_feat": 100,
    "lr": 0.01,
    "epochs": 1000,
    "act": "relu",
    "batch_size": 32,
    "loss": "mae",
}

results = matbench_benchmark(data, [[["E"]]], {"E": 1}, best_settings, save_folds=True)
# Mean of the per-fold scores.
np.mean(results['scores'])

In [None]:
# Training-loss curve of every fold's model. The first 50 epochs are skipped,
# presumably so the initial large losses do not dominate the y-axis scale.
# Iterating the models directly avoids hard-coding the number of folds.
for fold, model in enumerate(results["models"]):
    plt.plot(model.history.history["loss"][50:], label=f"fold {fold}")
plt.xlabel("Epoch (offset by 50)")
plt.ylabel("Training loss")
plt.legend();

In [None]:
import seaborn as sns

# Flatten the per-fold targets / predictions / errors into one long table,
# one row per test sample across all folds.
reg_df = pd.DataFrame(
    np.array([
        list(itertools.chain.from_iterable(results["targets"])),
        list(itertools.chain.from_iterable(results["predictions"])),
        list(itertools.chain.from_iterable(results["errors"])),
    ]).T,
    columns=["targets", "predictions", "errors"],
)
# Fold index of every row, derived from the data instead of assuming 5 folds.
reg_df["split"] = [
    fold
    for fold, fold_targets in enumerate(results["targets"])
    for _ in range(len(fold_targets))
]

In [None]:
# Parity plot: predictions vs. true values coloured by fold, with an overall
# linear regression fit drawn on top.
fig, ax = plt.subplots()
ax.set_aspect("equal")
parity_kws = dict(data=reg_df, x="targets", y="predictions", ax=ax)
sns.scatterplot(**parity_kws, hue="split", palette="Dark2", alpha=0.5)
sns.regplot(**parity_kws, scatter=False)
ax.set_xlabel("True")
ax.set_ylabel("Pred.")

In [None]:
# Joint distribution of errors vs. predictions, with per-fold marginals.
# alpha=0.0 hides jointplot's own scatter; the points are redrawn below.
g = sns.jointplot(
    data=reg_df, x="errors", y="predictions", hue="split", palette="Dark2",
    alpha=0.0,
    # `shade` is deprecated since seaborn 0.11 in favour of `fill`.
    marginal_kws={"fill": False},
)
g.plot_joint(sns.scatterplot, hue=None, c="black", s=5, alpha=0.8)
# Density contours behind the points. `color="split"` was dropped: "split" is
# a column name, not a Matplotlib colour; per-fold colouring comes from the
# grid's hue mapping.
g.plot_joint(sns.kdeplot, zorder=0, levels=5, alpha=0.5)

In [None]:
# Per-fold bivariate KDE contours of predictions vs. true targets.
sns.kdeplot(
    data=reg_df, x="targets", y="predictions", hue="split",
    fill=False,  # `shade` is deprecated since seaborn 0.11 in favour of `fill`
    levels=3, palette="Dark2", alpha=0.5,
)