<a href="https://colab.research.google.com/github/lockiultra/SCAMT/blob/main/MPNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install rdkit-pypi

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import re
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import MolsToGridImage

warnings.filterwarnings("ignore")
RDLogger.DisableLog("rdApp.*")

np.random.seed(42)
tf.random.set_seed(42)

# Featurizers

class Featurizer:
    def __init__(self, allowable_sets):
        self.dim = 0
        self.feature_mapping = {}
        for k, s in allowable_sets.items():
            s = sorted(list(s))
            self.feature_mapping[k] = dict(zip(s, range(self.dim, self.dim + len(s))))
            self.dim += len(s)

    def encode(self, inputs):
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.feature_mapping.items():
            feature = getattr(self, name_feature)(inputs)
            if feature not in feature_mapping:
                continue
            output[feature_mapping[feature]] = 1.0
        return output


class AtomFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()


class BondFeaturizer(Featurizer):
    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def encode(self, bond):
        output = np.zeros((self.dim,))
        if bond is None:
            output[-1] = 1.0
            return output
        output = super().encode(bond)
        return output

    def bond_type(self, bond):
        return bond.GetBondType().name.lower()

    def conjugated(self, bond):
        return bond.GetIsConjugated()
# Model

class EdgeNetwork(layers.Layer):
    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.bond_dim = input_shape[1][-1]
        self.kernel  = self.add_weight(
            shape=(self.bond_dim, self.atom_dim * self.atom_dim),
            initializer="glorot_uniform",
            name="kernel",
        )
        self.bias = self.add_weight(
            shape=(self.atom_dim * self.atom_dim),
            initializer="zeros",
            name="bias",
        )
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs
        bond_features = tf.matmul(bond_features, self.kernel) + self.bias
        bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))
        atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
        atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)
        transformed_features = tf.matmul(bond_features, atom_features_neighbors)
        transformed_features = tf.squeeze(transformed_features, axis=-1)
        aggregated_features = tf.math.unsorted_segment_sum(
            transformed_features,
            pair_indices[:, 0],
            num_segments=tf.shape(atom_features)[0]
        )
        return aggregated_features

