In [1]:
import csv
import dill
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from tabulate import tabulate
import matplotlib.pyplot as plt

from datasets.MD17.MD17Dataset import MD17SingleDataset
from datasets.RMD17.RMD17Dataset import RMD17SingleDataset
from datasets.SMD17.SMD17Dataset import SMD17SingleDataset
from scripts.Chemistry.losses import EnergyLoss, PosForceLoss
from scripts.commom_util import *

In [2]:
def get_configs(model_path):
    """Parse ``{model_path}/info.txt`` into a flat ``{key: value}`` dict.

    Each line is stripped and split on single spaces. The first token becomes
    the key and the second token the value; any further tokens are ignored.
    Lines that yield fewer than two tokens are skipped. Values are kept as
    strings exactly as they appear in the file.
    """
    parsed = {}
    with open(f"{model_path}/info.txt") as info_file:
        for raw_line in info_file:
            tokens = raw_line.strip().split(" ")
            # A line needs at least "key value" to contribute an entry.
            if len(tokens) < 2:
                continue
            key, value = tokens[0], tokens[1]
            parsed[key] = value

    return parsed

def predict(model_path, epoch, configs, device, root):
    """Evaluate a saved checkpoint on the test split and store the results.

    Args:
        model_path: Directory containing the checkpoint files
            (``{epoch:03d}_*.pth``); also passed to ``save_reult``.
        epoch: Integer epoch number used to locate the checkpoint file.
        configs: Mapping with at least ``"dataset"``, ``"style"``,
            ``"molecule"`` and ``"split"`` (as produced by ``get_configs``).
        device: Device the checkpoint is loaded onto and every batch is moved to.
        root: Dataset root directory handed to the dataset constructor.

    Raises:
        ValueError: If no checkpoint file matches the requested epoch.
    """
    # Imported locally so the function does not depend on the star import
    # at the top of the notebook happening to expose ``glob``.
    import glob

    # initialize
    dataset_classes = {
        "MD17SingleDataset": MD17SingleDataset,
        "SMD17SingleDataset": SMD17SingleDataset,
    }
    # Any other dataset name falls back to RMD17SingleDataset (original behavior).
    dataset_cls = dataset_classes.get(configs["dataset"], RMD17SingleDataset)
    dataset = dataset_cls(configs["style"], configs["molecule"], "test", configs["split"], root)
    identifier = dataset.identifier
    test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate)

    # Sort the matches so the chosen checkpoint is deterministic when several
    # files share the same epoch prefix.
    path = sorted(glob.glob(f"{model_path}/{epoch:03d}_*.pth"))
    if not path: raise ValueError("model not found")
    # SECURITY NOTE: this unpickles the checkpoint with dill, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    # Load onto the requested ``device`` (previously the current CUDA device was
    # used regardless of ``device``, which could mismatch the batches below).
    model = torch.load(path[0], map_location=device, pickle_module=dill)
    print(f"Using model {path[0]}")
    model.eval()

    # test
    # NOTE(review): gradients are deliberately left enabled; the model may need
    # autograd to produce forces -- confirm before wrapping in torch.no_grad().
    preds = []
    losses = []
    for data, label in tqdm(test_dataloader):
        data = {i: v.to(device) for i, v in data.items()}
        label = {i: v.to(device) for i, v in label.items()}
        pred = model(data)

        # One flat row per sample: energy followed by the flattened forces.
        flat_pred = torch.cat((pred["E"], pred["F"].reshape(1, -1)), dim=1).squeeze()
        preds.append(flat_pred.tolist())

        loss_E = EnergyLoss(pred, label).to("cpu").item()
        loss_F = [l.to("cpu").item() for l in PosForceLoss(pred, label)]
        losses.append([loss_E] + loss_F)

    # save result
    save_reult(model_path, identifier, preds, losses)

