In [None]:
import numpy as np
import pandas as pd
import os

from rdkit import Chem
from rdkit.Chem import MACCSkeys, rdFingerprintGenerator
from rdkit import DataStructs
from wrapMordred import mordredWrapper

import chemprop

from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import joblib

import matplotlib.pyplot as plt

In [None]:
# Pin NumPy's global RNG so any stochastic step in this notebook is repeatable.
SEED = 1234
np.random.seed(SEED)

In [None]:
# Endpoint (dataset) to evaluate; the other available choice is 'skin-sensitization'.
endpoint = 'eye-irritation'

# Data and model roots default to the original author's machine. Set the
# TOXICITY_DATA_DIR / TOXICITY_MODEL_DIR environment variables to run elsewhere
# without editing the notebook.
loc = os.environ.get(
    'TOXICITY_DATA_DIR',
    r'D:\School\Semester3\Seminar - Reproducibility\seminar-toxicity\data',
)
endpoint_loc = os.path.join(loc, endpoint)
model = os.environ.get(
    'TOXICITY_MODEL_DIR',
    r'D:\School\Semester3\Seminar - Reproducibility\seminar-toxicity\src\models',
)
model_loc = os.path.join(model, endpoint)

In [None]:
# Load the training split of the selected endpoint.
filename = 'train.csv'
train_csv_path = os.path.join(endpoint_loc, filename)
df_train = pd.read_csv(train_csv_path)

In [None]:
# (rows, columns) of the training split.
df_train.shape

In [None]:
# First rows of the training split, to eyeball the column layout.
df_train.head()

In [None]:
# Load the validation split of the selected endpoint.
filename = 'val.csv'
val_csv_path = os.path.join(endpoint_loc, filename)
df_val = pd.read_csv(val_csv_path)

In [None]:
# (rows, columns) of the validation split.
df_val.shape

In [None]:
# First rows of the validation split, to eyeball the column layout.
df_val.head()

In [None]:
# Separate each split into model inputs (the 'SMILES' column) and targets
# (the 'Activity' column) as NumPy arrays.
train_smiles, train_labels = (df_train[column].to_numpy() for column in ('SMILES', 'Activity'))
val_smiles, val_labels = (df_val[column].to_numpy() for column in ('SMILES', 'Activity'))

In [None]:
# Size and class balance of the validation split, reported as plain counts
# (printing `.shape` shows 1-tuples such as "(123,)").
print('val size smiles :', len(val_smiles))
print('val size labels :', len(val_labels))
print('pos samples in val size :', np.count_nonzero(val_labels == 1))
print('neg samples in val size :', np.count_nonzero(val_labels == 0))

In [None]:
def get_MPNN_pred(endpoint_loc, model_loc, val_smiles, val_labels, filename='val.csv'):
    """Score SMILES with the saved chemprop MPNN checkpoints and align them with labels.

    Molecules that chemprop reports as 'Invalid SMILES' are dropped from both
    outputs, so ``y_pred`` and ``y_true`` stay index-aligned.

    Parameters
    ----------
    endpoint_loc : str
        Kept for call-site compatibility; not used by this function.
    model_loc : str
        Directory passed to chemprop as ``--checkpoint_dir``.
    val_smiles : sequence of str
        SMILES strings to score, one molecule per entry.
    val_labels : array-like
        Labels aligned with ``val_smiles``.
    filename : str, optional
        Kept for call-site compatibility; not used by this function.

    Returns
    -------
    y_pred : np.ndarray of float32
        Model output for every molecule chemprop could parse (flattened; this
        assumes a single-task model).
    y_true : np.ndarray
        Labels of those same molecules, in the same order.
    """
    # chemprop takes one list of SMILES per molecule (one entry per SMILES column).
    smiles_input = [[smiles] for smiles in val_smiles]

    arguments = [
        # Molecules are passed via `smiles=` and the predictions are used from
        # the return value, so the CSV paths point at the platform's null
        # device ('/dev/null' on POSIX, 'nul' on Windows).
        '--test_path', os.devnull,
        '--preds_path', os.devnull,
        '--checkpoint_dir', model_loc,
        '--features_generator', 'rdkit_2d_normalized',
        '--no_features_scaling',
    ]

    args = chemprop.args.PredictArgs().parse_args(arguments)
    preds = chemprop.train.make_predictions(args=args, smiles=smiles_input)

    # Build the mask element by element in Python. Comparing an all-numeric
    # array against a string is not element-wise on NumPy < 1.25: it returns a
    # scalar True, which would keep only the first prediction.
    flat_preds = np.asarray(preds, dtype=object).ravel()
    invalid = np.array(
        [isinstance(p, str) and p == 'Invalid SMILES' for p in flat_preds],
        dtype=bool,
    )

    y_pred = flat_preds[~invalid].astype(np.float32)
    y_true = np.asarray(val_labels)[~invalid]
    return y_pred, y_true

