# AI-Powered Medical Image Analysis — Training Notebook (EfficientNetB0)

This notebook prepares data, trains the model, evaluates it, and exports artifacts.

In [None]:
import os, json, math, itertools
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, precision_recall_curve, auc
from sklearn.utils.class_weight import compute_class_weight

# --- Run configuration ------------------------------------------------------
# Spatial size every image is resized to, and the mini-batch size used for
# all three dataset splits.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16

# DATA_ROOT is joined with "train" / "val" / "test" when the datasets are
# built; OUTPUT_DIR receives checkpoints, logs and exported models.
DATA_ROOT = "data/samples"
OUTPUT_DIR = "models/v1"

# Seed TensorFlow's global random generator so repeated runs start from the
# same state.
tf.random.set_seed(42)

# Make sure the artifact directory exists before anything tries to write to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)


## Build datasets

In [None]:
def _load_split(split, shuffle):
    """Load one dataset split from ``DATA_ROOT/<split>`` as a batched tf.data.Dataset.

    Labels are inferred from the sub-directory names and returned in binary
    mode. Images are resized to IMAGE_SIZE and loaded as RGB. The shuffle seed
    is only passed when shuffling is requested.
    """
    return tf.keras.utils.image_dataset_from_directory(
        os.path.join(DATA_ROOT, split),
        labels="inferred", label_mode="binary",
        image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, color_mode="rgb",
        shuffle=shuffle, seed=42 if shuffle else None,
    )


train_ds = _load_split("train", shuffle=True)
val_ds = _load_split("val", shuffle=False)
test_ds = _load_split("test", shuffle=False)

# Read the class order from the loader instead of hard-coding it: the label
# index assigned to each class follows the inferred directory order, so this
# keeps report/plot labels in sync with what the model is actually trained on.
# Must be read before any dataset transformation drops the attribute.
class_names = list(train_ds.class_names)
for split_name, split_ds in (("val", val_ds), ("test", test_ds)):
    if list(split_ds.class_names) != class_names:
        raise ValueError(
            f"Class folders in '{split_name}' ({split_ds.class_names}) "
            f"do not match 'train' ({class_names})"
        )

# NOTE: pixels are intentionally left in the [0, 255] range. The Keras
# EfficientNet models include their own rescaling/normalisation layers and
# expect raw pixel values; dividing by 255 here would feed the backbone
# near-black images.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(AUTOTUNE)
val_ds = val_ds.prefetch(AUTOTUNE)
test_ds = test_ds.prefetch(AUTOTUNE)
print(class_names)


## Model & Augmentation

In [None]:
def augmentation():
    """Build the random augmentation pipeline applied in front of the backbone.

    Returns:
        A ``tf.keras.Sequential`` chaining, in order: horizontal flip,
        rotation (factor 0.03), zoom (factor 0.1) and contrast jitter
        (factor 0.15).
    """
    stages = [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.03),
        layers.RandomZoom(0.1),
        layers.RandomContrast(0.15),
    ]
    return tf.keras.Sequential(stages)

# ImageNet-pretrained EfficientNetB0 without its classification head, used as
# a frozen feature extractor for the first training phase.
input_shape = IMAGE_SIZE + (3,)
base = tf.keras.applications.efficientnet.EfficientNetB0(
    include_top=False,
    weights="imagenet",
    input_shape=input_shape,
)
base.trainable = False

# Head: augmentation -> backbone (inference mode) -> global pooling ->
# dropout -> single sigmoid unit for the binary decision.
inputs = layers.Input(shape=input_shape)
x = augmentation()(inputs)
x = base(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(inputs, outputs)

# Metric names matter: the callbacks monitor "val_auc".
train_metrics = [
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.BinaryAccuracy(name='acc'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
]
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss='binary_crossentropy',
    metrics=train_metrics,
)
model.summary()


## Class weights

In [None]:
# Collect every training label. Labels arrive per batch with a trailing
# singleton dimension (binary label mode), so flatten each batch and
# concatenate. Working on whole batches avoids calling int() on a 1-element
# array (deprecated in recent NumPy) and the slow per-sample unbatch() loop.
y_all = np.concatenate([y.numpy().ravel() for _, y in train_ds]).astype(int)

# "balanced" weighting: n_samples / (n_classes * count(class)), so the rarer
# class contributes proportionally more to the loss.
classes = np.unique(y_all)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_all)

