In [None]:
import os
import torch

import numpy as np

from omegaconf import OmegaConf
from ase.io import read
from ViSNetGW.model.visnet import create_model
from tqdm import tqdm
#from rdkit import Chem
#from rdkit.Chem import rdDetermineBonds

In [None]:
# Conversion factor from Hartree to electronvolt (1 Ha is approx. 27.2114 eV).
# NOTE(review): not referenced by any cell visible in this notebook.
HARTREE_TO_EV = 27.2114

In [None]:
def load_model(logs_path, ckpt_name):
    """Build the model described in a log directory and load a checkpoint into it.

    Args:
        logs_path: Directory that contains ``config.yaml`` and the checkpoint.
        ckpt_name: Checkpoint file name without the ``.ckpt`` extension.

    Returns:
        The model returned by ``create_model(cfg)`` with the checkpoint weights
        loaded (tensors are mapped to CPU when read) and switched to eval mode.

    Raises:
        Whatever ``OmegaConf.load``, ``torch.load`` or ``load_state_dict`` raise
        for missing files or mismatching keys.
    """
    model_file = f"{logs_path}/{ckpt_name}.ckpt"
    config_file = f"{logs_path}/config.yaml"
    cfg = OmegaConf.load(config_file)
    model = create_model(cfg)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_file, map_location=torch.device("cpu"))
    if "swa" in model_file:
        # Presumably the checkpoint was saved from an averaged-model wrapper whose
        # parameters are stored as "module.<name>" next to bookkeeping entries --
        # TODO confirm. Keep only keys that really start with the prefix and strip
        # exactly that prefix, so a key that merely contains "module" somewhere
        # else is no longer truncated by a fixed 7 characters.
        prefix = "module."
        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }
    model.load_state_dict(state_dict)
    model.eval()
    return model



