In [None]:
# PNEUMONIA DETECTION FROM CHEST X-RAYS - IMPROVED WORKING CODE
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.utils import class_weight
import os

# 1. SETUP ======================================================
IMG_SIZE = (224, 224)  # ResNet input size
BATCH_SIZE = 32
EPOCHS = 15  # Increased epochs to allow for early stopping
LR = 0.0001

# 2. DATA PIPELINE ==============================================
base_dir = '/content/chest_xray' # Corrected base directory
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

# Add checks for directory existence and content
print(f"Checking if {base_dir} exists: {os.path.exists(base_dir)}")
if os.path.exists(base_dir):
    print(f"Contents of {base_dir}: {os.listdir(base_dir)}")
    print(f"Checking if {train_dir} exists: {os.path.exists(train_dir)}")
    if os.path.exists(train_dir):
        print(f"Contents of {train_dir}: {os.listdir(train_dir)}")
    print(f"Checking if {val_dir} exists: {os.path.exists(val_dir)}")
    if os.path.exists(val_dir):
        print(f"Contents of {val_dir}: {os.listdir(val_dir)}")
    print(f"Checking if {test_dir} exists: {os.path.exists(test_dir)}")
    if os.path.exists(test_dir):
        print(f"Contents of {test_dir}: {os.listdir(test_dir)}")
else:
    print("Base directory does not exist. Please ensure the dataset is unzipped correctly.")


# Enhanced data augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.15,
    height_shift_range=0.15,
    shear_range=0.15,
    zoom_range=0.15,
    horizontal_flip=True,
    vertical_flip=True,  # Added vertical flip
    fill_mode='nearest'
)

test_val_datagen = ImageDataGenerator(rescale=1./255)

# Create generators with balanced classes
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=True,
    seed=42  # Fixed seed for reproducibility
)

val_generator = test_val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

test_generator = test_val_datagen.flow_from_directory(
    test_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

# Calculate class weights to handle imbalance
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
class_weights = dict(enumerate(class_weights))
print(f"Class weights: {class_weights}")  # Debug output

# 3. IMPROVED MODEL BUILDING ====================================
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3),
    pooling='avg'  # Directly add global pooling
)

base_model.trainable = False

# Simplified model architecture
model = models.Sequential([
    layers.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),
    base_model,
    layers.Dropout(0.3),  # Reduced dropout
    layers.Dense(128, activation='relu'),  # Smaller dense layer
    layers.Dense(1, activation='sigmoid')
])

# Mixed precision training for better performance
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

model.compile(
    optimizer=optimizers.Adam(LR),
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.TruePositives(name='tp'),
        tf.keras.metrics.FalsePositives(name='fp')
    ]
)

# 4. ENHANCED TRAINING ==========================================
callbacks = [
    callbacks.EarlyStopping(
        monitor='val_auc',  # Changed to monitor AUC
        patience=5,
        mode='max',
        restore_best_weights=True
    ),
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=2,
        min_lr=1e-7
    ),
    callbacks.ModelCheckpoint(
        'best_model.h5',
        monitor='val_auc',
        save_best_only=True,
        mode='max'
    )
]

history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    validation_data=val_generator,
    validation_steps=val_generator.samples // BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=class_weights,  # Added class weights
    verbose=2
)