# Keras expects a plain {class_index: weight} dict of Python scalars.
class_weights = {int(c): float(w) for c, w in zip(classes, weights)}
class_weights


## Train (frozen)

In [None]:
# Artifact locations for this training run.
checkpoint_path = os.path.join(OUTPUT_DIR, "best.h5")
log_path = os.path.join(OUTPUT_DIR, "training_log.csv")

# Keep only the weights with the highest validation AUC on disk.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, monitor="val_auc", mode="max", save_best_only=True
)
# Stop after 5 epochs without a val AUC improvement and roll back to the best.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_auc", mode="max", patience=5, restore_best_weights=True
)
# Halve the learning rate when validation loss stalls for 2 epochs.
reduce_lr_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6
)
# Per-epoch metrics as CSV.
csv_logger_cb = tf.keras.callbacks.CSVLogger(log_path)

callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb, csv_logger_cb]

# Phase 1: train only the new head on top of the frozen backbone.
hist = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=5,
    class_weight=class_weights,
    callbacks=callbacks,
)


## Fine-tune last blocks

In [None]:
# Unfreeze the backbone, then re-freeze everything except its last 40 layers
# so only the top blocks are adapted.
base.trainable = True
for layer in base.layers[:-40]:
    layer.trainable = False

# Changing `trainable` only takes effect after re-compiling. A much smaller
# learning rate is used so the pretrained weights are only nudged.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss='binary_crossentropy',
              metrics=[tf.keras.metrics.AUC(name='auc'),
                       tf.keras.metrics.BinaryAccuracy(name='acc'),
                       tf.keras.metrics.Precision(name='precision'),
                       tf.keras.metrics.Recall(name='recall')])

# Reuse the phase-1 callbacks, but swap the CSV logger for one opened in
# append mode: the original logger overwrites its file when training begins,
# which would erase the frozen-phase history from training_log.csv.
# (Epoch numbers in the appended rows restart from 0 for this phase.)
finetune_callbacks = [
    cb for cb in callbacks if not isinstance(cb, tf.keras.callbacks.CSVLogger)
]
finetune_callbacks.append(
    tf.keras.callbacks.CSVLogger(os.path.join(OUTPUT_DIR, "training_log.csv"), append=True)
)
hist2 = model.fit(train_ds, validation_data=val_ds, epochs=10, class_weight=class_weights,
                  callbacks=finetune_callbacks)


## Evaluation, Confusion Matrix, ROC/PR

In [None]:
# Ground truth and predictions over the test split. test_ds is built without
# shuffling, so iterating it for labels and calling predict() on it visit the
# samples in the same order.
y_true = np.concatenate([y.numpy().ravel() for _, y in test_ds], axis=0).astype(int)
y_prob = model.predict(test_ds).ravel()
y_pred = (y_prob >= 0.5).astype(int)

# Pass the label set explicitly so the matrix/report always have one row per
# class, even if a class never appears in y_true or y_pred (otherwise the
# matrix shrinks and target_names no longer lines up).
label_ids = np.arange(len(class_names))
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
print("Confusion Matrix:\n", cm)
print("\nClassification Report:\n",
      classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

roc = roc_auc_score(y_true, y_prob)
print("ROC AUC:", roc)

precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
pr_auc = auc(recall, precision)
print("PR AUC:", pr_auc)

# Plot confusion matrix
fig = plt.figure(figsize=(4,4))
plt.imshow(cm, interpolation='nearest')
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)
# Pick a text colour that contrasts with the cell. This assumes matplotlib's
# default colormap, where high counts render bright and low counts dark.
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             color="black" if cm[i, j] > thresh else "white")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()


## Save model

