<a href="https://colab.research.google.com/github/lockiultra/SCAMT/blob/main/MPNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install rdkit-pypi

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [31]:
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import re
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import MolsToGridImage

warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")

np.random.seed(42)
tf.random.set_seed(42)

# Featurizers

class Featurizer:
    """Concatenated one-hot encoder driven by named sets of allowed values.

    Each key of ``allowable_sets`` names a feature and must also be the name
    of a callable attribute on the instance (normally a method supplied by a
    subclass) that extracts that feature's value from the object being
    encoded. Every allowed value is assigned one slot of the output vector;
    slots are handed out feature by feature, in sorted value order.

    Attributes:
        dim: total length of the encoded vector.
        feature_mapping: ``{feature name: {allowed value: slot index}}``.
    """

    def __init__(self, allowable_sets):
        self.dim = 0
        self.feature_mapping = {}
        for feature_name, allowed_values in allowable_sets.items():
            ordered_values = sorted(allowed_values)
            first_slot = self.dim
            self.feature_mapping[feature_name] = {
                value: first_slot + position
                for position, value in enumerate(ordered_values)
            }
            self.dim = first_slot + len(ordered_values)

    def encode(self, inputs):
        """Return a float vector of length ``dim`` describing ``inputs``.

        For every feature the extractor of the same name is called on
        ``inputs``; the slot of the returned value is set to 1.0. A value
        outside the feature's allowed set leaves all of its slots at zero.
        """
        encoded = np.zeros((self.dim,))
        for feature_name, slots in self.feature_mapping.items():
            value = getattr(self, feature_name)(inputs)
            if value in slots:
                encoded[slots[value]] = 1.0
        return encoded


class AtomFeaturizer(Featurizer):
    """Featurizer whose extractors read properties of an RDKit atom.

    The constructor is inherited unchanged from ``Featurizer``; the keys of
    ``allowable_sets`` select which of the extractors below are used.
    """

    def symbol(self, atom):
        """Element symbol as reported by ``atom.GetSymbol()``."""
        return atom.GetSymbol()

    def n_valence(self, atom):
        """Total valence as reported by ``atom.GetTotalValence()``."""
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        """Hydrogen count as reported by ``atom.GetTotalNumHs()``."""
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        """Lower-cased name of the atom's hybridization enum value."""
        hybridization_type = atom.GetHybridization()
        return hybridization_type.name.lower()


class BondFeaturizer(Featurizer):
    """Featurizer whose extractors read properties of an RDKit bond.

    The constructor is inherited unchanged from ``Featurizer``.
    """

    def encode(self, bond):
        """Encode ``bond``; ``None`` stands for "no real bond".

        A ``None`` bond yields a vector that is zero everywhere except for
        its last slot.
        NOTE(review): the last slot is also the slot of the largest allowed
        value of the last feature set, so the placeholder is not a
        dedicated position of its own.
        """
        if bond is not None:
            return super().encode(bond)
        placeholder = np.zeros((self.dim,))
        placeholder[-1] = 1.0
        return placeholder

    def bond_type(self, bond):
        """Lower-cased name of the bond-type enum value."""
        kind = bond.GetBondType()
        return kind.name.lower()

    def conjugated(self, bond):
        """Value reported by ``bond.GetIsConjugated()``."""
        return bond.GetIsConjugated()
# Model

class EdgeNetwork(layers.Layer):
    """Message function: turns bond features into per-edge linear maps.

    Called on ``[atom_features, bond_features, pair_indices]``. Each row of
    ``bond_features`` is projected to an ``atom_dim x atom_dim`` matrix,
    which is applied to the features of the atom in column 1 of the
    matching ``pair_indices`` row. The results are summed per atom in
    column 0, giving one aggregated message per row of ``atom_features``.
    """

    def build(self, input_shape):
        """Create the projection weights from the atom and bond feature sizes."""
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        self.kernel = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            # Explicit 1-tuple: without the trailing comma the parentheses
            # are plain grouping and a bare int is passed as the shape.
            shape=(self.atom_dim * self.atom_dim,),
            initializer="zeros",
            name="bias",
        )
        self.built = True

    def call(self, inputs):
        """Return the summed incoming messages, one row per atom."""
        atom_features, bond_features, pair_indices = inputs
        # One (atom_dim, atom_dim) matrix per edge.
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))
        # Features of the atom at the far end of each edge, as column vectors.
        atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
        atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)
        transformed_features = tf.matmul(bond_features, atom_features_neighbors)
        transformed_features = tf.squeeze(transformed_features, axis=-1)
        # Sum the per-edge messages onto the atom in column 0 of each pair.
        aggregated_features = tf.math.unsorted_segment_sum(
            transformed_features,
            pair_indices[:, 0],
            num_segments=tf.shape(atom_features)[0],
        )
        return aggregated_features

