In [None]:
import os
import sys

# Let the notebook import the project package (`energy_minimization`) from the
# parent directory. The relative entry '..' is resolved against the current
# working directory, so this assumes the kernel was started from this folder.
sys.path.append('..')

In [None]:
from energy_minimization.data import ConfDataset
from energy_minimization.utils.utils import (add_prop, compute_alligment,
                                             get_energy, set_conformer,
                                             compute_desciptors3d)

In [None]:
from energy_minimization.models import COSMIC

In [None]:
# Root directory of the preprocessed dataset. It must contain
# 'summary_qm9_preprocessed.json', which the DataLoader cell reads.
# Set the COSMIC_DATA_DIR environment variable instead of editing this
# placeholder; the default is unchanged for backward compatibility.
path_to_data = os.environ.get('COSMIC_DATA_DIR', '/path/to/data')

In [None]:
import torch
from torch_geometric.data import DataLoader, Batch

# Test split of the preprocessed QM9 conformer dataset (no conditioning,
# 'argmin' task), served in batches of 128 molecules by one worker process.
test_dataset = ConfDataset(
    path_to_data,
    os.path.join(path_to_data, 'summary_qm9_preprocessed.json'),
    split='test',
    conditions='none',
    task_type='argmin',
)
dl = DataLoader(test_dataset, batch_size=128, num_workers=1)

In [None]:
# Device for the model and for every batch; the checkpoint is also
# remapped onto it when loaded.
device = "cpu"

In [None]:
# Constructor hyperparameters for COSMIC. load_state_dict is strict by default,
# so these presumably have to match the architecture the checkpoint was
# trained with. TODO confirm against the training config.
# Keys are grouped by prefix (ae_/vae_/aae_/wgan_).
model_params = dict(
    latent_size=3,
    node_hidden_size=128,
    edge_hidden_size=64,
    num_gaussians=64,
    num_backbone_layers=4,
    num_main_layers=6,
    num_refiner_steps=10,
    lambda_cosloss=0.5,
    lambda_mxloss=1.0,
    wgan_energy_loss_coeff=0.1,
    ae_num_encoder_layers=4,
    # VAE-specific
    vae_kl_beta=0.03,
    # AAE-specific
    aae_num_discriminator_layers=4,
    aae_discr_coeff=0.01,
    # WGAN-specific
    wgan_num_discr_inter=6,
    wgan_lambda_gp=10.,
    wgan_discr_coeff=0.01,
    wgan_num_mols_energy=32,
    # schedule / variant selection
    num_warmup_iteration=400,
    use_wgan_part=True,
    ae_part_type='aae',
)

In [None]:
# Build an unconditional COSMIC model from the hyperparameters above.
model = COSMIC(**model_params, conditions='none')

In [None]:
# Trained checkpoint; its tensors are remapped onto `device` while loading.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code,
# so only load checkpoints from a trusted source. On PyTorch versions that
# default to weights_only=True, a training checkpoint holding non-tensor
# objects may fail to load. TODO confirm with the installed version.
checkpoint_path = "../saved_models/cosmic_qm9.ckpt"
weights = torch.load(checkpoint_path, map_location=device)

In [None]:
def _strip_prefix(state_dict, prefix='model.'):
    """Return a copy of `state_dict` with a leading `prefix` removed from each key.

    Only a *leading* occurrence is stripped. The former `k.replace('model.', '')`
    also deleted `model.` from inside keys, e.g. 'model.refiner_model.w' became
    'refiner_w', which would then fail a strict `load_state_dict`.
    Keys without the prefix are kept unchanged.
    """
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


# The checkpoint stores the network under a 'model.' attribute prefix; drop it
# so the keys match COSMIC's own parameter names.
model.load_state_dict(_strip_prefix(weights['state_dict']))
model = model.to(device)
model.eval();

In [None]:
import tqdm
from rdkit import Chem

num_samples = 50   # conformers drawn per test molecule
sampling_seed = 0  # fixes torch's RNG so Restart & Run All reproduces the samples
sampled_confs = []

# NOTE(review): this only seeds torch; if `model.sample` also draws from
# numpy / `random`, seed those as well. TODO confirm.
torch.manual_seed(sampling_seed)

# Output order: for each batch, `num_samples` sampling passes, and in each
# pass one conformer per molecule of that batch.
with torch.no_grad():
    for batch in tqdm.tqdm(dl):
        batch = batch.to(device)
        for _ in range(num_samples):
            nodes_out = model.sample(batch)
            # Attach the predicted coordinates to the batch so that each item
            # of `to_data_list()` carries them as `d.cartesian_pred`.
            add_prop(batch, 'cartesian_pred', nodes_out)
            for d in batch.to_data_list():
                # set_conformer presumably puts the predicted coordinates on
                # d.mol (TODO confirm). AddHs(addCoords=True) then adds
                # hydrogens with coordinates.
                sampled_confs.append(
                    Chem.AddHs(set_conformer(d.mol, d.cartesian_pred), addCoords=True))

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw

# `rdkit.Chem` does not import its `Draw` submodule by itself, so `Chem.Draw`
# only works if some other module already imported it. Import it explicitly.
# Preview up to the first 48 sampled conformers, 8 per row.
Draw.MolsToGridImage(sampled_confs[:48], molsPerRow=8)

In [None]:
import tqdm

# Write every sampled conformer to a single SDF file.
sdf_path = './sampled_sdf/cosmic_qm9.sdf'
# SDWriter cannot create missing parent directories, so make sure the output
# folder exists first.
os.makedirs(os.path.dirname(sdf_path), exist_ok=True)

wr = Chem.SDWriter(sdf_path)
try:
    for m in tqdm.tqdm(sampled_confs):
        wr.write(m)
finally:
    # close() flushes buffered output. Doing it in `finally` releases the file
    # handle even if a write fails.
    wr.close()