<a href="https://colab.research.google.com/github/mohandaz/HIA-Research-Project/blob/main/TB_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jul  7 23:28:34 2024

@author: Mohan
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.layers import Flatten, Dropout, Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.applications import VGG16
import tensorflow as tf
from skimage.exposure import equalize_adapthist
from skimage.filters import gaussian
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, accuracy_score, recall_score, precision_score, f1_score, roc_auc_score
import seaborn as sns
from sklearn.model_selection import train_test_split

# Database: CSV file and image folder
csv_path = '/Volumes/MHIA/Mix/mix_data.csv'
image_dir = "/Volumes/MHIA/Mix/xray"

# Metadata table; the columns used below are "study_id" and "findings".
metadata_df = pd.read_csv(csv_path)
metadata_df.reset_index(inplace=True)
print("Variables:")
print(metadata_df.columns)

print("\nNull entries check:")
print(metadata_df.isnull().sum())

# Turn study_id into an image file name. IDs that already carry a ".png"
# extension (any case) are kept as-is. Otherwise they would become
# "<id>.png.png", fail to load, and be silently dropped as corrupted files.
study_ids = metadata_df["study_id"].astype(str)
has_png_ext = study_ids.str.lower().str.endswith(".png")
metadata_df["study_id"] = study_ids.where(has_png_ext, study_ids + ".png")

# Create labels: normalized finding text (surrounding whitespace removed, lower-cased)
metadata_df["labels"] = metadata_df["findings"].str.strip().str.lower()

# Get all image arrays: load every referenced file once as a 224x224 grayscale
# array. dict.fromkeys keeps the CSV order and collapses repeated study_ids;
# an entry stays None when its file could not be read.
img_arrs = dict.fromkeys(metadata_df["study_id"])
corrupted_files = []
for name_ in tqdm(img_arrs.keys()):
    # Names starting with "._" are never opened and are recorded as corrupted
    # (presumably macOS resource-fork side files -- confirm against the dataset).
    if name_.startswith("._"):
        corrupted_files.append(name_)
        continue
    try:
        image_path = os.path.join(image_dir, name_)
        pil_image = load_img(image_path, target_size=(224, 224), color_mode="grayscale")
        img_arrs[name_] = img_to_array(pil_image, data_format="channels_last", dtype=None)
    except Exception as e:
        # Best effort: report the failure and continue with the remaining files.
        print(f"Error loading {name_}: {e}")
        corrupted_files.append(name_)

# Remove metadata rows whose image could not be loaded
is_corrupted = metadata_df["study_id"].isin(corrupted_files)
metadata_df = metadata_df[~is_corrupted]

# Preprocess images. Every stage is kept per image: the intermediate ones are
# only used for plotting, and the final 'scaled' array becomes the model input.
img_arrs_preprocessed = {name_: {} for name_ in img_arrs.keys() if img_arrs[name_] is not None}
for name_, arr_ in tqdm(img_arrs.items()):
    if arr_ is None:
        # File could not be loaded (it is listed in corrupted_files); skip it.
        continue

    # Original
    img_arrs_preprocessed[name_]['original'] = arr_

    # Normalized to [0, 1] by the image's own maximum. A completely black image
    # has max 0, and 0/0 would fill it with NaNs that propagate through every
    # later stage into the training data, so it is kept as all zeros instead.
    arr_max = arr_.max()
    if arr_max > 0:
        arr_normalized = arr_ / arr_max
    else:
        arr_normalized = np.zeros_like(arr_)
    img_arrs_preprocessed[name_]['normalized'] = arr_normalized

    # Histogram equalization (contrast-limited adaptive, CLAHE)
    arr_clahe = equalize_adapthist(arr_normalized, kernel_size=15, clip_limit=0.05)
    img_arrs_preprocessed[name_]['equalized'] = arr_clahe

    # Gaussian smoothing with sigma = 0.5
    arr_gaussian = gaussian(arr_clahe, sigma=0.5)
    img_arrs_preprocessed[name_]['smoothed'] = arr_gaussian

    # Scaling: per-image standardization (zero mean, unit variance); TensorFlow
    # returns a tensor, so convert back to a NumPy array for storage.
    arr_scaled = tf.image.per_image_standardization(arr_gaussian)
    img_arrs_preprocessed[name_]['scaled'] = arr_scaled.numpy()