class MessagePassing(layers.Layer):
    """Runs ``steps`` rounds of edge-network messages followed by a GRU update.

    Called on ``[atom_features, bond_features, pair_indices]`` and returns
    the updated atom states. When ``units`` exceeds the atom feature size,
    the atom features are zero-padded at the front of the feature axis so
    the state has ``units`` columns; otherwise no padding is applied.
    """

    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        """Create the message and update sub-layers for the padded state size."""
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        state_size = self.atom_dim + self.pad_length
        self.update_step = layers.GRUCell(state_size)
        self.built = True

    def call(self, inputs):
        """Return the atom states after ``steps`` message-passing rounds."""
        atom_features, bond_features, pair_indices = inputs
        # Pad only the feature axis, on its leading side.
        paddings = [[0, 0], [self.pad_length, 0]]
        state = tf.pad(atom_features, paddings)
        for _ in range(self.steps):
            messages = self.message_step([state, bond_features, pair_indices])
            # GRUCell returns (output, new_states); only the output is kept.
            state, _ = self.update_step(messages, state)
        return state

class PartitionPadding(layers.Layer):
    """Regroups a flat atom tensor into one zero-padded block per molecule.

    Called on ``[atom_features, molecule_indicator]``. The atoms are split
    into ``batch_size`` partitions by ``molecule_indicator``, every
    partition is padded with zero rows to the size of the largest one, and
    the partitions are stacked. Partitions whose entries sum to zero are
    dropped from the result.
    """

    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        """Return the stacked, padded per-molecule atom features."""
        atom_features, molecule_indicator = inputs
        per_molecule = tf.dynamic_partition(atom_features, molecule_indicator, self.batch_size)
        atom_counts = [tf.shape(block)[0] for block in per_molecule]
        longest = tf.reduce_max(atom_counts)
        padded_blocks = []
        for block, count in zip(per_molecule, atom_counts):
            # Append zero rows so every block has ``longest`` atoms.
            padded_blocks.append(tf.pad(block, [(0, longest - count), (0, 0)]))
        stacked = tf.stack(padded_blocks, axis=0)
        # Keep only the partitions whose features do not sum to zero.
        kept = tf.where(tf.reduce_sum(stacked, (1, 2)) != 0)
        kept = tf.squeeze(kept, axis=-1)
        return tf.gather(stacked, kept, axis=0)

class TransformerEncoderReadout(layers.Layer):
    """Readout: pads atoms per molecule, applies one encoder block, then pools.

    Called on ``[atom_features, molecule_indicator]``. The padded
    per-molecule atom states go through self-attention and a two-layer
    dense projection, each followed by a residual connection and layer
    normalisation, and are then averaged over the atom axis.
    """

    def __init__(self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs):
        super().__init__(**kwargs)
        self.partition_padding = PartitionPadding(batch_size)
        self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
        feed_forward_layers = [
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ]
        self.dense_proj = keras.Sequential(feed_forward_layers)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.average_pooling = layers.GlobalAveragePooling1D()

    def call(self, inputs):
        """Return one pooled embedding per molecule."""
        padded_atoms = self.partition_padding(inputs)
        # An atom position counts as real if any of its features is non-zero.
        is_real_atom = tf.reduce_any(tf.not_equal(padded_atoms, 0.0), axis=-1)
        attention_mask = is_real_atom[:, tf.newaxis, tf.newaxis, :]
        attended = self.attention(padded_atoms, padded_atoms, attention_mask=attention_mask)
        normed = self.layernorm_1(padded_atoms + attended)
        projected = self.dense_proj(normed)
        encoded = self.layernorm_2(normed + projected)
        return self.average_pooling(encoded)

