In [None]:
import os
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from omegaconf import OmegaConf
from ase.io import read
from ViSNetGW.model.visnet import create_model
from tqdm import tqdm

In [None]:
def load_model(logs_path, ckpt_name):
    model_file = f"{logs_path}/{ckpt_name}.ckpt"
    config_file = f"{logs_path}/config.yaml"

    cfg = OmegaConf.load(os.path.join(config_file))
    model = create_model(cfg)
    state_dict = torch.load(model_file, map_location=torch.device("cpu"))
    if "swa" in model_file:
        new_state_dict = {key[7:]: value for key, value in state_dict.items() if "module" in key}
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
def test_model(model, test_set, folder, target):

    xyz_path = os.path.join("/Volumes/LaCie/test_sets", test_set, "mols")
    data_path = os.path.join("/Volumes/LaCie/test_sets", test_set, folder)

    xyz_files = os.listdir(xyz_path)
    all_mae = []
    with torch.no_grad():
        for xyz_file in tqdm(xyz_files, leave=False):
            mol = xyz_file[:-4]
            atoms = read(os.path.join(xyz_path, xyz_file), format="xyz")
            #if not (12 <= len(atoms) <= 24):
            #    continue
            #for elem in atoms.get_chemical_symbols():
            #    if elem not in ["H", "C", "N", "O", "F"]:
            #        print(elem)
            homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1) * (folder[:8] != "E_omol25")
            energies = np.loadtxt(os.path.join(data_path, f"{mol}.dat"))
            if target == "homo":
                y = float(energies[homo_idx])
            elif target == "lumo":
                y = float(energies[homo_idx + 1])
            elif target == "gap":
                homo = float(energies[homo_idx])
                lumo = float(energies[homo_idx + 1])
                y = lumo - homo
            Z = torch.from_numpy(atoms.get_atomic_numbers())
            R = torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32)
            B = torch.zeros((len(atoms),)).to(dtype=torch.int64)
            data = {"z": Z, "pos": R, "batch": B}
            y_pred, _ = model(data)
            y_pred = y_pred.item()
            all_mae.append(abs(y_pred - y))
    
    return all_mae

# Bar plots showing MAE change per sample

In [None]:
# Model without pretraining ("homo0M" run).
logs_path = "/Volumes/LaCie/trained_models/ViSNet/homo0M128502532"
ckpt_name = "model_45_epochs"
model = load_model(logs_path, ckpt_name)

# Model from the "homo5M..._transfer_converged" run (pretrained, then transfer-finetuned,
# judging by the directory name -- TODO confirm).
logs_path = "/Volumes/LaCie/trained_models/ViSNet/homo5M128502532_transfer_converged/visnet_logs"
ckpt_name = "model_30_epochs"
pretrained_model = load_model(logs_path, ckpt_name)

# Per-molecule absolute HOMO errors on PC9 against the E_qp reference, for both models.
all_mae, all_mae_pretrained = (
    test_model(net, test_set="PC9", folder="E_qp", target="homo")
    for net in (model, pretrained_model)
)

In [None]:
# Pair each molecule's error without / with pretraining and sort the pairs by the
# no-pretraining error, largest first. zip pairs by position, so both lists must
# cover the same molecules in the same order.
if len(all_mae) != len(all_mae_pretrained):
    raise ValueError(
        f"Error lists differ in length ({len(all_mae)} vs {len(all_mae_pretrained)}); "
        "they must come from the same molecules"
    )
n = len(all_mae)
x = list(range(n))
pairs = sorted(zip(all_mae, all_mae_pretrained), key=lambda pair: pair[0], reverse=True)
no_pretrain, pretrain = zip(*pairs)

fig, ax = plt.subplots(figsize=(8, 6))

font = {"fontname": "Helvetica", "fontsize": 20}

# The pretrained bars are drawn on top; wherever they are shorter, the top of the
# no-pretraining bar stays visible above them.
ax.bar(x, height=no_pretrain, align="edge", width=1.05, label="Without Pretraining")
ax.bar(x, height=pretrain, align="edge", width=1.05, label="With Pretraining")

