In [3]:
from argparse import ArgumentParser
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem
import numpy as np
import pickle
import pandas as pd
from tqdm import tqdm
import random
import torch
import yaml

from model.model import GeoMol
from model.featurization import featurize_mol_from_smiles
from torch_geometric.data import Batch
from model.inference import construct_conformers


# Command-line style options for conformer generation; the next cell parses a
# hand-written argument list against this parser.
parser = ArgumentParser()

# Plain string options with no default (each is None unless supplied).
for flag in ('--trained_model_dir', '--out', '--test_csv'):
    parser.add_argument(flag, type=str)

parser.add_argument('--dataset', default='qm9', type=str)
parser.add_argument('--mmff', default=False, action='store_true')
# Left as the final expression so the cell still echoes the returned action.
parser.add_argument('--seed', default=0, type=int)

_StoreAction(option_strings=['--seed'], dest='seed', nargs=None, const=None, default=0, type=<class 'int'>, choices=None, help=None, metavar=None)

In [6]:
# Stand-in for sys.argv: the notebook supplies the options as a whitespace-split string.
notebook_argv = '--trained_model_dir trained_models/qm9/ --test_csv data/QM9/test_smiles.csv --dataset qm9'.split()
args = parser.parse_args(notebook_argv)

In [7]:
# Seed every RNG in use (stdlib random, NumPy, PyTorch) from the parsed --seed value.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(args.seed)

# Pull the parsed options into plain names used by the later cells.
trained_model_dir, test_csv = args.trained_model_dir, args.test_csv
dataset, mmff = args.dataset, args.mmff

# Build the network from the hyperparameters stored next to the checkpoint.
# NOTE(review): yaml.full_load is not the restricted safe loader - only point
# this at model directories you trust.
with open(f'{trained_model_dir}/model_parameters.yml') as f:
    model_parameters = yaml.full_load(f)
model = GeoMol(**model_parameters)

# Load the weights onto the CPU; strict=True insists the checkpoint keys match the model.
state_dict = torch.load(f'{trained_model_dir}/best_model.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)
model.eval()

# Rows of this table are consumed by the generation loop as (smiles, n_confs) pairs.
test_data = pd.read_csv(test_csv)

In [9]:
# Debug peek at the most recently featurized molecule graph.
# NOTE(review): `tg_data` is only assigned inside the conformer-generation loop
# cell further down, so this cell depends on that cell having been executed
# first and will raise NameError on a fresh top-to-bottom run.
tg_data

Data(
  x=[19, 44],
  edge_index=[2, 36],
  edge_attr=[36, 4],
  neighbors={
    0=[2],
    1=[2],
    2=[2],
    3=[2],
    4=[4],
    5=[4],
    6=[4],
    7=[4],
    8=[2]
  },
  chiral_tag=[19],
  name='C#CC#C[C@@H](CC)CO',
  edge_index_dihedral_pairs=[2, 8]
)

In [8]:
# Generate conformers for every (smiles, n_confs) row of `test_data`, optionally
# relax them with MMFF94s, and pickle the result as {smiles: [rdkit Mol, ...]}.
#
# NOTE(review): the last recorded run of this cell stopped on the first row with
# "TypeError: argument of type 'int' is not iterable". `tg_data` had been
# populated by then, so featurization succeeded; the failure is presumably in
# batching or the model call - possibly the int-keyed `neighbors` dict under the
# installed torch_geometric version. TODO confirm against the full traceback.
conformer_dict = {}
for smi, n_confs in tqdm(test_data.values):

    # create data object (skip smiles rdkit can't handle)
    tg_data = featurize_mol_from_smiles(smi, dataset=dataset)
    if not tg_data:
        print(f'failed to featurize SMILES: {smi}')
        continue

    # generate model predictions; this is pure inference, so skip autograd
    # bookkeeping (the coordinates were already being detached before use)
    data = Batch.from_data_list([tg_data])
    with torch.no_grad():
        # twice the requested number of conformers is sampled from the model
        model(data, inference=True, n_model_confs=n_confs*2)
        model_coords = construct_conformers(data, model)

    # set coords: one fresh molecule per slice along dim 1 of the model output
    n_atoms = tg_data.x.size(0)
    mols = []
    for x in model_coords.split(1, dim=1):
        mol = Chem.AddHs(Chem.MolFromSmiles(smi))
        coords = x.squeeze(1).double().cpu().detach().numpy()

        # fill the conformer completely before attaching it, instead of
        # fetching it back from the molecule by a hard-coded id
        conformer = Chem.Conformer(n_atoms)
        for i in range(n_atoms):
            conformer.SetAtomPosition(i, Geometry.Point3D(coords[i, 0], coords[i, 1], coords[i, 2]))
        mol.AddConformer(conformer, assignId=True)

        if mmff:
            try:
                AllChem.MMFFOptimizeMoleculeConfs(mol, mmffVariant='MMFF94s')
            except Exception as e:
                # best effort: keep the unoptimized conformer, but say so
                print(f'MMFF optimization failed for SMILES {smi}: {e}')
        mols.append(mol)

    conformer_dict[smi] = mols

# save to file: --out wins, otherwise write next to the trained model
suffix = '_ff' if mmff else ''
out_path = args.out if args.out else f'{trained_model_dir}/test_mols{suffix}.pkl'
with open(out_path, 'wb') as f:
    pickle.dump(conformer_dict, f)

  0%|                                                  | 0/1000 [00:00<?, ?it/s]


TypeError: argument of type 'int' is not iterable