#Imports

In [None]:
import os
import numpy as np
import cv2  # OpenCV for image IO
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import pandas as pd
import albumentations as A

#Config

In [None]:
# Basic configuration: image size, directories, and mask class values
# IMG_WIDTH / IMG_HEIGHT are the target size for the resize transform and,
# together with IMG_CHANNELS, the input shape used when rebuilding the model.
IMG_WIDTH = 256
IMG_HEIGHT = 256
IMG_CHANNELS = 1
# Directory of input PNG images to run the model on.
INPUT_DIR = 'drive/MyDrive/DIP/BW'
# Directory of ground-truth PNG masks, paired with inputs by filename.
GT_DIR = 'drive/MyDrive/DIP/ground_truth'
# Directory where the predicted binary masks are written.
PREDICTIONS_DIR = 'drive/MyDrive/DIP/predicted_masks_final'
# Ground-truth pixels equal to any of these values are treated as foreground
# when the mask is binarised for evaluation; every other value is background.
CELL_PIXEL_VALUES = [51, 102, 255]  # Pre-determined mask intensity values

#Transforms

In [None]:
# Validation transform: resize mask and image to target dimensions.
# Nearest-neighbour interpolation never blends neighbouring pixels, so a
# resized ground-truth mask still contains only its original intensity values
# (the evaluation loop matches them exactly against CELL_PIXEL_VALUES).
# NOTE(review): the same nearest-neighbour resize is also applied to the input
# image — presumably to mirror the training-time preprocessing; confirm.
val_transform = A.Compose([
    A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH,
             interpolation=cv2.INTER_NEAREST, always_apply=True)
])

#ModelLoad