ax.set_xlabel("Sample Index", **font)
xticks = np.arange(0, n + 1, 500)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks, fontname="Helvetica", fontsize=14)
# Limits go AFTER the ticks: set_xticks/set_yticks expand the view so that every
# tick is visible, which would otherwise override the requested limits
# (e.g. y would run to 1.8 instead of 1.5).
ax.set_xlim([0, n])

ax.set_ylabel("MAE [eV]", **font)
yticks = np.arange(0, 2.0, 0.2)
ax.set_yticks(yticks)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim([0, 1.5])

ax.legend(prop={"family": font["fontname"], "size": 14})

plt.tight_layout()
plt.show()

### MAE bar plots comparing the Reference model with no, full and transfer finetuning (HOMO, LUMO and gap for the small model; HOMO for the large model; converged models only)

In [None]:
font = {"fontname": "Helvetica", "fontsize": 20}

# MAE [eV] of the small HOMO model per evaluation set. "Reference" holds a single
# value; every other finetuning mode holds one value per pretraining-set size in
# N_pre_values[1:].
categories = ["Reference", "None", "Full", "Transfer"]
N_pre_values = ["0", "1 000 000", "5 000 000", "10 000 000"]

data_small_homo = {
    "Test": {
        "Reference": [0.0518],
        "None": [0.2591, 0.2701, 0.2432],
        "Full": [0.0365, 0.0298, 0.0313],
        "Transfer": [0.0435, 0.0380, 0.0378],
    },
    "PC9": {
        "Reference": [0.2074],
        "None": [0.2910, 0.3071, 0.2720],
        "Full": [0.0895, 0.0751, 0.0701],
        "Transfer": [0.0794, 0.0632, 0.0568],
    },
    "OE62L": {
        "Reference": [0.3751],
        "None": [0.1694, 0.1652, 0.1629],
        "Full": [0.1289, 0.0827, 0.1002],
        "Transfer": [0.1115, 0.0866, 0.0784],
    },
    "OE62H": {
        "Reference": [0.4833],
        "None": [0.1474, 0.1189, 0.1220],
        "Full": [0.1696, 0.1369, 0.1753],
        "Transfer": [0.1233, 0.1041, 0.0948],
    }
}


def plot_finetuning_mae(data, title, bar_width=0.2, intra_gap=0.0, group_spacing=1.0,
                        y_max=0.5, y_step=0.05):
    """Bar chart of MAE per finetuning mode, one bar per pretraining-set size.

    Args:
        data: Mapping keyed by ``categories``. ``data["Reference"]`` holds one value;
            every other entry holds one value per size in ``N_pre_values[1:]``.
        title: Axes title (e.g. the evaluation-set name).
        bar_width: Width of each bar.
        intra_gap: Space between neighbouring bars inside a group.
        group_spacing: Distance between the centres of the multi-bar groups.
        y_max: Upper y-limit and last y-tick.
        y_step: Spacing of the y-ticks.

    Returns:
        The ``(fig, ax)`` pair.
    """
    axis_font = {"fontname": "Helvetica", "fontsize": 20}
    n_groups = len(categories) - 1    # groups besides "Reference"
    n_bars = len(N_pre_values) - 1    # bars per group
    # Light-to-dark blues; the colormap clips inputs above 1.0 to its end colour.
    colors = [plt.cm.Blues(v) for v in np.linspace(0.35, 1.05, len(N_pre_values))]

    w_group = n_bars * bar_width + (n_bars - 1) * intra_gap
    x_groups = group_spacing * np.arange(1, n_groups + 1, dtype=float)
    # Shift the single "Reference" bar so its gap to the first group equals the
    # gap between neighbouring groups (group_spacing - w_group).
    x_ref = x_groups[0] - (group_spacing - (w_group - bar_width) / 2.0)
    # Bar centres relative to the group centre (left to right).
    offsets = -w_group / 2 + bar_width / 2 + np.arange(n_bars) * (bar_width + intra_gap)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(x_ref, data["Reference"][0], width=bar_width, label=N_pre_values[0], color=colors[0])
    for k, (label, color) in enumerate(zip(N_pre_values[1:], colors[1:])):
        vals = [data[cat][k] for cat in categories[1:]]
        ax.bar(x_groups + offsets[k], vals, width=bar_width, label=label, color=color)

    ax.set_title(title, fontname="Helvetica", fontsize=25)
    ax.set_xlabel("Finetuning", **axis_font)
    ax.set_ylabel("MAE [eV]", **axis_font)
    ax.set_xticks(np.concatenate(([x_ref], x_groups)))
    ax.set_xticklabels(categories, fontname="Helvetica", fontsize=14)
    yticks = np.linspace(0.0, y_max, int(round(y_max / y_step)) + 1)
    ax.set_yticks(yticks)
    ax.tick_params(axis='y', labelsize=14)
    # Limit set after the ticks so it is not widened by set_yticks.
    ax.set_ylim(0, yticks[-1])

    legend = ax.legend(title="Number Pretraining Samples",
                       prop={"family": axis_font["fontname"], "size": 14})
    legend.get_title().set_fontsize(14)

    ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.7)
    fig.tight_layout()
    return fig, ax