class MessagePassing(layers.Layer):
    def __init__(self, units, steps=4, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.steps = steps

    def build(self, input_shape):
        self.atom_dim = input_shape[0][-1]
        self.message_step = EdgeNetwork()
        self.pad_length = max(0, self.units - self.atom_dim)
        self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
        self.built = True

    def call(self, inputs):
        atom_features, bond_features, pair_indices = inputs
        atom_features_updated = tf.pad(atom_features, [[0, 0], [self.pad_length, 0]])
        for i in range(self.steps):
            atom_features_aggregated = self.message_step([atom_features_updated, bond_features, pair_indices])
            atom_features_updated, _ = self.update_step(atom_features_aggregated, atom_features_updated)
        return atom_features_updated

class PartitionPadding(layers.Layer):
    def __init__(self, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def call(self, inputs):
        atom_features, molecule_indicator = inputs
        atom_features_partitioned = tf.dynamic_partition(atom_features, molecule_indicator, self.batch_size)
        num_atoms = [tf.shape(x)[0] for x in atom_features_partitioned]
        max_num_atoms = tf.reduce_max(num_atoms)
        atom_features_stacked = tf.stack(
            [
                tf.pad(f, [(0, max_num_atoms - n), (0, 0)]) for f, n in zip(atom_features_partitioned, num_atoms)
            ],
            axis=0
        )
        gather_indices = tf.where(tf.reduce_sum(atom_features_stacked, (1, 2)) != 0)
        gather_indices = tf.squeeze(gather_indices, axis=-1)
        return tf.gather(atom_features_stacked, gather_indices, axis=0)

class TransformerEncoderReadout(layers.Layer):
    def __init__(self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs):
        super().__init__(**kwargs)
        self.partition_padding = PartitionPadding(batch_size)
        self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.average_pooling = layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.partition_padding(inputs)
        padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
        padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        attention_output = self.attention(x, x, attention_mask=padding_mask)
        proj_input = self.layernorm_1(x + attention_output)
        proj_output = self.layernorm_2(proj_input + self.dense_proj(proj_input))
        return self.average_pooling(proj_output)

def MPNNModel(atom_dim, bond_dim, batch_size=128, message_units=64, message_steps=4, num_attention_heads=8, dense_units=512):
    atom_features = layers.Input((atom_dim,), dtype=tf.float32, name="atom_features")
    bond_features = layers.Input((bond_dim,), dtype=tf.float32, name="bond_features")
    pair_indices = layers.Input((2,), dtype=tf.int32, name="pair_indices")
    molecule_indicator = layers.Input((), dtype=tf.int32, name="molecule_indicator")
    x = MessagePassing(message_units, message_steps)([atom_features, bond_features, pair_indices])
    x = TransformerEncoderReadout(num_attention_heads, message_units, dense_units, batch_size)([x, molecule_indicator])
    x = layers.Dense(dense_units, activation='elu')(x)
    x = layers.Dense(1, activation="sigmoid")(x)
    model = keras.Model(inputs=[atom_features, bond_features, pair_indices, molecule_indicator], outputs=[x])
    return model



In [35]:
class DiseasePipeline:
  def __init__(self, data: pd.DataFrame):
    self.diseases = data.Disease.unique()
    self.atom_featurizer = AtomFeaturizer(
      allowable_sets={
        "symbol": {'B', 'Br', 'C', 'Ca', 'Cl', 'F', 'H', 'I', 'N', 'Na', 'O', 'P', 'S'},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {'s', 'sp', 'sp2', 'sp3'},
      }
    )
    self.bond_featurizer = BondFeaturizer(
      allowable_sets={
        "bond_type": {'single', 'double', 'triple', 'aromatic'},
        "conjugated": {False, True},
      }
    )
    self.data = data.copy()
    self.curr_df = None
    self.x_train, self.y_train, self.x_val, self.y_val = self.__get_train_val_data(data)
    self.models = {disease: MPNNModel(atom_dim=self.x_train[0][0][0].shape[0], bond_dim=self.x_train[1][0][0].shape[0]) for disease in self.diseases}
    self.is_trained = False
    self.train_history = dict()

  def train(self):
    for i, disease in enumerate(self.diseases):
      print(f'\n({i}) ======={disease}=======\n')
      self.curr_df = self.data.copy()
      self.curr_df['Disease'] = self.curr_df['Disease'].replace({x: 1 if x == disease else 0 for x in self.diseases})
      self.curr_df = self.curr_df.drop_duplicates(subset=['Drug'])
      x_train, y_train, x_val, y_val = self.__get_train_val_data(self.curr_df)
      train_dataset = self.__get_dataset(x_train, y_train, disease)
      val_dataset = self.__get_dataset(x_val, y_val, disease)
      self.models[disease].compile(
        loss=keras.losses.BinaryCrossentropy(),
        optimizer=keras.optimizers.AdamW(learning_rate=3e-4),
        metrics=[keras.metrics.AUC(name='AUC')],
      )
      history = self.models[disease].fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=40,
        verbose=2,
      )
      self.train_history[disease] = history
    self.is_trained = True

  def predict(self, smiles):
    if not self.is_trained:
      print('Error! Model is not trained')
      return
    result = {disease: None for disease in self.diseases}
    g = self.graph_from_smiles(smiles)
    dataset = tf.data.Dataset.from_tensors(((g), (1))).map(self.prepare_batch, -1).prefetch(-1)
    for disease, model in zip(self.models.keys(), self.models.values()):
      result[disease] = model.predict(dataset)
    return result


  def __get_train_val_data(self, data):
    data = data.dropna()
    permutation = np.random.permutation(np.arange(data.shape[0]))
    train_index = permutation#[:int(data.shape[0] * 0.8)]
    x_train = self.graph_from_smiles(data.iloc[train_index].Smiles)
    y_train = data.iloc[train_index].Disease
    val_index = permutation[int(data.shape[0] * 0.8):]
    x_val = self.graph_from_smiles(data.iloc[val_index].Smiles)
    y_val = data.iloc[val_index].Disease
    return (x_train, y_train, x_val, y_val)

  def __get_dataset(self, X, y, curr_disease, batch_size=128, shuffle=False):
    # replace_disease_dict = {
    #   x: 1 if x == curr_disease else 0 for x in self.diseases
    # }
    # y = y.replace(replace_disease_dict)
    dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
    if shuffle:
      dataset = dataset.shuffle(1024)
    return dataset.batch(batch_size).map(self.prepare_batch, -1).prefetch(-1)

  def prepare_batch(self, x_batch, y_batch):
    atom_features, bond_features, pair_indices = x_batch
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1])
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch

  def molecule_from_smiles(self, smiles):
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)
    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule

  def graph_from_molecule(self, molecule):
      atom_features = []
      bond_features = []
      pair_indices = []
      for atom in molecule.GetAtoms():
          atom_features.append(self.atom_featurizer.encode(atom))
          pair_indices.append((atom.GetIdx(), atom.GetIdx()))
          bond_features.append(self.bond_featurizer.encode(None))
          for neighbor in atom.GetNeighbors():
              bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
              pair_indices.append((atom.GetIdx(), neighbor.GetIdx()))
              bond_features.append(self.bond_featurizer.encode(bond))
      return np.array(atom_features), np.array(bond_features), np.array(pair_indices)

  def graph_from_smiles(self, smiles_list):
      atom_features_list = []
      bond_features_list = []
      pair_indices_list = []
      for smiles in smiles_list:
          molecule = self.molecule_from_smiles(smiles)
          atom_features, bond_features, pair_indices = self.graph_from_molecule(molecule)
          atom_features_list.append(atom_features)
          bond_features_list.append(bond_features)
          pair_indices_list.append(pair_indices)
      return (
          tf.ragged.constant(atom_features_list, dtype=tf.float32),
          tf.ragged.constant(bond_features_list, dtype=tf.float32),
          tf.ragged.constant(pair_indices_list, dtype=tf.int64),
      )