# 5. COMPREHENSIVE EVALUATION ===================================
def plot_history(history):
    plt.figure(figsize=(18, 6))

    # Accuracy
    plt.subplot(1, 3, 1)
    plt.plot(history.history['accuracy'], label='Train')
    plt.plot(history.history['val_accuracy'], label='Validation')
    plt.title('Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend()

    # AUC
    plt.subplot(1, 3, 2)
    plt.plot(history.history['auc'], label='Train')
    plt.plot(history.history['val_auc'], label='Validation')
    plt.title('AUC')
    plt.ylabel('AUC')
    plt.xlabel('Epoch')
    plt.legend()

    # Loss
    plt.subplot(1, 3, 3)
    plt.plot(history.history['loss'], label='Train')
    plt.plot(history.history['val_loss'], label='Validation')
    plt.title('Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.tight_layout()
    plt.show()

plot_history(history)

# Load best model for final evaluation
model = models.load_model('best_model.h5')

# Test evaluation
test_results = model.evaluate(test_generator)
print("\nFINAL TEST METRICS:")
print(f"Accuracy: {test_results[1]:.4f}")
print(f"Precision: {test_results[2]:.4f}")
print(f"Recall: {test_results[3]:.4f}")
print(f"AUC: {test_results[4]:.4f}")
print(f"True Positives: {test_results[5]}")
print(f"False Positives: {test_results[6]}")

# Enhanced confusion matrix
test_generator.reset()
y_true = test_generator.classes
y_pred = model.predict(test_generator) > 0.5

cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Normal', 'Pneumonia'],
            yticklabels=['Normal', 'Pneumonia'],
            annot_kws={"size": 16})
plt.title('Confusion Matrix', fontsize=14)
plt.xlabel('Predicted', fontsize=12)
plt.ylabel('Actual', fontsize=12)
plt.show()

# Detailed classification report
print("\nCLASSIFICATION REPORT:")
print(classification_report(y_true, y_pred, target_names=['Normal', 'Pneumonia']))

# 6. MODEL SAVING ===============================================
model.save('pneumonia_detection.h5')
print("Model saved as pneumonia_detection.h5")

# Save the class indices for deployment
import json
class_indices = train_generator.class_indices
with open('class_indices.json', 'w') as f:
    json.dump(class_indices, f)
print("Class indices saved to class_indices.json")

Checking if /content/chest_xray exists: True
Contents of /content/chest_xray: ['val', 'train', 'test']
Checking if /content/chest_xray/train exists: True
Contents of /content/chest_xray/train: ['NORMAL', 'PNEUMONIA']
Checking if /content/chest_xray/val exists: True
Contents of /content/chest_xray/val: ['NORMAL', 'PNEUMONIA']
Checking if /content/chest_xray/test exists: True
Contents of /content/chest_xray/test: ['NORMAL', 'PNEUMONIA']
Found 5216 images belonging to 2 classes.
Found 16 images belonging to 2 classes.
Found 624 images belonging to 2 classes.
Class weights: {0: np.float64(1.9448173005219984), 1: np.float64(0.6730322580645162)}
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  self._warn_if_super_not_called()


Epoch 1/15




163/163 - 696s - 4s/step - accuracy: 0.5015 - auc: 0.5130 - fp: 634.0000 - loss: 0.7112 - precision: 0.7507 - recall: 0.4926 - tp: 1909.0000 - val_accuracy: 0.6250 - val_auc: 0.8516 - val_fp: 6.0000 - val_loss: 0.6777 - val_precision: 0.5714 - val_recall: 1.0000 - val_tp: 8.0000 - learning_rate: 1.0000e-04
Epoch 2/15




163/163 - 690s - 4s/step - accuracy: 0.5684 - auc: 0.6072 - fp: 551.0000 - loss: 0.6759 - precision: 0.7979 - recall: 0.5613 - tp: 2175.0000 - val_accuracy: 0.5000 - val_auc: 0.8594 - val_fp: 0.0000e+00 - val_loss: 0.6800 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_tp: 0.0000e+00 - learning_rate: 1.0000e-04
Epoch 3/15
163/163 - 754s - 5s/step - accuracy: 0.6246 - auc: 0.6840 - fp: 457.0000 - loss: 0.6521 - precision: 0.8386 - recall: 0.6126 - tp: 2374.0000 - val_accuracy: 0.8125 - val_auc: 0.8516 - val_fp: 0.0000e+00 - val_loss: 0.6474 - val_precision: 1.0000 - val_recall: 0.6250 - val_tp: 5.0000 - learning_rate: 1.0000e-04
Epoch 4/15
163/163 - 701s - 4s/step - accuracy: 0.6495 - auc: 0.7138 - fp: 416.0000 - loss: 0.6382 - precision: 0.8555 - recall: 0.6356 - tp: 2463.0000 - val_accuracy: 0.8125 - val_auc: 0.8594 - val_fp: 0.0000e+00 - val_loss: 0.6336 - val_precision: 1.0000 - val_recall: 0.6250 - val_tp: 5.0000 - learning_rate: 1.0000e-04
Epoch 5/15




163/163 - 704s - 4s/step - accuracy: 0.6856 - auc: 0.7434 - fp: 417.0000 - loss: 0.6234 - precision: 0.8641 - recall: 0.6844 - tp: 2652.0000 - val_accuracy: 0.8125 - val_auc: 0.8672 - val_fp: 2.0000 - val_loss: 0.6154 - val_precision: 0.7778 - val_recall: 0.8750 - val_tp: 7.0000 - learning_rate: 1.0000e-04
Epoch 6/15
163/163 - 726s - 4s/step - accuracy: 0.7136 - auc: 0.7812 - fp: 411.0000 - loss: 0.6004 - precision: 0.8717 - recall: 0.7205 - tp: 2792.0000 - val_accuracy: 0.7500 - val_auc: 0.8672 - val_fp: 2.0000 - val_loss: 0.6009 - val_precision: 0.7500 - val_recall: 0.7500 - val_tp: 6.0000 - learning_rate: 1.0000e-04
Epoch 7/15


In [11]:
# For very large datasets (avoids re-uploading)
if not os.path.exists("/content/chest_xray"):
    !unzip "/content/drive/MyDrive/rep/chest_xray.zip" -d "/content"
else:
    print("Dataset already extracted")

unzip:  cannot find or open /content/drive/MyDrive/rep/chest_xray.zip, /content/drive/MyDrive/rep/chest_xray.zip.zip or /content/drive/MyDrive/rep/chest_xray.zip.ZIP.


In [10]:
# Check zip contents without extracting
!unzip -l "/content/drive/MyDrive/rep/chest_xray.zip"

# Fix permission issues
!chmod -R 755 "/content/chest_xray"

unzip:  cannot find or open /content/drive/MyDrive/rep/chest_xray.zip, /content/drive/MyDrive/rep/chest_xray.zip.zip or /content/drive/MyDrive/rep/chest_xray.zip.ZIP.
chmod: cannot access '/content/chest_xray': No such file or directory


In [14]:
import os
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Unzip dataset (run only once)
!unzip -q "/content/drive/MyDrive/rep/chest_xray.zip" -d "/content"

# Verify dataset structure
base_dir = '/content/chest_xray'
required_folders = ['train', 'val', 'test']
required_classes = ['NORMAL', 'PNEUMONIA']

print("🔍 Verifying dataset structure...")
for folder in required_folders:
    folder_path = os.path.join(base_dir, folder)
    if not os.path.exists(folder_path):
        raise FileNotFoundError(f"Missing folder: {folder_path}")

    for class_name in required_classes:
        class_path = os.path.join(folder_path, class_name)
        if not os.path.exists(class_path):
            raise FileNotFoundError(f"Missing class folder: {class_path}")

        num_images = len(os.listdir(class_path))
        print(f"✅ {folder}/{class_name}: {num_images} images")

print("\n🎉 Dataset structure verified successfully!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
🔍 Verifying dataset structure...
✅ train/NORMAL: 1341 images
✅ train/PNEUMONIA: 3875 images
✅ val/NORMAL: 8 images
✅ val/PNEUMONIA: 8 images
✅ test/NORMAL: 234 images
✅ test/PNEUMONIA: 390 images

🎉 Dataset structure verified successfully!
