In [None]:
import os
import torch

import numpy as np
import pandas as pd

from omegaconf import OmegaConf
from ase.io import read
from ViSNetGW.model.visnet import create_model
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import rdDetermineBonds

In [None]:
# Hartree -> eV conversion factor (1 Ha is approximately 27.2114 eV).
# NOTE(review): no cell visible in this notebook references this constant.
HARTREE_TO_EV = 27.2114

In [None]:
# Directory that holds the trained checkpoint and its config file.
# Alternative location on an external volume, kept for reference:
#logs = "/Volumes/LaCie/trained_models/ViSNet/homo1M128502532"
logs = "../visnet_logs"

# Checkpoint to evaluate; the commented line selects a different checkpoint file.
model_file = f"{logs}/best_model.ckpt"
#model_file = f"{logs}/model_40_epochs.ckpt"
# Config that is passed to OmegaConf.load when the model is rebuilt.
config_file = f"{logs}/config.yaml"

In [None]:
# Rebuild the network from its saved config, restore the stored weights on the
# CPU and switch the model to evaluation mode.
cfg = OmegaConf.load(config_file)
model = create_model(cfg)
state_dict = torch.load(model_file, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# PC9

### HOMO / LUMO

In [None]:
# Evaluate the loaded model on the PC9 test set and compare every prediction
# with the qsGW (E_qp) and the DFT (E_omol25) reference value of the target level.
target = "homo"
pc9_path = "../test_datasets/PC9"

if target not in ("homo", "lumo"):
    # Fail early instead of hitting unbound y_gw / y_dft inside the loop.
    raise ValueError(f"target must be 'homo' or 'lumo', got {target!r}")

# Keep geometry files only (os.listdir also returns e.g. hidden files) and sort
# them so the row order of the result table is reproducible.
xyz_files = sorted(
    name for name in os.listdir(f"{pc9_path}/mols") if name.endswith(".xyz")
)

all_mols = []
all_y_pred = []

# Per-molecule absolute errors against qsGW and their running extremes.
max_mol_gw = None
min_mol_gw = None
all_mae_gw = []
min_mae_gw = 1000000
max_mae_gw = 0
all_y_gw = []

# Same bookkeeping for the comparison against the DFT values.
max_mol_dft = None
min_mol_dft = None
all_mae_dft = []
min_mae_dft = 1000000
max_mae_dft = 0
all_y_dft = []

with torch.no_grad():
    for xyz_file in xyz_files:
        xyz_file_path = f"{pc9_path}/mols/{xyz_file}"

        # Skip molecules for which RDKit assigns a formal charge to any atom.
        # The previous version issued `continue` inside the atom loop, which only
        # advanced that inner loop, so no molecule was ever skipped.
        # NOTE(review): this changes the evaluated subset compared with earlier runs.
        rd_mol = Chem.Mol(Chem.MolFromXYZFile(xyz_file_path))
        rdDetermineBonds.DetermineBonds(rd_mol)
        if any(atom.GetFormalCharge() != 0 for atom in rd_mol.GetAtoms()):
            continue

        mol = xyz_file[:-4]  # file name without the ".xyz" suffix
        atoms = read(xyz_file_path, format="xyz")
        # Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
        # Assumes a neutral closed-shell molecule and that the E_qp file lists
        # one level per electron pair starting from the lowest one -- TODO confirm.
        homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
        e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
        e_dft = np.loadtxt(f"{pc9_path}/E_omol25/{mol}.dat")
        # The E_omol25 files are indexed with 0 for the HOMO and 1 for the LUMO here.
        if target == "homo":
            y_gw = float(e_qp[homo_idx])
            y_dft = float(e_dft[0])
        else:  # target == "lumo"
            y_gw = float(e_qp[homo_idx + 1])
            y_dft = float(e_dft[1])

        # Single-molecule batch: every atom belongs to graph 0.
        inputs = {
            "z": torch.from_numpy(atoms.get_atomic_numbers()),
            "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
            "batch": torch.zeros((len(atoms),)).to(dtype=torch.int64),
        }
        y_pred, _ = model(inputs)
        y_pred = y_pred.item()

        mae_gw = abs(y_pred - y_gw)
        mae_dft = abs(y_pred - y_dft)
        all_y_pred.append(y_pred)
        all_y_gw.append(y_gw)
        all_y_dft.append(y_dft)
        all_mae_gw.append(mae_gw)
        all_mae_dft.append(mae_dft)
        all_mols.append(mol)

        if mae_gw < min_mae_gw:
            min_mae_gw = mae_gw
            min_mol_gw = mol
        if mae_gw > max_mae_gw:
            max_mae_gw = mae_gw
            max_mol_gw = mol
        if mae_dft < min_mae_dft:
            min_mae_dft = mae_dft
            min_mol_dft = mol
        if mae_dft > max_mae_dft:
            max_mae_dft = mae_dft
            max_mol_dft = mol

# One row per evaluated molecule; later cells inspect this table.
data = {
    "mol": all_mols,
    "y_pred": all_y_pred,
    "y_gw": all_y_gw,
    "mae_gw": all_mae_gw,
    "y_dft": all_y_dft,
    "mae_dft": all_mae_dft,
}

df = pd.DataFrame(data)

print("qsGW")
print(f"Mean = {np.mean(all_mae_gw):.4f} +-({np.std(all_mae_gw):.4f}) eV")
print(f"Median = {np.median(all_mae_gw):.4f} eV")
print(f"Min. MAE = {min_mae_gw:.4f} eV for {min_mol_gw}")
print(f"Max. MAE = {max_mae_gw:.4f} eV for {max_mol_gw}")
print()
print("DFT")
print(f"Mean = {np.mean(all_mae_dft):.4f} +-({np.std(all_mae_dft):.4f}) eV")
print(f"Median = {np.median(all_mae_dft):.4f} eV")
print(f"Min. MAE = {min_mae_dft:.4f} eV for {min_mol_dft}")
print(f"Max. MAE = {max_mae_dft:.4f} eV for {max_mol_dft}")

In [None]:
# pure fine-tuning
# Five molecules with the largest absolute error against the qsGW reference.
df.nlargest(n=5, columns="mae_gw")

In [None]:
# pure fine-tuning (best model)
# Five molecules with the largest absolute error against the qsGW reference.
df.nlargest(n=5, columns="mae_gw")

In [None]:
# pretraining + fine-tuning (epoch 30)
# Five molecules with the largest absolute error against the qsGW reference.
df.nlargest(n=5, columns="mae_gw")

In [None]:
# pretraining + fine-tuning (epoch 25)
# Five molecules with the largest absolute error against the qsGW reference.
df.nlargest(n=5, columns="mae_gw")

In [None]:
# Distribution of the absolute errors against the qsGW reference.
# The table built in this section has "mae_gw" / "mae_dft" columns and no
# "mae" column, so the previous column="mae" raised a KeyError on a fresh run.
df.hist(column="mae_gw", bins=50)

In [None]:
# Distribution of the absolute errors against the DFT values.
# The table built in this section has no "mae" column (only "mae_gw" and
# "mae_dft"), so the previous column="mae" raised a KeyError on a fresh run.
# NOTE(review): "mae_dft" is chosen to complement the qsGW histogram -- confirm intent.
df.hist(column="mae_dft", bins=50)

### GAP

In [None]:
# Evaluate the predicted HOMO-LUMO gap on PC9: two separately trained models
# predict the HOMO and the LUMO, and their difference is compared with the gap
# taken from the E_qp reference levels.
pc9_path = "../test_datasets/PC9"
logs_homo = "../visnet_logs_homo"
logs_lumo = "../visnet_logs_lumo"

model_file_homo = f"{logs_homo}/best_model.ckpt"
config_file_homo = f"{logs_homo}/config.yaml"
model_file_lumo = f"{logs_lumo}/best_model.ckpt"
config_file_lumo = f"{logs_lumo}/config.yaml"

# HOMO model
cfg = OmegaConf.load(config_file_homo)
model_homo = create_model(cfg)
state_dict = torch.load(model_file_homo, map_location=torch.device("cpu"))
model_homo.load_state_dict(state_dict)
model_homo.eval()
# LUMO model
cfg = OmegaConf.load(config_file_lumo)
model_lumo = create_model(cfg)
state_dict = torch.load(model_file_lumo, map_location=torch.device("cpu"))
model_lumo.load_state_dict(state_dict)
model_lumo.eval()

# Keep geometry files only (os.listdir also returns e.g. hidden files) and sort
# them so the row order of the result table is reproducible.
xyz_files = sorted(
    name for name in os.listdir(f"{pc9_path}/mols") if name.endswith(".xyz")
)
all_mols = []
all_mae = []
min_mae = 1000000
max_mae = 0
# Initialised so the summary below cannot hit an unbound name.
min_mol = None
max_mol = None
with torch.no_grad():
    for xyz_file in xyz_files:
        xyz_file_path = f"{pc9_path}/mols/{xyz_file}"

        # Skip molecules for which RDKit assigns a formal charge to any atom.
        # The previous version issued `continue` inside the atom loop, which only
        # advanced that inner loop, so no molecule was ever skipped.
        # NOTE(review): this changes the evaluated subset compared with earlier runs.
        rd_mol = Chem.Mol(Chem.MolFromXYZFile(xyz_file_path))
        rdDetermineBonds.DetermineBonds(rd_mol)
        if any(atom.GetFormalCharge() != 0 for atom in rd_mol.GetAtoms()):
            continue

        mol = xyz_file[:-4]  # file name without the ".xyz" suffix
        atoms = read(xyz_file_path, format="xyz")
        # Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
        # Assumes a neutral closed-shell molecule -- TODO confirm.
        homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
        eqp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
        # Reference gap: level above the HOMO minus the HOMO.
        y = float(eqp[homo_idx + 1]) - float(eqp[homo_idx])

        # Single-molecule batch: every atom belongs to graph 0.
        Z = torch.from_numpy(atoms.get_atomic_numbers())
        R = torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32)
        B = torch.zeros((len(atoms),)).to(dtype=torch.int64)
        inputs = {"z": Z, "pos": R, "batch": B}
        y_pred_homo, _ = model_homo(inputs)
        y_pred_lumo, _ = model_lumo(inputs)
        y_pred = y_pred_lumo.item() - y_pred_homo.item()
        print(f"{y_pred:.4f} {y:.4f}")

        mae = abs(y_pred - y)
        all_mae.append(mae)
        all_mols.append(mol)
        if mae < min_mae:
            min_mae = mae
            min_mol = mol
        if mae > max_mae:
            max_mae = mae
            max_mol = mol

