In [3]:
# # Install VIT model libraries
# !pip install vit_keras
# !pip install tensorflow_addons

Import Libraries

In [1]:
import pandas as pd
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Functions to extract images with their labels

In [2]:
# Load data from csv files in chunks
def load_data(csv):
    # usecols: only load Path and Pneumonia columns from csv file
    chunks = pd.read_csv(csv, usecols=['Path', 'Pneumonia'], chunksize=100)
    dfs = []
    for df in chunks:
        # remove rows with Pneumonia = NaN
        df = df[df['Pneumonia'].notna()]
        # remove rows with Pneumonia = -1
        df = df[df['Pneumonia'] != -1]
        dfs.append(df)
    return pd.concat(dfs)

# extract image from path
def load_image(path, size):
    # read image from path
    img = cv2.imread(path)
    # change color space from gray to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # resize image to the desired size
    img = cv2.resize(img, (size, size))
    return img

# extract image and label from dataframe
def extract_data(df):
    # iterate over dataframe
    for index, row in df.iterrows():
        # load image from path at both sizes
        img_224 = load_image(row['Path'], size=224)
        img_256 = load_image(row['Path'], size=256)
        # get label
        label = row['Pneumonia']
        # create generator for both image sizes and label to use less memory and speed up training
        yield img_224, img_256, label


# call load_data function and give path to csv files
train_df_chunks = load_data("CheXpert-v1.0-small/train.csv")
valid_df_chunks = load_data("CheXpert-v1.0-small/valid.csv")

Imbalanced Classes

In [3]:
# print total number of zero and one labels in train and validation data
train_df_chunks['Pneumonia'].value_counts(), valid_df_chunks['Pneumonia'].value_counts()

(1.0    6039
 0.0    2799
 Name: Pneumonia, dtype: int64,
 0.0    226
 1.0      8
 Name: Pneumonia, dtype: int64)

Extract Data

In [4]:
# Define generators for train and valid data using extract_data function for 224 and 256 images
train_gen = extract_data(train_df_chunks)
valid_gen = extract_data(valid_df_chunks)

# Extract data and labels using generators for 224 image
train_data_224, train_data_256, train_labels = zip(*train_gen)
valid_data_224, valid_data_256, valid_labels = zip(*valid_gen)

# Convert lists to numpy arrays
train_data_224 = np.array(train_data_224)
train_data_256 = np.array(train_data_256)
train_labels = np.array(train_labels)
valid_data_224 = np.array(valid_data_224)
valid_data_256 = np.array(valid_data_256)
valid_labels = np.array(valid_labels)

print(train_data_224.shape)
print(valid_data_224.shape)
print(train_data_256.shape)
print(valid_data_256.shape)
print(train_labels.shape)
print(valid_labels.shape)

(8838, 224, 224, 3)
(234, 224, 224, 3)
(8838, 256, 256, 3)
(234, 256, 256, 3)
(8838,)
(234,)


Image Augmentation using Flip, Rotate, Zoom, and Shift