# Prepare the final dataset. The image array and the label vector are built
# from the same ordered list of study_ids, so row i of `images` and entry i of
# `labels` always describe the same study. Deriving labels from the filtered
# metadata rows instead breaks this pairing as soon as a study_id appears more
# than once in the CSV.
sample_ids = list(img_arrs_preprocessed.keys())
images = np.array([img_arrs_preprocessed[name_]['scaled'] for name_ in sample_ids])

# Convert 'findings' to binary labels: 'tb' -> 1, anything else -> 0.
# Matching ignores case and surrounding whitespace; a missing finding (NaN)
# becomes the string 'nan' and therefore maps to 0 instead of crashing.
# NOTE(review): a duplicated study_id keeps its first row's finding -- confirm
# that duplicates never disagree.
tb_flag_by_id = (
    metadata_df.drop_duplicates(subset="study_id")
    .set_index("study_id")["findings"]
    .astype(str)
    .str.strip()
    .str.lower()
    .eq("tb")
    .astype(int)
)
labels = tb_flag_by_id.loc[sample_ids].to_numpy()

# Ensure all images have the correct shape
print("Shapes of preprocessed images:")
for img in images[:5]:
    print(img.shape)

# Convert grayscale images to 3 channels for VGG16
def convert_to_rgb(x):
    """Repeat every entry of the last (channel) axis of ``x`` three times.

    For a single-channel (H, W, 1) image this yields an (H, W, 3) array whose
    three channels are identical copies, matching the 3-channel input shape
    the VGG16 model below is built with. The dtype of ``x`` is preserved.
    """
    channel_axis = -1
    return np.repeat(x, repeats=3, axis=channel_axis)

# Hold out 20% of the studies as the test set. A fixed seed makes the split
# (and therefore every reported metric) reproducible between runs. Stratifying
# on the labels keeps the TB/normal ratio the same in both splits.
SPLIT_SEED = 42
X_train, X_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.2, random_state=SPLIT_SEED, stratify=labels
)

X_train_rgb = np.stack([convert_to_rgb(img) for img in X_train])
X_test_rgb = np.stack([convert_to_rgb(img) for img in X_test])

print(f"Shape of X_train_rgb: {X_train_rgb.shape}")
print(f"Shape of X_test_rgb: {X_test_rgb.shape}")

# Verify the shape of one sample image
print(f"Shape of one sample image: {X_train_rgb[0].shape}")

# Define the VGG model: an ImageNet-pretrained VGG16 convolutional base (its own
# classifier removed with include_top=False) followed by a small binary head.
input_shape = (224, 224, 3)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)

# Freeze the pretrained base so that only the new head's weights are trained.
base_model.trainable = False

model = Sequential([
    Input(shape=input_shape),        # explicit input shape
    base_model,
    Flatten(),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),  # probability of label 1 ('tb')
])

# Compile with binary cross-entropy to match the single sigmoid output.
model.compile(optimizer=Adam(learning_rate=9e-5), loss='binary_crossentropy', metrics=['accuracy'])

# Ensure model summary works
model.summary()

# Define the EarlyStopping callback: stop after 10 epochs without validation-loss
# improvement and roll the model back to the best epoch's weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Data Augmentation (applied on the fly to training batches only)
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

# Carve a validation split out of the TRAINING data for early stopping.
# Monitoring the test set here would let restore_best_weights choose the epoch
# that scores best on the test images. That leaks test information into model
# selection and makes the reported test metrics optimistic.
X_fit_rgb, X_val_rgb, y_fit, y_val = train_test_split(
    X_train_rgb, y_train, test_size=0.1, random_state=42, stratify=y_train
)

# Train the Model
train_data_gen = datagen.flow(X_fit_rgb, y_fit, batch_size=32)

history = model.fit(
    train_data_gen,
    validation_data=(X_val_rgb, y_val),
    epochs=60,
    callbacks=[early_stopping]
)


# Plot training history: loss on the left panel, accuracy on the right;
# red = training curve, blue = validation curve.
fig, ax = plt.subplots(1, 2, figsize=(10, 6))
history_panels = (
    ('loss', 'val_loss', 'Loss'),
    ('accuracy', 'val_accuracy', 'Accuracy'),
)
for panel, (train_key, val_key, y_title) in zip(ax, history_panels):
    panel.plot(history.history[train_key], 'r', label='train')
    panel.plot(history.history[val_key], 'b', label='val')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_title)
    panel.legend()

