## Brute Force Evaluation of Search Space
In this notebook, we evaluate the IC50 and LogP of every molecule in a large search space to provide a baseline for the RL agent.

In [1]:
from molgym.mpnn.data import combine_graphs, convert_nx_to_dict
from molgym.mpnn.layers import custom_objects
from molgym.utils.conversions import convert_smiles_to_nx
from rdkit.Chem import Crippen
from rdkit import Chem
from csv import DictReader, DictWriter
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import gzip
import json
import os

In [2]:
# CSV listing every molecule in the search space to evaluate
search_space_dir = os.path.join('..', '..', 'search-spaces')
search_space = os.path.join(search_space_dir, 'E15.csv')

## Define Key Functions
Functions that load molecules in chunks, compute their properties, and write the results back to disk.

In [3]:
def load_molecules(path: str, chunk_size: int = 1024):
    """Stream the search space from disk in fixed-size chunks

    The file is read lazily as a header-less CSV whose columns are
    ``source``, ``identifier`` and ``smiles``, so only one chunk of records
    is held in memory at a time.

    Args:
        path (str): Path to the search space
        chunk_size (int): Maximum number of molecules per chunk (must be >= 1)
    Yields:
        ([dict]) Lists of at most ``chunk_size`` records with keys
        ``source``, ``identifier`` and ``smiles``. Every yielded list is
        non-empty; only the last one may be shorter than ``chunk_size``.
    Raises:
        ValueError: If ``chunk_size`` is less than 1 (raised on first iteration)
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be at least 1')

    with open(path) as fp:
        # Explicit fieldnames: the first row is data, not a header
        reader = DictReader(fp, fieldnames=['source', 'identifier', 'smiles'])

        # Loop through chunks
        chunk = []
        for entry in reader:
            chunk.append(entry)

            # Hand off the chunk as soon as it is full
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []

        # Yield the partial final chunk, but never an empty one, so that
        # consumers never receive an empty batch (e.g., for an empty file or
        # a row count that is an exact multiple of chunk_size)
        if chunk:
            yield chunk

In [4]:
def compute_logP(chunk: [dict]) -> [dict]:
    """Add the Crippen LogP of each molecule in a chunk

    Each record is updated in place with a ``'logP'`` key computed from its
    ``'smiles'`` string, and the same list is returned so the function can
    be chained inside a ``map`` pipeline.

    NOTE(review): RDKit's ``MolFromSmiles`` returns None for unparseable
    SMILES, which would make ``MolLogP`` fail; the search space is presumably
    all valid SMILES — confirm.
    """
    for record in chunk:
        record['logP'] = Crippen.MolLogP(Chem.MolFromSmiles(record['smiles']))
    return chunk

Load the trained MPNN, along with the atom- and bond-type vocabularies needed to featurize SMILES strings.

In [5]:
# Directory holding the trained MPNN and its featurization vocabularies
parent_dir = '..'
mpnn_dir = os.path.join(parent_dir, 'mpnn-training')

In [6]:
# Atom-type vocabulary passed to `convert_nx_to_dict` when featurizing molecules
atom_types_path = os.path.join(mpnn_dir, 'atom_types.json')
with open(atom_types_path) as handle:
    atom_types = json.load(handle)

In [7]:
# Bond-type vocabulary passed to `convert_nx_to_dict` when featurizing molecules
bond_types_path = os.path.join(mpnn_dir, 'bond_types.json')
with open(bond_types_path) as handle:
    bond_types = json.load(handle)

In [8]:
# Trained MPNN; `custom_objects` supplies the molgym layer classes Keras needs
# to deserialize the saved model
model_path = os.path.join(mpnn_dir, 'model.h5')
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


In [9]:
def compute_ic50(chunk: [dict]) -> [dict]:
    """Predict the IC50 of a chunk of molecules with the MPNN

    Molecules whose graphs cannot be converted to MPNN inputs
    (``convert_nx_to_dict`` raises ``AssertionError``) are skipped and get no
    ``'IC50_mpnn'`` key. All other records are updated in place.

    Args:
        chunk ([dict]): Records with at least a ``'smiles'`` key
    Returns:
        ([dict]): The same list, with ``'IC50_mpnn'`` added to every record
        that could be evaluated. An empty chunk, or one where no molecule
        could be featurized, is returned unchanged.
    """
    # Get the features for each molecule, remembering which records made it
    batch = []
    tested_mols = []
    for i, entry in enumerate(chunk):
        graph = convert_smiles_to_nx(entry['smiles'])
        try:
            graph_dict = convert_nx_to_dict(graph, atom_types, bond_types)
        except AssertionError:
            # NOTE(review): presumably raised for atom/bond types missing from
            # the vocabularies — confirm in molgym.mpnn.data
            continue
        batch.append(graph_dict)
        tested_mols.append(i)

    # Nothing to evaluate: avoid indexing batch[0] and calling the model
    if not batch:
        return chunk

    # Prepare in input format by stacking each feature across molecules
    batch_dict = {
        k: np.concatenate([np.atleast_1d(b[k]) for b in batch], axis=0)
        for k in batch[0]
    }
    inputs = combine_graphs(batch_dict)

    # Compute the IC50. np.asarray accepts either an eager tensor or a NumPy
    # array, whichever `predict_on_batch` returns for the installed TF version
    ic50 = np.asarray(model.predict_on_batch(inputs))[:, 0]

    # Store the predictions in the chunk data
    for i, v in zip(tested_mols, ic50):
        chunk[i]['IC50_mpnn'] = v

    return chunk

In [10]:
def flat_map(gen):
    """Flatten a stream of chunks into a stream of individual entries

    Wrapping the result in a progress bar therefore counts entries rather
    than chunks.
    """
    for batch in gen:
        yield from batch

In [11]:
def write_output(path, gen, fieldnames=None):
    """Write the output of a processing pipeline to a gzip-compressed CSV

    Args:
        path (str): Output file path (written in gzip text mode)
        gen: Iterable of chunks, each a list of dict records
        fieldnames ([str], optional): Column names for the CSV header. By
            default they are taken from the keys of the first record, in
            which case later records may omit keys (written as empty cells)
            but must not add new ones (``csv.DictWriter`` raises ValueError)
    Raises:
        ValueError: If ``gen`` yields no records and ``fieldnames`` is not
            given, since the header cannot be inferred
    """
    # Flatten chunks so the progress bar counts individual records
    entries = flat_map(gen)
    first = next(entries, None)
    if fieldnames is None:
        if first is None:
            raise ValueError('No records to write and no fieldnames given')
        fieldnames = list(first.keys())

    # newline='' is what the csv module expects of the files it writes to;
    # otherwise its '\r\n' terminators get translated again on some platforms
    with gzip.open(path, 'wt', newline='') as fp:
        # Write the header, then the first record if there is one
        writer = DictWriter(fp, fieldnames)
        writer.writeheader()
        if first is None:
            return
        writer.writerow(first)

        # Keep writing rows; initial=1 accounts for the row already written
        for entry in tqdm(entries, initial=1):
            writer.writerow(entry)

## Run it
The search space is a large file, so we process it as a lazy stream of chunks using Python generators and `map`, rather than loading it all into memory at once.

In [12]:
# Lazily chain the pipeline: read chunks, add LogP, then add the MPNN IC50.
# Nothing is read or computed until the generator is consumed.
gen = (compute_ic50(compute_logP(chunk))
       for chunk in load_molecules(search_space, 1024))

In [13]:
# Name the output after the search space file, e.g. E15.csv -> E15.csv.gz
output_path = os.path.basename(search_space) + '.gz'
write_output(output_path, gen)

15547090it [5:19:33, 810.86it/s] 