data_set = "OE62H"
data = data_small_homo[data_set]
plot_finetuning_mae(data, data_set)
plt.show()


# Reduction of data demand

### Small model

In [None]:
font = {"fontname": "Helvetica", "fontsize": 14}

# MAE [eV] on the test split ("Test") and on PC9 versus the number of training
# samples. The MAE lists are still empty placeholders. plt.scatter raises a
# ValueError when x and y differ in length, so incomplete series are skipped
# (with a message) until their values are filled in.
x = [10000, 20000, 40000, 80000, 120000]
y_gwset = []
y_pc9 = []
for y_values, label in ((y_gwset, "Test"), (y_pc9, "PC9")):
    if len(y_values) == len(x):
        plt.scatter(x, y_values, label=label)
    else:
        print(f"Skipping {label!r}: {len(y_values)} MAE values for {len(x)} training sizes")
plt.xticks([0, 20000, 40000, 60000, 80000, 100000, 120000])
plt.xlim([0, 125000])
plt.xlabel("Number of Training Samples", **font)
plt.ylabel("MAE [eV]", **font)
# Only draw a legend when at least one labelled series was plotted.
if plt.gca().get_legend_handles_labels()[0]:
    plt.legend(fontsize=12)
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

### Large model

In [None]:
font = {"fontname": "Helvetica", "fontsize": 14}

# MAE [eV] on the test split ("Test") and on PC9 versus the number of training samples.
x = [10000, 20000, 40000, 80000, 120000]
y_gwset = [0.1476, 0.1122, 0.0825, 0.0590, 0.0518]
y_pc9 = [0.3693, 0.3231, 0.2744, 0.2220, 0.2074]
# NOTE(review): the two values below are never plotted. They equal the 5M-sample
# "Transfer" entries of data_small_homo (Test 0.0380, PC9 0.0632). Likewise, the
# 120000-sample points equal data_small_homo's "Reference" values although this
# section is titled "Large model" -- confirm which model these numbers belong to.
y_pre_gwset = [0.0380]
y_pr_pc9 = [0.0632]

for y_values, label in ((y_gwset, "Test"), (y_pc9, "PC9")):
    plt.scatter(x, y_values, label=label)

plt.xticks(list(range(0, 120001, 20000)))
plt.xlim([0, 125000])
plt.xlabel("Number of Training Samples", **font)
plt.ylabel("MAE [eV]", **font)
plt.legend(fontsize=12)
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Checkpoint used for the case analysis below. Alternatives tried before:
# logs_path "../visnet_logs", ckpt_name "best_model".
logs_path, ckpt_name = (
    "/Volumes/LaCie/trained_models/ViSNet/homo10M216502532_transfer_best/visnet_logs",
    "model_10_epochs",
)
model = load_model(logs_path, ckpt_name)

