# Message-passing neural network (MPNN) for molecular property prediction

In [7]:
#!pip -q install rdkit-pypi
#!pip -q install pandas
#!pip -q install Pillow
#!pip -q install matplotlib
#!pip -q install pydot
#!sudo apt-get -qq install graphviz

### Import packages

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
# Temporary suppress tf logs
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
import numpy as np
import pandas as pd
import warnings
import matplotlib.pyplot as plt

# Temporary suppress warnings
#warnings.filterwarnings("ignore")


np.random.seed(42)
tf.random.set_seed(42)

In [77]:
from constants import *

## Data

In [53]:
df = pd.read_csv("../data/0_raw/data.csv").reset_index(drop=True)
df.head(
)

Unnamed: 0,P1,mol_id,smiles
0,1,CID2999678,Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C
1,0,CID2999679,Cn1ccnc1SCC(=O)Nc1ccc(Oc2ccccc2)cc1
2,1,CID2999672,COc1cc2c(cc1NC(=O)CN1C(=O)NC3(CCc4ccccc43)C1=O...
3,0,CID5390002,O=C1/C(=C/NC2CCS(=O)(=O)C2)c2ccccc2C(=O)N1c1cc...
4,1,CID2999670,NC(=O)NC(Cc1ccccc1)C(=O)O


In [75]:
from data import validate_dataframe
df = validate_dataframe(df)

INFO:root: Data Validation | Dataset imbalance | Proportions: {1: 0.82, 0: 0.18}
INFO:root: Data Validation | Finished!


## Split to train/validation/test


Although scaffold splitting is recommended in our case (see
[here](https://www.blopig.com/blog/2021/06/out-of-distribution-generalisation-and-scaffold-splitting-in-molecular-property-prediction/)), for simplicity, random strattified splittings were
performed.

In [76]:
from data import split_data
split_data(data_path=INPUT_DATA_PATH,
           output_path=INTERMEDIATE_DATA_PATH,
           test_only=False,
          )

INFO:root: Data Validation | Dataset imbalance | Proportions: {1: 0.82, 0: 0.18}
INFO:root: Data Validation | Finished!
INFO:root: Data Splitting | Train: 0.7, Valid: 0.15, Test: 0.15
INFO:root: Data Splitting | Finished!


('data/1_intermediate/data_train.csv',
 'data/1_intermediate/data_valid.csv',
 'data/1_intermediate/data_test.csv')

In [229]:
len(df_train), len(df_valid), len(df_test)

(3499, 750, 750)

In [272]:
INTERMEDIATE_DATA_PATH/DATA_TRAIN_FILENAME

PosixPath('data/1_intermediate/data_train.pkl')

## Featurization, Graph Generation & DataSet Creation


In [277]:
from data import get_mpnn_dataset
train_dataset, atom_dim, bond_dim = get_mpnn_dataset(df_train, return_dims=True)
valid_dataset = get_mpnn_dataset(df_valid,)
test_dataset = get_mpnn_dataset(df_test,)

### Handle imbalance

'None'

In [80]:
initial_bias, class_weight = get_imbalance_params(df_train)
initial_bias, class_weight

(array([1.52765758]), {0: 2.8036858974358974, 1: 0.6085217391304348})

In [83]:
from modeling import get_imbalance_params, MPNNModel
from data import get_mpnn_dataset


In [159]:
aa = False
bb = not aa
bb

True

In [164]:
aa = "bb"
if type(aa) == str:
    print("ok")

ok


In [90]:
def train(data_train_path, data_valid_path, save_model_path="models/my_model", handle_imbalance=False):
    df_train = pd.read_csv(data_train_path)
    df_train = validate_dataframe(df_train)
    
    train_dataset, atom_dim, bond_dim = get_mpnn_dataset(df_train, return_dims=True)
    
    if data_valid_path is not None:
        df_valid = pd.read_csv(data_valid_path)
        valid_dataset = get_mpnn_dataset(df_valid,)
    
    initial_bias=None
    class_weights=None
    if handle_imbalance==True:
        initial_bias, class_weight = get_imbalance_params(df_train)
    
    model = MPNNModel(
        atom_dim=atom_dim, bond_dim=bond_dim, output_bias=initial_bias,
    )
    
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
        metrics=[tf.keras.metrics.AUC(name="AUC")],
    )
    
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                  patience=5, min_lr=1e-7)
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        history = model.fit(
            train_dataset,
            validation_data=valid_dataset,
            epochs=MAX_EPOCHS,
            verbose=2,
            callbacks=[reduce_lr, early_stopping],
            class_weight=class_weight,
        )
    model.save(save_model_path)
    return model, history

In [88]:
train(INTERMEDIATE_DATA_PATH/DATA_VALID_FILENAME, INTERMEDIATE_DATA_PATH/DATA_VALID_FILENAME, 
      save_model_path="models/my_model",
      handle_imbalance=True,)

Epoch 1/2
24/24 - 21s - loss: 0.7253 - AUC: 0.4509 - val_loss: 0.6910 - val_AUC: 0.5909 - lr: 5.0000e-04 - 21s/epoch - 882ms/step
Epoch 2/2
24/24 - 10s - loss: 0.6938 - AUC: 0.5112 - val_loss: 0.6385 - val_AUC: 0.6757 - lr: 5.0000e-04 - 10s/epoch - 423ms/step




INFO:tensorflow:Assets written to: models/my_model/assets


INFO:tensorflow:Assets written to: models/my_model/assets


(<keras.engine.functional.Functional at 0x318e316d0>,
 <keras.callbacks.History at 0x318e2af40>)

In [255]:
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                              patience=5, min_lr=1e-7)
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
)