def test_model(model, test_set, folder, target,
               base_path="/Volumes/LaCie/test_sets", min_atoms=12, max_atoms=24):
    """Compare model predictions with reference orbital energies and print error statistics.

    Every visible ``*.xyz`` file in ``<base_path>/<test_set>/mols`` whose atom
    count lies in ``[min_atoms, max_atoms]`` is evaluated against the values in
    ``<base_path>/<test_set>/<folder>/<mol>.dat``. The mean, standard deviation
    and median of the absolute errors are printed (labelled eV); nothing is
    returned.

    Args:
        model: Callable taking a dict with keys ``"z"``, ``"pos"`` and ``"batch"``
            and returning a pair whose first element supports ``.item()``.
        test_set: Name of the test-set directory below ``base_path``.
        folder: Sub-directory holding the ``.dat`` files. For names starting with
            ``"E_omol25"`` index 0 of the file is used as the HOMO entry.
        target: ``"homo"``, ``"lumo"`` or ``"gap"`` (LUMO minus HOMO).
        base_path: Root directory of the test sets.
        min_atoms: Smallest number of atoms a structure may have to be evaluated.
        max_atoms: Largest number of atoms a structure may have to be evaluated.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # Fail early: previously an unknown target left `y` unbound inside the loop.
    if target not in ("homo", "lumo", "gap"):
        raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'")

    xyz_path = os.path.join(base_path, test_set, "mols")
    data_path = os.path.join(base_path, test_set, folder)

    # os.listdir returns every directory entry; keep only visible *.xyz files so
    # hidden files (e.g. ".DS_Store", "._name.xyz") are never handed to the parser.
    xyz_files = sorted(
        name for name in os.listdir(xyz_path)
        if name.lower().endswith(".xyz") and not name.startswith(".")
    )
    all_mae = []
    with torch.no_grad():
        for xyz_file in tqdm(xyz_files, leave=False):
            mol = os.path.splitext(xyz_file)[0]
            atoms = read(os.path.join(xyz_path, xyz_file), format="xyz")
            if not (min_atoms <= len(atoms) <= max_atoms):
                continue
            atomic_numbers = atoms.get_atomic_numbers()
            if folder.startswith("E_omol25"):
                # Files in these folders are read with the HOMO at index 0.
                homo_idx = 0
            else:
                # (sum of atomic numbers) // 2 - 1; presumably assumes a neutral
                # closed-shell molecule with levels listed from the lowest --
                # TODO confirm against the data files.
                homo_idx = int(np.sum(atomic_numbers)) // 2 - 1
            energies = np.loadtxt(os.path.join(data_path, f"{mol}.dat"))
            if target == "homo":
                y = float(energies[homo_idx])
            elif target == "lumo":
                y = float(energies[homo_idx + 1])
            else:  # "gap"
                y = float(energies[homo_idx + 1]) - float(energies[homo_idx])
            data = {
                "z": torch.from_numpy(atomic_numbers),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                # Single molecule per call: every atom belongs to batch entry 0.
                "batch": torch.zeros(len(atoms), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            all_mae.append(abs(y_pred.item() - y))

    if not all_mae:
        # Avoid nan statistics and numpy warnings on an empty selection.
        print("No structures matched the selection; nothing to report.")
        return

    print(f"Mean = {np.mean(all_mae):.4f} +-({np.std(all_mae):.4f}) eV")
    print(f"Median = {np.median(all_mae):.4f} eV")



def test_model_delta(model, target, base_path):
    """Evaluate a model that predicts the difference between two sets of orbital energies.

    For every visible ``*.xyz`` file in ``<base_path>/mols`` the reference value
    is the entry from ``<base_path>/E_qp/<mol>.dat`` minus the matching entry
    from ``<base_path>/E_dft/<mol>.dat``. The mean, standard deviation and median
    of the absolute errors are printed (labelled eV); nothing is returned.

    Args:
        model: Callable taking a dict with keys ``"z"``, ``"pos"`` and ``"batch"``
            and returning a pair whose first element supports ``.item()``.
        target: ``"homo"``, ``"lumo"`` or ``"gap"`` (LUMO minus HOMO).
        base_path: Directory containing the ``mols``, ``E_qp`` and ``E_dft`` folders.

    Raises:
        ValueError: If ``target`` is not one of the supported names.
    """
    # Fail early: previously an unknown target left `y` unbound inside the loop.
    if target not in ("homo", "lumo", "gap"):
        raise ValueError(f"Unknown target {target!r}; expected 'homo', 'lumo' or 'gap'")

    xyz_path = os.path.join(base_path, "mols")
    gw_path = os.path.join(base_path, "E_qp")
    dft_path = os.path.join(base_path, "E_dft")

    # os.listdir returns every directory entry; keep only visible *.xyz files so
    # hidden files (e.g. ".DS_Store", "._name.xyz") are never handed to the parser.
    xyz_files = sorted(
        name for name in os.listdir(xyz_path)
        if name.lower().endswith(".xyz") and not name.startswith(".")
    )
    all_mae = []
    with torch.no_grad():
        for xyz_file in tqdm(xyz_files, leave=False):
            mol = os.path.splitext(xyz_file)[0]
            atoms = read(os.path.join(xyz_path, xyz_file), format="xyz")
            atomic_numbers = atoms.get_atomic_numbers()
            # (sum of atomic numbers) // 2 - 1; presumably assumes a neutral
            # closed-shell molecule with levels listed from the lowest --
            # TODO confirm against the data files.
            homo_idx = int(np.sum(atomic_numbers)) // 2 - 1
            gw_energies = np.loadtxt(os.path.join(gw_path, f"{mol}.dat"))
            dft_energies = np.loadtxt(os.path.join(dft_path, f"{mol}.dat"))
            if target == "homo":
                y = float(gw_energies[homo_idx]) - float(dft_energies[homo_idx])
            elif target == "lumo":
                y = float(gw_energies[homo_idx + 1]) - float(dft_energies[homo_idx + 1])
            else:  # "gap"
                gw_gap = float(gw_energies[homo_idx + 1]) - float(gw_energies[homo_idx])
                dft_gap = float(dft_energies[homo_idx + 1]) - float(dft_energies[homo_idx])
                y = gw_gap - dft_gap
            data = {
                "z": torch.from_numpy(atomic_numbers),
                "pos": torch.from_numpy(atoms.get_positions()).to(dtype=torch.float32),
                # Single molecule per call: every atom belongs to batch entry 0.
                "batch": torch.zeros(len(atoms), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            all_mae.append(abs(y_pred.item() - y))

    if not all_mae:
        # Avoid nan statistics and numpy warnings on an empty directory.
        print("No structures found; nothing to report.")
        return

    print(f"Mean = {np.mean(all_mae):.4f} +-({np.std(all_mae):.4f}) eV")
    print(f"Median = {np.median(all_mae):.4f} eV")



def test_model_gwset(model, target, base_path="/Volumes/LaCie/test_sets/GWSet", num_samples=2885):
    """Evaluate a model on the tensors stored in a GWSet directory and print error statistics.

    Loads ``N.pt``, ``Z.pt``, ``R.pt``, ``M.pt`` and ``<target>.pt`` from
    ``base_path`` and feeds the samples to the model one at a time. The mean,
    standard deviation and median of the absolute errors are printed (labelled
    eV); nothing is returned.

    Args:
        model: Callable taking a dict with keys ``"z"``, ``"pos"`` and ``"batch"``
            and returning a pair whose first element supports ``.item()``.
        target: Stem of the reference tensor file, e.g. ``"HOMO"``.
        base_path: Directory holding the ``.pt`` files.
        num_samples: Number of leading samples to evaluate. The default keeps
            the previously hard-coded 2885; pass ``None`` to use every entry of
            the reference tensor.
    """
    # NOTE(review): torch.load unpickles the files -- only load trusted data.
    N = torch.load(os.path.join(base_path, "N.pt"))
    Z = torch.load(os.path.join(base_path, "Z.pt"))
    R = torch.load(os.path.join(base_path, "R.pt"))
    M = torch.load(os.path.join(base_path, "M.pt"))
    E = torch.load(os.path.join(base_path, f"{target}.pt"))

    total = len(E) if num_samples is None else num_samples

    all_mae = []
    with torch.no_grad():
        for i in tqdm(range(total), leave=False):
            # M[i] selects the entries of sample i along the second axis;
            # presumably a mask of real (non-padding) atoms -- TODO confirm.
            selection = M[i]
            data = {
                "z": Z[i, selection],
                "pos": R[i, selection, :],
                # One molecule per call: N[i] atoms, all in batch entry 0.
                "batch": torch.zeros(N[i].item(), dtype=torch.int64),
            }
            y_pred, _ = model(data)
            all_mae.append(abs(y_pred.item() - E[i].item()))

    if not all_mae:
        # Avoid nan statistics and numpy warnings when nothing was evaluated.
        print("No samples evaluated; nothing to report.")
        return

    print(f"Mean = {np.mean(all_mae):.4f} +-({np.std(all_mae):.4f}) eV")
    print(f"Median = {np.median(all_mae):.4f} eV")

# Test

In [None]:
# Checkpoint evaluated in this cell. Commented-out alternatives from the
# original cell are kept for quick switching:
#   logs_path = "/Volumes/LaCie/trained_models/ViSNet/prehomo10M216502532"
#   ckpt_name = "best_model"
logs_path, ckpt_name = "../visnet_logs", "model_30_epochs"

# Load the checkpoint and print the error statistics against HOMO.pt.
model = load_model(logs_path=logs_path, ckpt_name=ckpt_name)
test_model_gwset(model=model, target="HOMO")

# PC9

In [None]:
# Checkpoint evaluated in this cell. Commented-out alternatives from the
# original cell are kept for quick switching:
#   logs_path = "/Volumes/LaCie/trained_models/ViSNet/prehomo10M216502532"
#   ckpt_name = "best_model"
logs_path, ckpt_name = "../visnet_logs", "model_30_epochs"

# Load the checkpoint and print HOMO error statistics for PC9 / E_qp.
model = load_model(logs_path=logs_path, ckpt_name=ckpt_name)
test_model(model=model, test_set="PC9", folder="E_qp", target="homo")

# OE62

In [None]:
# Checkpoint evaluated in this cell. Commented-out alternatives from the
# original cell are kept for quick switching:
#   logs_path = "/Volumes/LaCie/trained_models/ViSNet/prehomo10M216502532"
#   ckpt_name = "best_model"
logs_path, ckpt_name = "../visnet_logs", "model_30_epochs"

# Load the checkpoint and print HOMO error statistics for OE6218 / E_qp.
model = load_model(logs_path=logs_path, ckpt_name=ckpt_name)
test_model(model=model, test_set="OE6218", folder="E_qp", target="homo")

# Elem

In [None]:
# Checkpoint evaluated in this cell. Commented-out alternatives from the
# original cell are kept for quick switching:
#   logs_path = "../visnet_logs"
#   logs_path = "/Volumes/LaCie/trained_models/ViSNet/lumo5M128502532_transfer_converged"
#   ckpt_name = "best_model"
logs_path, ckpt_name = "/Volumes/LaCie/trained_models/ViSNet/lumo0M128502532", "model_40_epochs"

# Load the checkpoint and print LUMO error statistics for OE62H1 / E_qp.
model = load_model(logs_path=logs_path, ckpt_name=ckpt_name)
test_model(model=model, test_set="OE62H1", folder="E_qp", target="lumo")

In [None]:
# Print the number of atoms of every structure file in the OE62H1 "mols" folder.
xyz_path = "/Volumes/LaCie/test_sets/OE62H1/mols"
xyz_files = os.listdir(xyz_path)
for xyz_file in xyz_files:
    structure_file = os.path.join(xyz_path, xyz_file)
    atoms = read(structure_file, format="xyz")
    print(len(atoms))

In [None]:
import shutil

# Copy the geometry of every molecule that has a file in E_qp from old_mols/ to mols/.
xyz_path = os.path.join("/Volumes/LaCie/test_sets/OE62", "old_mols")
new_xyz_path = os.path.join("/Volumes/LaCie/test_sets/OE62", "mols")
gw_path = os.path.join("/Volumes/LaCie/test_sets/OE62", "E_qp")
dft_path = os.path.join("/Volumes/LaCie/test_sets/OE62", "E_dft")

gw_files = os.listdir(gw_path)
# Not used by the loop below; kept so the cell still raises if E_dft is missing.
dft_files = os.listdir(dft_path)

# copyfile does not create missing parent directories.
os.makedirs(new_xyz_path, exist_ok=True)
for gw_file in gw_files:
    mol, ext = os.path.splitext(gw_file)
    # os.listdir returns every entry; skip hidden files and anything that is not
    # a "<mol>.dat" file instead of blindly cutting four characters off the name.
    if gw_file.startswith(".") or ext != ".dat":
        continue
    # A missing source geometry still raises, so gaps in old_mols/ are not hidden.
    shutil.copyfile(os.path.join(xyz_path, f"{mol}.xyz"), os.path.join(new_xyz_path, f"{mol}.xyz"))

In [None]:
import shutil
import numpy as np
from ase.io import read

# Source folders of the OE62 set.
xyz_path = "/Volumes/LaCie/test_sets/OE62/mols"
gw_path = "/Volumes/LaCie/test_sets/OE62/E_qp"
dft_path = "/Volumes/LaCie/test_sets/OE62/E_dft"

# Destination folders of the OE6260 subset (only used by the disabled copy step).
new_xyz_path = "/Volumes/LaCie/test_sets/OE6260/mols"
new_gw_path = "/Volumes/LaCie/test_sets/OE6260/E_qp"
new_dft_path = "/Volumes/LaCie/test_sets/OE6260/E_dft"

xyz_files = os.listdir(xyz_path)
num_samples = 0  # never incremented while the counting variant below stays disabled
max_num_atoms = 60
allnum = []  # heavy-atom counts of all structures with at most max_num_atoms atoms
for xyz_file in xyz_files:
    mol = xyz_file[:-4]
    atoms = read(os.path.join(xyz_path, xyz_file))
    # Heavy atoms = every atom whose atomic number is not 1.
    is_heavy = atoms.get_atomic_numbers() != 1
    num_heavy_atoms = is_heavy.astype(int).sum()
    if len(atoms) > max_num_atoms:
        continue
    allnum.append(num_heavy_atoms)
    # Disabled copy step for the selected structures:
    #shutil.copyfile(os.path.join(xyz_path, xyz_file), os.path.join(new_xyz_path, xyz_file))
    #shutil.copyfile(os.path.join(gw_path, f"{mol}.dat"), os.path.join(new_gw_path, f"{mol}.dat"))
    #shutil.copyfile(os.path.join(dft_path, f"{mol}.dat"), os.path.join(new_dft_path, f"{mol}.dat"))
    # Disabled alternatives from the original cell:
    #   select by `num_heavy_atoms <= 27` instead of the total atom count
    #   count structures with `60 < len(atoms)` via `num_samples += 1`

In [None]:
# Display the counter set up in the size-survey cell. NOTE(review): that cell
# initialises it to 0 and its only increment is commented out, so this shows 0
# when the cells are run in order.
num_samples

In [None]:
# Derive LUMO values as HOMO + GAP from the stored tensors and save them as
# LUMO.pt in the same directory (an existing LUMO.pt is overwritten).
# NOTE(review): absolute user-specific path; consider a configurable data root.
base_path = "/Users/dario/ViSNetGW/GWSet_10000_train/test"
homo = torch.load(os.path.join(base_path, "HOMO.pt"))
gap = torch.load(os.path.join(base_path, "GAP.pt"))

n = homo.shape[0]
lumo = torch.zeros_like(homo)
# One vectorised addition replaces the per-row Python loop with .item() round
# trips. Only column 0 is written; any further columns of `lumo` stay zero.
lumo[:, 0] = homo[:, 0] + gap[:n, 0]
torch.save(lumo, os.path.join(base_path, "LUMO.pt"))