In [5]:
# create function to augment data to balance classes with three arguments: data, labels and batch size
def augment_data(data_224, data_256, labels, batch_size):
    # Calculate number of zeros and ones
    zeros = np.count_nonzero(labels == 0)
    ones = np.count_nonzero(labels == 1)
    # Calculate difference between zeros and ones
    diff = abs(zeros - ones)
    # Create ImageDataGenerator object for augmentation
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode="nearest"
    )
    # Check if zeros are more than ones
    if zeros > ones:
        # Augment ones to balance classes
        augment_indices = np.random.choice(np.where(labels == 1)[0], diff)
        # Create empty arrays for augmented data
        new_data_224 = np.empty((diff, ) + data_224.shape[1:], dtype=data_224.dtype)
        new_data_256 = np.empty((diff, ) + data_256.shape[1:], dtype=data_256.dtype)
        # Create empty array for augmented labels
        new_labels = np.ones(diff, dtype=labels.dtype)
        # Iterate over augment_indices
        for i, index in enumerate(augment_indices):
            # Reshape images to 4D tensors for augmentation
            img_224 = data_224[index].reshape((1, ) + data_224[index].shape)
            img_256 = data_256[index].reshape((1, ) + data_256[index].shape)
            # Generate augmented images
            aug_imgs = datagen.flow(img_224, img_256, batch_size=batch_size)
            # Append augmented images to data lists
            new_data_224[i:i+batch_size] = aug_imgs[0][0]
            new_data_256[i:i+batch_size] = aug_imgs[0][1]
        # Concatenate augmented data and labels with original data and labels
        data_224 = np.concatenate((data_224, new_data_224), axis=0)
        data_256 = np.concatenate((data_256, new_data_256), axis=0)
        labels = np.concatenate((labels, new_labels), axis=0)
    # Check if ones are more than zeros
    elif ones > zeros:
        # Augment zeros to balance classes
        augment_indices = np.random.choice(np.where(labels == 0)[0], diff)
        new_data_224 = np.empty((diff, ) + data_224.shape[1:], dtype=data_224.dtype)
        new_data_256 = np.empty((diff, ) + data_256.shape[1:], dtype=data_256.dtype)
        new_labels = np.zeros(diff, dtype=labels.dtype)
        # Iterate over augment_indices
        for i, index in enumerate(augment_indices):
            # Reshape images to 4D tensors for augmentation
            img_224 = data_224[index].reshape((1, ) + data_224[index].shape)
            img_256 = data_256[index].reshape((1, ) + data_256[index].shape)
            # Generate augmented images
            aug_imgs = datagen.flow(img_224, img_256, batch_size=batch_size)
            # Append augmented images to data lists
            new_data_224[i:i+batch_size] = aug_imgs[0][0]
            new_data_256[i:i+batch_size] = aug_imgs[0][1]
        # Concatenate augmented data and labels with original data and labels
        data_224 = np.concatenate((data_224, new_data_224), axis=0)
        data_256 = np.concatenate((data_256, new_data_256), axis=0)
        labels = np.concatenate((labels, new_labels), axis=0)
    return data_224, data_256, labels

# Augment train data and labels with batch size of 16
train_data_224, train_data_256, train_labels = augment_data(train_data_224, train_data_256, train_labels, batch_size=16)
valid_data_224, valid_data_256, valid_labels = augment_data(valid_data_224, valid_data_256, valid_labels, batch_size=16)

Shuffle Data

In [6]:
# shuffle train and validation data
def shuffle_data(data, labels):
    # Get number of rows in data
    indices = np.arange(len(data))
    # Shuffle indices
    np.random.shuffle(indices)
    # Shuffle data and labels using shuffled indices
    data = data[indices]
    labels = labels[indices]
    return data, labels

# Shuffle train data and labels
train_data_224, train_labels = shuffle_data(train_data_224, train_labels)
valid_data_224, valid_labels = shuffle_data(valid_data_224, valid_labels)

Balanced Classes

In [7]:
# create function with one argument: array
def check_balance(arr):
    # Calculate number of zeros and ones
    num_zeros = sum(arr == 0)
    num_ones = sum(arr == 1)
    # Check if zeros and ones are equal
    if num_zeros == num_ones:
        # Return True if balanced
        return True, num_zeros, num_ones
    else:
        # Return False if not balanced
        return False, num_zeros, num_ones

# Check if training labels are balanced after augmentation
train_balanced = check_balance(train_labels)
print("Training data is balanced:", train_balanced)

# Check if validation labels are balanced after augmentation
valid_balanced = check_balance(valid_labels)
print("Validation data is balanced:", valid_balanced)

Training data is balanced: (True, 6039, 6039)
Validation data is balanced: (True, 226, 226)


Model MVCNN + VIT

In [8]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, GlobalMaxPooling2D
from vit_keras import vit

# Define CNN 1 layers with input shape of 224x224x3
cnn_input1 = Input(shape=(224, 224, 3))
# Create 3 convolutional layers with 32, 64 and 128 filters respectively and kernel size of 3x3 and activation function of relu
# Create 3 max pooling layers with pool size of 2x2
cnn_layer1 = Conv2D(filters=32, kernel_size=3, activation='relu')(cnn_input1)
cnn_layer1 = MaxPooling2D(pool_size=2)(cnn_layer1)
cnn_layer1 = Conv2D(filters=64, kernel_size=3, activation='relu')(cnn_layer1)
cnn_layer1 = MaxPooling2D(pool_size=2)(cnn_layer1)
cnn_layer1 = Conv2D(filters=128, kernel_size=3, activation='relu')(cnn_layer1)
cnn_layer1 = MaxPooling2D(pool_size=2)(cnn_layer1)