def MPNNModel(
    atom_dim,
    bond_dim,
    batch_size=32,
    message_units=64,
    message_steps=4,
    num_attention_heads=8,
    dense_units=512,
):
    """Build the (uncompiled) Keras model mapping a graph batch to one sigmoid output.

    The model takes four inputs, in this order: ``atom_features``
    (float32, ``atom_dim`` columns), ``bond_features`` (float32,
    ``bond_dim`` columns), ``pair_indices`` (int32, 2 columns) and
    ``molecule_indicator`` (int32 scalar per row). They pass through
    ``MessagePassing``, ``TransformerEncoderReadout``, a ReLU dense layer
    and a single-unit sigmoid layer.
    """
    atom_features = layers.Input((atom_dim,), dtype=tf.float32, name="atom_features")
    bond_features = layers.Input((bond_dim,), dtype=tf.float32, name="bond_features")
    pair_indices = layers.Input((2,), dtype=tf.int32, name="pair_indices")
    molecule_indicator = layers.Input((), dtype=tf.int32, name="molecule_indicator")
    model_inputs = [atom_features, bond_features, pair_indices, molecule_indicator]

    atom_states = MessagePassing(message_units, message_steps)(
        [atom_features, bond_features, pair_indices]
    )
    readout = TransformerEncoderReadout(num_attention_heads, message_units, dense_units, batch_size)
    molecule_embedding = readout([atom_states, molecule_indicator])
    hidden = layers.Dense(dense_units, activation="relu")(molecule_embedding)
    probability = layers.Dense(1, activation="sigmoid")(hidden)
    return keras.Model(inputs=model_inputs, outputs=[probability])



In [3]:
# Names of the entries in the molecule data folder that end in ".csv".
pattern = r'\.csv$'
files = list(filter(lambda name: re.search(pattern, name), os.listdir('./drive/MyDrive/data/molecule')))

In [4]:
# Empty frame with the target columns; the per-file tables are appended to it below.
data = pd.DataFrame(columns=['Name', 'Smiles', 'Disease'])

In [5]:
# Read every per-disease CSV, label its rows with the disease name taken from
# the file name, and append all of them to ``data`` with a single concat
# (concatenating inside the loop re-copies the accumulated frame every time).
frames = [data]
for f in files:
  tmp_df = pd.read_csv(f'./drive/MyDrive/data/molecule/{f}', usecols=['Name', 'Smiles'], sep=';')
  tmp_df['Disease'] = f[:-4]  # file name without its ".csv" suffix
  frames.append(tmp_df)
data = pd.concat(frames)

In [6]:
# Display the combined table.
data

Unnamed: 0,Name,Smiles,Disease
0,ERLOSAMIDE,COC[C@@H](NC(C)=O)C(=O)NCc1ccccc1,mental_disease
1,FLUNARIZINE,Fc1ccc(C(c2ccc(F)cc2)N2CCN(C/C=C/c3ccccc3)CC2)cc1,mental_disease
2,PROXIBARBAL,C=CCC1(CC(C)O)C(=O)NC(=O)NC1=O,mental_disease
3,DESVENLAFAXINE,CN(C)CC(c1ccc(O)cc1)C1(O)CCCCC1,mental_disease
4,RISPERIDONE,Cc1nc2n(c(=O)c1CCN1CCC(c3noc4cc(F)ccc34)CC1)CCCC2,mental_disease
...,...,...,...
128,CABERGOLINE,C=CCN1C[C@H](C(=O)N(CCCN(C)C)C(=O)NCC)C[C@@H]2...,urinary_disease
129,CICLOPIROX,Cc1cc(C2CCCCC2)n(O)c(=O)c1,urinary_disease
130,DEQUALINIUM,Cc1cc(N)c2ccccc2[n+]1CCCCCCCCCC[n+]1c(C)cc(N)c...,urinary_disease
131,SERTACONAZOLE,Clc1ccc(C(Cn2ccnc2)OCc2csc3c(Cl)cccc23)c(Cl)c1,urinary_disease


