#Imports

In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random
import albumentations as A


#Config

In [None]:
# --- Network input geometry -------------------------------------------------
# Square, single-channel (grayscale) inputs.
IMG_HEIGHT = IMG_WIDTH = 256
IMG_CHANNELS = 1

# --- Dataset locations -------------------------------------------------------
# Input images and their ground-truth masks live in separate directories.
INPUT_DIR = 'drive/MyDrive/DIP/BW'
GT_DIR = 'drive/MyDrive/DIP/ground_truth'

# --- Mask encoding -----------------------------------------------------------
# Mask pixel values that are treated as the positive class when binarizing.
CELL_PIXEL_VALUES = [51, 102, 255]

# --- Training schedule -------------------------------------------------------
BATCH_SIZE = 8    # samples per batch
EPOCHS = 100      # upper bound on training epochs
PATIENCE = 15     # epochs without improvement tolerated by the callbacks

#Helpers
def load_grayscale_image(path):
    """
    Read the file at `path` as a single-channel (grayscale) image.

    Returns the array produced by cv2.imread, or None when the file could
    not be read (a warning is printed in that case).
    """
    pixels = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if pixels is not None:
        return pixels
    # cv2.imread signals failure by returning None instead of raising.
    print(f"Warning: Could not read {path}")
    return None

#Paths
def get_image_paths(input_dir, gt_dir):
    """
    Collect and pair PNG filenames from input and ground-truth directories.

    Both directories are listed, filtered to names ending in '.png', and
    sorted, so that entry i of the first list corresponds to entry i of the
    second list.

    Args:
        input_dir: directory containing the input images.
        gt_dir: directory containing the ground-truth masks.

    Returns:
        A tuple (input_paths, gt_paths) of equally long, index-aligned lists
        of file paths.

    Raises:
        ValueError: if the directories hold a different number of PNG files,
            or if the sets of filenames are not identical.
    """
    inputs = sorted(f for f in os.listdir(input_dir) if f.endswith('.png'))
    gts = sorted(f for f in os.listdir(gt_dir) if f.endswith('.png'))
    if len(inputs) != len(gts):
        raise ValueError("Mismatched counts of input and ground truth images.")
    # Equal counts are not enough: pairing is positional, so a single renamed
    # file would silently attach images to the wrong masks.
    if inputs != gts:
        unmatched = sorted(set(inputs) ^ set(gts))
        raise ValueError(
            "Input and ground truth filenames do not match "
            f"(first unmatched names: {unmatched[:5]})."
        )
    # Return full file paths for both lists
    return ([os.path.join(input_dir, f) for f in inputs],
            [os.path.join(gt_dir, f) for f in gts])

#DataGen

