In [None]:
import os
import torch

import numpy as np
import matplotlib.pyplot as plt

from omegaconf import OmegaConf
from ase.io import read
from ViSNetGW.model.visnet import create_model
from tqdm import tqdm

In [None]:
def load_model(logs_path, ckpt_name):
    """Build a ViSNet model from a training-log directory and load its weights.

    Parameters
    ----------
    logs_path : str
        Directory containing ``config.yaml`` and the checkpoint file.
    ckpt_name : str
        Checkpoint file name without the ``.ckpt`` extension.

    Returns
    -------
    torch.nn.Module
        The model with weights loaded on CPU, switched to ``eval()`` mode.

    Notes
    -----
    When the checkpoint path contains ``"swa"``, the stored state dict is expected
    to carry a ``"module."`` prefix on every parameter name (presumably written by
    a ``torch.optim.swa_utils.AveragedModel`` wrapper — TODO confirm). Only the
    prefixed keys are kept, with the prefix removed; all other keys are dropped.
    """
    model_file = os.path.join(logs_path, f"{ckpt_name}.ckpt")
    config_file = os.path.join(logs_path, "config.yaml")

    cfg = OmegaConf.load(config_file)
    model = create_model(cfg)
    state_dict = torch.load(model_file, map_location=torch.device("cpu"))
    if "swa" in model_file:
        # Strip the wrapper prefix only where it really is a prefix. The previous
        # `"module" in key` + `key[7:]` combination would mangle any key that merely
        # contained "module" somewhere inside its name.
        prefix = "module."
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
def test_model(model, test_set, folder, target, limit_mol_size=False,
               base_path="/Volumes/LaCie/test_sets"):
    """Return the per-molecule absolute errors of ``model`` on a test set.

    Parameters
    ----------
    model : callable
        Called as ``model({"z": ..., "pos": ..., "batch": ...})``; the first element
        of its return value must hold a single number (it is read with ``.item()``).
    test_set : str
        Sub-directory of ``base_path`` (e.g. ``"PC9"``) containing a ``mols``
        folder with one ``<mol>.xyz`` file per molecule.
    folder : str
        Sub-directory of the test set with one ``<mol>.dat`` energy file per
        molecule (e.g. ``"E_qp"``). Folders whose name starts with ``"E_omol25"``
        hold the HOMO at index 0 and the LUMO at index 1; for all other folders the
        HOMO index is derived from the electron count.
    target : {"homo", "lumo", "gap"}
        Quantity the model predicts.
    limit_mol_size : bool, optional
        If True, only molecules with 12 to 24 atoms (inclusive) are evaluated.
    base_path : str, optional
        Root directory that contains the test sets.

    Returns
    -------
    list of float
        Absolute error for every evaluated molecule, in sorted file-name order, so
        lists from different models on the same test set line up sample-by-sample.
        The mean of the list is the MAE.

    Raises
    ------
    ValueError
        If ``target`` is not one of ``"homo"``, ``"lumo"``, ``"gap"``.
    """
    if target not in ("homo", "lumo", "gap"):
        raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'.")

    xyz_path = os.path.join(base_path, test_set, "mols")
    data_path = os.path.join(base_path, test_set, folder)
    is_omol25 = folder.startswith("E_omol25")

    # Sorted for a reproducible, model-independent sample order. Hidden files are
    # skipped because macOS can create "._<name>.xyz" companions on external drives.
    xyz_files = sorted(
        name for name in os.listdir(xyz_path)
        if name.endswith(".xyz") and not name.startswith(".")
    )

    abs_errors = []
    with torch.no_grad():
        for xyz_file in tqdm(xyz_files, leave=False):
            mol = xyz_file[:-4]  # drop the ".xyz" extension
            atoms = read(os.path.join(xyz_path, xyz_file), format="xyz")
            if limit_mol_size and not (12 <= len(atoms) <= 24):
                continue
            numbers = atoms.get_atomic_numbers()
            # Non-OMol25 files list all orbitals; sum(Z)//2 - 1 presumably indexes the
            # highest doubly occupied orbital of a neutral closed-shell molecule.
            homo_idx = 0 if is_omol25 else int(np.sum(numbers) // 2 - 1)
            energies = np.loadtxt(os.path.join(data_path, f"{mol}.dat"))
            if target == "homo":
                y = float(energies[homo_idx])
            elif target == "lumo":
                y = float(energies[homo_idx + 1])
            else:  # "gap"
                y = float(energies[homo_idx + 1]) - float(energies[homo_idx])
            data = {
                "z": torch.from_numpy(numbers),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                # single molecule -> every atom belongs to batch entry 0
                "batch": torch.zeros(len(atoms), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            abs_errors.append(abs(y_pred.item() - y))

    return abs_errors



def test_model_gwset(model, target, base_path="/Volumes/LaCie/test_sets/GWSet", n_samples=None):
    """Return the per-molecule absolute errors of ``model`` on the GWSet test set.

    The set is read from tensors saved in ``base_path``: ``Z.pt``, ``R.pt``,
    ``M.pt``, ``N.pt`` and ``<target>.pt``. For molecule ``i`` the atoms are
    ``Z[i, M[i]]`` with positions ``R[i, M[i], :]`` (presumably ``M`` is a padding
    mask — TODO confirm), and the reference value is ``E[i]``.

    Parameters
    ----------
    model : callable
        Called as ``model({"z": ..., "pos": ..., "batch": ...})``; the first element
        of its return value must hold a single number (read with ``.item()``).
    target : str
        Stem of the reference-value file, e.g. ``"GAP"`` loads ``GAP.pt``.
    base_path : str, optional
        Directory holding the tensor files.
    n_samples : int or None, optional
        Number of molecules to evaluate; ``None`` (default) evaluates all
        ``len(N)`` of them instead of the previously hard-coded 2885.

    Returns
    -------
    list of float
        Absolute error per molecule, in dataset order; the mean is the MAE.
    """
    N = torch.load(os.path.join(base_path, "N.pt"))
    Z = torch.load(os.path.join(base_path, "Z.pt"))
    R = torch.load(os.path.join(base_path, "R.pt"))
    M = torch.load(os.path.join(base_path, "M.pt"))
    E = torch.load(os.path.join(base_path, f"{target}.pt"))

    if n_samples is None:
        n_samples = len(N)

    abs_errors = []
    with torch.no_grad():
        for i in tqdm(range(n_samples), leave=False):
            z = Z[i, M[i]]
            pos = R[i, M[i], :]
            # One batch entry (all 0: a single molecule) per selected atom. Sized from
            # z itself so it can never disagree with the masked atom list.
            batch = torch.zeros(z.shape[0], dtype=torch.int64)
            y_pred, _ = model({"z": z, "pos": pos, "batch": batch})
            abs_errors.append(abs(y_pred.item() - E[i].item()))

    return abs_errors

# Per-sample absolute error with and without pretraining (bar plots)

In [None]:
# Gap models: presumably "gap0M..." saw no pretraining and "gap5M..._transfer_converged"
# was pretrained and then transfer-finetuned (inferred from directory names).
logs_path, ckpt_name = "/Volumes/LaCie/trained_models/ViSNet/gap0M128502532", "model_55_epochs"
model = load_model(logs_path, ckpt_name)

logs_path, ckpt_name = "/Volumes/LaCie/trained_models/ViSNet/gap5M128502532_transfer_converged", "model_35_epochs"
pretrained_model = load_model(logs_path, ckpt_name)

# Per-molecule absolute gap errors on PC9 for both models, all molecule sizes.
pc9_eval_kwargs = dict(test_set="PC9", folder="E_qp", target="gap", limit_mol_size=False)
all_mae = test_model(model, **pc9_eval_kwargs)
all_mae_pretrained = test_model(pretrained_model, **pc9_eval_kwargs)

In [None]:
# Same two gap checkpoints, this time evaluated on the GWSet test set.
# The results replace any all_mae / all_mae_pretrained from a previous evaluation.
model_root = "/Volumes/LaCie/trained_models/ViSNet"

logs_path, ckpt_name = f"{model_root}/gap0M128502532", "model_55_epochs"
model = load_model(logs_path, ckpt_name)

logs_path, ckpt_name = f"{model_root}/gap5M128502532_transfer_converged", "model_35_epochs"
pretrained_model = load_model(logs_path, ckpt_name)

all_mae, all_mae_pretrained = [
    test_model_gwset(m, target="GAP") for m in (model, pretrained_model)
]

In [None]:
# Per-sample absolute errors, ordered from the worst sample of the model without
# pretraining downwards; the pretrained model's error on the same sample is drawn on top.
n = len(all_mae)
x = list(range(n))
ranked_pairs = sorted(zip(all_mae, all_mae_pretrained), key=lambda pair: pair[0], reverse=True)
no_pretrain, pretrain = zip(*ranked_pairs)

font = {"fontname": "Helvetica", "fontsize": 20}
cmap = plt.cm.Blues

fig, ax = plt.subplots(figsize=(8, 6))

# Width slightly above 1 makes neighbouring bars overlap a little (presumably to avoid gaps).
shared_bar_style = dict(align="edge", width=1.05)
ax.bar(x, height=no_pretrain, label="Without Pretraining", color=cmap(0.4), **shared_bar_style)
ax.bar(x, height=pretrain, label="With Pretraining", color=cmap(0.9), **shared_bar_style)

ax.set_title("Test", fontname="Helvetica", fontsize=25)

# x axis: one tick every 500 samples
ax.set_xlabel("Sample Index", **font)
ax.set_xlim([0, n])
sample_ticks = np.arange(stop=n, step=500)
ax.set_xticks(sample_ticks)
ax.set_xticklabels(sample_ticks, fontname="Helvetica", fontsize=14)

# y axis: fixed range 0-2.4 eV
ax.set_ylabel("MAE [eV]", **font)
ax.set_ylim([0, 2.4])
error_ticks = np.arange(stop=2.6, step=0.2)
ax.tick_params(axis='y', labelsize=14)
ax.set_yticks(error_ticks)

ax.legend(prop={"family": font["fontname"], "size": 14})

plt.tight_layout()
plt.show()

# MAE bar plots

### HOMO: Small and Large Models

In [None]:
import matplotlib.pyplot as plt
import numpy as np

font = {"fontname": "Helvetica", "fontsize": 20}

# x groups: fine-tuning strategy. Bar colour / legend: number of pretraining samples.
categories = ["Reference", "None", "Full", "Transfer"]
N_pre_values = ["0", "1 000 000", "5 000 000", "10 000 000"]

# HOMO MAE [eV] per test set. "Reference" holds one value (drawn with the "0"
# pretraining label); the other strategies hold one value per non-zero N_pre_values entry.
data_homo_small = {
    "Test": {
        "Reference": [0.0518],
        "None": [0.2591, 0.2701, 0.2432],
        "Full": [0.0365, 0.0298, 0.0313],
        "Transfer": [0.0435, 0.0380, 0.0378],
    },
    "PC9": {
        "Reference": [0.2074],
        "None": [0.2910, 0.3071, 0.2720],
        "Full": [0.0895, 0.0751, 0.0701],
        "Transfer": [0.0794, 0.0632, 0.0568],
    },
    "OE62L": {
        "Reference": [0.3751],
        "None": [0.1694, 0.1652, 0.1629],
        "Full": [0.1289, 0.0827, 0.1002],
        "Transfer": [0.1115, 0.0866, 0.0784],
    },
    "OE62H": {
        "Reference": [0.4833],
        "None": [0.1474, 0.1189, 0.1220],
        "Full": [0.1696, 0.1369, 0.1753],
        "Transfer": [0.1233, 0.1041, 0.0948],
    }
}

data_homo_large = {
    "Test": {
        "Reference": [0.0378],
        "None": [0.2697, 0.2561, 0.2322],
        "Full": [0.0380, 0.0280, 0.0285],
        "Transfer": [0.0452, 0.0297, 0.0294],
    },
    "PC9": {
        "Reference": [0.1973],
        "None": [0.2980, 0.2914, 0.2668],
        "Full": [0.0880, 0.0555, 0.0454],
        "Transfer": [0.0769, 0.0517, 0.0503],
    },
    "OE62L": {
        "Reference": [0.2375],
        "None": [0.1714, 0.1588, 0.1662],
        "Full": [0.1160, 0.0870, 0.0709],
        "Transfer": [0.1332, 0.0775, 0.0707],
    },
    "OE62H": {
        "Reference": [0.5183],
        "None": [0.1565, 0.1686, 0.1334],
        "Full": [0.1788, 0.1247, 0.1032],
        "Transfer": [0.1595, 0.1108, 0.1084],
    }
}

# Select the model size (small/large dict) and the test set to plot.
data_set = "OE62H"
data = data_homo_small[data_set]

# One shade of blue per pretraining size, light to dark.
cmap = plt.cm.Blues
colors = [cmap(level) for level in np.linspace(0.35, 1.05, len(N_pre_values))]

# Bar geometry
bar_width = 0.2   # width of a single bar
intra_gap = 0.0   # space between bars inside a multi-bar group
n_bars = 3        # bars per None / Full / Transfer group

W_group = n_bars*bar_width + (n_bars - 1)*intra_gap   # total width of a multi-bar group
W_ref = bar_width                                    # the Reference "group" is a single bar
S_groups = 1.0                                       # centre-to-centre spacing of multi-bar groups

# Multi-bar groups sit on 1, 2, 3. Reference is placed left of "None" such that the
# edge-to-edge gap Reference|None equals the gap between neighbouring multi-bar groups.
x_groups = np.array([1, 2, 3], dtype=float)
x_ref = x_groups[0] - (S_groups - (W_group - W_ref)/2.0)

# Bar centres relative to their group centre (left, middle, right).
pitch = bar_width + intra_gap
offsets = np.array([-W_group/2 + bar_width/2 + k*pitch for k in range(n_bars)])

fig, ax = plt.subplots(figsize=(8, 6))

# Reference: a lone bar on its own tick.
ax.bar(x_ref, data["Reference"][0], width=bar_width, label=N_pre_values[0], color=colors[0])

# One call per pretraining size; each call draws that size's bar in every group.
for i, n_pre in enumerate(N_pre_values[1:]):
    ax.bar(
        x_groups + offsets[i],
        [data[cat][i] for cat in categories[1:]],
        width=bar_width,
        label=n_pre,
        color=colors[i + 1],
    )

# Title and axis labels
ax.set_title(data_set, fontname="Helvetica", fontsize=25)
ax.set_xlabel("Finetuning", **font)
ax.set_ylabel("MAE [eV]", **font)

# Ticks: one per group, Reference first
group_ticks = np.concatenate(([x_ref], x_groups))
ax.set_xticks(group_ticks)
ax.set_xticklabels(categories, fontname="Helvetica", fontsize=14)
mae_ticks = np.arange(stop=0.60, step=0.05)
ax.set_yticks(mae_ticks)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim(0, mae_ticks[-1])

legend = ax.legend(title="Number Pretraining Samples", prop={"family": font["fontname"], "size": 14})
legend.get_title().set_fontsize(14)

ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()


# LUMO and Gap

In [None]:
font = {"fontname": "Helvetica", "fontsize": 20}

# LUMO and gap MAE [eV] per test set. "Reference" holds one value; "None" and
# "Transfer" hold one value per non-zero entry of N_pre_values below.
data_lumo = {
    "Test": {
        "Reference": [0.0338],
        "None": [0.7496, 0.7301],
        "Transfer": [0.0336, 0.0321],
    },
    "PC9": {
        "Reference": [0.1258],
        "None": [0.6991, 0.6761],
        "Transfer": [0.0571, 0.0530],
    },
    "OE62L": {
        "Reference": [0.3075],
        "None": [0.4265, 0.4361],
        "Transfer": [0.1273, 0.1172],
    },
    "OE62H": {
        "Reference": [0.3530],
        "None": [0.3431, 0.3513],
        "Transfer": [0.1670, 0.1507],
    },
}

data_gap = {
    "Test": {
        "Reference": [0.0658],
        "None": [0.9825, 0.9799],
        "Transfer": [0.0646, 0.0513],
    },
    "PC9": {
        "Reference": [0.2696],
        "None": [0.9578, 0.9633],
        "Transfer": [0.1249, 0.0951],
    },
    "OE62L": {
        "Reference": [0.5245],
        "None": [0.5320, 0.5230],
        "Transfer": [0.2741, 0.2322],
    },
    "OE62H": {
        "Reference": [0.5187],
        "None": [0.4427, 0.4112],
        "Transfer": [0.3116, 0.2559],
    }
}

# Select the quantity (data_lumo or data_gap) and the test set to plot.
data_set = "OE62H"
data = data_lumo[data_set]

# x groups: fine-tuning strategy. Bar colour / legend: number of pretraining samples.
categories = ["Reference", "None", "Transfer"]
N_pre_values = ["0", "1 000 000", "5 000 000"]

# One shade of blue per pretraining size, light to dark.
cmap = plt.cm.Blues
colors = [cmap(level) for level in np.linspace(0.35, 1.05, len(N_pre_values))]

# Bar geometry
bar_width = 0.25
intra_gap = 0.0
n_bars = 2        # bars per None / Transfer group

W_group = n_bars*bar_width + (n_bars - 1)*intra_gap   # total width of a 2-bar group
W_ref = bar_width                                    # Reference is a single bar
S_groups = 1.0                                       # centre-to-centre spacing of 2-bar groups

# None / Transfer on 1 and 2; Reference left of "None" with the same edge-to-edge gap.
x_groups = np.array([1, 2], dtype=float)
x_ref = x_groups[0] - (S_groups - (W_group - W_ref)/2.0)

# Bar centres relative to their group centre (left, right).
pitch = bar_width + intra_gap
offsets = np.array([-W_group/2 + bar_width/2 + k*pitch for k in range(n_bars)])

fig, ax = plt.subplots(figsize=(8, 6))

# Reference: a lone bar on its own tick.
ax.bar(x_ref, data["Reference"][0], width=bar_width, label=N_pre_values[0], color=colors[0])

# One call per pretraining size; each call draws that size's bar in both groups.
for i, n_pre in enumerate(N_pre_values[1:]):
    ax.bar(
        x_groups + offsets[i],
        [data[cat][i] for cat in categories[1:]],
        width=bar_width,
        label=n_pre,
        color=colors[i + 1],
    )

# Title and axis labels
ax.set_title(data_set, fontname="Helvetica", fontsize=25)
ax.set_xlabel("Finetuning", **font)
ax.set_ylabel("MAE [eV]", **font)

# Ticks: one per group, Reference first; y range may need adjusting per dataset
group_ticks = np.concatenate(([x_ref], x_groups))
ax.set_xticks(group_ticks)
ax.set_xticklabels(categories, fontname="Helvetica", fontsize=14)
mae_ticks = np.arange(stop=0.9, step=0.1)
ax.set_yticks(mae_ticks)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim(0, mae_ticks[-1])

legend = ax.legend(title="Number Pretraining Samples", prop={"family": font["fontname"], "size": 14})
legend.get_title().set_fontsize(14)

ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

### Bar plots for a talk: before and after the "bitter lesson"

In [None]:
font = {"fontname": "Helvetica", "fontsize": 20}

data_sets = ["Test", "PC9", "OE62"]

# MAE [eV] on each entry of data_sets. The "reference 2" / "reference 3" variants
# zero out later data sets, presumably to reveal the bars one slide at a time.
data = {
    "homo reference": [0.0518, 0.2074, 0.3751],
    "homo reference 2": [0.0518, 0.0, 0.0],
    "homo reference 3": [0.0518, 0.2074, 0.0],
    "gap reference": [0.0658, 0.2696, 0.5245],
    "gap reference 2": [0.0658, 0.0, 0.0],
    "gap reference 3": [0.0658, 0.2696, 0.0],
    "lumo reference": [0.0338, 0.1258, 0.3075],
    "lumo reference 2": [0.0338, 0.0, 0.0],
    "lumo reference 3": [0.0338, 0.1258, 0.0],
    "homo small transfer 5M": [0.0298, 0.0751, 0.0827],
    "homo small transfer 10M": [0.0378, 0.0568, 0.0784],
    "gap small transfer 5M": [0.0513, 0.0951, 0.2322],
    "lumo small transfer 5M": [0.0321, 0.0530, 0.1172]
}

# Series to show. Kept under its own name: assigning to `all_mae` here used to
# overwrite the per-sample errors computed by the model-evaluation cells.
talk_series = "gap reference 2"
talk_mae = data[talk_series]

# Title follows the selected series ("gap ..." -> "Gap Prediction").
title_by_target = {"homo": "HOMO", "lumo": "LUMO", "gap": "Gap"}
talk_title = f"{title_by_target[talk_series.split()[0]]} Prediction"

cmap = plt.cm.Blues
bar_width = 0.5

fig, ax = plt.subplots(figsize=(8, 6))

# String x values put the bars on a categorical axis at positions 0, 1, 2.
ax.bar(data_sets, talk_mae, width=bar_width, color=cmap(0.75))

# Axes, ticks
ax.set_title(talk_title, fontname="Helvetica", fontsize=25)
ax.set_xlabel("Dataset", **font)
ax.set_ylabel("MAE [eV]", **font)
ax.set_xticks(range(len(data_sets)))
ax.set_xticklabels(data_sets, fontname="Helvetica", fontsize=14)
yticks = np.arange(stop=0.59, step=0.05)  # adjust depending on dataset
ax.set_yticks(yticks)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim(0, yticks[-1])

ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

# Reduction of data demand

In [None]:
# Data: MAE [eV] vs. number of fine-tuning samples, with and without pretraining.
N_fine = np.array([10000, 20000, 40000, 80000, 120000])

data_homo = {
    "Test": {
        "with": np.array([0.0582, 0.0487, 0.0413, 0.0387, 0.0380]),
        "without": np.array([0.1476, 0.1122, 0.0848, 0.0590, 0.0518]),
    },
    "PC9": {
        "with": np.array([0.0717, 0.0614, 0.0630, 0.0587, 0.0632]),
        "without": np.array([0.3693, 0.3231, 0.2735, 0.2220, 0.2074]),
    },
    "OE62L": {
        "with": np.array([0.1034, 0.1060, 0.0848, 0.0922, 0.0866]),
        "without": np.array([0.4414, 0.3918, 0.3955, 0.3521, 0.3751]),
    },
    "OE62H": {
        "with": np.array([0.1149, 0.1247, 0.1024, 0.1079, 0.1041]),
        "without": np.array([0.6500, 0.6346, 0.5514, 0.4913, 0.4833]),
    },
}

data_gap = {
    "Test": {
        "with": np.array([0.0818, 0.0758, 0.0617, 0.0566, 0.0513]),
        "without": np.array([0.1954, 0.1508, 0.1126, 0.0852, 0.0658]),
    },
    "PC9": {
        "with": np.array([0.1048, 0.0993, 0.0997, 0.0975, 0.0951]),
        "without": np.array([0.4708, 0.4017, 0.3596, 0.3081, 0.2696]),
    },
    "OE62L": {
        "with": np.array([0.2576, 0.2433, 0.2265, 0.2223, 0.2322]),
        "without": np.array([0.7663, 0.7301, 0.5581, 0.5146, 0.5245]),
    },
    "OE62H": {
        "with": np.array([0.3109, 0.2924, 0.2736, 0.2478, 0.2559]),
        "without": np.array([0.7555, 0.6693, 0.5933, 0.5352, 0.5187]),
    },
}

data_lumo = {
    "Test": {
        "with": np.array([0.0481, 0.0416, 0.0369, 0.0324, 0.0321]),
        "without": np.array([0.0893, 0.0749, 0.0569, 0.0364, 0.0338]),
    },
    "PC9": {
        "with": np.array([0.0638, 0.0585, 0.0540, 0.0539, 0.0530]),
        "without": np.array([0.1946, 0.1859, 0.1571, 0.1302, 0.1258]),
    },
    "OE62L": {
        "with": np.array([0.1353, 0.1220, 0.1249, 0.1155, 0.1172]),
        "without": np.array([0.3743, 0.3521, 0.5381, 0.2970, 0.3075]),
    },
    "OE62H": {
        "with": np.array([0.1692, 0.1575, 0.1616, 0.1468, 0.1507]),
        "without": np.array([0.3766, 0.3884, 0.4218, 0.3506, 0.3530]),
    },
}

data = data_homo
data_set = "OE62L"

# Plot each curve relative to its value at the largest fine-tuning set (N_fine[-1]).
# New arrays are built here; the former in-place `-=` also modified the arrays
# stored inside data_homo / data_gap / data_lumo.
with_pretraining = data[data_set]["with"] - data[data_set]["with"][-1]
without_pretraining = data[data_set]["without"] - data[data_set]["without"][-1]

cmap = plt.cm.Blues
color1, color2 = cmap(0.5), cmap(1.0)

# Plot
fig, ax = plt.subplots(figsize=(8, 6))
font = {"fontname": "Helvetica", "fontsize": 20}

ax.plot(N_fine, without_pretraining, label="Without Pretraining", color=color1, marker="s", markersize=8)
ax.plot(N_fine, with_pretraining, label="With Pretraining", color=color2, marker="o", markersize=8)

ax.set_title(data_set, fontname="Helvetica", fontsize=25)

ax.set_xlabel("Number of Finetuning Samples", **font)
ax.set_xlim([10000, 120000])
xticks = N_fine
ax.set_xticks(xticks)
ax.set_xticklabels(xticks, fontname="Helvetica", fontsize=14)

y_lower_lim = -0.03
y_upper_lim = 0.08
step = 0.02
# The curves are offsets from the final MAE, so the axis shows a MAE difference.
ax.set_ylabel(r"$\Delta$MAE [eV]", **font)
ax.set_ylim([y_lower_lim, y_upper_lim])
yticks = np.arange(stop=y_upper_lim + step, step=step)  # + step so y_upper_lim gets a tick
ax.tick_params(axis='y', labelsize=14)
ax.set_yticks(yticks)

# Grid + legend
ax.grid(True, linestyle="--", alpha=0.6, axis="y")
ax.legend(prop={"family": font["fontname"], "size": 14})

plt.tight_layout()
plt.show()

In [None]:
# Data: MAE [eV] vs. number of fine-tuning samples, with and without pretraining.
# Keys without a suffix, "...2" and "...3" are presumably repeated training runs of
# the same setting; they are averaged below.
N_fine = np.array([10000, 20000, 40000, 80000, 120000])

data_homo = {
    "Test": {
        "with": np.array([0.0582, 0.0487, 0.0413, 0.0387, 0.0380]),
        "with2": np.array([0.0507, 0.0449, 0.0424, 0.0348, 0.0348]),
        "with3": np.array([0.0504, 0.0472, 0.0428, 0.0385, 0.0435]),
        "without": np.array([0.1476, 0.1122, 0.0848, 0.0590, 0.0518]),
        "without2": np.array([0.1433, 0.1062, 0.0712, 0.0422, 0.0510]),
        "without3": np.array([0.1459, 0.1063, 0.0738, 0.0447, 0.0449]),
    },
    "PC9": {
        "with": np.array([0.0717, 0.0614, 0.0630, 0.0587, 0.0632]),
        "with2": np.array([0.0634, 0.0616, 0.0581, 0.0601, 0.0617]),
        "with3": np.array([0.0627, 0.0595, 0.0578, 0.0617, 0.0611]),
        "without": np.array([0.3693, 0.3231, 0.2735, 0.2220, 0.2074]),
        "without2": np.array([0.4137, 0.3628, 0.3031, 0.2354, 0.2085]),
        "without3": np.array([0.4303, 0.3633, 0.3067, 0.2284, 0.1944]),
    },
    "OE62L": {
        "with": np.array([0.1034, 0.1060, 0.0848, 0.0922, 0.0866]),
        "with2": np.array([0.1147, 0.1025, 0.0896, 0.1032, 0.0883]),
        "with3": np.array([0.1120, 0.0912, 0.0892, 0.0836, 0.0836]),
        "without": np.array([0.4414, 0.3918, 0.3955, 0.3521, 0.3751]),
        "without2": np.array([0.4072, 0.4288, 0.3425, 0.2843, 0.2209]),
        "without3": np.array([0.3863, 0.3755, 0.3070, 0.2492, 0.2433])
    },
    "OE62H": {
        "with": np.array([0.1149, 0.1247, 0.1024, 0.1079, 0.1041]),
        "with2": np.array([0.1399, 0.1246, 0.1108, 0.1263, 0.1091]),
        "with3": np.array([0.1244, 0.1023, 0.0995, 0.1042, 0.0969]),
        "without": np.array([0.6500, 0.6346, 0.5514, 0.4913, 0.4833]),
        "without2": np.array([0.7210, 0.6571, 0.6582, 0.5552, 0.5221]),
        "without3": np.array([0.7351, 0.6760, 0.6360, 0.5820, 0.5432]),
    },
}

data_gap = {
    "Test": {
        "with": np.array([0.0818, 0.0758, 0.0617, 0.0566, 0.0513]),
        "with2": np.array([0.0745, 0.0715, 0.0638, 0.0479, 0.0536]),
        "with3": np.array([0.0796, 0.0713, 0.0576, 0.0531, 0.0491]),
        "without": np.array([0.1954, 0.1508, 0.1126, 0.0852, 0.0658]),
        "without2": np.array([0.1645, 0.1228, 0.0876, 0.0504, 0.0713]),
        "without3": np.array([0.1688, 0.1229, 0.0883, 0.0584, 0.0621]),
    },
    "PC9": {
        "with": np.array([0.1048, 0.0993, 0.0997, 0.0975, 0.0951]),
        "with2": np.array([0.1013, 0.0967, 0.0934, 0.0995, 0.0985]),
        "with3": np.array([0.1010, 0.0956, 0.0915, 0.0953, 0.0977]),
        "without": np.array([0.4708, 0.4017, 0.3596, 0.3081, 0.2696]),
        "without2": np.array([0.4947, 0.4199, 0.3577, 0.2771, 0.2776]),
        "without3": np.array([0.4531, 0.3863, 0.3302, 0.2873, 0.2589]),
    },
    "OE62L": {
        "with": np.array([0.2576, 0.2433, 0.2265, 0.2223, 0.2322]),
        "with2": np.array([0.2314, 0.2441, 0.2177, 0.2341, 0.2366]),
        "with3": np.array([0.2535, 0.2391, 0.2154, 0.2199, 0.2176]),
        "without": np.array([0.7663, 0.7301, 0.5581, 0.5146, 0.5245]),
        "without2": np.array([0.7058, 0.4988, 0.5382, 0.3844, 0.4513]),
        "without3": np.array([0.5640, 0.5285, 0.5105, 0.4504, 0.4647]),
    },
    "OE62H": {
        "with": np.array([0.3109, 0.2924, 0.2736, 0.2478, 0.2559]),
        "with2": np.array([0.2727, 0.2900, 0.2527, 0.2620, 0.2522]),
        "with3": np.array([0.3019, 0.2833, 0.2460, 0.2456, 0.2404]),
        "without": np.array([0.7555, 0.6693, 0.5933, 0.5352, 0.5187]),
        "without2": np.array([0.8051, 0.6951, 0.6836, 0.5643, 0.5102]),
        "without3": np.array([0.7060, 0.6513, 0.6020, 0.5179, 0.5123]),
    },
}

data_lumo = {
    "Test": {
        "with": np.array([0.0481, 0.0416, 0.0369, 0.0324, 0.0321]),
        "with2": np.array([0.0486, 0.0410, 0.0394, 0.0302, 0.0298]),
        "with3": np.array([0.0468, 0.0405, 0.0352, 0.0301, 0.0291]),
        "without": np.array([0.0893, 0.0749, 0.0569, 0.0364, 0.0338]),
        "without2": np.array([0.0918, 0.0656, 0.0449, 0.0325, 0.0271]),
        "without3": np.array([0.0842, 0.0646, 0.0488, 0.0228, 0.0321]),
    },
    "PC9": {
        "with": np.array([0.0638, 0.0585, 0.0540, 0.0539, 0.0530]),
        "with2": np.array([0.0631, 0.0574, 0.0548, 0.0518, 0.0522]),
        "with3": np.array([0.0628, 0.0560, 0.0540, 0.0522, 0.0511]),
        "without": np.array([0.1946, 0.1859, 0.1571, 0.1302, 0.1258]),
        "without2": np.array([0.1918, 0.1728, 0.1485, 0.1338, 0.1180]),
        "without3": np.array([0.1954, 0.1816, 0.1551, 0.1210, 0.1089]),
    },
    "OE62L": {
        "with": np.array([0.1353, 0.1220, 0.1249, 0.1155, 0.1172]),
        "with2": np.array([0.1367, 0.1312, 0.1120, 0.1112, 0.1010]),
        "with3": np.array([0.1365, 0.1201, 0.1128, 0.1106, 0.1035]),
        "without": np.array([0.3743, 0.3521, 0.5381, 0.2970, 0.3075]),
        "without2": np.array([0.3381, 0.3014, 0.2576, 0.2864, 0.2720]),
        "without3": np.array([0.3638, 0.2991, 0.2761, 0.2560, 0.2550]),
    },
    "OE62H": {
        "with": np.array([0.1692, 0.1575, 0.1616, 0.1468, 0.1507]),
        "with2": np.array([0.1636, 0.1675, 0.1519, 0.1485, 0.1368]),
        "with3": np.array([0.1640, 0.1518, 0.1471, 0.1446, 0.1392]),
        "without": np.array([0.3766, 0.3884, 0.4218, 0.3506, 0.3530]),
        "without2": np.array([0.3795, 0.3600, 0.3378, 0.3751, 0.3489]),
        "without3": np.array([0.3921, 0.3621, 0.3471, 0.3371, 0.3274]),
    },
}

data = data_gap
data_set = "OE62H"
runs = data[data_set]

# Average the three runs of each setting (new arrays; the data dicts stay untouched).
with_pretraining = (runs["with"] + runs["with2"] + runs["with3"]) / 3
without_pretraining = (runs["without"] + runs["without2"] + runs["without3"]) / 3

# Plot each curve relative to its value at the largest fine-tuning set (N_fine[-1]).
with_pretraining = with_pretraining - with_pretraining[-1]
without_pretraining = without_pretraining - without_pretraining[-1]

cmap = plt.cm.Blues
color1, color2 = cmap(0.5), cmap(1.0)

# Plot
fig, ax = plt.subplots(figsize=(8, 6))
font = {"fontname": "Helvetica", "fontsize": 20}

ax.plot(N_fine, without_pretraining, label="Without Pretraining", color=color1, marker="s", markersize=8)
ax.plot(N_fine, with_pretraining, label="With Pretraining", color=color2, marker="o", markersize=8)

ax.set_title(data_set, fontname="Helvetica", fontsize=25)

ax.set_xlabel("Number of Finetuning Samples", **font)
ax.set_xlim([10000, 120000])
xticks = N_fine
ax.set_xticks(xticks)
ax.set_xticklabels(xticks, fontname="Helvetica", fontsize=14)

y_lower_lim = -0.05
y_upper_lim = 0.25
step = 0.05
# Raw string: "\D" in a normal string literal is an invalid escape sequence
# (DeprecationWarning / SyntaxWarning depending on the Python version).
ax.set_ylabel(r"$\Delta$MAE [eV]", **font)
ax.set_ylim([y_lower_lim, y_upper_lim])
yticks = np.arange(stop=y_upper_lim + step, step=step)  # + step so y_upper_lim gets a tick
ax.tick_params(axis='y', labelsize=14)
ax.set_yticks(yticks)

# Grid + legend
ax.grid(True, linestyle="--", alpha=0.6, axis="y")
ax.legend(prop={"family": font["fontname"], "size": 14})

plt.tight_layout()
plt.show()

### Small model

In [None]:
font = {"fontname": "Helvetica", "fontsize": 14}

x = [10000, 20000, 40000, 80000, 120000]
# TODO: fill in the small-model MAEs, one value per entry of x.
y_gwset = []
y_pc9 = []

fig, ax = plt.subplots()

# Series without values are skipped: scatter(x, []) raises
# "ValueError: x and y must be the same size", which made this cell fail on Run All.
any_plotted = False
for values, series_label in ((y_gwset, "Test"), (y_pc9, "PC9")):
    if len(values) == 0:
        continue
    ax.scatter(x, values, label=series_label)
    any_plotted = True

ax.set_xticks([0, 20000, 40000, 60000, 80000, 100000, 120000])
ax.set_xlim([0, 125000])
ax.set_xlabel("Number of Training Samples", **font)
ax.set_ylabel("MAE [eV]", **font)
if any_plotted:  # avoids the "No artists with labels" legend warning on an empty plot
    ax.legend(fontsize=12)
ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

### Large model

In [None]:
font = {"fontname": "Helvetica", "fontsize": 14}

x = [10000, 20000, 40000, 80000, 120000]
# MAE [eV] without pretraining on the Test and PC9 sets.
# NOTE(review): these match the HOMO "without" curves of the data-demand cells except
# at 40000 samples (0.0825 / 0.2744 here vs. 0.0848 / 0.2735 there), and the last
# values equal the *small*-model HOMO "Reference" entries (0.0518 / 0.2074), not the
# large-model ones — confirm which numbers and which heading are current.
y_gwset = [0.1476, 0.1122, 0.0825, 0.0590, 0.0518]
y_pc9 = [0.3693, 0.3231, 0.2744, 0.2220, 0.2074]
# Pretrained results; they equal the HOMO "with" values at 120000 samples in the
# data-demand cells. Not plotted yet.
y_pre_gwset = [0.0380]
y_pre_pc9 = [0.0632]

fig, ax = plt.subplots()
ax.scatter(x, y_gwset, label="Test")
ax.scatter(x, y_pc9, label="PC9")
ax.set_xticks([0, 20000, 40000, 60000, 80000, 100000, 120000])
ax.set_xlim([0, 125000])
ax.set_xlabel("Number of Training Samples", **font)
ax.set_ylabel("MAE [eV]", **font)
ax.legend(fontsize=12)
ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

# Pretraining and Finetuning Alignment

In [None]:
font = {"fontname": "Helvetica", "fontsize": 20}

# MAE [eV] per test set for the quantity named by the dict. The four values per
# test set follow bar_categories below, i.e. the pretraining target ("None" = no
# pretraining — presumably, judging by the legend title).
data_homo = {
    "Test": [0.0518, 0.0380, 0.0809, 0.0506],
    "PC9": [0.2074, 0.0632, 0.2215, 0.0910],
    "OE62L": [0.3751, 0.0866, 0.3004, 0.1179],
    "OE62H": [0.4833, 0.1041, 0.3971, 0.1330]
}

data_lumo = {
    "Test": [0.0338, 0.0577, 0.0321, 0.0370],
    "PC9": [0.1258, 0.1331, 0.0530, 0.0755],
    "OE62L": [0.3075, 0.4415, 0.1172, 0.2490],
    "OE62H": [0.3530, 0.3776, 0.1507, 0.1963]
}

data_gap = {
    "Test": [0.0658, 0.0684, 0.0867, 0.0513],
    "PC9": [0.2696, 0.1505, 0.2374, 0.0951],
    "OE62L": [0.5245, 0.3452, 0.3177, 0.2322],
    "OE62H": [0.5187, 0.4097, 0.4287, 0.2559]
}

target = "lumo"

# Explicit lookup: an unknown target raises KeyError instead of silently plotting
# whatever `data` was left over from an earlier cell.
data = {"homo": data_homo, "lumo": data_lumo, "gap": data_gap}[target]

bar_categories = ["None", "Homo", "Lumo", "Gap"]   # one bar per pretraining target
categories = ["Test", "PC9", "OE62L", "OE62H"]     # one group per test set

# One shade of blue per pretraining target, light to dark.
cmap = plt.cm.Blues
colors = [cmap(level) for level in np.linspace(0.35, 1.05, len(bar_categories))]

bar_width = 0.35
x_groups = np.array([1, 3, 5, 7], dtype=float)  # group centres

# Centres of the 4 touching bars relative to their group centre: -1.5w, -0.5w, 0.5w, 1.5w.
n_bars = len(bar_categories)
offsets = (np.arange(n_bars) - (n_bars - 1) / 2) * bar_width

fig, ax = plt.subplots(figsize=(8, 6))

# One call per pretraining target; each call draws that target's bar in every group.
for i, (b_cat, color) in enumerate(zip(bar_categories, colors)):
    vals = [data[cat][i] for cat in categories]
    ax.bar(
        x_groups + offsets[i],
        vals,
        width=bar_width,
        label=b_cat,
        color=color
    )

# Axes, ticks, legend
ax.set_title(target.title(), fontname="Helvetica", fontsize=25)
ax.set_xlabel("Dataset", **font)
ax.set_ylabel("MAE [eV]", **font)
ax.set_xticks(x_groups)
ax.set_xticklabels(categories, fontname="Helvetica", fontsize=14)
yticks = np.arange(stop=0.65, step=0.1)  # adjust depending on dataset
ax.set_yticks(yticks)
ax.tick_params(axis='y', labelsize=14)
ax.set_ylim(0, yticks[-1])

legend = ax.legend(title="Pretraining Targets", prop={"family": font["fontname"], "size": 14})
legend.get_title().set_fontsize(14)

ax.grid(True, axis="y", linestyle="--", linewidth=0.5, alpha=0.7)
plt.tight_layout()
plt.show()

# Prediction MAE on QM9GW vs. on OMol25

In [None]:
test_set = "PC9"
target = "homo"
if target not in ("homo", "lumo", "gap"):
    raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'")

# Two models are compared on `test_set`: `model` (plotted below as "Without Pretraining")
# and `pretrained_model` from the *_transfer_converged run ("With Pretraining").
logs_path = f"/Volumes/LaCie/trained_models/ViSNet/{target}0M128502532"
ckpt_name = "model_45_epochs"
model = load_model(logs_path, ckpt_name)

logs_path = f"/Volumes/LaCie/trained_models/ViSNet/{target}5M128502532_transfer_converged/visnet_logs"
ckpt_name = "model_30_epochs"
pretrained_model = load_model(logs_path, ckpt_name)

xyz_path = os.path.join("/Volumes/LaCie/test_sets", test_set, "mols")
gw_path = os.path.join("/Volumes/LaCie/test_sets", test_set, "E_qp")
omol25_path = os.path.join("/Volumes/LaCie/test_sets", test_set, "E_omol25")

# Only real .xyz files: os.listdir also returns hidden entries (e.g. ".DS_Store" or
# "._*" metadata files on macOS-mounted drives) that `read` cannot parse.
xyz_files = sorted(
    name for name in os.listdir(xyz_path)
    if name.endswith(".xyz") and not name.startswith(".")
)

# Per-molecule absolute errors of each model against the GW (E_qp) and OMol25 labels.
all_mae_gw = []
all_mae_omol25 = []
all_mae_gw_pretrain = []
all_mae_omol25_pretrain = []
with torch.no_grad():
    for xyz_file in tqdm(xyz_files, leave=False):
        mol = xyz_file[:-4]
        atoms = read(os.path.join(xyz_path, xyz_file), format="xyz")
        # HOMO position in the E_qp level list: (sum of atomic numbers) // 2 - 1,
        # i.e. this assumes a neutral, closed-shell molecule.
        homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
        e_gw = np.loadtxt(os.path.join(gw_path, f"{mol}.dat"))
        # E_omol25 files are indexed directly: [0] is used as HOMO, [1] as LUMO.
        e_omol25 = np.loadtxt(os.path.join(omol25_path, f"{mol}.dat"))
        if target == "homo":
            y_gw = float(e_gw[homo_idx])
            y_omol25 = float(e_omol25[0])
        elif target == "lumo":
            y_gw = float(e_gw[homo_idx + 1])
            y_omol25 = float(e_omol25[1])
        else:  # "gap"
            y_gw = float(e_gw[homo_idx + 1]) - float(e_gw[homo_idx])
            y_omol25 = float(e_omol25[1]) - float(e_omol25[0])
        inputs = {
            "z": torch.from_numpy(atoms.get_atomic_numbers()),
            "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
            # all atoms belong to the same single-molecule batch
            "batch": torch.zeros((len(atoms),), dtype=torch.int64),
        }
        y_pred, _ = model(inputs)
        y_pred = y_pred.item()
        y_pred_pretrain, _ = pretrained_model(inputs)
        y_pred_pretrain = y_pred_pretrain.item()
        all_mae_gw.append(abs(y_pred - y_gw))
        all_mae_omol25.append(abs(y_pred - y_omol25))
        all_mae_gw_pretrain.append(abs(y_pred_pretrain - y_gw))
        all_mae_omol25_pretrain.append(abs(y_pred_pretrain - y_omol25))

In [None]:
# Scatter of per-molecule absolute errors from the evaluation cell above:
# x = error vs. GW (E_qp) labels, y = error vs. OMol25 labels, one series per model.
font = {"fontname": "Helvetica", "fontsize": 20}
cmap = plt.cm.Blues
# Same tick positions on both axes, rounded to avoid float noise in the labels.
ticks = np.round(np.arange(0, 2.2, 0.2), 1)

fig, ax = plt.subplots(figsize=(8, 6))

ax.scatter(all_mae_gw, all_mae_omol25, label="Without Pretraining", s=30, color=cmap(0.5))
ax.scatter(all_mae_gw_pretrain, all_mae_omol25_pretrain, label="With Pretraining", s=30, color=cmap(0.9))

# Title follows the evaluated test set instead of being hard-coded.
ax.set_title(test_set, fontname="Helvetica", fontsize=25)

ax.set_xlabel("MAE QM9GW [eV]", **font)
ax.set_ylabel("MAE OMol25 [eV]", **font)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.tick_params(axis="both", labelsize=14)
# Limits are set after the ticks so set_*ticks cannot widen the view.
ax.set_xlim([0, 2])
ax.set_ylim([0, 2])

ax.legend(prop={"family": font["fontname"], "size": 14}, loc="upper right", frameon=True, framealpha=1, edgecolor="black")

plt.tight_layout()
plt.show()


# Trial Plots

In [None]:
# Split the PC9 molecules into six cases by how the model prediction is ordered
# relative to the two reference values, and record each molecule's absolute errors.
#
# definitions:
#
# case 1_1: y_gwset < y_omol25 < y_pred
# case 1_2: y_pred < y_omol25 < y_gwset
#
# case 2_1: y_omol25 < y_gwset < y_pred
# case 2_2: y_pred < y_gwset < y_omol25
#
# case 3_1: y_gwset < y_pred < y_omol25
# case 3_2: y_omol25 < y_pred < y_gwset
#
# y_gwset is read from E_qp, y_omol25 from E_omol25, and y_pred comes from `model`
# (loaded in an earlier cell). Because all comparisons are strict, a molecule where
# two of the three values are exactly equal matches no case and is not recorded.

target = "homo"
if target not in ("homo", "lumo", "gap"):
    raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'")

lacie_conn = os.path.exists("/Volumes/LaCie")
# The local copy is always evaluated; the LaCie copy is added when the drive is mounted.
pc9_dirs = ["../test_datasets/PC9"]
if lacie_conn:
    pc9_dirs.append("/Volumes/LaCie/test_sets/PC9")

mae_gwset_case_1_1 = []
mae_omol25_case_1_1 = []
mae_gwset_case_1_2 = []
mae_omol25_case_1_2 = []

mae_gwset_case_2_1 = []
mae_omol25_case_2_1 = []
mae_gwset_case_2_2 = []
mae_omol25_case_2_2 = []

mae_gwset_case_3_1 = []
mae_omol25_case_3_1 = []
mae_gwset_case_3_2 = []
mae_omol25_case_3_2 = []

with torch.no_grad():
    for pc9_path in pc9_dirs:
        # Only real .xyz files; skips hidden entries such as ".DS_Store" or "._*".
        xyz_files = sorted(
            name for name in os.listdir(f"{pc9_path}/mols")
            if name.endswith(".xyz") and not name.startswith(".")
        )
        for xyz_file in xyz_files:
            mol = xyz_file[:-4]
            atoms = read(f"{pc9_path}/mols/{xyz_file}", format="xyz")
            # HOMO position in E_qp, assuming a neutral, closed-shell molecule.
            homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
            e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
            # E_omol25 is indexed directly: [0] is used as HOMO, [1] as LUMO.
            e_dft = np.loadtxt(f"{pc9_path}/E_omol25/{mol}.dat")
            if target == "homo":
                y_gw = float(e_qp[homo_idx])
                y_dft = float(e_dft[0])
            elif target == "lumo":
                y_gw = float(e_qp[homo_idx + 1])
                y_dft = float(e_dft[1])
            else:  # "gap"
                y_gw = float(e_qp[homo_idx + 1]) - float(e_qp[homo_idx])
                y_dft = float(e_dft[1]) - float(e_dft[0])
            inputs = {
                "z": torch.from_numpy(atoms.get_atomic_numbers()),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                # all atoms belong to the same single-molecule batch
                "batch": torch.zeros((len(atoms),), dtype=torch.int64),
            }
            y_pred, _ = model(inputs)
            y_pred = y_pred.item()
            mae_gw = abs(y_pred - y_gw)
            mae_dft = abs(y_pred - y_dft)
            if y_gw < y_dft < y_pred:
                mae_gwset_case_1_1.append(mae_gw)
                mae_omol25_case_1_1.append(mae_dft)
            elif y_pred < y_dft < y_gw:
                mae_gwset_case_1_2.append(mae_gw)
                mae_omol25_case_1_2.append(mae_dft)
            elif y_dft < y_gw < y_pred:
                mae_gwset_case_2_1.append(mae_gw)
                mae_omol25_case_2_1.append(mae_dft)
            elif y_pred < y_gw < y_dft:
                mae_gwset_case_2_2.append(mae_gw)
                mae_omol25_case_2_2.append(mae_dft)
            elif y_gw < y_pred < y_dft:
                mae_gwset_case_3_1.append(mae_gw)
                mae_omol25_case_3_1.append(mae_dft)
            elif y_dft < y_pred < y_gw:
                mae_gwset_case_3_2.append(mae_gw)
                mae_omol25_case_3_2.append(mae_dft)

# Disabled summary, kept as comments so it neither runs nor gets echoed as cell output
# (a bare string as the cell's last expression is displayed by Jupyter).
# NOTE(review): pd, all_mols, all_y_pred, all_y_gw, all_y_dft, all_mae_dft, min_mae_gw,
# min_mol_gw, max_mae_gw and max_mol_gw are not defined in the visible part of this
# notebook, and pandas is not among the imports at the top.
#
# data = {
#     "mol": all_mols,
#     "y_pred": all_y_pred,
#     "y_gw": all_y_gw,
#     "mae_gw": all_mae_gw,
#     "y_dft": all_y_dft,
#     "mae_dft": all_mae_dft,
# }
#
# df = pd.DataFrame(data)
#
# print(f"Mean = {np.mean(all_mae_gw):.4f} +-({np.std(all_mae_gw):.4f}) eV")
# print(f"Median = {np.median(all_mae_gw):.4f} eV")
# print(f"Min. MAE = {min_mae_gw:.4f} eV for {min_mol_gw}")
# print(f"Max. MAE = {max_mae_gw:.4f} eV for {max_mol_gw}")

In [None]:
# Absolute error vs. GW (x) and vs. OMol25 (y) for every classified molecule,
# coloured by ordering case (see the case definitions in the classification cell).
case_series = [
    (mae_gwset_case_1_1, mae_omol25_case_1_1, "#7aa457", "1_1: y_gwset < y_omol25 < y_pred"),
    (mae_gwset_case_1_2, mae_omol25_case_1_2, "#496234", "1_2: y_pred < y_omol25 < y_gwset"),
    (mae_gwset_case_2_1, mae_omol25_case_2_1, "#9e6ebd", "2_1: y_omol25 < y_gwset < y_pred"),
    (mae_gwset_case_2_2, mae_omol25_case_2_2, "#5f4271", "2_2: y_pred < y_gwset < y_omol25"),
    (mae_gwset_case_3_1, mae_omol25_case_3_1, "#cb6751", "3_1: y_gwset < y_pred < y_omol25"),
    (mae_gwset_case_3_2, mae_omol25_case_3_2, "#7a3d31", "3_2: y_omol25 < y_pred < y_gwset"),
]

fig, ax = plt.subplots()
for xs, ys, color, label in case_series:
    ax.scatter(xs, ys, c=color, label=f"case {label}")
ax.set_xlabel("MAE GWSet [eV]")
ax.set_ylabel("MAE OMol25 [eV]")
ax.set_xlim([0, 2])
ax.set_ylim([0, 2])
# Without a legend the six colours cannot be mapped back to their cases.
ax.legend(fontsize="small")
plt.show()

In [None]:
# Zoomed scatter of selected ordering cases; by default only case 3, where the
# prediction lies between the two reference values. Edit `cases_to_plot` to show others.
case_errors = {
    "1_1": (mae_gwset_case_1_1, mae_omol25_case_1_1, "y_gwset < y_omol25 < y_pred"),
    "1_2": (mae_gwset_case_1_2, mae_omol25_case_1_2, "y_pred < y_omol25 < y_gwset"),
    "2_1": (mae_gwset_case_2_1, mae_omol25_case_2_1, "y_omol25 < y_gwset < y_pred"),
    "2_2": (mae_gwset_case_2_2, mae_omol25_case_2_2, "y_pred < y_gwset < y_omol25"),
    "3_1": (mae_gwset_case_3_1, mae_omol25_case_3_1, "y_gwset < y_pred < y_omol25"),
    "3_2": (mae_gwset_case_3_2, mae_omol25_case_3_2, "y_omol25 < y_pred < y_gwset"),
}
cases_to_plot = ("3_1", "3_2")

fig, ax = plt.subplots()
for case in cases_to_plot:
    xs, ys, ordering = case_errors[case]
    # Legend names the real case (previously shown as "case 1"/"case 2").
    ax.scatter(xs, ys, label=f"case {case}: {ordering}")
ax.set_xlabel("MAE GWSet [eV]")
ax.set_ylabel("MAE OMol25 [eV]")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.legend()
plt.show()

In [None]:
# Absolute error vs. GW (x) and vs. OMol25 (y) for every classified molecule,
# coloured by ordering case (see the case definitions in the classification cell).
case_series = [
    (mae_gwset_case_1_1, mae_omol25_case_1_1, "#7aa457", "1_1: y_gwset < y_omol25 < y_pred"),
    (mae_gwset_case_1_2, mae_omol25_case_1_2, "#496234", "1_2: y_pred < y_omol25 < y_gwset"),
    (mae_gwset_case_2_1, mae_omol25_case_2_1, "#9e6ebd", "2_1: y_omol25 < y_gwset < y_pred"),
    (mae_gwset_case_2_2, mae_omol25_case_2_2, "#5f4271", "2_2: y_pred < y_gwset < y_omol25"),
    (mae_gwset_case_3_1, mae_omol25_case_3_1, "#cb6751", "3_1: y_gwset < y_pred < y_omol25"),
    (mae_gwset_case_3_2, mae_omol25_case_3_2, "#7a3d31", "3_2: y_omol25 < y_pred < y_gwset"),
]

fig, ax = plt.subplots()
for xs, ys, color, label in case_series:
    ax.scatter(xs, ys, c=color, label=f"case {label}")
ax.set_xlabel("MAE GWSet [eV]")
ax.set_ylabel("MAE OMol25 [eV]")
ax.set_xlim([0, 2])
ax.set_ylim([0, 2])
# Without a legend the six colours cannot be mapped back to their cases.
ax.legend(fontsize="small")
plt.show()

In [None]:
# Zoomed scatter of selected ordering cases; by default only case 3, where the
# prediction lies between the two reference values. Edit `cases_to_plot` to show others.
case_errors = {
    "1_1": (mae_gwset_case_1_1, mae_omol25_case_1_1, "y_gwset < y_omol25 < y_pred"),
    "1_2": (mae_gwset_case_1_2, mae_omol25_case_1_2, "y_pred < y_omol25 < y_gwset"),
    "2_1": (mae_gwset_case_2_1, mae_omol25_case_2_1, "y_omol25 < y_gwset < y_pred"),
    "2_2": (mae_gwset_case_2_2, mae_omol25_case_2_2, "y_pred < y_gwset < y_omol25"),
    "3_1": (mae_gwset_case_3_1, mae_omol25_case_3_1, "y_gwset < y_pred < y_omol25"),
    "3_2": (mae_gwset_case_3_2, mae_omol25_case_3_2, "y_omol25 < y_pred < y_gwset"),
}
cases_to_plot = ("3_1", "3_2")

fig, ax = plt.subplots()
for case in cases_to_plot:
    xs, ys, ordering = case_errors[case]
    # Legend names the real case (previously shown as "case 1"/"case 2").
    ax.scatter(xs, ys, label=f"case {case}: {ordering}")
ax.set_xlabel("MAE GWSet [eV]")
ax.set_ylabel("MAE OMol25 [eV]")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.legend()
plt.show()

In [None]:
# Load the two "homo" checkpoints compared in the next cell: `model` is plotted there
# as "Without Pretraining", `pretrained_model` as "With Pretraining".
# Alternatives tried before: logs_path "../visnet_logs"; ckpt_name "best_model" or
# "model_10_epochs" for either run.

# Best checkpoint of the homo0M216502532 run.
logs_path, ckpt_name = "/Volumes/LaCie/trained_models/ViSNet/homo0M216502532", "best_model"
model = load_model(logs_path, ckpt_name)

# 10-epoch checkpoint of the homo10M216502532_best run.
logs_path, ckpt_name = (
    "/Volumes/LaCie/trained_models/ViSNet/homo10M216502532_best/visnet_logs",
    "model_10_epochs",
)
pretrained_model = load_model(logs_path, ckpt_name)

In [None]:
# FIXME(review): `all_mae_gwset` and `all_mae_gwset_pretrain` are not defined in the
# visible part of this notebook (the earlier evaluation cell names its lists
# `all_mae_gw` / `all_mae_gw_pretrain`), so this cell raises NameError on a fresh
# kernel unless they are produced elsewhere for the models loaded above.
plt.scatter(all_mae_gwset, all_mae_omol25, s=15.0, label="Without Pretraining")
plt.scatter(all_mae_gwset_pretrain, all_mae_omol25_pretrain, s=15.0, label="With Pretraining")

plt.xlabel("MAE GWSet [eV]"); plt.ylabel("MAE OMol25 [eV]")
plt.xlim(0, 2); plt.ylim(0, 2)

plt.legend()
plt.show()

In [None]:
logs_path = "../visnet_logs"
#logs_path = "/Volumes/LaCie/trained_models/ViSNet/prehomo10M216502532"
#ckpt_name = "best_model"
ckpt_name = "model_5_epochs"

target = "gap"
if target not in ("homo", "lumo", "gap"):
    raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'")

lacie_conn = os.path.exists("/Volumes/LaCie")
# The local copy is always evaluated; the LaCie copy is added when the drive is mounted.
pc9_dirs = ["../test_datasets/PC9"]
if lacie_conn:
    pc9_dirs.append("/Volumes/LaCie/test_sets/PC9")

model = load_model(logs_path, ckpt_name)

# GW reference values, predictions, and per-molecule absolute errors vs. the GW
# (E_qp) and OMol25 labels, accumulated over every directory in `pc9_dirs`.
all_target = []
all_pred = []
all_mae_pc9 = []
all_mae_omol25 = []

with torch.no_grad():
    for pc9_path in pc9_dirs:
        # Only real .xyz files; skips hidden entries such as ".DS_Store" or "._*".
        xyz_files = sorted(
            name for name in os.listdir(f"{pc9_path}/mols")
            if name.endswith(".xyz") and not name.startswith(".")
        )
        for xyz_file in xyz_files:
            mol = xyz_file[:-4]
            atoms = read(f"{pc9_path}/mols/{xyz_file}", format="xyz")
            # HOMO position in E_qp, assuming a neutral, closed-shell molecule.
            homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
            e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
            # E_omol25 is indexed directly: [0] is used as HOMO, [1] as LUMO.
            e_dft = np.loadtxt(f"{pc9_path}/E_omol25/{mol}.dat")
            if target == "homo":
                y_gw = float(e_qp[homo_idx])
                y_dft = float(e_dft[0])
            elif target == "lumo":
                y_gw = float(e_qp[homo_idx + 1])
                y_dft = float(e_dft[1])
            else:  # "gap"
                y_gw = float(e_qp[homo_idx + 1]) - float(e_qp[homo_idx])
                y_dft = float(e_dft[1]) - float(e_dft[0])
            inputs = {
                "z": torch.from_numpy(atoms.get_atomic_numbers()),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                # all atoms belong to the same single-molecule batch
                "batch": torch.zeros((len(atoms),), dtype=torch.int64),
            }
            y_pred, _ = model(inputs)
            y_pred = y_pred.item()
            all_mae_pc9.append(abs(y_pred - y_gw))
            all_mae_omol25.append(abs(y_pred - y_dft))
            all_target.append(y_gw)
            all_pred.append(y_pred)

In [None]:
# Distribution of GW reference values (`all_target`) vs. model predictions (`all_pred`)
# from the evaluation cell above. Both histograms share bin edges and are
# semi-transparent so that neither hides the other.
bin_edges = np.histogram_bin_edges(np.concatenate([all_target, all_pred]), bins=50)

fig, ax = plt.subplots()
ax.hist(all_target, bins=bin_edges, density=True, alpha=0.5, label="Target")
ax.hist(all_pred, bins=bin_edges, density=True, alpha=0.5, label="Prediction")
ax.set_xlabel(f"{target.title()} [eV]")
ax.set_ylabel("Density")
ax.legend()
plt.show()

In [None]:
# Distribution of GW reference values (`all_target`) vs. model predictions (`all_pred`)
# from the evaluation cell above. Both histograms share bin edges and are
# semi-transparent so that neither hides the other.
bin_edges = np.histogram_bin_edges(np.concatenate([all_target, all_pred]), bins=50)

fig, ax = plt.subplots()
ax.hist(all_target, bins=bin_edges, density=True, alpha=0.5, label="Target")
ax.hist(all_pred, bins=bin_edges, density=True, alpha=0.5, label="Prediction")
ax.set_xlabel(f"{target.title()} [eV]")
ax.set_ylabel("Density")
ax.legend()
plt.show()