if validation_dataset is None:
    history = model.fit(
    train_dataset,
    validation_split=VALIDATION_SIZE,
    epochs=MAX_EPOCHS,
    verbose=2,
    callbacks=[reduce_lr, early_stopping],
    class_weight={0: 2.0, 1: 0.5},
    )
else:
    history = model.fit(
        train_dataset,
        validation_data=valid_dataset,
        epochs=MAX_EPOCHS,
        verbose=2,
        callbacks=[reduce_lr, early_stopping],
        class_weight=class_weight,
    )

Epoch 1/2
110/110 - 242s - loss: 0.5113 - AUC: 0.6349 - val_loss: 0.5428 - val_AUC: 0.6695 - lr: 5.0000e-04 - 242s/epoch - 2s/step
Epoch 2/2
110/110 - 156s - loss: 0.4895 - AUC: 0.6860 - val_loss: 0.5843 - val_AUC: 0.6633 - lr: 5.0000e-04 - 156s/epoch - 1s/step


In [287]:
def evaluate(data_path, model_path):
    df_data = pd.read_pickle(data_path)
    dataset = get_mpnn_dataset(df_data)
    model = tf.keras.models.load_model(model_path)
    return model.evaluate(dataset)

In [288]:
loss, acc = evaluate(data_path=INTERMEDIATE_DATA_PATH/DATA_VALID_FILENAME,
        model_path="models/my_model",)



In [315]:
def predict(model_path, data_path=None, smiles=None):
    model = tf.keras.models.load_model(model_path)
    if (data_path is None) and (smiles is None):
        return "Error!"
    if data_path is not None:
        df_data = pd.read_pickle(data_path)
        dataset = get_mpnn_dataset(df_data)
        return  model.predict(dataset)
    if smiles is not None:
        dataset = get_mpnn_dataset(smiles)
        return model.predict(dataset)

In [316]:
predictions = predict(model_path="models/my_model",
        data_path=INTERMEDIATE_DATA_PATH/DATA_TEST_FILENAME,
       )

In [321]:
preds = predict(model_path="models/my_model",
        smiles=['CC1=C(C(=O)Nc2cc(-c3cccc(F)c3)[nH]n2)C2(CCCCC2)OC1=O',],
       )

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(history.history["AUC"], label="train AUC")
plt.plot(history.history["val_AUC"], label="valid AUC")
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("AUC", fontsize=16)
plt.legend(fontsize=16)

### Predicting

In [None]:
molecules = [molecule_from_smiles(df.smiles.values[index]) for index in test_index]
y_true = [df.p_np.values[index] for index in test_index]
y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1)

legends = [f"y_true/y_pred = {y_true[i]}/{y_pred[i]:.2f}" for i in range(len(y_true))]
MolsToGridImage(molecules, molsPerRow=4, legends=legends)

In [105]:
df.iloc[test_index].head(1).smiles.values

array(['Cc1ccc(-n2cc(C(=O)c3cc(Cl)ccc3O)cc(C#N)c2=O)cc1'], dtype=object)

In [133]:
sample = graphs_from_smiles(df.iloc[134:135].smiles.values)
len(sample)

3

