# U-Net Model for Segmentation

This notebook will be converted into a script: so far it contains only the model definition, with no data attached. We still need to compile and train the model, and then record the loss and metrics alongside its predictions.

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model

In [2]:
def double_conv(x, c_out):
    """Apply two back-to-back Conv2D -> BatchNormalization -> ReLU stages.

    Args:
        x: Input tensor.
        c_out: Number of output channels (filters) used by both convolutions.

    Returns:
        The tensor produced by the second ReLU.
    """
    for _ in range(2):
        # 3x3 convolution; "same" padding keeps the output height/width equal
        # to the input. The conv bias is switched off and the convolution is
        # followed directly by batch normalization.
        x = layers.Conv2D(c_out, 3, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x

In [3]:
def UNetSmallTF(input_shape=(256,256,1), num_classes=3, base=32):
    """Build a small three-level U-Net that outputs per-pixel class logits.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
        num_classes: Number of output channels of the final 1x1 convolution.
        base: Filter count of the first encoder stage; later stages use
            multiples of it (x2, x4, and x8 in the bottleneck).

    Returns:
        A Keras ``Model`` named "UNetSmallTF" mapping ``inputs`` to logits
        (no activation is applied to the final convolution).
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: each stage runs double_conv and then halves the spatial
    # resolution with 2x2 max pooling. The pre-pooling feature maps are kept
    # so the decoder can reuse them as skip connections.
    # With the default 256x256 input: 256 -> 128 -> 64 -> 32.
    skips = []
    x = inputs
    for mult in (1, 2, 4):
        x = double_conv(x, base * mult)
        skips.append(x)
        x = layers.MaxPool2D(2)(x)

    # Bottleneck: widest stage (base*8 filters) at the lowest resolution.
    x = double_conv(x, base * 8)

    # Decoder: each stage doubles the spatial size with a stride-2 transposed
    # convolution (learnable upsampling), concatenates the encoder feature map
    # of the same resolution, and fuses the two with double_conv.
    for mult in (4, 2, 1):
        x = layers.Conv2DTranspose(base * mult, 2, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skips.pop()])  # most recent skip first
        x = double_conv(x, base * mult)

    # A 1x1 convolution acts as a linear classifier applied at every pixel.
    logits = layers.Conv2D(num_classes, 1, padding="same")(x)
    return Model(inputs, logits, name="UNetSmallTF")

In [4]:
# Build the network with the default input shape and base filter count,
# producing 3 logits per pixel.
# NOTE(review): the pseudo-mask configuration later in this notebook sets
# num_classes = 2 — confirm which class count is intended.
model = UNetSmallTF(num_classes=3)
model.summary()  # print the layer table as a quick sanity check

## Otsu Method for Generated Masks

Now that we have our model, we need to attach data and generate pseudo-masks, since we will not be using ground-truth masks for this model.

In [35]:
import os, glob
import cv2
import numpy as np
from tqdm import tqdm # shows progress bar in notebook, we don't necessarily need this but it's nice
from skimage import filters, exposure, morphology, measure
from skimage.filters import threshold_otsu, threshold_multiotsu # import multi otsu in case we need it
from skimage.morphology import remove_small_holes, remove_small_objects, binary_closing, disk
# Root of the training split. This is a machine-specific absolute path and
# must be adjusted when the notebook runs anywhere else.
DATA_PATH = "E:/DOCUMENTS/CSULB/FALL 2025/CECS 361/project/ML/chest_xray/train"
img_dir = os.path.join(DATA_PATH, "normal")  # source images to be masked
mask_dir = os.path.join(DATA_PATH, "pseudo_masks")  # output folder for generated masks
os.makedirs(mask_dir, exist_ok=True)  # create the output folder; no error if it already exists


image_size = (256, 256)  # (height, width) each image is resized to before thresholding
# NOTE(review): presumably background/foreground for the binary Otsu mask; the
# model cell builds the network with num_classes=3 — confirm which is intended.
num_classes = 2
min_obj_px = 64 # size threshold (in pixels) used when removing small specks and small holes
CLAHE = True # enables CLAHE (adaptive contrast equalization) in preprocess_grey

## Creating our Pseudo Masks

In [37]:
def preprocess_grey(img):
    """Optionally contrast-equalize a greyscale image, then lightly blur it.

    Args:
        img: Single-channel image; it is passed to OpenCV's CLAHE ``apply``
            when the module-level ``CLAHE`` flag is truthy.

    Returns:
        The image after the optional CLAHE step and a 3x3 Gaussian blur.
    """
    # CLAHE (clip limit 2.0 on an 8x8 tile grid) is applied only when enabled.
    contrasted = (
        cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)
        if CLAHE
        else img
    )
    # sigma=0 lets OpenCV derive the standard deviation from the kernel size.
    return cv2.GaussianBlur(contrasted, (3, 3), 0)

def binary_otsu_mask(grey_u8):
    """Threshold a preprocessed greyscale image into a cleaned binary mask.

    Args:
        grey_u8: Greyscale image array (the caller passes 8-bit data).

    Returns:
        A ``uint8`` array of the same spatial shape holding 0 (below or equal
        to the Otsu threshold) or 1 (above it), after morphological clean-up.
    """
    # Otsu picks the threshold from the image's own intensity histogram.
    threshold = threshold_otsu(grey_u8)
    foreground = (grey_u8 > threshold).astype(np.uint8)

    # Clean-up: close small gaps, then drop isolated specks, then fill small
    # holes. Both size filters use the module-level min_obj_px threshold.
    closed = binary_closing(foreground, footprint=disk(3))
    without_specks = remove_small_objects(closed.astype(bool), min_obj_px)
    filled = remove_small_holes(without_specks, min_obj_px)

    return filled.astype(np.uint8)

# Generate one pseudo-mask per file found in img_dir. Each mask is written
# twice into mask_dir: as a lossless <name>.npy array of 0/1 values (the file
# the data pipeline reads) and as a 0/255 preview image.
img_paths = sorted(glob.glob(os.path.join(img_dir, "*")))
print("Found", len(img_paths), "images")

for p in tqdm(img_paths):
    # Load the file as stored, without forcing a channel count or bit depth.
    img = cv2.imread(p, cv2.IMREAD_UNCHANGED)
    if img is None:  # unreadable or non-image file: skip it
        continue
    if img.ndim == 3:
        # Collapse colour images to a single channel. IMREAD_UNCHANGED keeps an
        # alpha channel when the file has one, and COLOR_BGR2GRAY rejects
        # 4-channel input, so choose the conversion from the channel count.
        to_grey = cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        img = cv2.cvtColor(img, to_grey)

    # cv2.resize expects (width, height); image_size is stored as (height, width).
    img_r = cv2.resize(img, image_size[::-1], interpolation=cv2.INTER_LINEAR)
    img_p = preprocess_grey(img_r.astype(np.uint8))  # CLAHE expects 8-bit input here

    m = binary_otsu_mask(img_p)        # 0/1 mask
    m8 = (m * 255).astype(np.uint8)    # 0/255 version for viewing

    base = os.path.splitext(os.path.basename(p))[0]
    # The extension needs its leading dot. Without it np.save appends ".npy" to
    # "<name>npy", producing "<name>npy.npy", which the data pipeline (looking
    # for "<name>.npy") never finds.
    np.save(os.path.join(mask_dir, base + ".npy"), m)

    # Preview image: same file name (and therefore same format) as the source
    # image, but placed in the mask directory.
    save_path = os.path.join(mask_dir, os.path.basename(p))
    cv2.imwrite(save_path, m8)

Found 1341 images


100%|██████████████████████████████████████████████████████████████████████████████| 1341/1341 [00:30<00:00, 43.44it/s]


## Data Pipeline for Loading Images

Now that all of our images have been preprocessed, we need to feed them to the model, so we create a data pipeline that loads them for us.

In [38]:
import random  # needed for the seeded shuffle; not imported in the visible cells

SEED = 42  # shuffle seed; same value as `seed` in the training configuration

# Collect the source images by extension, derive the expected "<name>.npy" mask
# path for each one, and keep only the pairs whose mask file actually exists.
img_paths = sorted(
    p
    for ext in ("*.jpg", "*.jpeg", "*.png")
    for p in glob.glob(os.path.join(img_dir, ext))
)
mask_paths = [
    os.path.join(mask_dir, os.path.splitext(os.path.basename(p))[0] + ".npy")
    for p in img_paths
]
pairs = [(i, m) for i, m in zip(img_paths, mask_paths) if os.path.exists(m)]

# A dedicated Random instance gives a reproducible order without touching the
# global random state.
random.Random(SEED).shuffle(pairs)

def load_mask(path):
    """Load a saved ``.npy`` mask as a float32 array with a channel axis.

    Args:
        path: UTF-8 encoded ``bytes`` path to the ``.npy`` file.

    Returns:
        The mask as ``float32``. A 2-D array gains a trailing axis of length 1;
        arrays of any other rank keep their shape.
    """
    mask = np.load(path.decode("utf-8"))
    if mask.ndim == 2:
        mask = np.expand_dims(mask, axis=-1)  # (H, W) -> (H, W, 1)
    return mask.astype("float32")

def parse(img_path, mask_path):
    """Load one (image, mask) training pair from disk.

    Args:
        img_path: Scalar string tensor (or path) of the image file.
        mask_path: Scalar string tensor (or path) of the matching ``.npy`` mask.

    Returns:
        A tuple ``(img, mask)``: ``img`` is float32 scaled to [0, 1] with
        shape (height, width, 1), and ``mask`` is the float32 array returned
        by ``load_mask``.
    """
    # image_size is stored as (height, width), the same size the masks were
    # generated at.
    height, width = image_size

    img = tf.io.read_file(img_path)
    # expand_animations=False keeps the result 3-D so tf.image.resize accepts it.
    img = tf.io.decode_image(img, channels=1, expand_animations=False)
    img = tf.image.resize(img, (height, width), method="bilinear")
    # Cast first, then scale: the division belongs outside the cast.
    img = tf.cast(img, tf.float32) / 255.0

    # np.load cannot run inside a TF graph, so wrap the NumPy loader. It
    # receives the path as bytes, which is what load_mask decodes.
    mask = tf.numpy_function(load_mask, [mask_path], tf.float32)
    # numpy_function loses static shape information; restore it so batching
    # works. Assumes masks were saved at image_size — TODO confirm if the mask
    # generation settings change.
    mask.set_shape((height, width, 1))
    return img, mask


## Training / Testing the Model

In [None]:
# Training configuration. The cell that consumes these values is not complete
# yet, so the roles below are presumed from the names — confirm when it is.
batch = 8  # presumably the mini-batch size
seed = 42  # presumably the random seed for reproducible splitting/shuffling
val_split = 0.2  # presumably the fraction of pairs held out for validation
epochs = 40  # presumably the number of training epochs

n_val = int(len(pairs)