# One row per evaluated molecule.
data = {
    "mol": all_mols,
    "mae": all_mae,
}

df = pd.DataFrame(data)

print(f"Mean = {np.mean(all_mae):.4f} +-({np.std(all_mae):.4f}) eV")
print(f"Median = {np.median(all_mae):.4f} eV")
print(f"Min. MAE = {min_mae:.4f} eV for {min_mol}")
print(f"Max. MAE = {max_mae:.4f} eV for {max_mol}")

### $\Delta$HOMO

In [None]:
# Delta-style evaluation on PC9: the model output is added to the DFT level and
# the sum is compared with the qsGW (E_qp) reference level.
target = "homo"
pc9_path = "../test_datasets/PC9"

if target not in ("homo", "lumo"):
    # Fail early instead of hitting unbound y_gw / y_dft inside the loop.
    raise ValueError(f"target must be 'homo' or 'lumo', got {target!r}")

# Keep geometry files only (os.listdir also returns e.g. hidden files) and sort
# them so the row order of the result table is reproducible.
xyz_files = sorted(
    name for name in os.listdir(f"{pc9_path}/mols") if name.endswith(".xyz")
)

all_mols = []
all_y_pred = []

# Per-molecule absolute errors against qsGW and their running extremes.
max_mol_gw = None
min_mol_gw = None
all_mae_gw = []
min_mae_gw = 1000000
max_mae_gw = 0
all_y_gw = []

