# Semantic Segmentation with U-Net
- Understand pixel-wise classification
- Build an encoder–decoder network
- Understand skip connections
- Train a lightweight U-Net

## What Is Semantic Segmentation?
- Instead of: Image → Label (classification)
- We now predict: Image → Mask (same height and width as the input image)
- Each pixel gets its own class label.

In [25]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds

# -----------------------------
# Oxford-IIIT Pet, pinned to dataset version 4.0.0
# -----------------------------
# as_supervised=False keeps each element as a feature dict, so preprocess()
# can pick out "image" and "segmentation_mask" itself.
dataset, info = tfds.load(
    name="oxford_iiit_pet:4.0.0",
    with_info=True,
    as_supervised=False,
    shuffle_files=True,
)

# The "test" split serves as the validation set in this notebook.
train_ds, val_ds = dataset["train"], dataset["test"]

# -----------------------------
# Preprocessing: feature dict -> (image, mask)
# -----------------------------
def preprocess(sample):
    """Convert one raw dataset element into a model-ready ``(image, mask)`` pair.

    The image is resized to 128x128 and scaled from 0-255 into [0, 1] as
    float32. The segmentation mask is resized with nearest-neighbour
    interpolation (so no new label values are invented), cast to int32 and
    shifted down by one so the raw labels 1-3 become class indices 0-2.
    """
    target_hw = (128, 128)

    pixels = tf.image.resize(sample["image"], target_hw)
    pixels = tf.cast(pixels, tf.float32) / 255.0

    labels = tf.image.resize(sample["segmentation_mask"], target_hw, method="nearest")
    labels = tf.cast(labels, tf.int32) - 1

    return pixels, labels

# -----------------------------
# Apply preprocessing, shuffling (train only), batching, prefetching
# -----------------------------
batch_size = 16
# Example-level shuffle buffer for the training set. shuffle_files=True in
# tfds.load only randomises the order of the underlying files; without this,
# batches are composed largely the same way every epoch.
shuffle_buffer = 500


def _prepare(ds, *, shuffle=False):
    """Map ``preprocess`` over *ds*, optionally shuffle, then batch and prefetch."""
    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        # Shuffle after resizing so each buffered element is a small 128x128 pair.
        ds = ds.shuffle(shuffle_buffer)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


train_ds = _prepare(train_ds, shuffle=True)
val_ds = _prepare(val_ds)

# -----------------------------
# Sanity check: inspect shapes and label values of a single batch
# -----------------------------
for images, masks in train_ds.take(1):
    flat_mask_values = tf.reshape(masks, [-1])
    print("Images shape:", images.shape)
    print("Masks shape:", masks.shape)
    print("Mask unique values:", tf.unique(flat_mask_values).y.numpy())



Images shape: (16, 128, 128, 3)
Masks shape: (16, 128, 128, 1)
Mask unique values: [1 2 0]


In [26]:
# Build a lightweight U-Net:
#   encoder  = ImageNet-pretrained MobileNetV2 (frozen),
#   decoder  = transposed convolutions for upsampling,
#   plus skip connections from encoder to decoder.

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(128, 128, 3), include_top=False, weights="imagenet"
)
base_model.trainable = False

# Encoder activations to expose, ordered shallow -> deep. All but the last are
# used as skip connections; the last one ("block_16_project") is the
# bottleneck the decoder starts from.
layer_names = [
    "block_1_expand_relu",
    "block_3_expand_relu",
    "block_6_expand_relu",
    "block_13_expand_relu",
    "block_16_project",
]
layers = [base_model.get_layer(layer_name).output for layer_name in layer_names]

# Feature extractor returning every tapped activation; kept frozen.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False

# Upsampling path

def upsample(filters, size):
    """Return a 2x upsampling block: Conv2DTranspose (stride 2) -> BatchNorm -> ReLU.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size of the transposed convolution.
    """
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding="same"))
    block.add(tf.keras.layers.BatchNormalization())
    block.add(tf.keras.layers.ReLU())
    return block

# Construct U-Net:

inputs = tf.keras.Input(shape=(128,128,3))
# The input pipeline delivers images scaled to [0, 1], but ImageNet-pretrained
# MobileNetV2 expects [-1, 1] (what mobilenet_v2.preprocess_input produces).
# Rescale inside the model so training and model.predict() on [0, 1] images
# both feed the frozen encoder the range its weights were trained on.
encoder_input = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
skips = down_stack(encoder_input)
x = skips[-1]                   # bottleneck: deepest encoder activation
skips = reversed(skips[:-1])    # remaining skips, deepest first

up_stack = [
    upsample(512, 3),
    upsample(256, 3),
    upsample(128, 3),
    upsample(64, 3)
]

# Decoder: upsample 2x, then concatenate with the matching encoder activation.
for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

# Final 2x upsample to the mask resolution; 3 channels = 3 classes with a
# per-pixel softmax (matches loss="sparse_categorical_crossentropy" on probs).
outputs = tf.keras.layers.Conv2DTranspose(
    3, 3, strides=2, padding="same", activation="softmax"
)(x)

model = tf.keras.Model(inputs, outputs)


In [27]:
# Compile: Adam with default settings, per-pixel sparse categorical
# cross-entropy on the softmax output (integer class masks as targets), and
# pixel accuracy reported under the name "accuracy".

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)