In [None]:
# Per-molecule errors of `model` on PC9 against two references -- GW values from
# E_qp ("GWSet") and OMol25 values from E_omol25 -- grouped by how the three
# numbers are ordered:
#
# case 1_1: y_gwset < y_omol25 < y_pred
# case 1_2: y_pred < y_omol25 < y_gwset
#
# case 2_1: y_omol25 < y_gwset < y_pred
# case 2_2: y_pred < y_gwset < y_omol25
#
# case 3_1: y_gwset < y_pred < y_omol25
# case 3_2: y_omol25 < y_pred < y_gwset
#
# All comparisons are strict, so molecules with ties match no case and are not recorded.

target = "homo"
if target not in ("homo", "lumo"):
    # Only HOMO/LUMO references are handled below; any other target would leave
    # y_gw / y_dft undefined (or stale from an earlier run).
    raise ValueError(f"target must be 'homo' or 'lumo', got {target!r}")
lacie_conn = os.path.exists("/Volumes/LaCie")

# The local PC9 directory is always evaluated; the copy on the external drive is
# evaluated as well when the drive is mounted.
pc9_dirs = ["../test_datasets/PC9"]
if lacie_conn:
    pc9_dirs.append("/Volumes/LaCie/test_sets/PC9")

mae_gwset_case_1_1 = []
mae_omol25_case_1_1 = []
mae_gwset_case_1_2 = []
mae_omol25_case_1_2 = []

mae_gwset_case_2_1 = []
mae_omol25_case_2_1 = []
mae_gwset_case_2_2 = []
mae_omol25_case_2_2 = []

mae_gwset_case_3_1 = []
mae_omol25_case_3_1 = []
mae_gwset_case_3_2 = []
mae_omol25_case_3_2 = []

# case key -> (errors vs. GWSet, errors vs. OMol25)
case_errors = {
    "1_1": (mae_gwset_case_1_1, mae_omol25_case_1_1),
    "1_2": (mae_gwset_case_1_2, mae_omol25_case_1_2),
    "2_1": (mae_gwset_case_2_1, mae_omol25_case_2_1),
    "2_2": (mae_gwset_case_2_2, mae_omol25_case_2_2),
    "3_1": (mae_gwset_case_3_1, mae_omol25_case_3_1),
    "3_2": (mae_gwset_case_3_2, mae_omol25_case_3_2),
}


def _homo_lumo_references(e_qp, e_dft, homo_idx, target):
    """Return the (GW, OMol25) reference for ``target`` ("homo" or "lumo").

    ``e_qp`` is indexed from the molecule's HOMO index; ``e_dft`` is indexed
    directly as [HOMO, LUMO].
    """
    shift = 0 if target == "homo" else 1
    return float(e_qp[homo_idx + shift]), float(e_dft[shift])


def _ordering_case(y_gw, y_dft, y_pred):
    """Return the case key ("1_1" ... "3_2") for a strict ordering, or None on ties."""
    if y_gw < y_dft < y_pred:
        return "1_1"
    if y_pred < y_dft < y_gw:
        return "1_2"
    if y_dft < y_gw < y_pred:
        return "2_1"
    if y_pred < y_gw < y_dft:
        return "2_2"
    if y_gw < y_pred < y_dft:
        return "3_1"
    if y_dft < y_pred < y_gw:
        return "3_2"
    return None