def metric(model_path, configs, lrs=0, if_print=False, e_thresh=0.02, f_thresh=0.03):
    """Summarise the per-sample losses stored in ``loss_{dataset}_val_ood.csv``.

    Each CSV row holds the energy loss in column 0 followed by the force
    losses; empty cells (padding) are ignored.

    Args:
        model_path: Directory containing the loss CSV.
        configs: Mapping providing ``"dataset"``, used to build the file name.
        lrs: Optional factor applied to every numeric loss before computing
            the metrics; a falsy value (default ``0``) disables rescaling.
        if_print: When true, print a one-line summary.
        e_thresh: Maximum energy loss for a sample to count as a hit.
        f_thresh: Maximum force loss (largest entry of the row) for a sample
            to count as a hit.

    Returns:
        ``(EMAE, FMAE, EFWT)``: mean energy loss, mean of the per-row mean
        force loss, and the fraction of rows within both thresholds.

    Raises:
        ValueError: If the file has no data rows or a row has no force values.
    """
    ds = configs["dataset"]
    csv_path = f"{model_path}/loss_{ds}_val_ood.csv"

    # load (the file is only needed while reading)
    with open(csv_path, newline='') as fp:
        # QUOTE_NONNUMERIC converts unquoted fields to float; empty fields
        # stay as "" and are filtered out below. Blank lines are skipped.
        cdata = [row for row in csv.reader(fp, quoting=csv.QUOTE_NONNUMERIC) if row]

    if not cdata:
        raise ValueError(f"no loss rows found in {csv_path}")

    if lrs:
        # Rescale numeric cells only; multiplying the "" padding would raise
        # TypeError for a float factor.
        cdata = [[c * lrs if c != "" else c for c in row] for row in cdata]

    # calculate
    loss_Emole, loss_Fmole = [], []
    hit, total = 0, len(cdata)
    failE, failF = 0, 0

    for idx, row in enumerate(cdata):
        force = [f for f in row[1:] if f != ""]
        if not force:
            raise ValueError(f"row {idx} of {csv_path} has no force losses")

        loss_Emole.append(row[0])
        loss_Fmole.append(sum(force) / len(force))

        m = max(force)
        if row[0] <= e_thresh and m <= f_thresh:
            hit += 1
        else:
            # A sample may fail on energy, on force, or on both.
            if row[0] > e_thresh: failE += 1
            if m > f_thresh: failF += 1

    EMAE = np.mean(np.array(loss_Emole))
    FMAE = np.mean(np.array(loss_Fmole))
    EFWT = hit / total
    if if_print:
        print(f"Energy MAE: {EMAE:.3f},\tForce MAE: {FMAE:.3f},\tEFwT: {EFWT:.3f} (failE:{(failE/total):.3f}, failF:{(failF/total):.3f})\n")
    return EMAE, FMAE, EFWT

In [4]:
# Comparison with and without IReLU, denormalization, label rescaling
# Models evaluated below: SchNet, CGCNN, DimeNet++ (dimenet2), ForceNet.
# A trailing "*" selects the "*_SOC22_IR" checkpoint; the plain name selects
# the "*_OC22_og" checkpoint.
# NOTE(review): the original header also listed GemNet, but it is not in the loop.
# NOTE(review): the header says the dataset is OC22, yet the prediction root
# below points at the SMD17 data directory -- confirm which one is intended.

need_predict = False
EMAE_table = [["model", "loss"]]
FMAE_table = [["model", "loss"]]
EFWT_table = [["model", "loss"]]

# starred model name -> (checkpoint directory suffix, loss rescale factor)
starred_runs = {
    "schnet*": ("_SOC22_IR", 10000),
    "cgcnn*": ("nobn_SOC22_IR", 10000),
    "dimenet2*": ("_SOC22_IR", 10000),
    "forcenet*": ("nobn_SOC22_IR", 100000),
}

for model in ["schnet", "schnet*", "cgcnn", "cgcnn*", "dimenet2", "dimenet2*", "forcenet", "forcenet*"]:
    if model in starred_runs:
        suffix, lrs = starred_runs[model]
        model_path = f"./checkpoints/{model[:-1]}{suffix}"
    else:
        model_path = f"./checkpoints/{model}_OC22_og"
        lrs = 0
    cfg = get_configs(model_path)

    if need_predict:
        root = "../../datasets/SMD17/datas"
        predict(model_path, 49, cfg, "cuda", root)
    EMAE, FMAE, EFWT = metric(model_path, cfg, lrs, if_print=False)

    # One [model, value] row per metric table.
    EMAE_row = [model, EMAE]
    FMAE_row = [model, FMAE]
    EFWT_row = [model, EFWT]

    EMAE_table.append(EMAE_row)
    FMAE_table.append(FMAE_row)
    EFWT_table.append(EFWT_row)

# Only the energy and force tables are rendered; EFWT_table is collected but
# not printed here.
tableE = tabulate(EMAE_table, headers='firstrow', tablefmt='grid')
tableF = tabulate(FMAE_table, headers='firstrow', tablefmt='grid')

print("Energy Loss", tableE, "", "Force Loss", tableF, sep="\n")

Energy Loss
+-----------+------------+
| model     |       loss |
+===========+============+
| schnet    |   6.94179  |
+-----------+------------+
| schnet*   |  22.4621   |
+-----------+------------+
| cgcnn     | 223.904    |
+-----------+------------+
| cgcnn*    |  71.3368   |
+-----------+------------+
| dimenet2  |   5.10658  |
+-----------+------------+
| dimenet2* |   0.576522 |
+-----------+------------+
| forcenet  |   4.48477  |
+-----------+------------+
| forcenet* |   5.90202  |
+-----------+------------+

Force Loss
+-----------+-----------+
| model     |      loss |
+===========+===========+
| schnet    | 0.109957  |
+-----------+-----------+
| schnet*   | 0.128023  |
+-----------+-----------+
| cgcnn     | 0.324934  |
+-----------+-----------+
| cgcnn*    | 0.0829506 |
+-----------+-----------+
| dimenet2  | 0.0999023 |
+-----------+-----------+
| dimenet2* | 0.0107285 |
+-----------+-----------+
| forcenet  | 0.331259  |
+-----------+-----------+
| forcenet* | 0.102364  |
+-----------+-----------+