plt.show()

# Evaluate on test data set. `results` holds the loss followed by the metrics
# given to model.compile (here: accuracy).
results = model.evaluate(X_test_rgb, y_test, batch_size=64)
print("Test loss, Test acc:", results)

# Predictions on test set: sigmoid probabilities, then hard 0/1 classes.
decision_threshold = 0.5
y_pred = model.predict(X_test_rgb)
y_pred_classes = (y_pred > decision_threshold).astype("int32")


# Visualization

# Confusion matrix. labels=[0, 1] pins the row/column order to Normal, TB and
# keeps the matrix 2x2 even when a class is absent from y_test or from the
# predictions, so the tick labels below always line up with the cells.
class_ids = [0, 1]
class_names = ['Normal', 'TB']
cm = confusion_matrix(y_test, y_pred_classes, labels=class_ids)

# Plot
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Class-wise performance metrics. Passing labels keeps both target_names valid
# even if one class is missing from this split.
print(classification_report(y_test, y_pred_classes, labels=class_ids, target_names=class_names))

# ROC curve of the test-set probabilities, with the chance diagonal for reference.
fpr, tpr, thresholds = roc_curve(y_test, y_pred)
roc_auc = auc(fpr, tpr)

chance_line = [0, 1]
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot(chance_line, chance_line, color='navy', lw=2, linestyle='--')
plt.xlim(0.0, 1.0)
plt.ylim(0.0, 1.05)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()

# Calculate metrics
def calculate_metrics(y_true, y_pred, y_pred_prob):
    """Compute the binary-classification summary used in the metrics table.

    Args:
        y_true: ground-truth 0/1 labels.
        y_pred: predicted 0/1 classes.
        y_pred_prob: predicted probabilities of class 1 (used only for AUROC).

    Returns:
        Tuple ``(accuracy, auroc, sensitivity, specificity, ppv, npv, f1)``.
    """
    # Class 1 is the positive class; the pos_label=0 variants give the
    # corresponding rates for the negative class.
    sensitivity = recall_score(y_true, y_pred)                 # recall of class 1
    specificity = recall_score(y_true, y_pred, pos_label=0)    # recall of class 0
    ppv = precision_score(y_true, y_pred)                      # precision of class 1
    npv = precision_score(y_true, y_pred, pos_label=0)         # precision of class 0
    return (
        accuracy_score(y_true, y_pred),
        roc_auc_score(y_true, y_pred_prob),
        sensitivity,
        specificity,
        ppv,
        npv,
        f1_score(y_true, y_pred),
    )

# Predictions on training set (probabilities, then 0/1 classes at 0.5)
y_train_pred = model.predict(X_train_rgb)
y_train_pred_classes = (y_train_pred > 0.5).astype("int32")

# Metrics for both splits, in calculate_metrics' return order:
# (acc, auroc, sensitivity, specificity, ppv, npv, f1)
train_scores = calculate_metrics(y_train, y_train_pred_classes, y_train_pred)
test_scores = calculate_metrics(y_test, y_pred_classes, y_pred)
train_acc, train_auroc, train_sensitivity, train_specificity, train_ppv, train_npv, train_f1 = train_scores
test_acc, test_auroc, test_sensitivity, test_specificity, test_ppv, test_npv, test_f1 = test_scores

# The table lists AUROC before Accuracy, so the first two entries are swapped.
data = {
    'Metric': ['AUROC', 'Accuracy', 'Sensitivity', 'Specificity', 'PPV', 'NPV', 'F1'],
    'Training': [train_auroc, train_acc, *train_scores[2:]],
    'Testing': [test_auroc, test_acc, *test_scores[2:]],
}

# Build and display the table
metrics_df = pd.DataFrame(data)
print(metrics_df)


# Save the model in the (legacy) HDF5 format. The format is inferred from the
# ".h5" extension, so the explicit save_format argument -- deprecated in
# Keras 3 -- is no longer passed.
save_path = '/Volumes/MHIA/Mix/mix_model_VGG3.h5'
model.save(save_path)
print(f"Model saved successfully to {save_path}.")