# Same bookkeeping for the comparison against the DFT values.
max_mol_dft = None
min_mol_dft = None
all_mae_dft = []
min_mae_dft = 1000000
max_mae_dft = 0
all_y_dft = []

with torch.no_grad():
    for xyz_file in xyz_files:
        xyz_file_path = f"{pc9_path}/mols/{xyz_file}"

        # Skip molecules for which RDKit assigns a formal charge to any atom.
        # The previous version issued `continue` inside the atom loop, which only
        # advanced that inner loop, so no molecule was ever skipped.
        # NOTE(review): this changes the evaluated subset compared with earlier runs.
        rd_mol = Chem.Mol(Chem.MolFromXYZFile(xyz_file_path))
        rdDetermineBonds.DetermineBonds(rd_mol)
        if any(atom.GetFormalCharge() != 0 for atom in rd_mol.GetAtoms()):
            continue

        mol = xyz_file[:-4]  # file name without the ".xyz" suffix
        atoms = read(xyz_file_path, format="xyz")
        # Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
        # Assumes a neutral closed-shell molecule -- TODO confirm.
        homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
        e_qp = np.loadtxt(f"{pc9_path}/E_qp/{mol}.dat")
        e_dft = np.loadtxt(f"{pc9_path}/E_dft/{mol}.dat")
        # Both files are indexed by orbital position in this cell.
        level_idx = homo_idx if target == "homo" else homo_idx + 1
        y_gw = float(e_qp[level_idx])
        y_dft = float(e_dft[level_idx])
        # The reference is the qsGW level for both targets. The "lumo" branch
        # used y_gw - y_dft before, which made the error below count y_dft twice.
        y = y_gw

        # Single-molecule batch: every atom belongs to graph 0.
        inputs = {
            "z": torch.from_numpy(atoms.get_atomic_numbers()),
            "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
            "batch": torch.zeros((len(atoms),)).to(dtype=torch.int64),
        }
        y_pred, _ = model(inputs)
        y_pred = y_pred.item()
        print(y_pred, y_dft, y)

        # Error of the corrected level (DFT level + model output) against qsGW.
        mae_gw = abs(y_pred + y_dft - y)
        # NOTE(review): this compares the raw model output with the DFT level; if
        # the model predicts the qsGW-DFT correction this is not an error
        # measure -- confirm what is intended here.
        mae_dft = abs(y_pred - y_dft)
        all_y_pred.append(y_pred)
        all_y_gw.append(y_gw)
        all_y_dft.append(y_dft)
        all_mae_gw.append(mae_gw)
        all_mae_dft.append(mae_dft)
        all_mols.append(mol)

        if mae_gw < min_mae_gw:
            min_mae_gw = mae_gw
            min_mol_gw = mol
        if mae_gw > max_mae_gw:
            max_mae_gw = mae_gw
            max_mol_gw = mol
        if mae_dft < min_mae_dft:
            min_mae_dft = mae_dft
            min_mol_dft = mol
        if mae_dft > max_mae_dft:
            max_mae_dft = mae_dft
            max_mol_dft = mol