In [28]:
df = pd.read_csv('data.csv')

In [37]:
df

Unnamed: 0.1,Unnamed: 0,Conditions,Drug,Smiles,Study Status,Phases,Disease
0,0,Metastatic Colorectal Cancer,Irinotecan,CCC1=C2CN3C(=CC4=C(C3=O)COC(=O)C4(CC)O)C2=NC5=...,COMPLETED,PHASE1|PHASE2,digestive_system_disease
1,1,Ulcerative Colitis|Left-sided Ulcerative Colit...,Mesalazine,C1=CC(=C(C=C1N)C(=O)O)O,UNKNOWN,PHASE2,digestive_system_disease
2,2,Ulcerative Colitis|Left-sided Ulcerative Colit...,Mesalamine,C1=CC(=C(C=C1N)C(=O)O)O,UNKNOWN,PHASE2,digestive_system_disease
3,5,HER-2 Positive Gastric Cancer|Metastatic Cancer,Capecitabine,CCCCCOC(=O)NC1=NC(=O)N(C=C1F)C2C(C(C(O2)C)O)O,UNKNOWN,PHASE2,digestive_system_disease
4,6,HER-2 Positive Gastric Cancer|Metastatic Cancer,Cisplatin,N.N.Cl[Pt]Cl,UNKNOWN,PHASE2,digestive_system_disease
...,...,...,...,...,...,...,...
4870,4211,Metastatic Renal Cell Cancer|Stage IV Renal Ce...,Ibrutinib,C=CC(=O)N1CCCC(C1)N2C3=NC=NC(=C3C(=N2)C4=CC=C(...,COMPLETED,PHASE1|PHASE2,urinary_system_disease
4871,4214,Interstitial Fibrosis|Kidney Transplant; Compl...,Fingolimod,CCCCCCCCC1=CC=C(C=C1)CCC(CO)(CO)N,ENROLLING_BY_INVITATION,PHASE2,urinary_system_disease
4872,4221,Clear Cell Renal Cell Carcinoma,Imatinib,CC1=C(C=C(C=C1)NC(=O)C2=CC=C(C=C2)CN3CCN(CC3)C...,COMPLETED,PHASE1|PHASE2,urinary_system_disease
4873,4223,Cannabis|Chronic Kidney Diseases|Dialysis,Dronabinol,CCCCCC1=CC(=C2C3C=C(CCC3C(OC2=C1)(C)C)C)O,RECRUITING,PHASE1,urinary_system_disease


In [36]:
disease_pipeline = DiseasePipeline(df)

In [38]:
disease_pipeline.train()



Epoch 1/40
15/15 - 17s - loss: 0.7403 - AUC: 0.5035 - val_loss: 0.6918 - val_AUC: 0.5243 - 17s/epoch - 1s/step
Epoch 2/40
15/15 - 3s - loss: 0.6918 - AUC: 0.5178 - val_loss: 0.6866 - val_AUC: 0.5249 - 3s/epoch - 216ms/step
Epoch 3/40
15/15 - 3s - loss: 0.6902 - AUC: 0.5220 - val_loss: 0.6856 - val_AUC: 0.5335 - 3s/epoch - 214ms/step
Epoch 4/40
15/15 - 3s - loss: 0.6896 - AUC: 0.5269 - val_loss: 0.6850 - val_AUC: 0.5341 - 3s/epoch - 227ms/step
Epoch 5/40
15/15 - 3s - loss: 0.6887 - AUC: 0.5363 - val_loss: 0.6842 - val_AUC: 0.5397 - 3s/epoch - 213ms/step
Epoch 6/40
15/15 - 3s - loss: 0.6868 - AUC: 0.5485 - val_loss: 0.6825 - val_AUC: 0.5468 - 3s/epoch - 227ms/step
Epoch 7/40
15/15 - 3s - loss: 0.6842 - AUC: 0.5685 - val_loss: 0.6815 - val_AUC: 0.5664 - 3s/epoch - 220ms/step
Epoch 8/40
15/15 - 3s - loss: 0.6835 - AUC: 0.5724 - val_loss: 0.6790 - val_AUC: 0.5748 - 3s/epoch - 216ms/step
Epoch 9/40
15/15 - 3s - loss: 0.6798 - AUC: 0.5866 - val_loss: 0.6766 - val_AUC: 0.5850 - 3s/epoch - 22

In [41]:
disease_pipeline.predict([df.iloc[133].Smiles]), df.iloc[133].Disease











({'digestive_system_disease': array([[0.44990876]], dtype=float32),
  'skin_and_connective_tissue_disease': array([[0.00795242]], dtype=float32),
  'cardiovascular_disease': array([[0.22503713]], dtype=float32),
  'immune_system_disease': array([[0.1450851]], dtype=float32),
  'mental_and_behavioural_disorder': array([[0.0030801]], dtype=float32),
  'metabolic_disease': array([[0.09662353]], dtype=float32),
  'nervous_system_disease': array([[0.03983595]], dtype=float32),
  'urinary_system_disease': array([[0.20622356]], dtype=float32)},
 'digestive_system_disease')

In [44]:
disease_pipeline.train_history['cardiovascular_disease']

<keras.callbacks.History at 0x7b3976452ad0>

In [45]:
for model in disease_pipeline.models:
  with open(f'./drive/MyDrive/Models/{model}', 'wb') as f:
    pickle.dump(disease_pipeline.models[model], f)

In [46]:
!ls ./drive/MyDrive/Models

cardiovascular_disease		 mental_disease
digestive_disease		 metabolic_disease
digestive_system_disease	 nervous_system_disease
immune_disease			 skin_and_connective_tissue_disease
immune_system_disease		 urinary_disease
mental_and_behavioural_disorder  urinary_system_disease
