# Training a Neural Network for Disease Detection in Chest X-rays

Time estimate: **30** minutes

## Objectives

After completing this lab, you will be able to:

- Build and train a fully connected neural network using the ChestMNIST dataset from MedMNIST
- Predict multiple pathologies from chest X-ray images through multi-label classification

## What you will do in this lab

In this lab, you will work with the ChestMNIST dataset from MedMNIST to develop and test a neural network model for multi-label medical image classification.

You will:
- Load and preprocess medical imaging data using MedMNIST
- Design a dense neural network for multi-label classification
- Handle medical image data with multiple simultaneous conditions
- Train a model with appropriate hyperparameters for medical imaging
- Evaluate multi-label classification performance using clinically relevant metrics


## Overview

Multi-label classification in medical imaging presents unique challenges compared to traditional single-label classification tasks. In real-world clinical scenarios, patients often present with multiple concurrent pathologies visible in a single chest X-ray. This lab introduces you to building neural networks capable of identifying multiple diseases simultaneously from medical images.

You will work with the ChestMNIST dataset, a standardized subset of chest X-ray images from the MedMNIST collection. This dataset contains images labeled with multiple thoracic pathologies, making it ideal for learning multi-label classification techniques. Unlike standard classification where each image belongs to one category, multi-label classification requires the model to predict the presence or absence of several conditions independently.

Throughout this lab, you will build a fully connected (dense) neural network architecture suitable for processing medical images. You will learn how to properly preprocess medical imaging data, configure appropriate loss functions and evaluation metrics for multi-label scenarios, and interpret model performance in a clinically meaningful way.

Understanding these techniques is essential for developing AI systems that can assist radiologists in detecting multiple pathologies and improving diagnostic accuracy in clinical practice.

By the end of this lab, you will have hands-on experience with the complete workflow of developing a multi-label medical image classification system, from data loading through model evaluation, preparing you for more advanced deep learning applications in healthcare.


## About the dataset

In this lab, you will use the ChestMNIST dataset, which is derived from the NIH Chest X-ray dataset.

- **Dataset Size**: 112,120 chest X-ray images (28×28 grayscale), split into 78,468 training, 11,219 validation, and 22,433 test images
- **Task**: Multi-label binary classification
- **Classes**: 14 thoracic disease labels
  1. Atelectasis
  2. Cardiomegaly
  3. Effusion
  4. Infiltration
  5. Mass
  6. Nodule
  7. Pneumonia
  8. Pneumothorax
  9. Consolidation
  10. Edema
  11. Emphysema
  12. Fibrosis
  13. Pleural Thickening
  14. Hernia
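
As a small illustration of this label format, the sketch below treats one image's label as a 14-element multi-hot vector whose positions follow the order of the list above. The label vector is made up for the example (it is not taken from the dataset), and `decode_label` is a hypothetical helper, not part of MedMNIST.

```python
import numpy as np

# Disease names in the order listed above (index 0 = Atelectasis).
DISEASES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural Thickening", "Hernia",
]

def decode_label(label):
    """Return the names of the diseases marked 1 in a multi-hot label (hypothetical helper)."""
    label = np.asarray(label).ravel()
    return [name for name, flag in zip(DISEASES, label) if flag == 1]

# Made-up example: an image showing both Cardiomegaly and Effusion.
example = np.zeros(14, dtype=np.uint8)
example[[1, 2]] = 1

print(decode_label(example))        # ['Cardiomegaly', 'Effusion']
print(decode_label(np.zeros(14)))   # [] (no findings)
```

An all-zero vector is a valid label and simply means that none of the 14 findings is marked for that image.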

## Setup

### Installation commands

In [None]:
!pip -q install medmnist tensorflow numpy matplotlib seaborn scikit-learn


### Importing required libraries

In [None]:
# Import required libraries

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, multilabel_confusion_matrix, roc_auc_score, hamming_loss, f1_score
import tensorflow as tf
from tensorflow import keras

## Step 1: Load ChestMNIST dataset

In [None]:
import medmnist
from medmnist import INFO

In [None]:
# Get dataset information
data_flag = 'chestmnist'
info = INFO[data_flag]
print(f"\nDataset: {info['python_class']}")
print(f"Task: {info['task']}")
print(f"Input channels: {info['n_channels']}, number of labels: {len(info['label'])}")

In [None]:
# Load the dataset
from medmnist import ChestMNIST

In [None]:
# Load train, validation, and test sets
train_dataset = ChestMNIST(split='train', download=True)
val_dataset = ChestMNIST(split='val', download=True)
test_dataset = ChestMNIST(split='test', download=True)

print(f"\nTraining samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}\n\n")

## Step 2: Explore the data

In [None]:
# Get a sample image and labels
img, label = train_dataset[0]

# Convert PIL Image to NumPy array to access shape
img_array = np.array(img)

print(f"Image shape: {img_array.shape}")  # (28, 28) for a grayscale image
print(f"Label shape: {label.shape}")  # (14,)
print(f"Label values: {label}")
print(f"Data type: {img_array.dtype}")
print(f"Pixel value range: [{img_array.min()}, {img_array.max()}]")

In [None]:
# Visualize sample images
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
disease_names = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

for idx, ax in enumerate(axes.flat):
    if idx < len(train_dataset):
        img, label = train_dataset[idx]
        # Convert PIL Image to NumPy array before displaying
        img_array = np.array(img)
        ax.imshow(img_array.squeeze(), cmap='gray')

        # Show which diseases are present
        diseases_present = [disease_names[i] for i, val in enumerate(label) if val == 1]
        title = ', '.join(diseases_present) if diseases_present else 'No findings'
        ax.set_title(title[:30], fontsize=8)  # Truncate long titles
        ax.axis('off')

plt.tight_layout()
plt.suptitle('Sample Chest X-rays with Labels', y=1.02, fontsize=14)
plt.show()

## Step 3: Analyze label distribution

In [None]:
# Extract all labels from training set
all_labels = np.array([train_dataset[i][1] for i in range(len(train_dataset))])

In [None]:
# Count positive cases for each disease
disease_counts = all_labels.sum(axis=0)
disease_percentages = (disease_counts / len(train_dataset)) * 100

# Visualize label distribution
plt.figure(figsize=(14, 6))
bars = plt.bar(range(len(disease_names)), disease_counts)
plt.xticks(range(len(disease_names)), disease_names, rotation=45, ha='right')
plt.ylabel('Number of Positive Cases')
plt.title('Distribution of Diseases in Training Set')
plt.grid(axis='y', alpha=0.3)

# Add percentage labels on bars
for bar, pct in zip(bars, disease_percentages):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{pct:.1f}%', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()

# Check for class imbalance
print("\nClass Distribution:")
for name, count, pct in zip(disease_names, disease_counts, disease_percentages):
    print(f"{name:20s}: {int(count):5d} ({pct:5.2f}%)")

# Average number of conditions per image
avg_conditions = all_labels.sum(axis=1).mean()
print(f"\nAverage conditions per X-ray: {avg_conditions:.2f}")

## Step 4: Convert to NumPy arrays

In [None]:
# Convert datasets to numpy arrays
def dataset_to_numpy(dataset):
    images = []
    labels = []
    for img, label in dataset:
        # Each item is a (PIL Image, label array) pair; convert the image explicitly
        images.append(np.array(img))
        labels.append(label)
    return np.array(images), np.array(labels)

X_train, y_train = dataset_to_numpy(train_dataset)
X_val, y_val = dataset_to_numpy(val_dataset)
X_test, y_test = dataset_to_numpy(test_dataset)

print(f"X_train shape: {X_train.shape}")  # (n_samples, 28, 28)
print(f"y_train shape: {y_train.shape}")  # (n_samples, 14)

## Step 5: Flatten images for dense network

In [None]:
# Flatten each 28×28 image into a 784-element vector
X_train_flat = X_train.reshape(X_train.shape[0], -1)
X_val_flat = X_val.reshape(X_val.shape[0], -1)
X_test_flat = X_test.reshape(X_test.shape[0], -1)

print(f"X_train_flat shape: {X_train_flat.shape}")  # (n_samples, 784)

## Step 6: Normalize pixel values

In [None]:
# Normalize to [0, 1] range
X_train_norm = X_train_flat.astype('float32') / 255.0
X_val_norm = X_val_flat.astype('float32') / 255.0
X_test_norm = X_test_flat.astype('float32') / 255.0

print(f"Normalized range: [{X_train_norm.min()}, {X_train_norm.max()}]")
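
To sanity-check the flattening and scaling steps without downloading the dataset, the sketch below applies the same two operations to a randomly generated batch. The fake batch and the helper name `flatten_and_scale` are assumptions made for illustration only.

```python
import numpy as np

def flatten_and_scale(images):
    """Flatten each image to a 1-D vector and scale uint8 pixels to [0, 1]."""
    images = np.asarray(images)
    flat = images.reshape(images.shape[0], -1)   # (n, 28, 28[, 1]) -> (n, 784)
    return flat.astype("float32") / 255.0

rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

out = flatten_and_scale(fake_batch)
print(out.shape, out.dtype)                      # (4, 784) float32
print(out.min() >= 0.0, out.max() <= 1.0)        # True True

# A trailing channel axis makes no difference after flattening.
out_with_channel = flatten_and_scale(fake_batch[..., np.newaxis])
print(np.array_equal(out, out_with_channel))     # True
```

Because `reshape(n, -1)` collapses every axis after the first, the same code handles images stored as `(28, 28)` or `(28, 28, 1)`.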

## Step 7: Verify data

In [None]:
print("\n=== Data Summary ===")
print(f"Training set: {X_train_norm.shape[0]} samples")
print(f"Validation set: {X_val_norm.shape[0]} samples")
print(f"Test set: {X_test_norm.shape[0]} samples")
print(f"Input features: {X_train_norm.shape[1]}")
print(f"Output labels: {y_train.shape[1]}")
print(f"Label type: {y_train.dtype}")

## Step 8: Define model architecture

In [None]:
# Define input shape
input_dim = X_train_norm.shape[1]  # 784
output_dim = y_train.shape[1]  # 14

In [None]:
# Build the dense neural network
model = keras.Sequential([
    # Explicit Input layer
    keras.Input(shape=(input_dim,)),

    # Hidden layer 1
    keras.layers.Dense(512, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.2),

    # Hidden layer 2
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.2),

    # Hidden layer 3
    keras.layers.Dense(128, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.1),

    # Hidden layer 4 (no batch normalization)
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dropout(0.1),

    # Output layer (multi-label classification)
    keras.layers.Dense(output_dim, activation='sigmoid')
])

In [None]:
# Display model summary
model.summary()

In [None]:
# Count all parameters (count_params includes non-trainable ones,
# such as the batch normalization moving statistics)
total_params = model.count_params()
print(f"\nTotal parameters: {total_params:,}")
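
One way to check the number reported by `model.summary()` is to recompute it by hand. The sketch below assumes the Step 8 layer sizes and the Keras defaults for `Dense` (weights plus biases) and `BatchNormalization` (two trainable and two non-trainable vectors per feature). The helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    """Return (trainable, non_trainable): gamma and beta vs. moving mean and variance."""
    return 2 * n_features, 2 * n_features

layer_sizes = [784, 512, 256, 128, 64, 14]   # input, four hidden layers, output
bn_after = [512, 256, 128]                   # hidden layers followed by batch norm

dense_total = sum(dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))
bn_trainable = sum(batchnorm_params(n)[0] for n in bn_after)
bn_frozen = sum(batchnorm_params(n)[1] for n in bn_after)

trainable = dense_total + bn_trainable
total = trainable + bn_frozen

print(f"Dense parameters:      {dense_total:,}")   # 575,310
print(f"Trainable parameters:  {trainable:,}")     # 577,102
print(f"Total parameters:      {total:,}")         # 578,894
```

Dropout layers add no parameters, and the first `Dense` layer alone accounts for roughly 70% of the total because it connects all 784 pixels to 512 units.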

### Architecture explanation:
- **Input**: 784 features (28×28 flattened image)
- **Layer 1**: 512 neurons with ReLU, Batch Normalization, 20% dropout
- **Layer 2**: 256 neurons with ReLU, Batch Normalization, 20% dropout
- **Layer 3**: 128 neurons with ReLU, Batch Normalization, 10% dropout
- **Layer 4**: 64 neurons with ReLU, 10% dropout
- **Output**: 14 neurons with sigmoid activation (independent binary predictions)

**Why sigmoid?** The 14 labels are not mutually exclusive, so each output needs its own probability between 0 and 1. Sigmoid scores each label independently, whereas softmax would force the 14 outputs to sum to 1 and make the diseases compete with one another.
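
The difference can be seen on a toy example. The sketch below uses three made-up logits (two strongly positive, one strongly negative) and plain NumPy versions of both activations. It is an illustration, not part of the lab's model code.

```python
import numpy as np

def sigmoid(z):
    """Element-wise logistic function: each score is mapped on its own."""
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    """Normalised exponentials: the outputs always sum to 1."""
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

# Made-up logits for three labels: two clearly present, one clearly absent.
logits = np.array([3.0, 2.5, -4.0])

sig = sigmoid(logits)
soft = softmax(logits)

print("sigmoid:", np.round(sig, 3), "-> labels above 0.5:", int((sig > 0.5).sum()))
print("softmax:", np.round(soft, 3), "-> labels above 0.5:", int((soft > 0.5).sum()))
print("softmax sum:", round(float(soft.sum()), 6))
```

With sigmoid, both strong logits clear the 0.5 threshold. With softmax, the two strong logits share the probability mass, so only one can exceed 0.5, which is the wrong behavior when two conditions are present in the same image.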

In [None]:

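# Custom metric: fraction of individual label predictions (across all
# images and all 14 labels) that match the ground truth at a 0.5 threshold
def multilabel_accuracy(y_true, y_pred):
    threshold = tf.cast(0.5, dtype=y_pred.dtype)
    y_pred_binary = tf.cast(y_pred > threshold, dtype=y_pred.dtype)
    y_true = tf.cast(y_true, dtype=y_pred.dtype)
    correct_predictions = tf.equal(y_true, y_pred_binary)
    return tf.reduce_mean(tf.cast(correct_predictions, dtype=tf.float32))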
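
Because this metric counts every correct "absent" prediction, it can look high on imbalanced labels even when the model finds nothing. The NumPy sketch below shows this with made-up data in which each label is positive for 5% of images. The helper `elementwise_accuracy` is a hypothetical stand-in that mirrors the metric above.

```python
import numpy as np

def elementwise_accuracy(y_true, y_prob, threshold=0.5):
    """Share of individual label decisions that match the ground truth."""
    y_hat = (np.asarray(y_prob) > threshold).astype(int)
    return float((y_hat == np.asarray(y_true)).mean())

# Made-up ground truth: 100 images x 14 labels, 5 positives per label (5%).
y_true = np.zeros((100, 14), dtype=int)
y_true[:5, :] = 1

# A "model" that always predicts probability 0 for every label.
y_prob_all_negative = np.zeros((100, 14))

acc = elementwise_accuracy(y_true, y_prob_all_negative)
true_positives = int(((y_prob_all_negative > 0.5) & (y_true == 1)).sum())
recall = true_positives / int(y_true.sum())

print(f"Element-wise accuracy: {acc:.2f}")   # 0.95
print(f"Recall:                {recall:.2f}")  # 0.00
```

This is why the lab also tracks AUC, precision, and recall rather than relying on accuracy alone.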

## Step 9: Compile the model

In [None]:
# For multi-label classification, use binary crossentropy
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=[
        multilabel_accuracy,
        keras.metrics.AUC(name='auc', multi_label=True),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall')
    ]
)

## Step 10: Calculate class weights (Handle imbalance)

In [None]:
# Calculate positive class frequency for each label
pos_freq = y_train.sum(axis=0) / len(y_train)
neg_freq = 1 - pos_freq

In [None]:
# Calculate class weights (inverse frequency)
pos_weight = neg_freq / pos_freq
print("Positive class weights per disease:")
for name, weight in zip(disease_names, pos_weight):
    print(f"{name:20s}: {weight:.2f}")
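
To make the inverse-frequency weighting concrete, the sketch below repeats the same calculation on a tiny made-up label matrix with two labels. The data and the zero-positive guard at the end are assumptions added for illustration. The lab code above does not include such a guard.

```python
import numpy as np

# Made-up labels: 10 images, 2 labels.
# Label 0 is positive once (10%), label 1 is positive twice (20%).
y = np.array([
    [1, 0], [0, 0], [0, 0], [0, 1], [0, 0],
    [0, 0], [0, 0], [0, 0], [0, 0], [0, 1],
])

pos_freq = y.sum(axis=0) / len(y)          # [0.1, 0.2]
neg_freq = 1 - pos_freq                    # [0.9, 0.8]
pos_weight = neg_freq / pos_freq           # [9.0, 4.0]
print(pos_weight)

# The same weights expressed as negative count / positive count.
pos_count = y.sum(axis=0)
neg_count = len(y) - pos_count
print(neg_count / pos_count)

# Assumption: a label with no positives would divide by zero, so a small
# floor on the frequency keeps the weight finite.
safe_weight = neg_freq / np.maximum(pos_freq, 1e-6)
```

The rarer a label, the larger its weight: a finding present in 10% of images gets a weight of 9, so each missed positive costs as much as nine false alarms on that label.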

In [None]:
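# Create a weighted binary crossentropy loss
def weighted_binary_crossentropy(pos_weight):
    def loss(y_true, y_pred):
        # Labels may arrive as integers; match the prediction dtype
        y_true = tf.cast(y_true, y_pred.dtype)
        weights = tf.cast(pos_weight, y_pred.dtype)

        # Clip predictions to prevent log(0)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

        # Positive terms are scaled by the per-label weight;
        # negative terms keep a weight of 1
        per_element = -(
            weights * y_true * tf.math.log(y_pred) +
            (1 - y_true) * tf.math.log(1 - y_pred)
        )
        return tf.reduce_mean(per_element)
    return loss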
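
A NumPy version of the same formula is handy for checking the loss by hand. The sketch below is a reference calculation under the same clipping assumption, with made-up inputs. The function name `weighted_bce_numpy` is hypothetical and it is not used by the model.

```python
import numpy as np

def weighted_bce_numpy(y_true, y_pred, pos_weight, eps=1e-7):
    """Mean of -(w * y * log(p) + (1 - y) * log(1 - p)) over all elements."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    per_element = -(
        pos_weight * y_true * np.log(y_pred)
        + (1 - y_true) * np.log(1 - y_pred)
    )
    return float(per_element.mean())

# One image, two labels: the first is present, the second is absent.
# The model is maximally unsure about both (probability 0.5).
y_true = np.array([[1, 0]])
y_pred = np.array([[0.5, 0.5]])

unweighted = weighted_bce_numpy(y_true, y_pred, np.array([1.0, 1.0]))
weighted = weighted_bce_numpy(y_true, y_pred, np.array([9.0, 1.0]))

print(f"Unweighted: {unweighted:.4f}")   # log(2)     = 0.6931
print(f"Weighted:   {weighted:.4f}")     # 5 * log(2) = 3.4657
```

With a weight of 9 on the first label, the uncertain prediction on the positive case contributes nine times as much, so the mean rises from log 2 to (9 log 2 + log 2) / 2 = 5 log 2.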

In [None]:
# Recompile with weighted loss
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=weighted_binary_crossentropy(pos_weight),
    metrics=[
        multilabel_accuracy,
        keras.metrics.AUC(name='auc', multi_label=True),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall')
    ]
)

## Step 11: Set up callbacks

In [None]:
# Early stopping
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_auc',
    patience=15,
    mode='max',
    restore_best_weights=True,
    verbose=1
)

# Learning rate reduction
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1
)

# Model checkpoint
checkpoint = keras.callbacks.ModelCheckpoint(
    'best_chestmnist_model.keras',
    monitor='val_auc',
    mode='max',
    save_best_only=True,
    verbose=1
)
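
The effect of `patience` can be explored without training anything. The sketch below is a simplified model of the patience rule for a metric that should increase, such as `val_auc`. It is not Keras's implementation and ignores options such as `min_delta`. The validation scores are made up.

```python
def simulate_early_stopping(val_scores, patience):
    """Return (last_epoch_run, best_epoch), both 1-based, for a 'max' metric.

    Training stops once `patience` consecutive epochs fail to beat the
    best score seen so far (a tie does not count as an improvement).
    """
    best = float("-inf")
    best_epoch = 0
    wait = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best:
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_scores), best_epoch

# Made-up validation AUC values over seven epochs.
scores = [0.60, 0.65, 0.70, 0.69, 0.68, 0.70, 0.67]

print(simulate_early_stopping(scores, patience=3))    # (6, 3)
print(simulate_early_stopping(scores, patience=15))   # (7, 3)
```

With `patience=15` and at most 30 epochs, as configured above, stopping can only trigger if the best validation AUC occurs in the first 15 epochs. Otherwise training runs to the end, and `restore_best_weights=True` is expected to roll the model back to the best epoch only when stopping actually triggers.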

## Step 12: Train the model

In [None]:
# Train the model
history = model.fit(
    X_train_norm, y_train,
    validation_data=(X_val_norm, y_val),
    epochs=30,
    batch_size=128,
    callbacks=[early_stopping, reduce_lr, checkpoint],
    verbose=1
)
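
The progress bar shows the number of batches per epoch, which follows from the training-set size and `batch_size`. The arithmetic below assumes the 78,468 images quoted in the dataset section are the training samples and that the final partial batch is kept.

```python
import math

n_train = 78_468     # training images, as quoted in the dataset section
batch_size = 128
max_epochs = 30

steps_per_epoch = math.ceil(n_train / batch_size)
last_batch = n_train - (steps_per_epoch - 1) * batch_size
max_updates = steps_per_epoch * max_epochs

print(f"Steps per epoch:       {steps_per_epoch}")   # 614
print(f"Samples in last batch: {last_batch}")        # 4
print(f"Max weight updates:    {max_updates:,}")     # 18,420
```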

print("\nTraining completed!")
print(f"Best epoch: {np.argmax(history.history['val_auc']) + 1}")

## Step 13: Visualize training history

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Accuracy
axes[0, 0].plot(history.history['multilabel_accuracy'], label='Train')
axes[0, 0].plot(history.history['val_multilabel_accuracy'], label='Validation')
axes[0, 0].set_title('Model Accuracy')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('multilabel_accuracy')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# Loss
axes[0, 1].plot(history.history['loss'], label='Train')
axes[0, 1].plot(history.history['val_loss'], label='Validation')
axes[0, 1].set_title('Model Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

# AUC
axes[1, 0].plot(history.history['auc'], label='Train')
axes[1, 0].plot(history.history['val_auc'], label='Validation')
axes[1, 0].set_title('Model AUC')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('AUC')
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)

# Precision & Recall
axes[1, 1].plot(history.history['precision'], label='Train Precision')
axes[1, 1].plot(history.history['val_precision'], label='Val Precision')
axes[1, 1].plot(history.history['recall'], label='Train Recall')
axes[1, 1].plot(history.history['val_recall'], label='Val Recall')
axes[1, 1].set_title('Precision & Recall')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Score')
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Step 14: Make predictions

In [None]:
# Predict on test set (probabilities)
y_pred_prob = model.predict(X_test_norm)

# Convert probabilities to binary predictions (threshold = 0.5)
y_pred = (y_pred_prob > 0.5).astype(int)

print(f"Prediction probabilities shape: {y_pred_prob.shape}")
print(f"Binary predictions shape: {y_pred.shape}")
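
The 0.5 cut-off is a choice, not a requirement, and moving it trades precision against recall. The sketch below illustrates this on eight made-up probabilities for a single label. The helper `precision_recall_at` is hypothetical.

```python
import numpy as np

def precision_recall_at(y_true, y_prob, threshold):
    """Precision and recall for one label at a given probability threshold."""
    y_hat = y_prob > threshold
    tp = int(np.sum(y_hat & (y_true == 1)))
    fp = int(np.sum(y_hat & (y_true == 0)))
    fn = int(np.sum(~y_hat & (y_true == 1)))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall

# Made-up ground truth and predicted probabilities for one disease.
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 0])
y_prob = np.array([0.9, 0.6, 0.3, 0.7, 0.4, 0.2, 0.1, 0.05])

for threshold in (0.25, 0.5, 0.8):
    p, r = precision_recall_at(y_true, y_prob, threshold)
    print(f"threshold={threshold:.2f}  precision={p:.2f}  recall={r:.2f}")
```

Lowering the threshold to 0.25 catches every positive case at the cost of more false alarms, while raising it to 0.8 removes the false alarms but misses two of the three positives.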

## Step 15: Calculate overall metrics

In [None]:
# Calculate overall metrics
test_loss, test_acc, test_auc, test_precision, test_recall = model.evaluate(
    X_test_norm, y_test, verbose=0
)

print("\n=== Overall Test Set Performance ===")
print(f"Loss: {test_loss:.4f}")
print(f"Accuracy: {test_acc:.4f}")
print(f"AUC: {test_auc:.4f}")
print(f"Precision: {test_precision:.4f}")
print(f"Recall: {test_recall:.4f}")



## Step 16: Calculate per-disease performance

In [None]:
# Calculate per-disease metrics
print("\n=== Per-Disease Performance ===")
print(f"{'Disease':<20s} {'AUC':>6s} {'Precision':>10s} {'Recall':>8s} {'F1':>6s} {'Support':>8s}")
print("=" * 70)

per_disease_metrics = []
for i, disease in enumerate(disease_names):
    # AUC for this disease
    auc = roc_auc_score(y_test[:, i], y_pred_prob[:, i])

    # Precision, Recall, F1
    tp = ((y_test[:, i] == 1) & (y_pred[:, i] == 1)).sum()
    fp = ((y_test[:, i] == 0) & (y_pred[:, i] == 1)).sum()
    fn = ((y_test[:, i] == 1) & (y_pred[:, i] == 0)).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    support = y_test[:, i].sum()

    print(f"{disease:<20s} {auc:>6.3f} {precision:>10.3f} {recall:>8.3f} {f1:>6.3f} {int(support):>8d}")

    per_disease_metrics.append({
        'disease': disease,
        'auc': auc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'support': int(support)
    })

## Step 17: Visualize confusion matrices for each disease

In [None]:
# Calculate multilabel confusion matrix
cm = multilabel_confusion_matrix(y_test, y_pred)

# Visualize confusion matrices for first 4 diseases
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()

for i in range(4):
    sns.heatmap(cm[i], annot=True, fmt='d', cmap='Blues', ax=axes[i],
                xticklabels=['Negative', 'Positive'],
                yticklabels=['Negative', 'Positive'])
    axes[i].set_title(f'{disease_names[i]} Confusion Matrix')
    axes[i].set_ylabel('Actual')
    axes[i].set_xlabel('Predicted')

plt.tight_layout()
plt.show()
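
Each matrix returned by `multilabel_confusion_matrix` is laid out as `[[TN, FP], [FN, TP]]`, with rows for the actual class and columns for the predicted class. The NumPy sketch below rebuilds that layout for one label from made-up predictions. The helper `binary_confusion` is hypothetical and is meant only to show how to read the heatmaps.

```python
import numpy as np

def binary_confusion(y_true, y_pred):
    """Return [[TN, FP], [FN, TP]] for one label (rows: actual, cols: predicted)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    return np.array([[tn, fp], [fn, tp]])

# Made-up example: five images, one label.
y_true = [1, 1, 0, 0, 0]
y_pred = [1, 0, 1, 0, 0]

cm_example = binary_confusion(y_true, y_pred)
print(cm_example)
# [[2 1]
#  [1 1]]
```

In the heatmaps above, the bottom-left cell (actual positive, predicted negative) therefore holds the missed cases, which is usually the cell of most concern in a screening setting.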

## Step 18: Plot ROC curves

In [None]:
from sklearn.metrics import roc_curve

# Plot ROC curves for selected diseases
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.ravel()

# Select 6 most common diseases
selected_indices = np.argsort(y_test.sum(axis=0))[-6:]

for idx, i in enumerate(selected_indices):
    fpr, tpr, _ = roc_curve(y_test[:, i], y_pred_prob[:, i])
    auc_score = roc_auc_score(y_test[:, i], y_pred_prob[:, i])

    axes[idx].plot(fpr, tpr, color='darkorange', lw=2,
                   label=f'ROC curve (AUC = {auc_score:.3f})')
    axes[idx].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    axes[idx].set_xlim([0.0, 1.0])
    axes[idx].set_ylim([0.0, 1.05])
    axes[idx].set_xlabel('False Positive Rate')
    axes[idx].set_ylabel('True Positive Rate')
    axes[idx].set_title(f'{disease_names[i]} ROC Curve')
    axes[idx].legend(loc="lower right")
    axes[idx].grid(alpha=0.3)

plt.tight_layout()
plt.show()
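
The AUC shown in each legend has a direct interpretation: it is the probability that a randomly chosen positive image receives a higher score than a randomly chosen negative one. The sketch below computes it that way by comparing every positive-negative pair on made-up scores. The helper `auc_pairwise` is a hypothetical illustration for small arrays, not a replacement for `roc_auc_score`.

```python
import numpy as np

def auc_pairwise(y_true, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC needs at least one positive and one negative sample")
    diff = pos[:, None] - neg[None, :]          # every positive minus every negative
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(wins / diff.size)

# Made-up labels and scores for one disease.
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

print(auc_pairwise(y_true, scores))                  # 0.75
print(auc_pairwise(y_true, [0.1, 0.2, 0.8, 0.9]))    # 1.0
print(auc_pairwise(y_true, [0.5, 0.5, 0.5, 0.5]))    # 0.5
```

Because it depends only on the ranking of scores, AUC does not change with the 0.5 threshold, which makes it a useful complement to precision and recall on rare labels.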

## Step 19: Visualize predictions

In [None]:
# Visualize some predictions
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

for idx, ax in enumerate(axes.flat):
    if idx < 12:
        # Get random test sample
        i = np.random.randint(0, len(X_test))
        img = X_test[i].squeeze()
        true_labels = y_test[i]
        pred_probs = y_pred_prob[i]

        # Display image
        ax.imshow(img, cmap='gray')

        # Get disease names
        true_diseases = [disease_names[j] for j, val in enumerate(true_labels) if val == 1]
        pred_diseases = [(disease_names[j], pred_probs[j])
                        for j in range(len(pred_probs)) if pred_probs[j] > 0.5]

        # Create title
        true_str = ', '.join(true_diseases) if true_diseases else 'No findings'
        pred_str = ', '.join([f"{d}({p:.2f})" for d, p in pred_diseases]) if pred_diseases else 'No findings'

        title = f"True: {true_str[:30]}\nPred: {pred_str[:30]}"
        ax.set_title(title, fontsize=8)
        ax.axis('off')

plt.tight_layout()
plt.suptitle('Sample Predictions (Probability > 0.5)', y=1.02, fontsize=14)
plt.show()

## Step 20: Save and deploy the model

In [None]:
# Save entire model
model.save('chestmnist_dnn_model.keras')

# Save model architecture as JSON
model_json = model.to_json()
with open('model_architecture.json', 'w') as json_file:
    json_file.write(model_json)

# Save weights separately
model.save_weights('model_weights.weights.h5')

print("Model saved successfully!")

In [None]:
# Load the saved model, supplying the custom loss and metric
loaded_model = keras.models.load_model(
    'chestmnist_dnn_model.keras',
    custom_objects={'loss': weighted_binary_crossentropy(pos_weight), 'multilabel_accuracy': multilabel_accuracy}
)

In [None]:
# Make prediction on new data
from PIL import Image

def predict_diseases(model, image):
    """
    Predict diseases from a single chest X-ray image

    Args:
        model: trained Keras model
        image: PIL Image or numpy array of shape (28, 28) or (28, 28, 1)

    Returns:
        dict: disease names and probabilities
    """
    # Convert PIL Image to NumPy array if necessary
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Preprocess
    img_flat = image.reshape(1, -1).astype('float32') / 255.0

    # Predict
    probs = model.predict(img_flat, verbose=0)[0]

    # Create results dictionary
    results = {}
    for disease, prob in zip(disease_names, probs):
        results[disease] = float(prob)

    return results

In [None]:
# Test prediction function
test_img, test_label = test_dataset[0]
predictions = predict_diseases(loaded_model, test_img)

print("\nPrediction Results:")
for disease, prob in sorted(predictions.items(), key=lambda x: x[1], reverse=True):
    # Print every disease, flagging those above the 0.5 threshold
    status = 'POSITIVE' if prob > 0.5 else 'negative'
    print(f"{disease:20s}: {prob:.3f} ({status})")

# Exercises
In Step 8, you built a neural network. Below are two alternative architectures. The first, labeled the deeper network, keeps four hidden layers as in Step 8 but doubles the width of each (1024, 512, 256, 128). The second is a wider network with fewer layers (two hidden layers of 1024 neurons). Replace the existing network with each of these in turn, recompile and retrain it, and compare the resulting metrics with those of the original model.

## Exercise 1: Use the deeper network architecture variation

In [None]:
# Deeper network
model_deep = keras.Sequential([
    # Explicit Input layer, as in Step 8
    keras.Input(shape=(input_dim,)),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.1),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dropout(0.1),
    keras.layers.Dense(output_dim, activation='sigmoid')
])

## Exercise 2: Use the wider network architecture variation

In [None]:
# Wider network
model_wide = keras.Sequential([
    # Explicit Input layer, as in Step 8
    keras.Input(shape=(input_dim,)),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(output_dim, activation='sigmoid')
])

# Congratulations!

You have successfully completed this lab. You loaded and preprocessed medical imaging data with MedMNIST, designed a dense neural network for multi-label classification, handled class imbalance with a weighted loss, trained the model with callbacks, and evaluated its multi-label performance using per-disease AUC, precision, recall, and F1.

## Authors

Ramesh Sannareddy

Copyright © 2025 SkillUp. All rights reserved.