# One row per evaluated molecule.
data = {
    "mol": all_mols,
    "y_pred": all_y_pred,
    "y_gw": all_y_gw,
    "mae_gw": all_mae_gw,
    "y_dft": all_y_dft,
    "mae_dft": all_mae_dft,
}

df = pd.DataFrame(data)

print(f"Mean = {np.mean(all_mae_gw):.4f} +-({np.std(all_mae_gw):.4f}) eV")
print(f"Median = {np.median(all_mae_gw):.4f} eV")
print(f"Min. MAE = {min_mae_gw:.4f} eV for {min_mol_gw}")
print(f"Max. MAE = {max_mae_gw:.4f} eV for {max_mol_gw}")

# Large

In [None]:
# Locations of the "Large" test set: geometries (mols) and reference levels (E_qp).
large_path = "../test_datasets/Large"
xyz_path = f"{large_path}/mols"
eqp_path = f"{large_path}/E_qp"

In [None]:
# File stem of the molecule evaluated in the next cell ({mol}.xyz / {mol}.dat).
mol = "SH3_guest"

In [None]:
# Single-molecule check on the "Large" set: predict the HOMO level of `mol` and
# print the signed deviation from the E_qp reference.
atoms = read(f"{large_path}/mols/{mol}.xyz", format="xyz")
# Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
# Assumes a neutral closed-shell molecule -- TODO confirm.
homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
eqp = np.loadtxt(f"{large_path}/E_qp/{mol}.dat")
y = float(eqp[homo_idx])

# Single-molecule batch: every atom belongs to graph 0.
Z = torch.from_numpy(atoms.get_atomic_numbers())
R = torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32)
B = torch.zeros((len(atoms),)).to(dtype=torch.int64)
data = {"z": Z, "pos": R, "batch": B}
# Inference only: run without autograd bookkeeping, as the PC9 evaluation
# cells do; this avoids keeping a graph alive for a large molecule.
with torch.no_grad():
    y_pred, _ = model(data)
y_pred = y_pred.item()

print(f"Error = {y_pred - y:.5f} eV")

# QM9