In [None]:
# Single-file HDF5 copy of the trained model.
model.save(os.path.join(OUTPUT_DIR, "model.h5"))

# Directory-style SavedModel. Older Keras versions write this through
# model.save(); Keras 3 rejects extension-less paths with a ValueError and
# provides model.export() for SavedModel output instead, so fall back to that.
saved_model_dir = os.path.join(OUTPUT_DIR, "saved_model")
try:
    model.save(saved_model_dir)
except ValueError:
    export_fn = getattr(model, "export", None)
    if export_fn is None:
        raise
    export_fn(saved_model_dir)
print("Saved to", OUTPUT_DIR)


## Grad-CAM utility

In [None]:
def _find_last_conv_layer(model):
    """Return ``(owner, name)`` of the last Conv2D layer, or ``(None, None)``.

    Top-level layers are searched first (in reverse order); if none is a
    Conv2D, the layers of each nested sub-model are searched. ``owner`` is the
    model that directly contains the layer.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            return model, layer.name
    for layer in reversed(model.layers):
        for inner in reversed(getattr(layer, "layers", [])):
            if isinstance(inner, tf.keras.layers.Conv2D):
                return layer, inner.name
    return None, None


def _find_layer_owner(model, layer_name):
    """Return the model (outer or one-level nested) that directly holds ``layer_name``."""
    if any(layer.name == layer_name for layer in model.layers):
        return model
    for layer in model.layers:
        if any(inner.name == layer_name for inner in getattr(layer, "layers", [])):
            return layer
    return None


def make_gradcam_heatmap(img_array, model, last_conv_layer_name=None):
    """Compute a Grad-CAM heatmap for the first output unit of ``model``.

    Args:
        img_array: Batch of input images; intended for a batch of one. Only
            the first image's activations are used for the returned map.
        model: Keras model. The target conv layer may live directly in it or
            inside a nested sub-model one level down (e.g. a pretrained
            backbone used as a layer).
        last_conv_layer_name: Name of the conv layer to explain. When None,
            the last Conv2D layer is located automatically.

    Returns:
        2-D ``numpy`` array with values scaled to [0, 1].

    Raises:
        ValueError: If no suitable conv layer can be found.
        RuntimeError: If no gradient reaches the conv layer output.

    Note:
        For a nested conv layer the outer model's layers are applied one after
        another, so the outer model must be a plain linear chain of layers.
    """
    if last_conv_layer_name is None:
        owner, last_conv_layer_name = _find_last_conv_layer(model)
    else:
        owner = _find_layer_owner(model, last_conv_layer_name)
    if owner is None:
        raise ValueError(
            "Could not locate a Conv2D layer for Grad-CAM"
            if last_conv_layer_name is None
            else f"Layer '{last_conv_layer_name}' not found in model or its sub-models"
        )

    # Sub-graph exposing both the conv activations and the owner's output.
    # For a nested layer this must be built on the sub-model: the layer is
    # not part of the outer model's graph, so outer.get_layer() cannot see it.
    grad_model = tf.keras.models.Model(
        owner.inputs, [owner.get_layer(last_conv_layer_name).output, owner.output]
    )

    inputs = tf.cast(img_array, tf.float32)
    with tf.GradientTape() as tape:
        # Watch the input so every downstream op is recorded even when the
        # layers before the conv output are frozen (non-trainable variables
        # are not watched automatically).
        tape.watch(inputs)
        if owner is model:
            conv_outputs, predictions = grad_model(inputs)
        else:
            x = inputs
            for layer in model.layers:
                if isinstance(layer, tf.keras.layers.InputLayer):
                    continue
                if layer is owner:
                    conv_outputs, x = grad_model(x)
                else:
                    x = layer(x)
            predictions = x
        loss = predictions[:, 0]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        raise RuntimeError("No gradient flows from the model output to the conv layer")

    # Channel importance = gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs[0]), axis=-1)

    # Keep positive evidence only, then normalise by the post-ReLU maximum so
    # an all-negative map yields zeros instead of dividing by a negative value.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-6)
    return heatmap.numpy()
