In [None]:
!pip install medmnist

In [None]:
# %% [markdown]
# # Training ResNet‑50 (28) on PathMNIST with TensorFlow and CUDA
#
# This notebook downloads and preprocesses the PathMNIST dataset, defines a ResNet‑50 architecture adapted to 28×28 input images, trains the model (on a CUDA-enabled GPU when one is available, otherwise on the CPU), evaluates its performance on the test set, and visualizes the results with a confusion matrix.

# %% [code]
# Imports and GPU Setup
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

# Uncomment the following line if medmnist is not installed:
# !pip install medmnist

import medmnist
from medmnist import PathMNIST

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# This must happen before the GPUs are initialised, so it runs before anything else
# touches TensorFlow's runtime.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"{len(gpus)} GPU(s) detected; memory growth enabled.")
    except RuntimeError as err:
        # Raised when the GPUs are already initialised (e.g. this cell is re-run in a
        # live kernel). The existing configuration stays in effect, so just report it.
        print(f"Could not enable GPU memory growth (GPUs already initialised): {err}")
else:
    print("No GPU found. Running on CPU.")

# Reproducibility: seeds Python's `random`, NumPy and TensorFlow in one call, so
# weight initialisation and tf.data shuffling repeat across runs (some GPU kernels
# may still be non-deterministic).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# %% [markdown]
# ## 1. Data Acquisition & Loading
# This section downloads the PathMNIST dataset, normalizes the images, converts them to 3 channels if needed, converts labels to one-hot encoding, and creates TensorFlow datasets for training, validation, and testing.
#
# Note: the PathMNIST dataset object exposes no `num_classes` attribute, so we set `num_classes = 9` explicitly.

# %% [code]
# Download (if not already cached) and load the three official PathMNIST splits
train_dataset = PathMNIST(split='train', download=True)
val_dataset   = PathMNIST(split='val', download=True)
test_dataset  = PathMNIST(split='test', download=True)

# Extract the raw image and label arrays from each split
X_train, y_train = train_dataset.imgs, train_dataset.labels
X_val,   y_val   = val_dataset.imgs,   val_dataset.labels
X_test,  y_test  = test_dataset.imgs,  test_dataset.labels

# Scale pixel values to [0, 1] (assumes 8-bit images in [0, 255] — TODO confirm dtype)
X_train = X_train.astype('float32') / 255.
X_val   = X_val.astype('float32')   / 255.
X_test  = X_test.astype('float32')  / 255.


def _to_three_channels(images):
    """Return `images` with a trailing channel axis of size 3.

    - (N, H, W) arrays gain a channel axis and are replicated to 3 channels.
    - (N, H, W, 1) arrays are replicated along their existing channel axis
      (appending a new axis here would produce an invalid (N, H, W, 1, 3) array).
    - (N, H, W, 3) arrays are returned unchanged.
    """
    if images.ndim == 3:
        images = images[..., np.newaxis]
    if images.shape[-1] == 1:
        images = np.repeat(images, 3, axis=-1)
    return images


# The model expects 3-channel input, so normalise every split to (N, H, W, 3)
X_train = _to_three_channels(X_train)
X_val   = _to_three_channels(X_val)
X_test  = _to_three_channels(X_test)
for _split, _imgs in (('train', X_train), ('val', X_val), ('test', X_test)):
    assert _imgs.ndim == 4 and _imgs.shape[-1] == 3, (
        f"{_split} images have unexpected shape {_imgs.shape}")

# Since PathMNIST does not have a num_classes attribute, we explicitly set it to 9.
num_classes = 9

# Fail fast with a clear message if any integer label lies outside [0, num_classes)
for _split, _labels in (('train', y_train), ('val', y_val), ('test', y_test)):
    assert _labels.min() >= 0 and _labels.max() < num_classes, (
        f"{_split} labels outside [0, {num_classes})")

# Convert labels to one-hot encoding
y_train = to_categorical(y_train, num_classes)
y_val   = to_categorical(y_val, num_classes)
y_test  = to_categorical(y_test, num_classes)

print("Training data:", X_train.shape, y_train.shape)
print("Validation data:", X_val.shape, y_val.shape)
print("Test data:", X_test.shape, y_test.shape)

# Create tf.data.Dataset objects for efficient loading
batch_size = 64
# A shuffle buffer much smaller than the ~90k training examples only mixes data
# locally; 10k elements (~94 MB of 28x28x3 float32 images) mixes far better.
shuffle_buffer = 10_000
train_ds = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
            .shuffle(shuffle_buffer)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))
# Validation and test streams are NOT shuffled, so their batches stay in the same
# order as y_val / y_test.
val_ds   = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
test_ds  = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# %% [markdown]
# ## 2. Model Definition – ResNet‑50 (28)
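# In this section, we define a ResNet‑50 model adapted for 28×28 images. The standard ImageNet ResNet stem (a 7×7 convolution with stride 2 followed by max pooling) would shrink a 28×28 input to 7×7 before the first residual stage, so here the initial convolution instead uses a 3×3 kernel with stride 1 and no max pooling.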
#
# We define two building blocks:
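# - **identity_block:** a bottleneck of three convolutions (1×1 → k×k → 1×1, each followed by batch normalization) whose output is added to the unchanged block input before the final ReLU.
# - **conv_block:** the same bottleneck, but the shortcut passes through a 1×1 convolution (plus batch normalization) so that its channel count and spatial stride match the main path.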
#
# These blocks are then stacked to create the ResNet‑50 (28) architecture.

# %% [code]
# Define an identity block
def identity_block(input_tensor, filters, kernel_size=3):
    """Bottleneck residual block whose shortcut is the unchanged input.

    Main path: 1x1 conv -> BN -> ReLU -> (kernel_size x kernel_size, 'same') conv
    -> BN -> ReLU -> 1x1 conv -> BN. The result is added element-wise to
    `input_tensor` and passed through a final ReLU.

    Args:
        input_tensor: Keras tensor entering the block. Its channel count must
            equal filters[2] for the addition to be shape-compatible.
        filters: three filter counts (reduce, middle, expand), one per conv.
        kernel_size: spatial size of the middle convolution.

    Returns:
        Keras tensor with the same shape as `input_tensor`.
    """
    f_reduce, f_mid, f_expand = filters
    # (filters, kernel, padding, apply ReLU after BN?) for each conv on the main path
    conv_specs = (
        (f_reduce, (1, 1), 'valid', True),
        (f_mid, (kernel_size, kernel_size), 'same', True),
        (f_expand, (1, 1), 'valid', False),  # ReLU deferred until after the addition
    )

    residual = input_tensor
    for n_filters, ksize, pad, with_relu in conv_specs:
        residual = Conv2D(n_filters, ksize, padding=pad, kernel_initializer='he_normal')(residual)
        residual = BatchNormalization()(residual)
        if with_relu:
            residual = Activation('relu')(residual)

    merged = add([residual, input_tensor])
    return Activation('relu')(merged)

# Define a convolutional block with a shortcut path
def conv_block(input_tensor, filters, kernel_size=3, strides=(2,2)):
    """Bottleneck residual block with a projection (1x1 conv + BN) shortcut.

    Main path: 1x1 conv with `strides` -> BN -> ReLU -> (kernel_size x kernel_size,
    'same') conv -> BN -> ReLU -> 1x1 conv -> BN. The shortcut applies a 1x1
    conv with the same `strides` and filters[2] filters, then BN, so both
    branches have matching shapes. The sum goes through a final ReLU.

    Args:
        input_tensor: Keras tensor entering the block.
        filters: three filter counts (reduce, middle, expand), one per conv.
        kernel_size: spatial size of the middle convolution.
        strides: stride of the first main-path conv and of the shortcut conv;
            (2, 2) halves the spatial resolution, (1, 1) keeps it.

    Returns:
        Keras tensor with filters[2] channels.
    """
    f_reduce, f_mid, f_expand = filters
    # (filters, kernel, strides, padding, apply ReLU after BN?) per main-path conv
    conv_specs = (
        (f_reduce, (1, 1), strides, 'valid', True),
        (f_mid, (kernel_size, kernel_size), (1, 1), 'same', True),
        (f_expand, (1, 1), (1, 1), 'valid', False),  # ReLU deferred until after the addition
    )

    residual = input_tensor
    for n_filters, ksize, step, pad, with_relu in conv_specs:
        residual = Conv2D(n_filters, ksize, strides=step, padding=pad,
                          kernel_initializer='he_normal')(residual)
        residual = BatchNormalization()(residual)
        if with_relu:
            residual = Activation('relu')(residual)

    # Projection shortcut: matches channel count and stride of the main path.
    projection = Conv2D(f_expand, (1, 1), strides=strides, kernel_initializer='he_normal')(input_tensor)
    projection = BatchNormalization()(projection)

    return Activation('relu')(add([residual, projection]))

# Assemble the ResNet-50 (28) model
def ResNet50_28(input_shape=(28,28,3), num_classes=9):
    """Build a ResNet-50 (bottleneck 3/4/6/3 layout) adapted to small images.

    The usual 7x7/stride-2 stem and max pooling are replaced by a single 3x3,
    stride-1 convolution. Stage 2 keeps the resolution (stride 1) and stages
    3-5 each halve it, so a 28x28 input goes 28 -> 14 -> 7 -> 4 before global
    average pooling and a softmax classifier.

    Args:
        input_shape: shape of one input image, channels last.
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled `tf.keras.Model`.
    """
    # Per stage: (bottleneck width, stride of its first block, total blocks).
    # Each stage starts with a conv_block and continues with identity_blocks;
    # the expand layer uses 4x the bottleneck width.
    stage_layout = (
        (64, (1, 1), 3),
        (128, (2, 2), 4),
        (256, (2, 2), 6),
        (512, (2, 2), 3),
    )

    img_input = Input(shape=input_shape)

    # Small-image stem: 3x3 stride-1 conv, no max pooling
    features = Conv2D(64, (3, 3), strides=(1, 1), padding='same', kernel_initializer='he_normal')(img_input)
    features = BatchNormalization()(features)
    features = Activation('relu')(features)

    for width, first_stride, n_blocks in stage_layout:
        stage_filters = [width, width, 4 * width]
        features = conv_block(features, filters=stage_filters, strides=first_stride)
        for _ in range(n_blocks - 1):
            features = identity_block(features, filters=stage_filters)

    pooled = GlobalAveragePooling2D()(features)
    class_probs = Dense(num_classes, activation='softmax')(pooled)
    return Model(inputs=img_input, outputs=class_probs)

# Build the 28x28 ResNet-50 with one softmax unit per class, then print its
# layer-by-layer summary.
model = ResNet50_28(num_classes=num_classes, input_shape=(28, 28, 3))
model.summary()

# %% [markdown]
# ## 3. Training
# We compile the model with the Adam optimizer and categorical crossentropy loss, and then train it for 10 epochs while monitoring performance on the validation dataset.

# %% [code]
# Compile: Adam + categorical cross-entropy, matching the one-hot labels built
# in the data-loading cell.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train for up to `epochs` epochs, using the validation split for model selection:
# stop once val_loss has not improved for `patience` epochs, and restore the weights
# of the best validation epoch so the test evaluation below scores that model rather
# than whatever the final epoch produced.
# NOTE(review): older Keras releases restore best weights only when early stopping
# actually triggers — confirm for the installed version.
epochs = 10
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
history = model.fit(train_ds,
                    epochs=epochs,
                    validation_data=val_ds,
                    callbacks=[early_stopping])

# %% [markdown]
# ## 4. Evaluation
# After training, we evaluate the model’s performance on the test dataset.

# %% [code]
# Score the trained network on the held-out test split; evaluate() yields the
# loss followed by the compiled metric (accuracy).
test_scores = model.evaluate(test_ds)
test_loss, test_acc = test_scores
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# %% [markdown]
# ## 5. Confusion Matrix
# In this final section, we generate predictions on the test set, compute the confusion matrix using scikit-learn, and visualize it with a heatmap.

# %% [code]
# Ground truth: test_ds is built from (X_test, y_test) without shuffling, so the
# one-hot labels can be taken straight from y_test in the same order as the
# predictions (no second pass over the dataset needed).
y_true_labels = np.argmax(y_test, axis=1)
# Predicted class = arg-max of the softmax probabilities
y_pred_probs  = model.predict(test_ds)
y_pred_labels = np.argmax(y_pred_probs, axis=1)
assert len(y_pred_labels) == len(y_true_labels), "prediction/label count mismatch"

# Pin the label set: without `labels`, confusion_matrix uses only classes that
# appear in y_true or y_pred, so the matrix could shrink and the axes misalign.
class_ids = np.arange(num_classes)
cm = confusion_matrix(y_true_labels, y_pred_labels, labels=class_ids)

# Plot the confusion matrix (rows = true class, columns = predicted class)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap='Blues',
            xticklabels=class_ids, yticklabels=class_ids, ax=ax)
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.show()

No GPU found. Running on CPU.
Training data: (89996, 28, 28, 3) (89996, 9)
Validation data: (10004, 28, 28, 3) (10004, 9)
Test data: (7180, 28, 28, 3) (7180, 9)


Epoch 1/10
[1m 189/1407[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m27:48[0m 1s/step - accuracy: 0.4866 - loss: 1.8257