In [None]:
# Ensure the U-Net model is available and loaded with best weights.
# A non-None `model` variable already present in this scope is reused as-is;
# otherwise the architecture is rebuilt below and its weights are read from
# `model_load_path`. If that file does not exist the script exits.
model_load_path = 'unet_hela_best.keras'
if 'model' not in locals() or model is None:
    if os.path.exists(model_load_path):
        print(f"Loading model weights from: {model_load_path}")
        # Rebuild U-Net architecture if needed
        # NOTE(review): the layers built here must match the network the
        # weights were saved from — confirm against the training code.
        from tensorflow.keras import layers
        # Two (3x3 same-padded conv -> batch norm -> ReLU) stages with n filters.
        def conv_block(i, n):
            x = layers.Conv2D(n, 3, padding='same')(i)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)
            x = layers.Conv2D(n, 3, padding='same')(x)
            x = layers.BatchNormalization()(x)
            return layers.Activation('relu')(x)
        # Returns (c, p): the conv_block output kept for the skip connection
        # and its 2x2 max-pooled version fed to the next encoder level.
        def encoder_block(i, n):
            c = conv_block(i, n)
            p = layers.MaxPooling2D((2,2))(c)
            return c, p
        # Upsamples i with a stride-2 transposed conv, concatenates the skip
        # tensor s along the last axis, then applies conv_block.
        def decoder_block(i, s, n):
            x = layers.Conv2DTranspose(n, 2, strides=(2,2), padding='same')(i)
            x = layers.Concatenate(axis=-1)([x, s])
            return conv_block(x, n)
        # Four encoder levels (64-512 filters), a 1024-filter bottleneck, four
        # mirrored decoder levels, and a 1x1 conv with a single sigmoid output.
        def build_unet(input_shape):
            inputs = keras.Input(shape=input_shape)
            s1, p1 = encoder_block(inputs, 64)
            s2, p2 = encoder_block(p1, 128)
            s3, p3 = encoder_block(p2, 256)
            s4, p4 = encoder_block(p3, 512)
            b1 = conv_block(p4, 1024)
            d1 = decoder_block(b1, s4, 512)
            d2 = decoder_block(d1, s3, 256)
            d3 = decoder_block(d2, s2, 128)
            d4 = decoder_block(d3, s1, 64)
            outputs = layers.Conv2D(1, 1, activation='sigmoid')(d4)
            return keras.Model(inputs, outputs, name='U-Net')
        model = build_unet((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
        model.load_weights(model_load_path)
        print("Model rebuilt and weights loaded.")
    else:
        print(f"Error: Model file {model_load_path} not found.")
        # Stop here: the evaluation loop further down calls model.predict.
        sys.exit()
else:
    print("Using pre-loaded model variable.")

#Helpers

In [None]:
# Load an image in grayscale mode, return None on failure
def load_grayscale_image(path):
    """Read the file at ``path`` with OpenCV in grayscale mode.

    Returns the value produced by ``cv2.imread``. When that value is ``None``
    a warning is printed and ``None`` is returned. If reading raises, the
    error is printed and ``None`` is returned instead of propagating.
    """
    try:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    except Exception as err:
        print(f"Error loading {path}: {err}")
        return None
    else:
        if image is None:
            print(f"Warning: Could not read {path}")
        return image

#Paths

In [None]:
# Pair input and GT images by filename, ensuring exact match
def get_image_paths(input_dir, gt_dir):
    """Pair every input PNG with the ground-truth PNG of the same filename.

    Args:
        input_dir: Directory containing the input ``.png`` images.
        gt_dir: Directory containing the ground-truth ``.png`` masks.

    Returns:
        A 2-tuple ``(input_paths, gt_paths)`` of equal-length tuples ordered
        by filename; ``gt_paths[i]`` is the mask for ``input_paths[i]``.

    Raises:
        ValueError: If the two directories hold a different number of PNGs,
            hold no PNGs at all, or their PNG filenames do not match exactly.
    """
    input_names = sorted(f for f in os.listdir(input_dir) if f.endswith('.png'))
    gt_names = sorted(f for f in os.listdir(gt_dir) if f.endswith('.png'))

    if len(input_names) != len(gt_names):
        raise ValueError("Input/GT PNG count mismatch.")
    if not input_names:
        # Fail here with a clear message rather than letting the caller's
        # two-way unpacking fail on an empty result.
        raise ValueError(f"No PNG files found in {input_dir}.")
    if input_names != gt_names:
        # Equal counts do not imply equal names; without this check a GT path
        # would be built for a file that does not exist in gt_dir.
        unmatched = sorted(set(input_names) ^ set(gt_names))
        raise ValueError(f"Input/GT PNG filename mismatch: {unmatched}")

    input_paths = tuple(os.path.join(input_dir, name) for name in input_names)
    gt_paths = tuple(os.path.join(gt_dir, name) for name in input_names)
    return input_paths, gt_paths

#Metrics

In [None]:
# Functions to compute TP, TN, FP, FN and derive IoU, Dice, F1, accuracy

def numpy_calculate_stats(y_true, y_pred):
    """Count confusion-matrix entries between two binary arrays.

    Both arrays are flattened and compared element by element; 1 is the
    positive label and 0 the negative one. Elements holding any other value
    are not counted in any of the four totals.

    Returns:
        The tuple ``(TP, TN, FP, FN)``.
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()
    truth_pos, truth_neg = truth == 1, truth == 0
    guess_pos, guess_neg = guess == 1, guess == 0
    return (
        np.sum(truth_pos & guess_pos),
        np.sum(truth_neg & guess_neg),
        np.sum(truth_neg & guess_pos),
        np.sum(truth_pos & guess_neg),
    )

def numpy_jaccard(TP, TN, FP, FN, smooth=1e-6):
    """Return the smoothed Jaccard index (IoU): TP / (TP + FP + FN).

    ``smooth`` is added to numerator and denominator so the ratio is defined
    (and equals 1) when all three counts are zero. ``TN`` is accepted for a
    uniform metric signature but does not enter the formula.
    """
    intersection = TP + smooth
    union = TP + FP + FN + smooth
    return intersection / union

def numpy_dice(TP, TN, FP, FN, smooth=1e-6):
    """Return the smoothed Dice coefficient: 2*TP / (2*TP + FP + FN).

    ``smooth`` is added to numerator and denominator so the ratio is defined
    (and equals 1) when all three counts are zero. ``TN`` is accepted for a
    uniform metric signature but does not enter the formula.
    """
    numerator = 2 * TP + smooth
    denominator = 2 * TP + FP + FN + smooth
    return numerator / denominator

def numpy_f1_score(TP, TN, FP, FN, smooth=1e-6):
    """Return a smoothed F1 score from confusion-matrix counts.

    Precision and recall are each computed with ``smooth`` added to their
    numerator and denominator, and ``smooth`` is applied once more when the
    two are combined. ``TN`` is accepted for a uniform metric signature but
    does not enter the formula.
    """
    hits = TP + smooth
    precision = hits / (TP + FP + smooth)
    recall = hits / (TP + FN + smooth)
    numerator = 2 * precision * recall + smooth
    return numerator / (precision + recall + smooth)

def numpy_accuracy(TP, TN, FP, FN):
    """Return (TP + TN) / (TP + TN + FP + FN), or 0.0 when all counts are zero."""
    total = TP + TN + FP + FN
    if not total:
        # Nothing was counted; avoid dividing by zero.
        return 0.0
    return (TP + TN) / total

#Main

In [None]:
if __name__ == '__main__':
    # Evaluation driver: for every input/ground-truth pair, predict a binary
    # mask, save it to PREDICTIONS_DIR, and score it against the ground truth.

    # Prepare output folder and gather image pairs
    os.makedirs(PREDICTIONS_DIR, exist_ok=True)
    inputs, gts = get_image_paths(INPUT_DIR, GT_DIR)

    # Containers for per-image metrics
    iou_scores, dice_scores, f1_scores, acc_scores = [], [], [], []

    # Iterate over images: preprocess, predict, save, and evaluate
    for inp, gt in zip(inputs, gts):
        img = load_grayscale_image(inp)
        mask_gt = load_grayscale_image(gt)
        if img is None or mask_gt is None:
            print(f"Skipping {inp}")
            continue
        # Preprocess input: resize, scale to [0, 1], then add a leading batch
        # axis and a trailing channel axis for the model.
        tr = val_transform(image=img)
        inp_norm = tr['image'].astype(np.float32) / 255.0
        inp_tensor = inp_norm[None, ..., None]
        # Model prediction: take the single item of the batch and threshold
        # the sigmoid output at 0.5.
        pred_prob = model.predict(inp_tensor, verbose=0)[0]
        pred_bin = (pred_prob > 0.5).astype(np.float32)
        # Save binary mask (0/255) under the input's filename
        save_path = os.path.join(PREDICTIONS_DIR, os.path.basename(inp))
        mask_u8 = (pred_bin.squeeze() * 255).astype(np.uint8)
        if not cv2.imwrite(save_path, mask_u8):
            # imwrite reports failure through its return value, not by
            # raising; surface it but keep evaluating.
            print(f"Warning: Could not write {save_path}")
        # Prepare GT for metrics: resize like the input, then mark every pixel
        # whose value is listed in CELL_PIXEL_VALUES as foreground.
        gt_tr = val_transform(image=mask_gt)
        gt_bin = np.isin(gt_tr['image'], CELL_PIXEL_VALUES).astype(np.float32)
        # Compute stats and metrics
        TP, TN, FP, FN = numpy_calculate_stats(gt_bin, pred_bin)
        iou_scores.append(numpy_jaccard(TP, TN, FP, FN))
        dice_scores.append(numpy_dice(TP, TN, FP, FN))
        f1_scores.append(numpy_f1_score(TP, TN, FP, FN))
        acc_scores.append(numpy_accuracy(TP, TN, FP, FN))

    if not iou_scores:
        # Every pair was skipped: there is nothing to plot or summarise.
        print("No images were evaluated; skipping plot and summary.")
    else:
        # Plot all metrics over image index
        sns.set_style("whitegrid")
        plt.figure(figsize=(12, 8))
        plt.plot(iou_scores, label='IoU')
        plt.plot(dice_scores, label='Dice')
        plt.plot(f1_scores, label='F1')
        plt.plot(acc_scores, label='Accuracy')
        plt.legend()
        plt.title('Metrics per Image')
        plt.xlabel('Image Index')
        plt.ylabel('Score')
        plt.show()

        # Summarize metrics in a DataFrame
        df = pd.DataFrame({
            'IoU': iou_scores,
            'Dice': dice_scores,
            'F1': f1_scores,
            'Accuracy': acc_scores
        })
        summary = df.agg(['mean', 'median', 'std']).round(4)
        print(summary)
