## Training Coarse-Grained Crystal Graph Neural Network ($CG^2$-Net) for band gap prediction

### 1. Download data

In this notebook, we show how the coarse-grained crystal graph network can be used to quantitatively reproduce DFT-calculated band gaps of metal-organic frameworks (MOFs). First, we download [the Quantum MOF (QMOF) database](https://github.com/Andrew-S-Rosen/QMOF), which serves as the data source.

In [1]:
import urllib.request
import zipfile


urllib.request.urlretrieve(
    "https://figshare.com/ndownloader/files/31713017", "qmof_database.zip"
)
with zipfile.ZipFile("qmof_database.zip", "r") as zip_ref:
    zip_ref.extractall()

Then, we load data from two JSON files, containing (among many other things) information on MOF crystal structures and target property values.

In [2]:
import json


with open("qmof_database/qmof.json") as f:
    qmof = json.load(f)
with open("qmof_database/qmof_structure_data.json") as f:
    struct_data = json.load(f)
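
The nested layout of the loaded records matters for the next step, so a small sketch of safe access may help. The record below is a hypothetical placeholder that mimics only the keys this notebook reads (`qmof_id`, `outputs`, `hse06`, `bandgap`); its values are not real QMOF entries, and `get_bandgap` is an illustrative helper, not part of any library.

```python
# Hypothetical record mimicking only the keys this notebook reads from qmof.json.
# The identifier and the number are placeholders, not real QMOF data.
record = {
    "qmof_id": "qmof-0000000",
    "outputs": {"hse06": {"bandgap": 2.0}},
}


def get_bandgap(entry, functional="hse06"):
    """Return the band gap stored for `functional`, or None if it is absent."""
    return entry.get("outputs", {}).get(functional, {}).get("bandgap")


print(get_bandgap(record))                  # value stored under "hse06"
print(get_bandgap(record, "not_computed"))  # None: no such entry in the record
```

Chained `.get` calls avoid a `KeyError` when a record lacks the requested calculation.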

### 2. Prepare coarse-grained crystal graphs

We iterate through the `struct_data` and `qmof` lists in parallel to build two new lists, `graphs` and `labels`. The former holds coarse-grained crystal graphs as [`dgl.heterograph.DGLHeteroGraph`](https://docs.dgl.ai/en/0.9.x/_modules/dgl/heterograph.html) objects; the latter holds the corresponding HSE06 band gaps, each stored as a one-element list of floats. MOFs with potentially problematic linkers (e.g., containing pentavalent carbon or boron) are skipped, and their identifiers are printed.

In [3]:
import warnings
from pymatgen.core.structure import Structure

from dgl.base import DGLError
from rdkit.Chem.rdchem import AtomValenceException
from rdkit import RDLogger
from openbabel.pybel import ob

from cgcgnet.featurizer import get_2cg_inputs


# Silence warnings and log messages emitted by pymatgen, RDKit, and Open Babel.
warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")
ob.obErrorLog.SetOutputLevel(0)

graphs, labels = [], []

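for q, s in zip(qmof, struct_data):
    # Keep only MOFs for which an HSE06 calculation is available.
    if "hse06" in q["outputs"]:
        structure = Structure.from_dict(s["structure"])
        try:
            graph = get_2cg_inputs(structure)
            label = [q["outputs"]["hse06"]["bandgap"]]
        except (KeyError, IndexError, DGLError, AtomValenceException) as err:
            # Report and skip structures that cannot be converted to a graph.
            print(q["qmof_id"], err)
        else:
            # Append both together so that graphs and labels stay aligned.
            graphs.append(graph)
            labels.append(label)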

qmof-ec9d083 Explicit valence for atom # 6 B, 5, is greater than permitted
qmof-1a0c62d Explicit valence for atom # 4 C, 5, is greater than permitted
qmof-e9351eb Explicit valence for atom # 6 B, 5, is greater than permitted
qmof-ea0f641 Explicit valence for atom # 6 B, 5, is greater than permitted
qmof-abf5294 Explicit valence for atom # 6 B, 5, is greater than permitted
qmof-509d624 Explicit valence for atom # 6 B, 5, is greater than permitted
qmof-6bc32fd Explicit valence for atom # 1 B, 5, is greater than permitted
qmof-cadb719 Explicit valence for atom # 3 C, 5, is greater than permitted
qmof-b78b0ad Explicit valence for atom # 3 C, 5, is greater than permitted
qmof-cd46c95 Explicit valence for atom # 9 B, 5, is greater than permitted
qmof-00a7fd4 Explicit valence for atom # 4 B, 5, is greater than permitted
qmof-9d2522c Explicit valence for atom # 2 B, 5, is greater than permitted
qmof-ca3f908 Explicit valence for atom # 1 B, 5, is greater than permitted
qmof-5bcb068 Explicit val

In [4]:
print(len(graphs))
print(len(labels))

10718
10718
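
The two counts match because the loop above keeps a label only when its structure was featurized successfully. The sketch below reproduces that skip-on-failure pattern on toy data; `toy_featurize` and the three items are hypothetical stand-ins and have nothing to do with `get_2cg_inputs` or the QMOF records.

```python
def toy_featurize(item):
    """Hypothetical featurizer that rejects some inputs by raising an error."""
    if item["value"] < 0:
        raise ValueError("cannot featurize a negative value")
    return item["value"] * 2


items = [
    {"id": "a", "value": 1, "target": 0.5},
    {"id": "b", "value": -1, "target": 1.5},  # will be rejected
    {"id": "c", "value": 3, "target": 2.5},
]

features, targets, skipped = [], [], []
for item in items:
    try:
        feature = toy_featurize(item)
    except ValueError:
        # Record the failure and move on without touching either output list.
        skipped.append(item["id"])
        continue
    # Both lists grow together, so index i always refers to the same item.
    features.append(feature)
    targets.append([item["target"]])

print(len(features), len(targets), skipped)
```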


### 3. Prepare PyTorch (DGL) data

The $CG^2$-Net framework is implemented with [PyTorch](https://pytorch.org) and the [Deep Graph Library (DGL)](https://www.dgl.ai). The data are split into training, validation, and test subsets in an 8:1:1 ratio, stratified on the labels, and each subset is wrapped in a [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader).

In [5]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from cgcgnet.utils import get_stratified_folds, get_samples, collate_fn


RANDOM_STATE = 0
BATCH_SIZE = 64

graphs_train_valid, graphs_test, labels_train_valid, labels_test = train_test_split(
    graphs,
    labels,
    test_size=1 / 10,
    random_state=RANDOM_STATE,
    stratify=get_stratified_folds(labels),
)
graphs_train, graphs_valid, labels_train, labels_valid = train_test_split(
    graphs_train_valid,
    labels_train_valid,
    test_size=1 / 9,
    random_state=RANDOM_STATE,
    stratify=get_stratified_folds(labels_train_valid),
)

loader_train = DataLoader(
    get_samples(graphs_train, labels_train),
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)
loader_valid = DataLoader(
    get_samples(graphs_valid, labels_valid),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
loader_test = DataLoader(
    get_samples(graphs_test, labels_test),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)
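
As a quick sanity check, the subset sizes implied by the two `train_test_split` calls can be worked out by hand. The sketch assumes that scikit-learn rounds a fractional test size up to the next integer and assigns the remainder to the training side; the total of 10718 samples is taken from the output of cell 4.

```python
import math

n_total = 10718   # number of graphs reported above
batch_size = 64

# First split: hold out 1/10 of all samples for testing.
n_test = math.ceil(n_total * (1 / 10))
n_train_valid = n_total - n_test

# Second split: 1/9 of the remaining 9/10 is again about 1/10 of the total.
n_valid = math.ceil(n_train_valid * (1 / 9))
n_train = n_train_valid - n_valid

# drop_last=True discards the final incomplete training batch.
train_batches = n_train // batch_size
valid_batches = math.ceil(n_valid / batch_size)

print(n_train, n_valid, n_test)
print(train_batches, valid_batches)
```

Under these assumptions the training subset has 8574 samples, which agrees with the sample counter printed in the training log.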

### 4. Initialize model, loss function, and optimizer

In our experiments, the [GraphSAGE](https://doi.org/10.48550/arXiv.1706.02216) message-passing mechanism provided the highest overall performance in modeling the properties of reticular materials. We therefore initialize `SAGEConvModel` together with the [Adam](https://doi.org/10.48550/arXiv.1412.6980) optimizer and use the mean squared error as the loss function. The `PATIENCE` parameter sets the early-stopping criterion.

In [6]:
import torch
from torch.nn import MSELoss
from torch.optim import Adam

from cgcgnet.nn import SAGEConvModel
from cgcgnet.utils import EarlyStopping


LEARNING_RATE = 1e-3
PATIENCE = 50
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = SAGEConvModel().to(DEVICE)
loss_fn = MSELoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
stopper = EarlyStopping(prefix="hse06_bandgap", patience=PATIENCE)
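
To give a feel for what a GraphSAGE-style update does, here is a minimal NumPy sketch of a single layer with a mean aggregator: each node combines its own features with the average of its neighbors' features. It is a generic illustration on a homogeneous toy graph, not the internals of `SAGEConvModel`; the function name, the identity weight matrices, and the three-node graph are all assumptions made for the example.

```python
import numpy as np


def sage_mean_layer(h, neighbors, w_self, w_neigh):
    """One GraphSAGE-style update with a mean aggregator and a ReLU.

    h         : (num_nodes, in_dim) node features
    neighbors : list of neighbor index lists, one per node
    w_self    : (in_dim, out_dim) weights applied to the node's own features
    w_neigh   : (in_dim, out_dim) weights applied to the aggregated neighbors
    """
    h_neigh = np.zeros_like(h)
    for v, nbrs in enumerate(neighbors):
        if nbrs:  # nodes without neighbors keep a zero aggregate
            h_neigh[v] = h[nbrs].mean(axis=0)
    return np.maximum(h @ w_self + h_neigh @ w_neigh, 0.0)


# Toy graph: node 0 sees nodes 1 and 2, node 1 sees node 0, node 2 is isolated.
h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
neighbors = [[1, 2], [0], []]
out = sage_mean_layer(h, neighbors, np.eye(2), np.eye(2))
print(out)
```

With identity weights the output is simply each node's features plus the mean of its neighbors' features, which makes the aggregation easy to verify by hand.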

### 5. Train $CG^2$-Net model

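Now we can train the initialized model with a simple training loop built from the `train_step` and `valid_step` functions of the `cgcgnet.nn` module. Training stops early once the stopper signals that the validation loss has stopped improving, after which the best checkpoint is restored.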

In [7]:
from cgcgnet.nn import train_step, valid_step


NUM_EPOCHS = 500

for t in range(NUM_EPOCHS):
    print(f"Epoch {t+1}\n-----------------------------------")
    loss_train = train_step(loader_train, model, loss_fn, optimizer, DEVICE)
    loss_valid = valid_step(loader_valid, model, loss_fn, DEVICE)
    if stopper.step(loss_valid, model):
        break
stopper.load_checkpoint(model)

Epoch 1
-----------------------------------
train loss: 16.072250  [    0/ 8574]
train loss: 0.635301  [ 3200/ 8574]
train loss: 0.423649  [ 6400/ 8574]
valid loss: 0.464042 

Epoch 2
-----------------------------------
train loss: 0.644926  [    0/ 8574]
train loss: 0.396809  [ 3200/ 8574]
train loss: 0.639479  [ 6400/ 8574]
valid loss: 0.354791 

Epoch 3
-----------------------------------
train loss: 0.505838  [    0/ 8574]
train loss: 0.402150  [ 3200/ 8574]
train loss: 0.240160  [ 6400/ 8574]
valid loss: 0.347572 

Epoch 4
-----------------------------------
train loss: 0.255367  [    0/ 8574]
train loss: 0.297830  [ 3200/ 8574]
train loss: 0.213291  [ 6400/ 8574]
valid loss: 0.312485 

Epoch 5
-----------------------------------
train loss: 0.205765  [    0/ 8574]
train loss: 0.178630  [ 3200/ 8574]
train loss: 0.289767  [ 6400/ 8574]
valid loss: 0.393052 

Epoch 6
-----------------------------------
train loss: 0.319856  [    0/ 8574]
train loss: 0.247658  [ 3200/ 8574]
...

train loss: 0.076396  [ 6400/ 8574]
valid loss: 0.256360 

Epoch 48
-----------------------------------
train loss: 0.072696  [    0/ 8574]
train loss: 0.061887  [ 3200/ 8574]
train loss: 0.069241  [ 6400/ 8574]
valid loss: 0.251148 

Epoch 49
-----------------------------------
train loss: 0.086514  [    0/ 8574]
train loss: 0.068835  [ 3200/ 8574]
train loss: 0.092119  [ 6400/ 8574]
valid loss: 0.266969 

Epoch 50
-----------------------------------
train loss: 0.072967  [    0/ 8574]
train loss: 0.076100  [ 3200/ 8574]
train loss: 0.077357  [ 6400/ 8574]
valid loss: 0.250441 

Epoch 51
-----------------------------------
train loss: 0.085696  [    0/ 8574]
train loss: 0.217398  [ 3200/ 8574]
train loss: 0.091978  [ 6400/ 8574]
valid loss: 0.266724 

Epoch 52
-----------------------------------
train loss: 0.064540  [    0/ 8574]
train loss: 0.058132  [ 3200/ 8574]
train loss: 0.090952  [ 6400/ 8574]
valid loss: 0.251058 

Epoch 53
-----------------------------------
...

train loss: 0.038183  [ 3200/ 8574]
train loss: 0.033523  [ 6400/ 8574]
valid loss: 0.256392 

Epoch 95
-----------------------------------
train loss: 0.037934  [    0/ 8574]
train loss: 0.044755  [ 3200/ 8574]
train loss: 0.052380  [ 6400/ 8574]
valid loss: 0.257126 

Epoch 96
-----------------------------------
train loss: 0.029199  [    0/ 8574]
train loss: 0.045754  [ 3200/ 8574]
train loss: 0.049559  [ 6400/ 8574]
valid loss: 0.243797 

Epoch 97
-----------------------------------
train loss: 0.042068  [    0/ 8574]
train loss: 0.085267  [ 3200/ 8574]
train loss: 0.051294  [ 6400/ 8574]
valid loss: 0.255613 

Epoch 98
-----------------------------------
train loss: 0.045131  [    0/ 8574]
train loss: 0.024014  [ 3200/ 8574]
train loss: 0.036452  [ 6400/ 8574]
valid loss: 0.254714 

Epoch 99
-----------------------------------
train loss: 0.045264  [    0/ 8574]
train loss: 0.029253  [ 3200/ 8574]
train loss: 0.045132  [ 6400/ 8574]
valid loss: 0.267090 

Epoch 100
-----------------------------------
...

valid loss: 0.241913 

Epoch 141
-----------------------------------
train loss: 0.035809  [    0/ 8574]
train loss: 0.050801  [ 3200/ 8574]
train loss: 0.031374  [ 6400/ 8574]
valid loss: 0.245128 

Epoch 142
-----------------------------------
train loss: 0.022550  [    0/ 8574]
train loss: 0.030647  [ 3200/ 8574]
train loss: 0.047714  [ 6400/ 8574]
valid loss: 0.243721 

Epoch 143
-----------------------------------
train loss: 0.027198  [    0/ 8574]
train loss: 0.038449  [ 3200/ 8574]
train loss: 0.062847  [ 6400/ 8574]
valid loss: 0.261791 

Epoch 144
-----------------------------------
train loss: 0.026232  [    0/ 8574]
train loss: 0.022608  [ 3200/ 8574]
train loss: 0.030399  [ 6400/ 8574]
valid loss: 0.275967 

Epoch 145
-----------------------------------
train loss: 0.060128  [    0/ 8574]
train loss: 0.023810  [ 3200/ 8574]
train loss: 0.051245  [ 6400/ 8574]
valid loss: 0.256156 

Epoch 146
-----------------------------------
train loss: 0.030541  [    0/ 8574]
...

train loss: 0.014797  [ 3200/ 8574]
train loss: 0.012571  [ 6400/ 8574]
valid loss: 0.231345 

Epoch 188
-----------------------------------
train loss: 0.023480  [    0/ 8574]
train loss: 0.010429  [ 3200/ 8574]
train loss: 0.011290  [ 6400/ 8574]
valid loss: 0.230270 

Epoch 189
-----------------------------------
train loss: 0.020286  [    0/ 8574]
train loss: 0.029197  [ 3200/ 8574]
train loss: 0.017984  [ 6400/ 8574]
valid loss: 0.234305 

Epoch 190
-----------------------------------
train loss: 0.037351  [    0/ 8574]
train loss: 0.021629  [ 3200/ 8574]
train loss: 0.050746  [ 6400/ 8574]
valid loss: 0.237284 

Epoch 191
-----------------------------------
train loss: 0.011157  [    0/ 8574]
train loss: 0.017407  [ 3200/ 8574]
train loss: 0.025311  [ 6400/ 8574]
valid loss: 0.243405 

Epoch 192
-----------------------------------
train loss: 0.022489  [    0/ 8574]
train loss: 0.026705  [ 3200/ 8574]
train loss: 0.014640  [ 6400/ 8574]
valid loss: 0.236236 

Epoch 193
-----------------------------------
...
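
The early-stopping logic used in the loop can be sketched in a few lines. The class below is a simplified, hypothetical stand-in and makes an assumption about how `cgcgnet.utils.EarlyStopping` behaves: it counts consecutive epochs without a new best validation loss and signals a stop once that count reaches the patience. The validation losses are made-up numbers chosen for illustration.

```python
class SimpleEarlyStopping:
    """Hypothetical patience-based early stopping (no checkpointing)."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = None
        self.counter = 0

    def step(self, loss, epoch):
        """Return True when training should stop."""
        if loss < self.best_loss:
            # New best: remember it and reset the counter. A real
            # implementation would also save the model weights here.
            self.best_loss = loss
            self.best_epoch = epoch
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Made-up validation losses; the best value occurs at epoch 4.
valid_losses = [0.46, 0.35, 0.34, 0.31, 0.39, 0.33, 0.32]
stopper_demo = SimpleEarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(valid_losses, start=1):
    if stopper_demo.step(loss, epoch):
        stopped_at = epoch
        break

print(stopped_at, stopper_demo.best_epoch, stopper_demo.best_loss)
```

Restoring the weights from the best epoch after the loop, as `stopper.load_checkpoint(model)` does in the notebook, means the evaluated model is the one with the lowest validation loss rather than the one from the final epoch.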

### 6. Evaluate performance

Finally, we evaluate the predictive performance of the trained $CG^2$-Net on the test data using standard regression metrics.

In [8]:
from cgcgnet.nn import predict_dataloader


y_true, y_pred = predict_dataloader(model, loader_test, DEVICE)
y_true, y_pred = y_true.flatten(), y_pred.flatten()

In [9]:
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error


def regression_report(y_true, y_pred):
    metrics = [
        ("mean absolute error (MAE)", mean_absolute_error(y_true, y_pred)),
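        (
            "root mean square error (RMSE)",
            # Take the square root explicitly: the `squared` argument was
            # deprecated and later removed in recent scikit-learn releases.
            mean_squared_error(y_true, y_pred) ** 0.5,
        ),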
        ("coefficient of determination (R^2)", r2_score(y_true, y_pred)),
    ]
    for metric_name, metric_value in metrics:
        print(f"{metric_name:>35s}: {metric_value:>5.2f}")


regression_report(y_true, y_pred)

          mean absolute error (MAE):  0.33
      root mean square error (RMSE):  0.46
 coefficient of determination (R^2):  0.80
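
For reference, the three metrics can be computed directly from their definitions. The sketch below uses plain Python on made-up toy values (not the model's predictions), and `regression_metrics` is an illustrative helper rather than part of `cgcgnet` or scikit-learn.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return (MAE, RMSE, R^2) computed from their textbook definitions."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    ss_res = sum(e * e for e in errors)            # residual sum of squares
    rmse = math.sqrt(ss_res / n)
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return mae, rmse, r2


# Toy values chosen so that the arithmetic is easy to follow by hand.
toy_true = [1.0, 2.0, 3.0, 4.0]
toy_pred = [1.5, 2.0, 2.5, 4.0]
mae, rmse, r2 = regression_metrics(toy_true, toy_pred)
print(f"MAE={mae:.3f}  RMSE={rmse:.3f}  R^2={r2:.3f}")
```

Because the model is trained with a mean squared error loss, the square root of the validation loss is directly comparable to the RMSE: a validation loss of about 0.23 corresponds to an RMSE of roughly 0.48, close to the 0.46 obtained on the test set.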