In [None]:
# Score the validation molecules with the saved MPNN checkpoints.
y_pred_MPNN, y_true_MPNN = get_MPNN_pred(
    endpoint_loc=endpoint_loc,
    model_loc=model_loc,
    val_smiles=val_smiles,
    val_labels=val_labels,
    filename='val.csv',
)

In [None]:
# Spot-check: first MPNN score next to its true label.
y_pred_MPNN[0], y_true_MPNN[0]

In [None]:
def dist_to_model(pred, true, dist):
    """Keep only confident predictions and binarise them at 0.5.

    A prediction is kept when it lies within ``dist`` of either extreme, i.e.
    ``pred >= 1 - dist`` or ``pred <= dist``. With ``dist = 0.5`` every
    prediction is kept.

    Parameters
    ----------
    pred : array-like of float
        Predicted scores (presumably probabilities in [0, 1]).
    true : array-like
        Labels aligned with ``pred``.
    dist : float
        Maximum distance from 0 or 1 for a prediction to be kept.

    Returns
    -------
    tuple
        ``(pred_labels, true_kept, coverage)``:
        - ``pred_labels``: the kept predictions, thresholded (``> 0.5``) to int64.
        - ``true_kept``: the matching entries of ``true``.
        - ``coverage``: the fraction of predictions kept, 0.0 for empty input.
    """
    # Accept plain lists as well as arrays.
    pred = np.asarray(pred)
    true = np.asarray(true)

    mask = (pred >= 1 - dist) | (pred <= dist)
    n_total = pred.shape[0]
    # Guard empty input so coverage is 0.0 instead of 0/0 -> nan plus a warning.
    coverage = np.sum(mask) / n_total if n_total else 0.0

    return (pred[mask] > 0.5).astype(np.int64), true[mask], coverage

In [None]:
# Selective-prediction analysis. For each dist, keep only MPNN scores with
# score >= 1 - dist or score <= dist (see dist_to_model). Then report accuracy,
# sensitivity and specificity on that subset, plus the fraction kept (coverage).
measurement = {'ACC': [], 'SEN': [], 'SPE': []}
coverage = []
xlabels = []
for dist in [0.1, 0.2, 0.3, 0.4, 0.5]:

    y_pred, y_true, cov = dist_to_model(y_pred_MPNN, y_true_MPNN, dist)

    # sklearn's signature is confusion_matrix(y_true, y_pred). Swapping the
    # arguments transposes the matrix and exchanges fp and fn. labels=[0, 1]
    # keeps the matrix 2x2 even when a confidence band holds only one class.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # Use nan (not a ZeroDivision warning) when a band has no samples of a class.
    total = tp + tn + fn + fp
    ACC = (tp + tn) / total if total else np.nan
    SEN = tp / (tp + fn) if (tp + fn) else np.nan
    SPE = tn / (tn + fp) if (tn + fp) else np.nan

    xlabels.append(dist)
    coverage.append(cov)
    measurement['ACC'].append(ACC)
    measurement['SEN'].append(SEN)
    measurement['SPE'].append(SPE)

x = np.arange(len(xlabels))  # one group of bars per dist value
width = 0.25  # bar width; three metric bars per group

fig, ax = plt.subplots(layout='constrained')

colours = ['blue', 'red', 'green']
for multiplier, (key, value) in enumerate(measurement.items()):
    offset = width * multiplier
    ax.bar(x + offset, value, width, label=key, color=colours[multiplier])

# Draw coverage over the middle bar of each group.
ax.plot(x + width, coverage, ls='--', marker='o', c='k', label='COV')

ax.set_title('Dist to MPNN Model')
ax.set_xlabel('dist (max distance of MPNN score from 0 or 1)')
ax.set_ylabel('metric value / coverage')
ax.set_xticks(x + width, xlabels)
ax.set_yticks(np.arange(0, 11) / 10)
ax.legend(loc='upper center', ncols=4)
ax.set_ylim(0, 1)
ax.set_axisbelow(True)
ax.grid(axis='y')

plt.show()

In [None]:
# Metric lists (ACC / SEN / SPE), one entry per dist value from the evaluation loop.
measurement

In [None]:
# Fraction of validly-predicted molecules kept at each dist value.
coverage