In [None]:
# Evaluate the loaded model on the QM9 geometries of GWSet and store the
# per-molecule results as CSV so the analysis cells can reload them.
# NOTE(review): absolute, user-specific paths -- adjust for your machine.
xyz_path = "/Users/dario/datasets/GWSet/QM9/QM9_xyz_files"
eqp_path = "/Users/dario/datasets/GWSet/results/E_qp"

# Keep geometry files only (os.listdir also returns e.g. hidden files) and sort
# them so the row order of the CSV is reproducible.
xyz_files = sorted(name for name in os.listdir(xyz_path) if name.endswith(".xyz"))
all_mols = []
all_y = []
all_y_pred = []
all_mae = []
min_mae = 1000000
max_mae = 0
# Initialised so they are always bound, even for an empty directory.
min_mol = None
max_mol = None
# Inference only: run without autograd bookkeeping, as the PC9 evaluation cells do.
with torch.no_grad():
    for xyz_file in tqdm(xyz_files, leave=False):
        mol = xyz_file[:-4]  # file name without the ".xyz" suffix
        atoms = read(f"{xyz_path}/{xyz_file}", format="xyz")
        # Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
        # Assumes a neutral closed-shell molecule -- TODO confirm.
        homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
        eqp = np.loadtxt(f"{eqp_path}/{mol}.dat")
        y = float(eqp[homo_idx])

        # Single-molecule batch: every atom belongs to graph 0.
        Z = torch.from_numpy(atoms.get_atomic_numbers())
        R = torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32)
        B = torch.zeros((len(atoms),)).to(dtype=torch.int64)
        inputs = {"z": Z, "pos": R, "batch": B}
        y_pred, _ = model(inputs)
        y_pred = y_pred.item()

        mae = abs(y_pred - y)
        all_mols.append(mol)
        all_y.append(y)
        all_y_pred.append(y_pred)
        all_mae.append(mae)
        if mae < min_mae:
            min_mae = mae
            min_mol = mol
        if mae > max_mae:
            max_mae = mae
            max_mol = mol

# One row per molecule: reference, prediction and absolute error.
data = {
    "mol": all_mols,
    "y": all_y,
    "y_pred": all_y_pred,
    "mae": all_mae,
}

df = pd.DataFrame(data)
df.to_csv("gwset_homo_visnet.csv", sep=",", index=False)

In [None]:
# Reload the per-molecule results from the CSV file that the evaluation cell
# above writes, so the analysis below can run without repeating inference.
df = pd.read_csv("gwset_homo_visnet.csv")

In [None]:
# Keep only the rows whose absolute error does not exceed 1.0.
# NOTE(review): this overwrites `df`; re-run the read_csv cell to get the full table back.
df = df[df["mae"].le(1.0)]

In [None]:
# Display the current (filtered) result table.
df

In [None]:
# The 15 rows with the largest absolute error.
df.nlargest(15, columns="mae")

In [None]:
# Histogram of the absolute errors, using 500 bins.
df.hist(column="mae", bins=500)

In [None]:
# The 10 rows with the largest absolute error.
df.nlargest(10, columns="mae")

In [None]:
# Re-evaluate one QM9 molecule in isolation and print the reference value next
# to the prediction.
# NOTE(review): absolute, user-specific paths -- adjust for your machine.
xyz_path = "/Users/dario/datasets/GWSet/QM9/QM9_xyz_files"
eqp_path = "/Users/dario/datasets/GWSet/results/E_qp"
mol = "mol_133854"

atoms = read(f"{xyz_path}/{mol}.xyz", format="xyz")
# Index of the highest occupied level: (sum of atomic numbers) // 2 - 1.
# Assumes a neutral closed-shell molecule -- TODO confirm.
homo_idx = int(np.sum(atoms.get_atomic_numbers()) // 2 - 1)
eqp = np.loadtxt(f"{eqp_path}/{mol}.dat")
y = float(eqp[homo_idx])

# Single-molecule batch: every atom belongs to graph 0.
Z = torch.from_numpy(atoms.get_atomic_numbers())
R = torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32)
B = torch.zeros((len(atoms),)).to(dtype=torch.int64)
data = {"z": Z, "pos": R, "batch": B}
# Inference only: run without autograd bookkeeping, as the PC9 evaluation cells do.
with torch.no_grad():
    y_pred, _ = model(data)
y_pred = y_pred.item()
print(y, y_pred)