with torch.no_grad():
    for pc9_path in pc9_dirs:
        # Skip hidden / non-.xyz entries; sort because os.listdir order is arbitrary.
        xyz_files = sorted(
            name for name in os.listdir(f"{pc9_path}/mols")
            if name.endswith(".xyz") and not name.startswith(".")
        )
        for xyz_file in xyz_files:
            mol = xyz_file[:-4]
            atoms = read(f"{pc9_path}/mols/{xyz_file}", format="xyz")
            numbers = atoms.get_atomic_numbers()
            homo_idx = int(np.sum(numbers) // 2 - 1)
            e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
            e_dft = np.loadtxt(f"{pc9_path}/E_omol25/{mol}.dat")
            y_gw, y_dft = _homo_lumo_references(e_qp, e_dft, homo_idx, target)
            data = {
                "z": torch.from_numpy(numbers),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                "batch": torch.zeros((len(atoms),), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            y_pred = y_pred.item()
            case = _ordering_case(y_gw, y_dft, y_pred)
            if case is not None:
                gw_errors, dft_errors = case_errors[case]
                gw_errors.append(abs(y_pred - y_gw))
                dft_errors.append(abs(y_pred - y_dft))

# Disabled summary: the names below (all_mols, all_mae_gw, min_mae_gw, ...) are not
# defined in this cell, so the block is kept as an inert string.
'''
data = {
    "mol": all_mols,
    "y_pred": all_y_pred,
    "y_gw": all_y_gw,
    "mae_gw": all_mae_gw,
    "y_dft": all_y_dft,
    "mae_dft": all_mae_dft,
}

df = pd.DataFrame(data)

print(f"Mean = {np.mean(all_mae_gw):.4f} +-({np.std(all_mae_gw):.4f}) eV")
print(f"Median = {np.median(all_mae_gw):.4f} eV")
print(f"Min. MAE = {min_mae_gw:.4f} eV for {min_mol_gw}")
print(f"Max. MAE = {max_mae_gw:.4f} eV for {max_mol_gw}")
'''

In [None]:
# Error vs. GWSet (x) against error vs. OMol25 (y), one colour per ordering case.
# The legend spells out each case's ordering so the colours can be interpreted.
case_styles = [
    (mae_gwset_case_1_1, mae_omol25_case_1_1, "#7aa457", "y_gwset < y_omol25 < y_pred"),
    (mae_gwset_case_1_2, mae_omol25_case_1_2, "#496234", "y_pred < y_omol25 < y_gwset"),
    (mae_gwset_case_2_1, mae_omol25_case_2_1, "#9e6ebd", "y_omol25 < y_gwset < y_pred"),
    (mae_gwset_case_2_2, mae_omol25_case_2_2, "#5f4271", "y_pred < y_gwset < y_omol25"),
    (mae_gwset_case_3_1, mae_omol25_case_3_1, "#cb6751", "y_gwset < y_pred < y_omol25"),
    (mae_gwset_case_3_2, mae_omol25_case_3_2, "#7a3d31", "y_omol25 < y_pred < y_gwset"),
]
for gw_errors, dft_errors, color, ordering in case_styles:
    plt.scatter(gw_errors, dft_errors, c=color, label=ordering)
plt.xlabel("MAE GWSet [eV]")
plt.ylabel("MAE OMol25 [eV]")
plt.xlim([0, 2])
plt.ylim([0, 2])
plt.legend(fontsize=8)
plt.show()

In [None]:
# Only the cases where the prediction lies between the two references. Here
# |y_pred - y_gwset| + |y_pred - y_omol25| = |y_gwset - y_omol25|, so each point sits
# on an anti-diagonal whose offset is the GW/OMol25 disagreement for that molecule.
plt.scatter(mae_gwset_case_3_1, mae_omol25_case_3_1, label="case 3_1: y_gwset < y_pred < y_omol25")
plt.scatter(mae_gwset_case_3_2, mae_omol25_case_3_2, label="case 3_2: y_omol25 < y_pred < y_gwset")
plt.xlabel("MAE GWSet [eV]")
plt.ylabel("MAE OMol25 [eV]")
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.legend()
plt.show()

In [None]:
# NOTE(review): this cell draws the same data as the earlier case-coloured scatter
# cell; consider keeping only one of them.
for errs_gw, errs_dft, colour in (
    (mae_gwset_case_1_1, mae_omol25_case_1_1, "#7aa457"),  # y_gwset < y_omol25 < y_pred
    (mae_gwset_case_1_2, mae_omol25_case_1_2, "#496234"),  # y_pred < y_omol25 < y_gwset
    (mae_gwset_case_2_1, mae_omol25_case_2_1, "#9e6ebd"),  # y_omol25 < y_gwset < y_pred
    (mae_gwset_case_2_2, mae_omol25_case_2_2, "#5f4271"),  # y_pred < y_gwset < y_omol25
    (mae_gwset_case_3_1, mae_omol25_case_3_1, "#cb6751"),  # y_gwset < y_pred < y_omol25
    (mae_gwset_case_3_2, mae_omol25_case_3_2, "#7a3d31"),  # y_omol25 < y_pred < y_gwset
):
    plt.scatter(errs_gw, errs_dft, c=colour)
plt.xlabel("MAE GWSet [eV]")
plt.ylabel("MAE OMol25 [eV]")
plt.xlim([0, 2])
plt.ylim([0, 2])
plt.show()

In [None]:
# NOTE(review): duplicates the earlier case-3-only scatter cell.
for errs_gw, errs_dft, case_label in (
    (mae_gwset_case_3_1, mae_omol25_case_3_1, "case 1"),  # y_gwset < y_pred < y_omol25
    (mae_gwset_case_3_2, mae_omol25_case_3_2, "case 2"),  # y_omol25 < y_pred < y_gwset
):
    plt.scatter(errs_gw, errs_dft, label=case_label)
plt.xlabel("MAE GWSet [eV]")
plt.ylabel("MAE OMol25 [eV]")
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.legend()
plt.show()

In [None]:
# Run "homo0M216502532" (no pretraining, per its name); "../visnet_logs" and
# "model_10_epochs" were earlier alternatives.
logs_path, ckpt_name = "/Volumes/LaCie/trained_models/ViSNet/homo0M216502532", "best_model"
model = load_model(logs_path, ckpt_name)

# Run "homo10M216502532_best" (presumably pretrained on 10M samples -- TODO confirm);
# "../visnet_logs" and "best_model" were earlier alternatives.
logs_path, ckpt_name = (
    "/Volumes/LaCie/trained_models/ViSNet/homo10M216502532_best/visnet_logs",
    "model_10_epochs",
)
pretrained_model = load_model(logs_path, ckpt_name)

In [None]:
# NOTE(review): all_mae_gwset / all_mae_omol25 / *_pretrain were not defined in any
# earlier cell, so this cell raised NameError. They are computed here from `model`
# and `pretrained_model` (loaded in the previous cell), assuming both predict the
# HOMO -- confirm.


def _pc9_abs_errors(net, target="homo"):
    """Return (errors vs. E_qp, errors vs. E_omol25) of ``net`` over the PC9 molecules.

    Uses the local PC9 copy and, when the external drive is mounted, the copy on
    it. ``target`` is "homo" or "lumo"; E_omol25 files are indexed as [HOMO, LUMO].
    """
    shift = {"homo": 0, "lumo": 1}[target]
    dirs = ["../test_datasets/PC9"]
    if os.path.exists("/Volumes/LaCie"):
        dirs.append("/Volumes/LaCie/test_sets/PC9")
    errs_gw, errs_dft = [], []
    with torch.no_grad():
        for root in dirs:
            for xyz_file in sorted(os.listdir(f"{root}/mols")):
                if not xyz_file.endswith(".xyz") or xyz_file.startswith("."):
                    continue
                mol = xyz_file[:-4]
                atoms = read(f"{root}/mols/{xyz_file}", format="xyz")
                numbers = atoms.get_atomic_numbers()
                homo_idx = int(np.sum(numbers) // 2 - 1)
                y_gw = float(np.loadtxt(f"{root}/E_qp/{mol}.dat")[homo_idx + shift])
                y_dft = float(np.loadtxt(f"{root}/E_omol25/{mol}.dat")[shift])
                batch = {
                    "z": torch.from_numpy(numbers),
                    "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                    "batch": torch.zeros((len(atoms),), dtype=torch.int64),
                }
                y_pred, _ = net(batch)
                y_pred = y_pred.item()
                errs_gw.append(abs(y_pred - y_gw))
                errs_dft.append(abs(y_pred - y_dft))
    return errs_gw, errs_dft


all_mae_gwset, all_mae_omol25 = _pc9_abs_errors(model)
all_mae_gwset_pretrain, all_mae_omol25_pretrain = _pc9_abs_errors(pretrained_model)

plt.scatter(all_mae_gwset, all_mae_omol25, label="Without Pretraining", s=15.0)
plt.scatter(all_mae_gwset_pretrain, all_mae_omol25_pretrain, label="With Pretraining", s=15.0)
plt.xlabel("MAE GWSet [eV]")
plt.ylabel("MAE OMol25 [eV]")
plt.xlim([0, 2])
plt.ylim([0, 2])
plt.legend()
plt.show()

In [None]:
# Evaluate a checkpoint on PC9 for `target` ("homo", "lumo" or "gap"). Collect
# per-molecule errors against the GW (E_qp) and OMol25 (E_omol25) references,
# plus GW reference values and predictions for the histograms below.
logs_path = "../visnet_logs"
ckpt_name = "model_5_epochs"

target = "gap"
if target not in ("homo", "lumo", "gap"):
    # Any other value would leave y_gw / y_dft undefined (or stale from an earlier run).
    raise ValueError(f"target must be 'homo', 'lumo' or 'gap', got {target!r}")
lacie_conn = os.path.exists("/Volumes/LaCie")

# The local PC9 directory is always evaluated; the external-drive copy is added
# when the drive is mounted.
pc9_dirs = ["../test_datasets/PC9"]
if lacie_conn:
    pc9_dirs.append("/Volumes/LaCie/test_sets/PC9")

model = load_model(logs_path, ckpt_name)

all_target = []
all_pred = []
all_mae_pc9 = []
all_mae_omol25 = []


def _reference_targets(e_qp, e_dft, homo_idx, target):
    """Return the (GW, OMol25) reference for ``target``.

    ``e_qp`` is indexed from the molecule's HOMO index; ``e_dft`` is indexed
    directly as [HOMO, LUMO]. The gap is LUMO - HOMO.
    """
    if target == "homo":
        return float(e_qp[homo_idx]), float(e_dft[0])
    if target == "lumo":
        return float(e_qp[homo_idx + 1]), float(e_dft[1])
    if target == "gap":
        return (float(e_qp[homo_idx + 1]) - float(e_qp[homo_idx]),
                float(e_dft[1]) - float(e_dft[0]))
    raise ValueError(f"target must be 'homo', 'lumo' or 'gap', got {target!r}")


with torch.no_grad():
    for pc9_path in pc9_dirs:
        # Skip hidden / non-.xyz entries; sort because os.listdir order is arbitrary.
        xyz_files = sorted(
            name for name in os.listdir(f"{pc9_path}/mols")
            if name.endswith(".xyz") and not name.startswith(".")
        )
        for xyz_file in xyz_files:
            mol = xyz_file[:-4]
            atoms = read(f"{pc9_path}/mols/{xyz_file}", format="xyz")
            numbers = atoms.get_atomic_numbers()
            homo_idx = int(np.sum(numbers) // 2 - 1)
            e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
            e_dft = np.loadtxt(f"{pc9_path}/E_omol25/{mol}.dat")
            y_gw, y_dft = _reference_targets(e_qp, e_dft, homo_idx, target)
            data = {
                "z": torch.from_numpy(numbers),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                "batch": torch.zeros((len(atoms),), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            y_pred = y_pred.item()
            all_mae_pc9.append(abs(y_pred - y_gw))
            all_mae_omol25.append(abs(y_pred - y_dft))
            all_target.append(y_gw)
            all_pred.append(y_pred)

In [None]:
# Distribution of the GW reference values vs. the model predictions from the
# evaluation cell above. Both histograms share one set of bin edges so their
# densities are directly comparable. They are semi-transparent so the "Target"
# histogram is not hidden behind the predictions.
shared_bins = np.histogram_bin_edges(np.concatenate([all_target, all_pred]), bins=50)
plt.hist(all_target, bins=shared_bins, density=True, alpha=0.5, label="Target")
plt.hist(all_pred, bins=shared_bins, density=True, alpha=0.5, label="Prediction")
plt.xlabel(f"{target} [eV]")
plt.ylabel("Density")
plt.legend()
plt.show()

In [None]:
# NOTE(review): repeats the target/prediction histogram of the previous cell.
for values, name in ((all_target, "Target"), (all_pred, "Prediction")):
    plt.hist(values, bins=50, density=True, label=name)
plt.legend()
plt.show()