# Define CNN 2 layers with input shape of 224x224x3
cnn_input2 = Input(shape=(224, 224, 3))
# Create 3 convolutional layers with 32, 64 and 128 filters respectively and kernel size of 3x3 and activation function of relu
# Create 3 max pooling layers with pool size of 2x2
cnn_layer2 = Conv2D(filters=32, kernel_size=3, activation='relu')(cnn_input2)
cnn_layer2 = MaxPooling2D(pool_size=2)(cnn_layer2)
cnn_layer2 = Conv2D(filters=64, kernel_size=3, activation='relu')(cnn_layer2)
cnn_layer2 = MaxPooling2D(pool_size=2)(cnn_layer2)
cnn_layer2 = Conv2D(filters=128, kernel_size=3, activation='relu')(cnn_layer2)
cnn_layer2 = MaxPooling2D(pool_size=2)(cnn_layer2)

# Concatenate CNN Model 1 and CNN Model 2 layers
concat_layer = tf.keras.layers.Concatenate()([cnn_layer1, cnn_layer2])

# Pooling layer to reduce dimensionality
pooled = GlobalMaxPooling2D()(concat_layer)
# Define ViT layers with input shape of 224x224x3
vit_input = Input(shape=(pooled.shape[1], 256, 3))  # change the last dimension to 3
# Create ViT model with pretrained weights
vit_layer1 = vit.vit_b16(
    image_size=256,
    classes=2,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
)(vit_input)
flatten = Flatten()(vit_layer1)
# Create final output layer with 2 classes and sigmoid activation function
output = Dense(units=2, activation='sigmoid')(flatten)
# Define the model
model = tf.keras.models.Model(inputs=[cnn_input1, cnn_input2, vit_input], outputs=output)
# Compile the model with binary crossentropy loss, Adam optimizer and accuracy metric
model.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(0.0001), metrics=["accuracy"])
# Print model summary
model.summary()



Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 vit-b16 (Functional)           (None, 1000)         86613736    ['input_3[0][0]']                
                                                                                                  
 flatten (Flatten)              (None, 1000)         0           ['vit-b16[0][0]']                
                                                                                                  
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                           

Model Plotting

In [9]:
# Install libraries for model Plotting
!pip install pydot
!pip install pydotplus 
!pip install graphviz



In [10]:
# Import libraries for model plotting
from keras.utils.vis_utils import plot_model
# Plot model
plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


Labels are set with 0, 1 category

In [11]:
# convert labels to set of 0 and 1 
from keras.utils import to_categorical

# One-hot encode the labels
train_labels_one_hot = to_categorical(train_labels)
valid_labels_one_hot = to_categorical(valid_labels)
train_labels_one_hot.shape, valid_labels_one_hot.shape

((12078, 2), (452, 2))

Model Training

In [12]:
# Train the model with list of data because model has 3 inputs
history = model.fit([train_data_224, train_data_224, train_data_256], train_labels_one_hot, epochs=2, batch_size=4, 
                    validation_data=([valid_data_224, valid_data_224, valid_data_256], valid_labels_one_hot))

Epoch 1/2
 254/3020 [=>............................] - ETA: 11:17:57 - loss: 0.7085 - accuracy: 0.5089

In [None]:
# save model after training into disk
model.save('model.h5')

ROC Curve

In [None]:
# plot the training and validation accuracy and loss at each epoch
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# Evaluate model on validation set
loss, accuracy = model.evaluate([valid_data, valid_data, valid_data256], valid_labels_one_hot)

# Get predicted probabilities for each class
probs = model.predict([valid_data, valid_data, valid_data256])

# Calculate False Positive Rate, True Positive Rate, and thresholds for the positive class
fpr, tpr, thresholds = roc_curve(valid_labels_one_hot[:, 1], probs[:, 1])
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic Curve')
plt.legend(loc="lower right")
plt.show()