In [1]:
"""
CheXNet applied to chest X-ray images classifying pneumonia vs normal.

Reimplementation of the paper: 
https://doi.org/10.48550/arXiv.1711.05225

The dataset can be downloaded from here:
https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
"""

import tensorflow as tf
from loss import WeightedCrossEntropyBinaryLoss
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense
from keras.models import Model
import prepare_data
import os
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, accuracy_score

In [3]:
# Root directory of the chest X-ray dataset. It is expected to contain the
# 'train', 'val' and 'test' sub-directories referenced below.
# Set the CHEST_XRAY_DIR environment variable to point at your local copy;
# the original author's location is kept as the fallback so that existing
# setups keep working unchanged.
# os.path.join(..., '') guarantees a trailing path separator, which the
# string concatenations below rely on.
input_path = os.path.join(
    os.environ.get('CHEST_XRAY_DIR')
    or 'C:/Users/marta/OneDrive/Desktop/Osnabruck/ImplementingANNswithTensorFlow/FinalProject/chest_xray/',
    '')

# Checkpoint file path
checkpoint_filepath = './checkpoint/weights.hdf5'

# Train data path
train_data_path = input_path+'train'

# Validation data path
val_data_path = input_path+'val'

# Testing data path
test_data_path = input_path+'test'

# Tensorboard log files path
log_dir = './logs'

# Hyperparameters set in accordance with the methods of the original paper
img_dims = 224     # images are fed to the network as img_dims x img_dims
n_epochs = 10      # maximum number of training epochs
batch_size = 16    # batch size used for training and prediction
output_size = 1    # a single sigmoid output unit (binary classification)

In [4]:
# Adding files from the training dataset to the validation set, since the
# validation set shipped with the dataset is tiny compared to the training set
# (roughly 0.3% of its size, not 0.03% - NOTE(review): assumes the 16 vs 5216
# images of the Kaggle release; verify against your local copy).
# NOTE(review): 0.1 is presumably the fraction of training files moved to the
# validation set - confirm in prepare_data.split_data.
prepare_data.split_data(0.1, input_path)

Data has already been split. Skipping data split.


In [5]:
# Getting the data.
# train_gen / val_gen are batch generators for the training and validation
# images; test_data / test_labels hold the test set and are passed directly to
# model.predict and the sklearn metrics in the evaluation cell.
# NOTE(review): the generators presumably resize to img_dims x img_dims and
# yield batches of batch_size - confirm in prepare_data.prepare_data.
train_gen, val_gen, test_data, test_labels = prepare_data.prepare_data(img_dims, batch_size, input_path)

Found 4695 images belonging to 2 classes.
Found 537 images belonging to 2 classes.


In [6]:
class CheXNet(tf.keras.Model):
  """
  The CheXNet model that uses DenseNet121 as its backbone.

  This class acts as a builder/holder: it stores the class-balancing weights
  used by the loss and the compiled Keras model (in `self.model`). It does not
  define `call()` itself; training and inference are performed on the model
  returned by `get_model()`.
  """

  def __init__(self):
    """
    The constructor instantiates the weights and the model placeholders.
    """
    super().__init__()

    # Loss weights for the two classes; filled in by get_weights()
    self.zero_weight = None
    self.one_weight = None

    # Step counts; initialised to None and not assigned anywhere in this class
    self.train_steps = None
    self.val_steps = None

    # get_model() will initialize this to the DenseNet121-based model
    self.model = None

  def get_weights(self, train_data_path):
    """
    Computes the class distribution of pneumonia vs normal images and stores
    it as loss weights on the instance.

    After the call, `self.one_weight` holds the fraction of NORMAL images and
    `self.zero_weight` holds the fraction of PNEUMONIA images, i.e. each weight
    is the relative frequency of the *other* class.

    NOTE: this method shadows `tf.keras.Model.get_weights()`, which takes no
    arguments and returns the layer weights. The name is kept for backward
    compatibility with existing callers; do not rely on the Keras meaning of
    `get_weights()` on instances of this class.

    Args:
      train_data_path: path to the training data directory. It must contain
        the sub-directories 'NORMAL' and 'PNEUMONIA'.

    Raises:
      ValueError: if both class directories are empty.
    """

    # Count images in each class in the train data (every directory entry is
    # counted, so stray non-image files would be included as well)
    n_normal = len(os.listdir(os.path.join(train_data_path, 'NORMAL')))
    n_pneumonia = len(os.listdir(os.path.join(train_data_path, 'PNEUMONIA')))

    n_total = n_normal + n_pneumonia
    if n_total == 0:
      raise ValueError(
          f"No training files found under '{train_data_path}' "
          "(expected files in the 'NORMAL' and 'PNEUMONIA' sub-directories).")

    # Compute class distribution
    self.one_weight = float(n_normal) / n_total
    self.zero_weight = float(n_pneumonia) / n_total


  def get_model(self):
    """
    Builds and compiles the DenseNet121-based binary classifier.

    The network is an ImageNet-pretrained DenseNet121 without its
    classification head and with global average pooling, followed by a dense
    sigmoid layer with `output_size` units. Input size is taken from the
    module-level `img_dims`. The model is stored in `self.model`.

    `get_weights()` should be called first so that the class weights passed
    to the loss are set; otherwise they are still None.

    Returns:
      The compiled Keras model.
    """

    # Using pretrained DenseNet121 as the foundation of the model.
    # DenseNet121 expects the number of channels to be 3.
    base_model = tf.keras.applications.densenet.DenseNet121(include_top=False, weights='imagenet',
                                                            input_shape=(img_dims, img_dims, 3), pooling='avg')

    # Add custom output layer on top of the pooled backbone features
    x = base_model.output
    x = tf.keras.layers.Dense(output_size, activation='sigmoid')(x)

    self.model = tf.keras.models.Model(inputs=base_model.input, outputs=x)

    # Use weighted binary crossentropy loss
    loss = WeightedCrossEntropyBinaryLoss(self.zero_weight, self.one_weight)

    # Compile the model (learning rate is left at the optimizer's default)
    self.model.compile(optimizer=tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.999),
                       loss=loss.weighted_binary_crossentropy,
                       metrics=['accuracy'])

    return self.model