In [28]:
# Train
# Fit on the first 200 training batches per epoch, and evaluate on the
# held-out split after each epoch so overfitting shows up as val_loss /
# val_accuracy diverging from the training metrics.

history = model.fit(train_ds.take(200), validation_data=val_ds, epochs=3)


Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x216fc0f1a80>

In [34]:
# Visualizing Predictions

import os
import tensorflow as tf

# -----------------------------
# Output folder for saved overlays (created up front; OK if it already exists)
# -----------------------------
RESULTS_DIR = "segmentation_results"
os.makedirs(name=RESULTS_DIR, mode=0o777, exist_ok=True)

# -----------------------------
# Color mapping for masks
# -----------------------------
# Row i is the RGB colour drawn for class index i. preprocess() shifts the raw
# trimap labels 1-3 down to 0-2, so class k corresponds to raw label k+1.
# In the Oxford-IIIT Pet trimaps raw label 1 marks the pet itself, so class 0
# is the pet (foreground), NOT the background. Raw labels 2 and 3 are the
# background and the border / unclassified region (see the dataset README).
COLOR_MAP = tf.constant([
    [0, 0, 0],       # class 0 = raw label 1, pet / foreground (black)
    [255, 0, 0],     # class 1 = raw label 2 (red)
    [0, 255, 0],     # class 2 = raw label 3 (green)
], dtype=tf.uint8)

def apply_color_mask(mask_2d):
    """Colourise a 2-D mask of class indices.

    Each class index is looked up as a row of ``COLOR_MAP``, giving an
    ``(H, W, 3)`` uint8 RGB image.
    """
    class_ids = tf.cast(mask_2d, dtype=tf.int32)
    rgb = tf.gather(params=COLOR_MAP, indices=class_ids)
    return tf.cast(rgb, dtype=tf.uint8)

# -----------------------------
# Function to save a single sample
# -----------------------------
def _blend_overlay(color_mask, image_uint8, alpha):
    """Alpha-blend a uint8 RGB colour mask over a uint8 image of the same shape."""
    alpha_tf = tf.constant(alpha, dtype=tf.float32)
    blended = (alpha_tf * tf.cast(color_mask, tf.float32)
               + (1 - alpha_tf) * tf.cast(image_uint8, tf.float32))
    # A convex combination of values in [0, 255] stays in range; cast truncates.
    return tf.cast(blended, tf.uint8)


def save_colored_sample(image, mask, model, index, alpha=0.5):
    """
    Save input image, true mask overlay, and predicted mask overlay as JPEGs.

    Files are written to RESULTS_DIR as ``sample_{index}_input.jpg``,
    ``sample_{index}_true_overlay.jpg`` and ``sample_{index}_pred_overlay.jpg``.

    Args:
        image: tf.Tensor, shape (H, W, 3); uint8, or float assumed in [0, 1]
        mask: tf.Tensor, shape (H, W) or (H, W, 1) of class indices
        model: trained segmentation model returning per-pixel class scores
        index: int, sample index for filenames
        alpha: float, weight of the colour mask in the overlay
    """
    if image.dtype != tf.uint8:
        # saturate=True clips instead of wrapping if a float pixel strays
        # slightly outside [0, 1].
        image_uint8 = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
    else:
        image_uint8 = image

    # Drop a trailing channel axis so the mask indexes COLOR_MAP per pixel.
    mask_2d = mask[..., 0] if mask.shape[-1] == 1 else mask
    true_overlay = _blend_overlay(apply_color_mask(mask_2d), image_uint8, alpha)

    # Call the model directly in inference mode: model.predict() on a single
    # image is slow and prints a progress bar on every call.
    pred_probs = model(tf.expand_dims(image, axis=0), training=False)
    pred_mask = tf.argmax(pred_probs[0], axis=-1)
    pred_overlay = _blend_overlay(apply_color_mask(pred_mask), image_uint8, alpha)

    for suffix, picture in (
        ("input", image_uint8),
        ("true_overlay", true_overlay),
        ("pred_overlay", pred_overlay),
    ):
        tf.io.write_file(os.path.join(RESULTS_DIR, f"sample_{index}_{suffix}.jpg"),
                         tf.io.encode_jpeg(picture))

    print(f"Saved sample {index} images to {RESULTS_DIR}")

# -----------------------------
# Save N samples from the held-out split
# -----------------------------
# Use val_ds rather than train_ds: predictions on batches the model was just
# trained on overstate how well it generalises. unbatch() yields single
# (image, mask) pairs, so no per-batch bookkeeping is needed (and the global
# batch_size is no longer overwritten).
N = 5  # number of samples to save

for sample_count, (image, mask) in enumerate(val_ds.unbatch().take(N)):
    save_colored_sample(image, mask, model, index=sample_count)


Saved sample 0 images to segmentation_results
Saved sample 1 images to segmentation_results
Saved sample 2 images to segmentation_results
Saved sample 3 images to segmentation_results
Saved sample 4 images to segmentation_results


## Final Conceptual Comparison

| Task | Output | Complexity |
|---|---|---|
| Classification | One label | Low |
| Detection | Boxes + labels | Medium |
| Segmentation | Pixel-level mask | High |
