<a href="https://colab.research.google.com/github/kthuang20/BetaLactamaseCNN/blob/main/Beta_Lactamase_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Click the badge above or the link below to open this notebook in Google Colab:

https://colab.research.google.com/gist/kthuang20/64c59b559422625b438bd10f45051a09/beta-lactamase-cnn.ipynb

For a faster runtime in Google Colab, select Runtime -> Change runtime type and choose the T4 GPU option.

In [None]:
# download necessary packages
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2023.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.4 MB)
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.5


In [None]:
# import data manipulation tools
import zipfile
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw

# import visualization tool
import plotly.express as px
from matplotlib import pyplot as plt

# import modeling tools
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten

# import metrics to evaluate model
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
from tensorflow.keras.metrics import Precision, Recall, BinaryAccuracy

# ML and AI Final Project

Antibiotics are compounds that kill bacteria directly or inhibit their growth. For instance, penicillin inhibits an enzyme involved in cell wall synthesis. This weakens the bacterial cell wall, making the bacteria more susceptible to changes in osmotic pressure and resulting in cell lysis [[1]](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6723624/). Although antibiotics have been effective against bacterial infections, some bacteria produce β-lactamase, an enzyme that breaks down and thereby inactivates these antibiotics, rendering them ineffective [[1]](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6723624/). This allows the bacteria to continue proliferating in the presence of antibiotics, leading to antibiotic resistance. Inhibiting β-lactamase is therefore considered a viable option for overcoming this resistance.

The goal of this project is to develop an approach that helps accelerate the discovery of β-lactamase inhibitors that effectively combat antibiotic resistance. Traditional drug discovery has relied predominantly on quantitative structure-activity relationship (QSAR) modeling [[2]](https://pubs.acs.org/doi/10.1021/jm4004285), and most of these models are based on 1D representations of molecular structure. Convolutional neural networks (CNNs) are a promising alternative. Unlike 1D descriptors, which often require manual selection and extraction of features [[3]](https://doi.org/10.1186/s12859-018-2523-5), CNNs can automatically extract hierarchical features from raw input data, such as 2D molecular structures [[3]](https://doi.org/10.1186/s12859-018-2523-5). This allows a CNN to capture both local and global patterns in molecular images, giving a more comprehensive representation of complex chemical structures and their relationships with biological activity [[3]](https://doi.org/10.1186/s12859-018-2523-5). By leveraging these capabilities, CNNs have the potential to improve predictive accuracy and make drug discovery more efficient by learning directly from the molecular structure of the compounds [[3]](https://doi.org/10.1186/s12859-018-2523-5). Here, a convolutional neural network is trained on the chemical structures of compounds known to bind to β-lactamase to predict whether a new compound would be a strong candidate for inhibiting β-lactamase.

## 1. Generate Training Dataset

A total of 136 CSV files, one for each of 136 variants of the β-lactamase protein, were retrieved from the ChEMBL database (version 29).

In [None]:
# download the file
! gdown --id 1HvDDqoBJdNnFg3i14raMes1oedgC_BFs

Downloading...
From: https://drive.google.com/uc?id=1HvDDqoBJdNnFg3i14raMes1oedgC_BFs
To: /content/beta_lactamase_CHEMBL29.zip
100% 1.42M/1.42M [00:00<00:00, 96.8MB/s]


In [None]:
# name of the zip file containing all 136 csv files
file_path = "beta_lactamase_CHEMBL29.zip"
# read in all 136 variants of β-lactamase
zf = zipfile.ZipFile(file_path, "r")
# combine all the compounds that are known to interact with each variant into one dataframe
beta_lactamase_data = pd.concat((pd.read_csv(zf.open(f)) for f in zf.namelist()))
beta_lactamase_data

Unnamed: 0,molecule_chembl_id,canonical_smiles,standard_relation,standard_value,standard_units,standard_type,pchembl_value,target_pref_name,bao_label
0,CHEMBL1730,CO/N=C(\C(=O)N[C@@H]1C(=O)N2C(C(=O)O)=C(COC(C)...,=,10.0,/mM/s,Kcat/Km,,Gil1,assay format
1,CHEMBL996,CO[C@@]1(NC(=O)Cc2cccs2)C(=O)N2C(C(=O)O)=C(COC...,,,,Kcat/Km,,Gil1,assay format
2,CHEMBL617,CC(=O)OCC1=C(C(=O)O)N2C(=O)[C@@H](NC(=O)Cc3ccc...,=,598.0,/mM/s,Kcat/Km,,Gil1,assay format
3,CHEMBL702,CCN1CCN(C(=O)N[C@@H](C(=O)N[C@@H]2C(=O)N3[C@@H...,=,3400.0,/mM/s,Kcat/Km,,Gil1,assay format
4,CHEMBL1449,CC1(C)S[C@@H]2[C@H](NC(=O)[C@H](C(=O)O)c3ccsc3...,=,10000.0,/mM/s,Kcat/Km,,Gil1,assay format
...,...,...,...,...,...,...,...,...,...
13,CHEMBL561555,COC(=O)CC(N)(CC(=O)OC)C(=O)OCc1ccccc1,,,,Inhibition,,Beta-lactamase VIM-4,single protein format
14,CHEMBL561821,NC(CC(=O)OCc1ccccc1)(CC(=O)OCc1ccccc1)C(=O)OCc...,,,,Inhibition,,Beta-lactamase VIM-4,single protein format
15,CHEMBL561896,COC(=O)CC(CC(=O)OC)(NC(=O)Cc1ccccc1)C(=O)OC,,,,Inhibition,,Beta-lactamase VIM-4,single protein format
16,CHEMBL563044,COC(=O)C(CC(=O)OCc1ccccc1)(CC(=O)OCc1ccccc1)NC...,,,,Inhibition,,Beta-lactamase VIM-4,single protein format


In [None]:
# keep only measurements whose bioactivity value is exact (standard_relation is '=')
train_data = beta_lactamase_data[beta_lactamase_data['standard_relation'] == '=']
# remove samples without any pchembl values
train_data = train_data[train_data['pchembl_value'].notna()]

# create a boolean series stating whether the standard deviation of pchembl values for each compound is less than 2
# (a compound with a single measurement has a NaN standard deviation, so the comparison is False for it as well)
low_pchembl_std = train_data.groupby('molecule_chembl_id')['pchembl_value'].std() < 2
# store a list containing the compounds that had small standard deviations
cps = low_pchembl_std[low_pchembl_std].index.tolist()
# keep only those compounds and drop the columns that are no longer needed
cols = ['standard_relation', 'standard_type', 'target_pref_name', 'bao_label']
train_data = train_data.loc[train_data['molecule_chembl_id'].isin(cps)].drop(columns=cols)

# define aggregation function to remove duplicates by taking the mean pChEMBL value
remove_dup = {'molecule_chembl_id': 'first',
                'canonical_smiles': 'first',
                'standard_value': 'mean',
                'standard_units': 'first',
                'pchembl_value': 'mean'}

# remove duplicates
train_data = train_data.groupby('molecule_chembl_id').agg(remove_dup).reset_index(drop=True)
train_data

Unnamed: 0,molecule_chembl_id,canonical_smiles,standard_value,standard_units,pchembl_value
0,CHEMBL104,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,27500.000000,nM,4.580000
1,CHEMBL1089781,O=S(=O)(NCB(O)O)c1cc2c(Cl)ccc(Cl)c2s1,1997.500000,nM,5.905000
2,CHEMBL1091,CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=...,84217.950000,nM,4.100000
3,CHEMBL109227,OB(O)c1ccc(-c2ccc(B(O)O)cc2)cc1,200.000000,nM,6.700000
4,CHEMBL1126,CC1(C)S[C@@H]2[C@H](NC(=O)Cc3ccccc3)C(=O)N2[C@...,5400.000000,nM,5.290000
...,...,...,...,...,...
791,CHEMBL87686,O=C(O)[C@H](S)Cc1ccc2oc3ccccc3c2c1,4961.505000,nM,6.320000
792,CHEMBL87719,CC1(C)[C@H](C(=O)O)N2C(=O)[C@]3(C[C@@H]3OC3CCC...,270.000000,nM,6.905000
793,CHEMBL891,Cc1onc(-c2ccccc2Cl)c1C(=O)N[C@@H]1C(=O)N2[C@@H...,4343.333333,nM,6.956667
794,CHEMBL9306,O=C([O-])[C@H]1/C(=C/CO)O[C@@H]2CC(=O)N21.[Li+],234.000000,nM,6.785000
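
The filtering and de-duplication steps above can be illustrated on a small, made-up table. The sketch below uses hypothetical compound IDs and values (not taken from ChEMBL) and relies on pandas' default sample standard deviation (`ddof=1`). It is only meant to show how the filter behaves, in particular that a compound with a single measurement has an undefined (NaN) standard deviation and therefore does not pass the `< 2` test.

```python
import pandas as pd

# hypothetical measurements (IDs and values are made up for illustration)
toy = pd.DataFrame({
    'molecule_chembl_id': ['A', 'A', 'B', 'B', 'C'],
    'pchembl_value':      [5.0, 5.4, 4.0, 8.0, 6.0],
})

# per-compound standard deviation of the pChEMBL values
std_per_cp = toy.groupby('molecule_chembl_id')['pchembl_value'].std()
# A: consistent replicates, B: conflicting replicates, C: single measurement (NaN)
keep = std_per_cp[std_per_cp < 2].index.tolist()

# average the replicate measurements of the compounds that were kept
toy_clean = (toy[toy['molecule_chembl_id'].isin(keep)]
             .groupby('molecule_chembl_id', as_index=False)['pchembl_value']
             .mean())
print(keep)
print(toy_clean)
```

If single-measurement compounds should be retained instead, one possible option is to treat their undefined standard deviation as zero before the comparison, for example with `std_per_cp.fillna(0) < 2`.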


In [None]:
# save a csv file for future use
#train_data.to_csv('processed_data.csv')

In [None]:
# show the summary statistics of the pchembl values
sum_stats = train_data['pchembl_value'].describe()
print('Summary Statistics and Quartiles of the pChEMBL Values:')
sum_stats

Summary Statistics and Quartiles of the pChEMBL Values:


count    796.000000
mean       5.757514
std        1.081195
min        2.946667
25%        4.949167
50%        5.480000
75%        6.530250
max        8.800000
Name: pchembl_value, dtype: float64

In [None]:
# create a histogram to show the distribution of pChEMBL values
fig = px.histogram(train_data, x='pchembl_value')

# add title, axis labels
fig.update_layout(title = 'Figure 1. Distribution of pChEMBL Values of Compounds',
                  title_x = 0.5,
                  xaxis_title = 'pChEMBL Value',
                  yaxis_title = 'Number of Compounds',
                  bargap = 0.2)

# show the histogram
fig.show()

Based on the summary statistics, I will use the median pChEMBL value (the 50% quantile) as the cutoff to create 2 classes:
* at or below the median: *inactive*
* above the median: *active*
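
To give the median cutoff a physical meaning, it helps to recall that ChEMBL defines pChEMBL as the negative base-10 logarithm of the molar activity value (IC50, Ki, and similar measures). The sketch below uses two hypothetical helper functions, `nm_to_pchembl` and `pchembl_to_nm`, to convert between nanomolar values and pChEMBL. Under that definition, the median of 5.48 from the summary statistics corresponds to roughly 3.3 µM.

```python
import math

def nm_to_pchembl(value_nm):
    """Convert an activity in nM to pChEMBL (-log10 of the molar value)."""
    return -math.log10(value_nm * 1e-9)

def pchembl_to_nm(pchembl):
    """Convert a pChEMBL value back to an activity in nM."""
    return 10 ** (9 - pchembl)

# 200 nM corresponds to a pChEMBL of about 6.70
print(round(nm_to_pchembl(200), 2))
# a pChEMBL of 5.48 corresponds to about 3311 nM (~3.3 uM)
print(round(pchembl_to_nm(5.48)))
```

Higher pChEMBL values therefore mean more potent compounds, which is why the upper half of the distribution is labelled active.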

In [None]:
### function to classify bioactivity of compound
def classify_bioactivity(bioactivity, threshold):
    ## if the compound has a bioactivity above this threshold,
    if bioactivity > threshold:
        # label it as an active compound
        return 1
    ## otherwise
    else:
        # it is an inactive compound
        return 0

In [None]:
# define the threshold for classifying a compound as active/inactive as the median
threshold = sum_stats.loc['50%']
# add a column with the label for each compound: 1 if active, 0 if inactive
train_data['active'] = train_data['pchembl_value'].apply(classify_bioactivity, threshold=threshold)
train_data

Unnamed: 0,molecule_chembl_id,canonical_smiles,standard_value,standard_units,pchembl_value,active
0,CHEMBL104,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,27500.000000,nM,4.580000,0
1,CHEMBL1089781,O=S(=O)(NCB(O)O)c1cc2c(Cl)ccc(Cl)c2s1,1997.500000,nM,5.905000,1
2,CHEMBL1091,CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=...,84217.950000,nM,4.100000,0
3,CHEMBL109227,OB(O)c1ccc(-c2ccc(B(O)O)cc2)cc1,200.000000,nM,6.700000,1
4,CHEMBL1126,CC1(C)S[C@@H]2[C@H](NC(=O)Cc3ccccc3)C(=O)N2[C@...,5400.000000,nM,5.290000,0
...,...,...,...,...,...,...
791,CHEMBL87686,O=C(O)[C@H](S)Cc1ccc2oc3ccccc3c2c1,4961.505000,nM,6.320000,1
792,CHEMBL87719,CC1(C)[C@H](C(=O)O)N2C(=O)[C@]3(C[C@@H]3OC3CCC...,270.000000,nM,6.905000,1
793,CHEMBL891,Cc1onc(-c2ccccc2Cl)c1C(=O)N[C@@H]1C(=O)N2[C@@H...,4343.333333,nM,6.956667,1
794,CHEMBL9306,O=C([O-])[C@H]1/C(=C/CO)O[C@@H]2CC(=O)N21.[Li+],234.000000,nM,6.785000,1
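
The row-wise `apply` above can also be expressed as a single vectorized comparison. The sketch below uses a small hypothetical series of pChEMBL values (made up for illustration) and labels a compound active only when its value is strictly above the median, matching the `>` comparison in `classify_bioactivity`. It also counts the classes to check how balanced they are.

```python
import pandas as pd

# hypothetical pChEMBL values (made up for illustration)
pchembl = pd.Series([4.6, 5.9, 4.1, 6.7, 5.3, 6.3])

# median of the values, equivalent to the '50%' entry of describe()
threshold = pchembl.median()
# 1 = active (strictly above the median), 0 = inactive
active = (pchembl > threshold).astype(int)

print(threshold)
print(active.value_counts().sort_index())
```

Splitting at the median gives two classes of roughly equal size, so accuracy is a reasonable first metric for this task.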


In [None]:
# export the cleaned dataframe as a csv file for future use
#train_data.to_csv('processed_data.csv', index=False)

## 2. Preprocess the Data

In [None]:
### function to generate a 2D image of the compound
def gen_image(smiles):
    ## get the molecule for this SMILES string
    mol = Chem.MolFromSmiles(smiles)
    ## convert this molecule into an image with a standardized size
    img = Draw.MolToImage(mol, size=(256,256))
    ## convert the image into a numpy array of pixels
    img_px = np.array(img)
    return img_px

In [None]:
# return a list of the images of the compounds
mols = train_data['canonical_smiles'].apply(gen_image)
# combine all the numpy array representations of the chemical compounds as a single tensor
stacked_imgs = tf.stack(mols.tolist())
# create a tensorflow dataset from the stacked tensor
dataset = tf.data.Dataset.from_tensor_slices((stacked_imgs, train_data['active']))

# scale pixel values from 0-255 to 0-1
dataset = dataset.map(lambda x, y: (x/255, y))
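# shuffle the dataset once; reshuffle_each_iteration=False keeps the order fixed so that the
# take/skip splits made later do not exchange examples between training, validation, and testing
dataset = dataset.shuffle(buffer_size=len(mols), reshuffle_each_iteration=False)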

In [None]:
### define a function to split the dataset into training, validation, and testing datasets
def gen_datasets(dataset, batch_size, train_split, val_split, test_split):
    # create batches based on batch size
    batched_dataset = dataset.batch(batch_size=batch_size)
    # store the total number of batches
    nbatches = len(batched_dataset)

    # define the number of batches in each dataset from the requested fractions
    train_size = int(nbatches * train_split)
    val_size = int(nbatches * val_split)
    test_size = int(nbatches * test_split) + 1

    ## generate the datasets
    train = batched_dataset.take(train_size)
    val = batched_dataset.skip(train_size).take(val_size)
    test = batched_dataset.skip(train_size + val_size).take(test_size)

    return train, val, test

In [None]:
# split the data into three datasets: training, validation, and testing datasets
train, val, test = gen_datasets(dataset, 64, 0.7, 0.2, 0.1)
print('Number of batches in training dataset: ', str(len(train)))
print('Number of batches in validation dataset: ', str(len(val)))
print('Number of batches in testing dataset: ', str(len(test)))

Number of batches in training dataset:  9
Number of batches in validation dataset:  2
Number of batches in testing dataset:  2
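
The batch counts printed above follow directly from the dataset size. As a rough check, assuming the 796 compounds reported in the summary statistics and the batch size of 64 used in the call, the arithmetic inside `gen_datasets` can be reproduced without TensorFlow (the helper name `split_sizes` is hypothetical):

```python
import math

def split_sizes(n_samples, batch_size, train_split, val_split, test_split):
    """Reproduce the batch-count arithmetic of gen_datasets."""
    # the last batch may be partially filled, hence the ceiling
    nbatches = math.ceil(n_samples / batch_size)
    train_size = int(nbatches * train_split)
    val_size = int(nbatches * val_split)
    test_size = int(nbatches * test_split) + 1
    return nbatches, train_size, val_size, test_size

print(split_sizes(796, 64, 0.7, 0.2, 0.1))  # (13, 9, 2, 2)
```

Because the split is made over whole batches, the final, partially filled batch (796 - 12 * 64 = 28 compounds) ends up in the testing dataset.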


## 3. Generate the CNN

In [None]:
### function to create the model
def gen_model():
  ## initialize a sequential model
  tf.random.set_seed(42)
  model = Sequential()
  ## add convolutional layers
  model.add(Conv2D(16, (3,3), 1, activation='relu', input_shape=(256, 256, 3)))
  model.add(MaxPooling2D())

  model.add(Conv2D(32, (3,3), 1, activation='relu'))
  model.add(MaxPooling2D())

  model.add(Conv2D(16, (3,3), 1, activation='relu'))
  model.add(MaxPooling2D())

  ## add flatten layer
  model.add(Flatten())

  ## add dense layers
  model.add(Dense(256, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  ## compile model
  model.compile('adam', loss=tf.losses.BinaryCrossentropy(), metrics=['accuracy'])

  ## show model summary (with architecture of model); summary() prints directly and returns None
  model.summary()

  return model

In [None]:
# create the architecture of the CNN
model = gen_model()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 254, 254, 16)      448       
                                                                 
 max_pooling2d (MaxPooling2  (None, 127, 127, 16)      0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 125, 125, 32)      4640      
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 62, 62, 32)        0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 60, 60, 16)        4624      
                                                                 
 max_pooling2d_2 (MaxPoolin  (None, 30, 30, 16)        0
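
The output shapes and parameter counts in the summary can be checked by hand. The sketch below assumes what `gen_model` specifies: 3x3 kernels with stride 1 and no padding (the Keras default `padding='valid'`), each followed by 2x2 max pooling (the `MaxPooling2D` default). The helper names are hypothetical.

```python
def conv_out(size, kernel=3, stride=1):
    """Spatial size after a 'valid' convolution."""
    return (size - kernel) // stride + 1

def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool

def conv_params(in_channels, filters, kernel=3):
    """Kernel weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters

size, channels = 256, 3
for filters in (16, 32, 16):
    conv_size = conv_out(size)
    size = pool_out(conv_size)
    print(f'conv: {conv_size}x{conv_size}x{filters}, '
          f'params: {conv_params(channels, filters)}, pool: {size}x{size}x{filters}')
    channels = filters

# number of features passed to the first dense layer after flattening
flat = size * size * channels
print(flat)
```

The printed sizes (254, 127, 125, 62, 60, 30) and parameter counts (448, 4640, 4624) agree with the rows of the summary shown above.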

## 4. Train the CNN

In [None]:
# set up a log directory on local drive to store how model performed at each epoch
logdir = 'logs'
tensorboard_callbacks = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# train the model
hist = model.fit(train, epochs=20, validation_data=val, callbacks=[tensorboard_callbacks])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [None]:
# show a dataframe of the results
hist_df = pd.DataFrame(hist.history)
# rename columns
hist_df.columns = ['Training Loss', 'Training Accuracy', 'Validation Loss', 'Validation Accuracy']
hist_df

Unnamed: 0,Training Loss,Training Accuracy,Validation Loss,Validation Accuracy
0,0.937053,0.53125,0.679289,0.578125
1,0.691075,0.539931,0.675226,0.65625
2,0.665947,0.59375,0.63139,0.703125
3,0.633035,0.652778,0.504343,0.773438
4,0.556681,0.699653,0.493084,0.78125
5,0.496169,0.758681,0.397994,0.804688
6,0.470706,0.765625,0.399197,0.820312
7,0.408533,0.817708,0.348969,0.828125
8,0.361022,0.840278,0.302711,0.882812
9,0.291937,0.876736,0.22781,0.921875
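
A useful reference point when reading the loss columns is the binary cross-entropy of an uninformative classifier. The sketch below is a minimal pure-Python version of the loss (the function name `binary_cross_entropy` and the toy labels are illustrative assumptions, not the Keras implementation). A model that outputs 0.5 for every compound scores ln 2 ≈ 0.693, which is close to the losses seen in the first epochs above.

```python
import math

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # clip the probability to avoid log(0)
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

labels = [1, 0, 1, 0]  # toy labels
chance_loss = binary_cross_entropy(labels, [0.5, 0.5, 0.5, 0.5])
good_loss = binary_cross_entropy(labels, [0.9, 0.1, 0.8, 0.2])
print(chance_loss)  # ln 2, about 0.693
print(good_loss)    # confident and correct predictions give a lower loss
```

Losses well below 0.693 therefore indicate that the network has learned something beyond guessing.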


In [None]:
### function to compare metric between training and validation dataset
def compare_metric(metric_results, metric, fig_num):
  ## create a line plot comparing the training and validation values of the metric at each epoch
  fig = px.line(metric_results,
                x = metric_results.index,
                y = ['Training '+ metric, 'Validation ' + metric],
                markers = True)

  ## add title, axis labels
  fig.update_layout(title = 'Figure ' + str(fig_num) + '. Training and Validation ' + metric,
                    title_x = 0.5,
                    xaxis_title = 'Epoch',
                    yaxis_title = metric,
                    legend_title_text = 'Dataset')

  ## show figure
  fig.show()

In [None]:
# compare loss between training and validation datasets
compare_metric(hist_df, 'Loss', 2)

In [None]:
# compare accuracies between training and validation dataset
compare_metric(hist_df, 'Accuracy', 3)

## 5. Evaluate Performance of CNN

In [None]:
# store all datasets into one dictionary for easier access
all_data = {'Training': train,
            'Validation': val,
            'Test': test}

In [None]:
### function to plot a ROC curve on the global axes `ax` (created before the function is called) and return the AUC
def gen_roc_curve(all_yactual, all_ypred, dataset_type):
  ## get the FPR and TPR for ROC curve
  fpr, tpr, thresholds = roc_curve(all_yactual, all_ypred)
  ## calculate the auc
  auc = roc_auc_score(all_yactual, all_ypred)
  ## generate ROC Curve
  ax.plot(fpr, tpr, label=f'{dataset_type} (AUC = {auc:0.3f})')

  return auc
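
The AUC returned by `roc_auc_score` can be interpreted as the probability that a randomly chosen active compound receives a higher score than a randomly chosen inactive one, with ties counting one half. The sketch below computes it that way on made-up labels and scores. The function name `pairwise_auc` is hypothetical, and this quadratic-time version is only meant for illustration.

```python
def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (active, inactive) pairs ranked correctly."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# made-up labels and predicted probabilities
y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(y_true, y_score))  # 0.75
```

An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect separation of the two classes.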

In [None]:
### function to obtain results from metrics for dataset
def get_metrics(all_data, dataset_type):
  ## initialize the metrics
  precision = Precision()
  recall = Recall()
  acc = BinaryAccuracy()

  ## initialize empty lists to store all the true labels and model predictions
  all_yactual = []
  all_ypred = []

  ## iterate through each batch of dataset
  for batch in all_data[dataset_type].as_numpy_iterator():
    # get the input images and true labels for this batch
    X, yactual = batch
    # store the model's predictions for this batch
    ypred = model.predict(X)

    # add the results to their respective lists (predictions are flattened from shape (n, 1) to (n,))
    all_yactual.extend(yactual)
    all_ypred.extend(ypred.ravel())

    # compute and store the metrics for that dataset
    precision.update_state(yactual, ypred)
    recall.update_state(yactual, ypred)
    acc.update_state(yactual, ypred)

  ## generate a ROC curve and calculate AUC across all batches in dataset
  auc = gen_roc_curve(all_yactual, all_ypred, dataset_type)
  pred_labels = (np.array(all_ypred) >= 0.5).astype(int)
  cm = confusion_matrix(all_yactual, pred_labels).tolist()

  ## store results in dictionary
  metric_results = {'Dataset': dataset_type,
                    'Precision': precision.result().numpy(),
                    'Recall': recall.result().numpy(),
                    'Accuracy': acc.result().numpy(),
                    'AUC': auc}

  return metric_results, cm

In [None]:
### initialize a list to store the metric results and a dictionary to store the confusion matrices
metric_results = []
conf_matrices = {}

### create a new figure
fig, ax = plt.subplots(figsize=(8,8))
### iterate through each dataset
for dataset_type in all_data:
  ## get a dictionary with metric results and the confusion matrix for that dataset
  results, conf_matrix = get_metrics(all_data, dataset_type)
  ## add the results to the list and store the confusion matrix
  metric_results.append(results)
  conf_matrices[dataset_type] = conf_matrix

### add labels to figure
ax.set_title('Figure 4. Receiver Operating Characteristic (ROC) Curves')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend()

### show plot
plt.show()

### convert metrics to a dataframe
metric_results = pd.DataFrame(metric_results)

In [None]:
# show the results
metric_results

Unnamed: 0,Dataset,Precision,Recall,Accuracy,AUC
0,Training,0.982332,0.985816,0.984375,0.997925
1,Validation,0.96875,1.0,0.984375,0.999511
2,Test,0.981481,0.981481,0.978261,0.998538
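
Precision, recall, and accuracy can all be derived from the four cells of a confusion matrix. The sketch below uses hypothetical counts (not the notebook's results) laid out the way scikit-learn's `confusion_matrix` returns them, with rows as true labels and columns as predicted labels. The function name `metrics_from_cm` is an assumption made for illustration.

```python
def metrics_from_cm(cm):
    """Compute precision, recall, and accuracy from a 2x2 confusion matrix."""
    (tn, fp), (fn, tp) = cm
    precision = tp / (tp + fp)   # share of predicted actives that are truly active
    recall = tp / (tp + fn)      # share of true actives that were found
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return precision, recall, accuracy

# hypothetical counts: [[TN, FP], [FN, TP]]
cm = [[40, 10], [5, 45]]
precision, recall, accuracy = metrics_from_cm(cm)
print(round(precision, 3), round(recall, 3), round(accuracy, 3))
```

In a screening setting, precision reflects how many predicted inhibitors would be worth testing, while recall reflects how many true inhibitors would be missed.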


In [None]:
### function to show confusion matrix for each dataset
def show_cm(cm, dataset, fig_num):
  # show the confusion matrix
  fig = px.imshow(cm[dataset], text_auto=True, color_continuous_scale='Blues')

  # add labels to figure
  fig.update_layout(title_text = f'Figure {fig_num}. Confusion Matrix for {dataset} Dataset',
                    title_x = 0.5,
                    xaxis = dict(tickvals=[0, 1], ticktext=['Inactive', 'Active']),
                    yaxis = dict(tickvals=[0, 1], ticktext=['Inactive', 'Active']),
                    xaxis_title = 'Predicted Labels',
                    yaxis_title = 'True Labels')
  # show plot
  fig.show()

In [None]:
### iterate through each confusion matrix
for idx, dataset in enumerate(conf_matrices):
  ## show heatmap
  show_cm(conf_matrices, dataset, idx+5)

## 6. Save the Model

In [None]:
# import necessary package
#import os

# save the model for future use
#model.save(os.path.join('models', 'BetaLactmaseCNN.h5'))

## References

[1] C. L. Tooke *et al.*, “β-Lactamases and β-Lactamase Inhibitors in the 21st Century,” *J Mol Biol*, vol. 431, no. 18, pp. 3472–3500, Aug. 2019, doi: 10.1016/j.jmb.2019.04.002. Available: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6723624/

[2] A. Cherkasov *et al.*, “QSAR Modeling: Where Have You Been? Where Are You Going To?,” *J. Med. Chem.*, vol. 57, no. 12, pp. 4977–5010, Jun. 2014, doi: 10.1021/jm4004285. Available: https://pubs.acs.org/doi/10.1021/jm4004285

[3] M. Hirohara, Y. Saito, Y. Koda, K. Sato, and Y. Sakakibara, “Convolutional Neural Network Based On SMILES Representation of Compounds for Detecting Chemical Motif,” *BMC Bioinformatics*, vol. 19, no. 19, p. 526, Dec. 2018, doi: 10.1186/s12859-018-2523-5. Available: https://doi.org/10.1186/s12859-018-2523-5