# Save the model in the recommended native Keras format (".keras")
save_path_keras = '/Volumes/MHIA/Mix/mix_model_VGG3.keras'
model.save(save_path_keras)
print(f"Model saved successfully to {save_path_keras}.")


################ to print sample image ###########

import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Select 3 sample images: the first three successfully preprocessed study_ids.
n_samples_to_show = 3
sample_images = [*img_arrs_preprocessed][:n_samples_to_show]
print(f"Selected sample images: {sample_images}")

# Define a function to plot preprocessing steps
def plot_preprocessing_steps(samples):
    """Show every stored preprocessing stage of each sample, one row per sample.

    Columns are the stages kept in the module-level ``img_arrs_preprocessed``
    dict: original, normalized, equalized, smoothed, scaled.

    Args:
        samples: keys of ``img_arrs_preprocessed`` (study_ids) to display.
    """
    if not samples:
        return  # plt.subplots cannot build a figure with zero rows

    steps = ['original', 'normalized', 'equalized', 'smoothed', 'scaled']
    # squeeze=False keeps `axes` 2-D, so a single sample still indexes as axes[i, j].
    fig, axes = plt.subplots(nrows=len(samples), ncols=len(steps), figsize=(20, 12), squeeze=False)

    for i, sample in enumerate(samples):
        for j, step in enumerate(steps):
            ax = axes[i, j]
            # imshow needs a 2-D array for grayscale, so drop the singleton channel axis.
            ax.imshow(np.squeeze(img_arrs_preprocessed[sample][step]), cmap='gray')
            ax.set_title(step.capitalize(), fontsize=10)
            # Hide ticks only: ax.axis('off') would also hide the axis labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel(f'{sample}', fontsize=8)
        # One row label per sample, on the first column.
        axes[i, 0].set_ylabel(f'Image {i+1}', fontsize=8)

    plt.tight_layout()
    plt.show()

# Plot the preprocessing steps
plot_preprocessing_steps(sample_images)

# Define the data generator for augmentation (same settings as the training generator)
augmentation_settings = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
datagen = ImageDataGenerator(**augmentation_settings)

# Function to plot augmentation steps
def plot_augmentation_steps(samples):
    """Show each training augmentation applied on its own to every sample.

    Column 0 is the preprocessed ('scaled') image from the module-level
    ``img_arrs_preprocessed`` dict. Each following column applies exactly ONE
    transform, with the same range as the training generator, so the column
    title names what was actually done to that image.

    Args:
        samples: keys of ``img_arrs_preprocessed`` (study_ids) to display.
    """
    if not samples:
        return  # plt.subplots cannot build a figure with zero rows

    flipper = ImageDataGenerator()
    single_transforms = [
        ('Rotation', ImageDataGenerator(rotation_range=10).random_transform),
        ('Width Shift', ImageDataGenerator(width_shift_range=0.1).random_transform),
        ('Height Shift', ImageDataGenerator(height_shift_range=0.1).random_transform),
        ('Zoom', ImageDataGenerator(zoom_range=0.1).random_transform),
        # horizontal_flip=True only flips at random; force the flip so this
        # column always shows it.
        ('Horizontal Flip', lambda img: flipper.apply_transform(img, {'flip_horizontal': True})),
    ]

    # squeeze=False keeps `axes` 2-D, so a single sample still indexes as axes[i, j].
    fig, axes = plt.subplots(nrows=len(samples), ncols=1 + len(single_transforms),
                             figsize=(24, 12), squeeze=False)

    for i, sample in enumerate(samples):
        scaled = img_arrs_preprocessed[sample]['scaled']
        panels = [('Original', scaled)]
        panels += [(title, transform(scaled)) for title, transform in single_transforms]
        for j, (title, panel_image) in enumerate(panels):
            ax = axes[i, j]
            # imshow needs a 2-D array for grayscale, so drop the singleton channel axis.
            ax.imshow(np.squeeze(panel_image), cmap='gray')
            ax.set_title(title, fontsize=10)
            # Hide ticks only: ax.axis('off') would also hide the axis labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
        axes[i, 0].set_xlabel(f'{sample}', fontsize=8)
        axes[i, 0].set_ylabel(f'Image {i+1}', fontsize=8)

    plt.tight_layout()
    plt.show()

# Plot the augmentation steps
plot_augmentation_steps(sample_images)







