In [30]:
%load_ext autoreload
%autoreload 2

import sys
# NOTE(review): presumably resets argv so that any command-line parsing done
# by the project modules imported below does not see the Jupyter kernel's own
# arguments -- confirm against those modules.
sys.argv = [""]

from models.variance_edm import VarianceEDM
from configs.model_config import EDMConfig
import torch
from configs.dataset_config import DATASET_INFO

from qm9.rdkit_functions import BasicMolecularMetrics
from configs.datasets_config import qm9_with_h

import numpy as np
from tqdm.auto import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Define the device

In [17]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load the pretrained weights

In [18]:
# Build the model on the selected device, restore the saved state dict from
# disk, and call eval() on the model.
weights_path = "./pretrained/variance_with_h/model.pth"

edm = VarianceEDM(EDMConfig(device=device))
sd = torch.load(weights_path, map_location=device)
edm.load_state_dict(sd)
edm.eval();  # trailing ';' keeps the returned object from being echoed

### Create 64 samples (reduced from the original 1000)

In [19]:
# Number of molecules to generate (kept small here; the earlier value was 1000).
num_molecules = 64
batch_size = 64

# The molecule-size histogram supplies the candidate sizes (its keys) and the
# weight attached to each size (its values).
size_histogram = DATASET_INFO["qm9"]["molecule_size_histogram"]
mol_sizes = torch.tensor(
    list(size_histogram.keys()), dtype=torch.long, device=device
)
mol_size_probs = torch.tensor(
    list(size_histogram.values()), dtype=torch.float, device=device
)

samples = edm.sample(num_molecules, batch_size, mol_sizes, mol_size_probs)

  0%|          | 0/64 [00:00<?, ?sample/s]

  0%|          | 0/1000 [00:00<?, ?step/s]

### Get validity and uniqueness

In [27]:
m = BasicMolecularMetrics(qm9_with_h)

# Each sample is converted to (first array as a tensor, argmax over the last
# axis of the second array) before being handed to the metrics object.
# NOTE(review): the second array is presumably a one-hot atom-type encoding.
rdkit_inputs = [
    (torch.from_numpy(s[0]), torch.from_numpy(s[1]).argmax(dim=-1))
    for s in samples
]
(validity, uniqueness, _), _ = m.evaluate(rdkit_inputs)

[14:03:45] Explicit valence for atom # 6 C, 5, is greater than permitted
[14:03:45] Explicit valence for atom # 8 N, 4, is greater than permitted
[14:03:45] Explicit valence for atom # 10 N, 4, is greater than permitted


Validity over 64 molecules: 93.75%
Uniqueness over 60 valid molecules: 100.00%


[14:03:45] Explicit valence for atom # 12 N, 4, is greater than permitted


### Get stability

In [31]:
from qm9 import bond_analyze

# this is copied from the original repo
def check_stability(positions, atom_type, dataset_info, debug=False):
    """Count how many atoms of one molecule carry an allowed number of bonds.

    Every unordered pair of atoms is assigned a bond order from its Euclidean
    distance by ``bond_analyze``. The orders are summed per atom and each sum
    is compared with ``bond_analyze.allowed_bonds`` for that atom's symbol.

    Args:
        positions: 2-D array-like of shape ``(n_atoms, 3)`` with coordinates.
        atom_type: Sequence of integer indices into
            ``dataset_info['atom_decoder']``, one entry per atom.
        dataset_info: Mapping that provides ``'atom_decoder'`` and ``'name'``.
            ``'name'`` selects the bond-order rule: ``'qm7b'``, ``'qm9'``,
            ``'qm9_second_half'`` and ``'qm9_first_half'`` use
            ``bond_analyze.get_bond_order``; ``'geom'`` uses
            ``bond_analyze.geom_predictor``.
        debug: When true, print one line per atom whose bond count is not
            allowed.

    Returns:
        A tuple ``(molecule_stable, nr_stable_bonds, n_atoms)`` where
        ``molecule_stable`` is true only if every atom is stable,
        ``nr_stable_bonds`` is the number of stable atoms, and ``n_atoms`` is
        the number of rows in ``positions``.

    Raises:
        AssertionError: If ``positions`` is not two-dimensional with three
            columns.
        ValueError: If a bond order has to be computed and
            ``dataset_info['name']`` is not one of the supported names.
    """
    assert len(positions.shape) == 2
    assert positions.shape[1] == 3
    atom_decoder = dataset_info['atom_decoder']
    dataset_name = dataset_info['name']
    qm_names = ('qm7b', 'qm9', 'qm9_second_half', 'qm9_first_half')

    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    n_atoms = len(x)

    nr_bonds = np.zeros(n_atoms, dtype='int')

    for i in range(n_atoms):
        # Atom i's coordinates do not change across the inner loop, so build
        # them once here instead of once per pair.
        p1 = np.array([x[i], y[i], z[i]])
        for j in range(i + 1, n_atoms):
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            if dataset_name in qm_names:
                atom1 = atom_decoder[atom_type[i]]
                atom2 = atom_decoder[atom_type[j]]
                order = bond_analyze.get_bond_order(atom1, atom2, dist)
            elif dataset_name == 'geom':
                # The geom predictor is called with the pair ordered by
                # atom-type index.
                pair = sorted([atom_type[i], atom_type[j]])
                order = bond_analyze.geom_predictor(
                    (atom_decoder[pair[0]], atom_decoder[pair[1]]), dist)
            else:
                # Previously this fell through with `order` unbound and
                # failed with an opaque UnboundLocalError.
                raise ValueError(
                    "Unsupported dataset name for bond-order prediction: %r"
                    % (dataset_name,))
            # A bond contributes to the count of both of its end atoms.
            nr_bonds[i] += order
            nr_bonds[j] += order

    nr_stable_bonds = 0
    for atom_type_i, nr_bonds_i in zip(atom_type, nr_bonds):
        possible_bonds = bond_analyze.allowed_bonds[atom_decoder[atom_type_i]]
        # allowed_bonds holds either a single int or a collection of ints.
        if isinstance(possible_bonds, int):
            is_stable = possible_bonds == nr_bonds_i
        else:
            is_stable = nr_bonds_i in possible_bonds
        if not is_stable and debug:
            print("Invalid bonds for molecule %s with %d bonds" % (atom_decoder[atom_type_i], nr_bonds_i))
        nr_stable_bonds += int(is_stable)

    molecule_stable = nr_stable_bonds == n_atoms
    return molecule_stable, nr_stable_bonds, n_atoms

In [32]:
# Pair each sample's first array (as a tensor) with the argmax over the last
# axis of its second array.
samples_torch = [
    (torch.from_numpy(sample[0]), torch.from_numpy(sample[1]).argmax(dim=-1))
    for sample in samples
]

In [33]:
# One (molecule_stable, nr_stable_bonds, n_atoms) tuple per generated sample.
res = [
    check_stability(coords, atom_types, qm9_with_h)
    for coords, atom_types in tqdm(samples_torch)
]

  0%|          | 0/64 [00:00<?, ?it/s]

In [36]:
# Fraction of molecules whose stability flag (first tuple entry) is true.
molecule_stability = np.mean([r[0] for r in res])
# Keep the previous, misspelled name alive for any cell that still reads it.
molecule_stabililty = molecule_stability
# Stable atoms summed over all molecules, divided by the total atom count.
atom_stability = np.sum([r[1] for r in res]) / np.sum([r[2] for r in res])

print(f"Molecule stability was {molecule_stability:.2f} and atom stability was {atom_stability:.2f}")

Molecule stability was 0.86 and atom stability was 0.99