In [134]:
sample_dataset = MPNNDataset(sample, None)

In [135]:
sample_preds = tf.squeeze(mpnn.predict(sample_dataset), axis=1)

In [136]:
sample_preds

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.891305], dtype=float32)>

In [139]:
len(sample)

3

In [145]:
(1, 2, 3) + (4,)

(1, 2, 3, 4)

In [97]:
def evaluate(data_path, model_path):
    df_data = pd.read_csv(data_path)
    df_data = validate_dataframe(df_data)
    dataset = get_mpnn_dataset(df_data)
    model = tf.keras.models.load_model(model_path)
    return model.evaluate(dataset)

In [98]:
loss, acc = evaluate(data_path=INTERMEDIATE_DATA_PATH/DATA_VALID_FILENAME,
        model_path="models/my_model",)

INFO:root: Data Validation | Dataset imbalance | Proportions: {1: 0.82, 0: 0.18}
INFO:root: Data Validation | Finished!




In [151]:
def predict(model_path, data_path=None, smiles=None):
    model = tf.keras.models.load_model(model_path)
    if (data_path is None) and (smiles is None):
        raise Exception('No data input is given!')
    if data_path is not None:
        df_data = pd.read_csv(data_path)
        df_data = validate_dataframe(df_data, predict=True)
        if COL_TARGET not in df_data:
            df_data[COL_TARGET] = 0
        dataset = get_mpnn_dataset(df_data)
        return  model.predict(dataset)
    if smiles is not None:
        dataset = get_mpnn_dataset(smiles)
        return model.predict(dataset)

In [152]:
predictions = predict(model_path="models/my_model",
        data_path=INTERMEDIATE_DATA_PATH/DATA_TEST_FILENAME,
       )

INFO:root: Data Validation | Finished!


In [155]:
df["smiles"]

0         Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C
1                     Cn1ccnc1SCC(=O)Nc1ccc(Oc2ccccc2)cc1
2       COc1cc2c(cc1NC(=O)CN1C(=O)NC3(CCc4ccccc43)C1=O...
3       O=C1/C(=C/NC2CCS(=O)(=O)C2)c2ccccc2C(=O)N1c1cc...
4                               NC(=O)NC(Cc1ccccc1)C(=O)O
                              ...                        
4994         CC1CCC(NC(=O)CN2CCCN(Cc3ccc(F)cc3Cl)C2=O)CC1
4995                  Cc1cccc(-n2cnc(C(=O)Nc3cccnc3)c2)n1
4996                  COc1ccc(CCNC(=O)c2noc3c2CCCC3)cc1OC
4997                             COCc1ccc2oc(C(=O)O)cc2c1
4998        Cc1ccc(/C=C2\C(=O)NC(=O)N(Cc3ccccc3Cl)C2=O)o1
Name: smiles, Length: 4999, dtype: object

In [150]:
predictions

array([[0.5159536 ],
       [0.56485164],
       [0.53658825],
       [0.5545654 ],
       [0.5613068 ],
       [0.54563105],
       [0.49215645],
       [0.5497382 ],
       [0.51019347],
       [0.5732583 ],
       [0.52465665],
       [0.53291994],
       [0.5340168 ],
       [0.5367876 ],
       [0.54328936],
       [0.550211  ],
       [0.54901946],
       [0.55827874],
       [0.5512804 ],
       [0.5823129 ],
       [0.52714807],
       [0.5333784 ],
       [0.5569229 ],
       [0.5905736 ],
       [0.5232183 ],
       [0.5190081 ],
       [0.52322274],
       [0.5147129 ],
       [0.5402329 ],
       [0.578955  ],
       [0.5393267 ],
       [0.52713054],
       [0.58897245],
       [0.5020621 ],
       [0.5352524 ],
       [0.53178567],
       [0.5548853 ],
       [0.53727424],
       [0.5210109 ],
       [0.5075571 ],
       [0.57831657],
       [0.53182054],
       [0.56721884],
       [0.55539465],
       [0.5616628 ],
       [0.56586933],
       [0.5590251 ],
       [0.535

In [157]:
predict(model_path="models/my_model", smiles=["Cc1cccc(N2CCN(C(=O)C34CC5CC(CC(C5)C3)C4)CC2)c1C",
                                             "Cn1ccnc1SCC(=O)Nc1ccc(Oc2ccccc2)cc1"])

array([[0.54842335],
       [0.5827183 ]], dtype=float32)