In [None]:
class DataGenerator(keras.utils.Sequence):
    """
    Keras Sequence for batch-wise loading and augmentation of images and masks.
    Applies augmentations if provided, normalizes images, and binarizes masks.

    Args:
        image_paths: list of input image file paths.
        mask_paths: list of mask file paths, index-aligned with image_paths.
        cell_values: mask pixel values that are mapped to 1 (everything else
            becomes 0).
        batch_size: maximum number of samples per batch; the last batch of an
            epoch may be smaller.
        dim: (height, width) of the arrays placed in a batch.
        n_channels: channel count of the image batch array. Only channel 0 is
            written, because images are loaded as grayscale.
        shuffle: if True, the sample order is reshuffled whenever
            on_epoch_end runs (it is also called once from the constructor).
        augmentations: optional callable invoked as
            augmentations(image=..., mask=...) and returning a mapping with
            'image' and 'mask' entries. It is responsible for producing
            arrays of shape `dim`. When falsy, samples are only resized.
        **kwargs: forwarded to the keras.utils.Sequence constructor.
    """
    def __init__(self, image_paths, mask_paths, cell_values,
                 batch_size=BATCH_SIZE, dim=(IMG_HEIGHT, IMG_WIDTH),
                 n_channels=IMG_CHANNELS, shuffle=True, augmentations=None,
                 **kwargs):
        # Newer Keras versions expect the base constructor to be called.
        super().__init__(**kwargs)
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.cell_values = cell_values
        self.batch_size = batch_size
        self.dim = dim
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.augmentations = augmentations
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch. Round up so the trailing partial batch
        # is served too; flooring would drop those samples every epoch and
        # yield zero batches for datasets smaller than batch_size.
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, idx):
        # Generate one batch of data (slicing clips the last batch naturally)
        idxs = self.indexes[idx*self.batch_size : (idx+1)*self.batch_size]
        batch_imgs = [self.image_paths[i] for i in idxs]
        batch_masks = [self.mask_paths[i] for i in idxs]
        return self._generate(batch_imgs, batch_masks)

    def on_epoch_end(self):
        # Rebuild (and optionally shuffle) the sample order after each epoch
        self.indexes = np.arange(len(self.image_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _generate(self, imgs, masks):
        """
        Load, transform, and stack one batch.

        Returns a tuple (X, y): X has shape (n, *dim, n_channels) with pixel
        values scaled by 1/255, y has shape (n, *dim, 1) with values in
        {0, 1}; n is the number of paths given.

        Raises:
            OSError: if an image or mask file cannot be read.
        """
        # Size the arrays by the actual sample count so a short final batch
        # is not padded with all-zero images and masks.
        n_samples = len(imgs)
        X = np.zeros((n_samples, *self.dim, self.n_channels), dtype=np.float32)
        y = np.zeros((n_samples, *self.dim, 1), dtype=np.float32)
        # cv2.resize takes its target size as (width, height), the reverse of
        # the (height, width) order stored in self.dim.
        dsize = (self.dim[1], self.dim[0])
        for i, (img_p, mask_p) in enumerate(zip(imgs, masks)):
            img = load_grayscale_image(img_p)
            m = load_grayscale_image(mask_p)
            if img is None or m is None:
                # Fail with the offending paths instead of an opaque error
                # from the resize/augmentation call below.
                raise OSError(
                    f"Could not read image/mask pair: {img_p!r}, {mask_p!r}"
                )
            # Apply augmentations if provided, else just resize
            if self.augmentations:
                aug = self.augmentations(image=img, mask=m)
                img, m = aug['image'], aug['mask']
            else:
                img = cv2.resize(img, dsize)
                # Nearest-neighbour keeps the mask's discrete label values
                m = cv2.resize(m, dsize, interpolation=cv2.INTER_NEAREST)
            # Normalize image and binarize mask based on class values
            X[i,...,0] = img / 255.0
            y[i,...,0] = np.isin(m, self.cell_values).astype(np.float32)
        return X, y

#Augmentations

In [None]:
# Validation pipeline: deterministic resize to the network input size only.
val_transform = A.Compose(
    [A.Resize(IMG_HEIGHT, IMG_WIDTH, interpolation=cv2.INTER_NEAREST)]
)

# Training pipeline: the same resize, followed by random geometric and
# photometric augmentations (each applied with its own probability `p`).
_train_steps = [
    A.Resize(IMG_HEIGHT, IMG_WIDTH, interpolation=cv2.INTER_NEAREST),
    # Random mirroring along either axis
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    # Random rotation (limit=35) with constant-value border fill
    A.Rotate(limit=35, p=0.3, border_mode=cv2.BORDER_CONSTANT),
    # Local elastic deformation
    A.ElasticTransform(p=0.3, alpha=120, sigma=6, alpha_affine=3),
    # Random brightness/contrast change
    A.RandomBrightnessContrast(p=0.3),
]
train_transform = A.Compose(_train_steps)

#Model

In [None]:
def build_unet(input_shape):
    """
    Build a U-Net for binary segmentation.

    Layout:
      * contracting path - four stages (64, 128, 256, 512 filters), each a
        double 3x3 conv block followed by max pooling; the pre-pooling
        activations are kept as skip connections
      * bottleneck - one double conv block with 1024 filters
      * expanding path - four stages (512, 256, 128, 64 filters), each a
        stride-2 transpose conv, concatenation with the matching skip
        tensor, and a double conv block
      * head - 1x1 conv with sigmoid activation producing one channel

    `input_shape` is the per-sample input shape handed to keras.Input.
    Returns a keras.Model named 'U-Net'.
    """
    down_filters = (64, 128, 256, 512)
    up_filters = (512, 256, 128, 64)

    def double_conv(tensor, n_filters):
        # Two rounds of Conv(3x3, same padding) -> BatchNorm -> ReLU
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, 3, padding='same')(tensor)
            tensor = layers.BatchNormalization()(tensor)
            tensor = layers.ReLU()(tensor)
        return tensor

    inputs = keras.Input(tuple(input_shape))

    # Contracting path: remember each stage's output before pooling
    skips = []
    x = inputs
    for n_filters in down_filters:
        x = double_conv(x, n_filters)
        skips.append(x)
        x = layers.MaxPool2D()(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Expanding path: deepest skip connection is consumed first
    for n_filters, skip in zip(up_filters, reversed(skips)):
        x = layers.Conv2DTranspose(n_filters, 2, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip])
        x = double_conv(x, n_filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    return keras.Model(inputs, outputs, name='U-Net')

#Training

In [None]:
if __name__ == '__main__':
    # Collect index-aligned image and mask paths
    img_paths, mask_paths = get_image_paths(INPUT_DIR, GT_DIR)

    # Split once into 70 % training and 30 % hold-out data. With two input
    # lists train_test_split returns four lists in the order
    # (imgs_train, imgs_holdout, masks_train, masks_holdout), so they must be
    # unpacked by name rather than splatted into the generator constructor.
    train_x, hold_x, train_y, hold_y = train_test_split(
        img_paths, mask_paths, test_size=0.3, random_state=42
    )
    # Divide the hold-out part evenly into validation and test sets. With
    # fewer than two hold-out samples a further split is impossible, so all
    # of it is used for validation and the test set stays empty.
    if len(hold_x) >= 2:
        val_x, test_x, val_y, test_y = train_test_split(
            hold_x, hold_y, test_size=0.5, random_state=42
        )
    else:
        val_x, val_y, test_x, test_y = hold_x, hold_y, [], []

    # Training data is shuffled and augmented; validation data is only resized
    train_gen = DataGenerator(train_x, train_y,
                              cell_values=CELL_PIXEL_VALUES,
                              augmentations=train_transform)
    val_gen = DataGenerator(val_x, val_y,
                            cell_values=CELL_PIXEL_VALUES,
                            shuffle=False, augmentations=val_transform)

    # Build and compile model
    model = build_unet((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
    # Callbacks: save best, adjust LR, early stop
    callbacks = [
        keras.callbacks.ModelCheckpoint('best.h5', save_best_only=True),
        keras.callbacks.ReduceLROnPlateau(patience=PATIENCE//2),
        keras.callbacks.EarlyStopping(patience=PATIENCE, restore_best_weights=True)
    ]
    # Train
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        callbacks=callbacks
    )
    # Evaluate on test set if available
    if test_x:
        X_test, y_test = [], []
        for xi, mi in zip(test_x, test_y):
            img = load_grayscale_image(xi)
            m = load_grayscale_image(mi)
            if img is None or m is None:
                # The loader has already printed a warning for this file
                continue
            t = val_transform(image=img, mask=m)
            # Same preprocessing as the generators: scale to [0, 1] and
            # binarize the mask, both as float32 with a channel axis
            X_test.append((t['image'][..., None] / 255.0).astype(np.float32))
            y_test.append(
                np.isin(t['mask'], CELL_PIXEL_VALUES)[..., None].astype(np.float32)
            )
        if X_test:
            X_test, y_test = np.array(X_test), np.array(y_test)
            print(model.evaluate(X_test, y_test, batch_size=BATCH_SIZE))
    # Plot training loss curves
    plt.plot(history.history['loss'], label='loss')
    plt.plot(history.history.get('val_loss', []), label='val_loss')
    plt.legend()
    plt.show()