In [24]:
class DiseasePipeline:
    """Train and query one binary MPNN classifier per disease (one-vs-rest).

    The constructor turns every SMILES string of ``data`` into a molecular
    graph, splits the rows 80/20 into a training and a validation part, and
    creates one untrained ``MPNNModel`` per distinct ``Disease`` value.

    Attributes:
        diseases: distinct values of ``data.Disease``.
        atom_featurizer: ``AtomFeaturizer`` used for the graph nodes.
        bond_featurizer: ``BondFeaturizer`` used for the graph edges.
        x_train, x_val: tuples ``(atom_features, bond_features,
            pair_indices)`` of ragged tensors, one entry per molecule.
        y_train, y_val: the ``Disease`` column of the matching rows.
        models: ``{disease: keras model}``.
        histories: ``{disease: value returned by model.fit}``, filled by
            ``train``.
        is_trained: set to True once ``train`` has finished.
    """

    def __init__(self, data: pd.DataFrame):
        """Featurize ``data`` (columns ``Smiles`` and ``Disease``) and build the models."""
        self.diseases = data.Disease.unique()
        self.atom_featurizer = AtomFeaturizer(
            allowable_sets={
                "symbol": {'B', 'Br', 'C', 'Ca', 'Cl', 'F', 'H', 'I', 'N', 'Na', 'O', 'P', 'S'},
                "n_valence": {0, 1, 2, 3, 4, 5, 6},
                "n_hydrogens": {0, 1, 2, 3, 4},
                "hybridization": {'s', 'sp', 'sp2', 'sp3'},
            }
        )
        self.bond_featurizer = BondFeaturizer(
            allowable_sets={
                "bond_type": {'single', 'double', 'triple', 'aromatic'},
                "conjugated": {False, True},
            }
        )
        self.x_train, self.y_train, self.x_val, self.y_val = self.__get_train_val_data(data)
        # The feature widths are defined by the featurizers, so read them
        # there instead of indexing the first atom/bond of the training set.
        self.models = {
            disease: MPNNModel(
                atom_dim=self.atom_featurizer.dim,
                bond_dim=self.bond_featurizer.dim,
            )
            for disease in self.diseases
        }
        self.histories = {}
        self.is_trained = False

    def train(self):
        """Compile and fit every per-disease model for 40 epochs.

        Each model is trained on the shared training graphs with labels
        1 for its own disease and 0 otherwise, and validated the same way.
        The object returned by ``fit`` is kept in ``self.histories``.
        """
        for i, disease in enumerate(self.diseases):
            print(f'\n{i}======={disease}=======\n')
            train_dataset = self.__get_dataset(self.x_train, self.y_train, disease)
            val_dataset = self.__get_dataset(self.x_val, self.y_val, disease)
            model = self.models[disease]
            model.compile(
                loss=keras.losses.BinaryCrossentropy(),
                optimizer=keras.optimizers.Adam(learning_rate=5e-4),
                metrics=[keras.metrics.AUC(name='AUC')],
            )
            self.histories[disease] = model.fit(
                train_dataset,
                validation_data=val_dataset,
                epochs=40,
                verbose=2,
            )
        self.is_trained = True

    def predict(self, smiles):
        """Return ``{disease: model.predict output}`` for the given molecules.

        Args:
            smiles: an iterable of SMILES strings, or a single SMILES
                string. All molecules are fed to the models as one batch.

        Returns:
            A dict with one prediction array per disease, or ``None``
            (after printing a message) when ``train`` has not been run.
        """
        if not self.is_trained:
            print('Error! Model is not trained')
            return None
        g = self.graph_from_smiles(smiles)
        # The integer is a dummy label: prepare_batch expects an (x, y) pair.
        dataset = tf.data.Dataset.from_tensors(((g), (1))).map(self.prepare_batch, -1).prefetch(-1)
        return {disease: model.predict(dataset) for disease, model in self.models.items()}

    def __get_train_val_data(self, data):
        """Drop rows with missing values and split the rest 80/20 at random.

        Returns ``(x_train, y_train, x_val, y_val)`` where the ``x`` parts
        are graph tuples and the ``y`` parts are ``Disease`` columns.
        """
        data = data.dropna()
        permutation = np.random.permutation(np.arange(data.shape[0]))
        split_at = int(data.shape[0] * 0.8)
        train_index = permutation[:split_at]
        val_index = permutation[split_at:]
        x_train = self.graph_from_smiles(data.iloc[train_index].Smiles)
        y_train = data.iloc[train_index].Disease
        x_val = self.graph_from_smiles(data.iloc[val_index].Smiles)
        y_val = data.iloc[val_index].Disease
        return (x_train, y_train, x_val, y_val)

    def __get_dataset(self, X, y, curr_disease, batch_size=32, shuffle=False):
        """Build a batched ``tf.data`` pipeline with binary labels for one disease.

        Args:
            X: graph tuple as returned by ``graph_from_smiles``.
            y: disease name per molecule.
            curr_disease: the disease treated as the positive class.
            batch_size: molecules per batch.
            shuffle: shuffle the molecules (buffer of 1024) before batching.
        """
        # A direct comparison gives well-defined int64 labels; replacing
        # strings by ints relied on pandas downcasting the object column.
        labels = (y == curr_disease).astype(np.int64).to_numpy()
        dataset = tf.data.Dataset.from_tensor_slices((X, labels))
        if shuffle:
            dataset = dataset.shuffle(1024)
        return dataset.batch(batch_size).map(self.prepare_batch, -1).prefetch(-1)

    def prepare_batch(self, x_batch, y_batch):
        """Merge a batch of per-molecule ragged graphs into one flat graph.

        Returns ``((atom_features, bond_features, pair_indices,
        molecule_indicator), y_batch)``. ``molecule_indicator`` holds, for
        every atom, the position of its molecule within the batch, and the
        pair indices are shifted so they address rows of the merged atom
        tensor.
        """
        atom_features, bond_features, pair_indices = x_batch
        num_atoms = atom_features.row_lengths()
        num_bonds = bond_features.row_lengths()
        molecule_indices = tf.range(len(num_atoms))
        molecule_indicator = tf.repeat(molecule_indices, num_atoms)
        # Offset for each bond: the number of atoms in all earlier molecules.
        # Bonds of the first molecule get offset 0 through the front padding.
        gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
        increment = tf.cumsum(num_atoms[:-1])
        increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
        pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
        pair_indices = pair_indices + increment[:, tf.newaxis]
        atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
        bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
        return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch

    def molecule_from_smiles(self, smiles):
        """Parse one SMILES string into a sanitized RDKit molecule.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        molecule = Chem.MolFromSmiles(smiles, sanitize=False)
        if molecule is None:
            # MolFromSmiles signals a parse failure by returning None.
            raise ValueError(f'Could not parse SMILES: {smiles!r}')
        flag = Chem.SanitizeMol(molecule, catchErrors=True)
        if flag != Chem.SanitizeFlags.SANITIZE_NONE:
            # Sanitize again with the operation reported in ``flag`` switched off.
            Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)
        Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
        return molecule

    def graph_from_molecule(self, molecule):
        """Return ``(atom_features, bond_features, pair_indices)`` arrays for one molecule.

        Every atom contributes a self-loop ``(i, i)`` encoded with the
        ``None`` bond, followed by one directed pair per neighbour.
        """
        atom_features = []
        bond_features = []
        pair_indices = []
        for atom in molecule.GetAtoms():
            atom_features.append(self.atom_featurizer.encode(atom))
            pair_indices.append((atom.GetIdx(), atom.GetIdx()))
            bond_features.append(self.bond_featurizer.encode(None))
            for neighbor in atom.GetNeighbors():
                bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
                pair_indices.append((atom.GetIdx(), neighbor.GetIdx()))
                bond_features.append(self.bond_featurizer.encode(bond))
        return np.array(atom_features), np.array(bond_features), np.array(pair_indices)

    def graph_from_smiles(self, smiles_list):
        """Featurize SMILES strings into a tuple of three ragged tensors.

        Args:
            smiles_list: an iterable of SMILES strings; a single string is
                treated as a one-element list rather than iterated
                character by character.

        Returns:
            ``(atom_features, bond_features, pair_indices)`` with dtypes
            float32, float32 and int64, one ragged row per molecule.
        """
        if isinstance(smiles_list, str):
            smiles_list = [smiles_list]
        atom_features_list = []
        bond_features_list = []
        pair_indices_list = []
        for smiles in smiles_list:
            molecule = self.molecule_from_smiles(smiles)
            atom_features, bond_features, pair_indices = self.graph_from_molecule(molecule)
            atom_features_list.append(atom_features)
            bond_features_list.append(bond_features)
            pair_indices_list.append(pair_indices)
        return (
            tf.ragged.constant(atom_features_list, dtype=tf.float32),
            tf.ragged.constant(bond_features_list, dtype=tf.float32),
            tf.ragged.constant(pair_indices_list, dtype=tf.int64),
        )

In [25]:
# Featurize the molecules, split them into train/validation parts and create one model per disease.
disease_pipeline = DiseasePipeline(data)

In [26]:
# Fit every per-disease model (40 epochs each, as set in DiseasePipeline.train).
disease_pipeline.train()



Epoch 1/40
20/20 - 21s - loss: 0.6828 - AUC: 0.5958 - val_loss: 0.6139 - val_AUC: 0.6890 - 21s/epoch - 1s/step
Epoch 2/40
20/20 - 12s - loss: 0.6367 - AUC: 0.6720 - val_loss: 0.5971 - val_AUC: 0.7221 - 12s/epoch - 589ms/step
Epoch 3/40
20/20 - 10s - loss: 0.6078 - AUC: 0.7190 - val_loss: 0.5517 - val_AUC: 0.8059 - 10s/epoch - 506ms/step
Epoch 4/40
20/20 - 12s - loss: 0.5840 - AUC: 0.7485 - val_loss: 0.4929 - val_AUC: 0.8391 - 12s/epoch - 576ms/step
Epoch 5/40
20/20 - 12s - loss: 0.5553 - AUC: 0.7841 - val_loss: 0.4974 - val_AUC: 0.8380 - 12s/epoch - 593ms/step
Epoch 6/40
20/20 - 13s - loss: 0.5515 - AUC: 0.7813 - val_loss: 0.5080 - val_AUC: 0.8363 - 13s/epoch - 638ms/step
Epoch 7/40
20/20 - 10s - loss: 0.5518 - AUC: 0.7796 - val_loss: 0.5108 - val_AUC: 0.8564 - 10s/epoch - 518ms/step
Epoch 8/40
20/20 - 12s - loss: 0.5527 - AUC: 0.7846 - val_loss: 0.4825 - val_AUC: 0.8410 - 12s/epoch - 602ms/step
Epoch 9/40
20/20 - 12s - loss: 0.5538 - AUC: 0.7880 - val_loss: 0.4731 - val_AUC: 0.8478 

In [29]:
# Spot-check one row: the per-disease model outputs next to the disease recorded for that row.
disease_pipeline.predict([data.iloc[133].Smiles]), data.iloc[133].Disease



({'mental_disease': array([[0.86554986]], dtype=float32),
  'cardiovascular_disease': array([[0.3225131]], dtype=float32),
  'digestive_disease': array([[0.00685264]], dtype=float32),
  'immune_disease': array([[0.471586]], dtype=float32),
  'urinary_disease': array([[0.04342876]], dtype=float32)},
 'mental_disease')

In [32]:
# Pickle each trained model into its own file, named after its disease.
# NOTE(review): unpickling presumably needs the custom layer classes of this
# notebook to be defined first - confirm before relying on these files.
for model, network in disease_pipeline.models.items():
  with open(f'./drive/MyDrive/Models/{model}', 'wb') as f:
    pickle.dump(network, f)