# Instantiate the CheXNet model
chexnet = CheXNet()

# Compute class distribution of pneumonia vs normal images.
# This must run before get_model(), which passes the resulting weights to the loss.
chexnet.get_weights(train_data_path)

# Create and compile the DenseNet121 model.
# `model` is the plain Keras model stored in chexnet.model; it is the object
# used for training and evaluation in the cells below.
model = chexnet.get_model()

# Training callbacks. They are only created here; the training cell decides
# which of them are actually passed to model.fit().

# Write TensorBoard logs once per epoch
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    update_freq='epoch',
)

# Save the full model (not just the weights) whenever the validation loss improves
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)

# Stop when the validation loss has not improved by at least 0.1 for one epoch
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.1,
    patience=1,
)

# Divide the learning rate by 10 when the validation loss has not improved for one epoch
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=1,
    verbose=1,
)


### Training the model
The model was trained by running the cell below, which is now commented out so that the notebook can be re-run without retraining.

To save time, the trained weights can instead be loaded from the checkpoint file in the following cell and evaluated on the test set.

In [7]:
# NOTE: this cell is intentionally commented out; uncomment it to retrain.
# The `early_stop` callback is not part of the callbacks list below, so
# training runs for all `n_epochs` epochs.
# # Fitting the model
# history = model.fit(train_gen,
#                     epochs=n_epochs,
#                     batch_size=batch_size,
#                     steps_per_epoch=train_gen.samples // batch_size,
#                     validation_steps=val_gen.samples // batch_size,
#                     validation_data=val_gen,
#                     callbacks=[reduce_lr, model_checkpoint, tensorboard_callback])

Epoch 1/10
Epoch 2/10
Epoch 2: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 3/10
Epoch 4/10
Epoch 4: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 5/10
Epoch 5: ReduceLROnPlateau reducing learning rate to 1.0000000656873453e-06.
Epoch 6/10
Epoch 6: ReduceLROnPlateau reducing learning rate to 1.0000001111620805e-07.
Epoch 7/10
Epoch 7: ReduceLROnPlateau reducing learning rate to 1.000000082740371e-08.
Epoch 8/10
Epoch 8: ReduceLROnPlateau reducing learning rate to 1.000000082740371e-09.
Epoch 9/10
Epoch 9: ReduceLROnPlateau reducing learning rate to 1.000000082740371e-10.
Epoch 10/10
Epoch 10: ReduceLROnPlateau reducing learning rate to 1.000000082740371e-11.


### Model evaluation

In [9]:
# Load the checkpointed weights (the checkpoint callback only keeps the epoch
# with the lowest validation loss)
model.load_weights(checkpoint_filepath)

# Predict probabilities on the test set, then round to obtain binary predictions
probabilities = model.predict(test_data, batch_size=batch_size)
yhat = np.round(probabilities)

# Compute the confusion matrix once and reuse it for the individual counts
# and for the printed report
cm = confusion_matrix(test_labels, yhat)
tn, fp, fn, tp = cm.ravel()

# Accuracy (in percent) and precision / recall / F1 of the positive class
acc = accuracy_score(test_labels, yhat) * 100
precision, recall, f1_score, _ = precision_recall_fscore_support(test_labels, yhat, average='binary')

# Print the results
print("\nModel evaluation\n") 
print(f"Confusion Matrix:\n{cm}\n")
print(f"Accuracy: {acc}%")
print(f"Precision: {precision * 100}%")
print(f"Recall: {recall * 100}%")
print(f"F1 Score: {f1_score * 100}")



Model evaluation

Confusion Matrix:
[[179  55]
 [  4 386]]

Accuracy: 90.5448717948718%
Precision: 87.52834467120182%
Recall: 98.97435897435898%
F1 Score: 92.90012033694344
