In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
import cv2
import pickle
from pathlib import Path
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configuration shared by the data pipeline, model and training cells
CONFIG = dict(
    dataset_root='./dataset2',       # root folder of the dataset and the class-mapping pickle
    image_size=(384, 384),           # (height, width) every sample is resized to
    batch_size=4,
    epochs=20,
    learning_rate=1e-3,              # initial learning rate of the schedule
    train_val_split=0.8,             # fraction of files kept for training
    ignore_label=255,                # mask value excluded from loss and metrics
    early_stopping_patience=10,
    reduce_lr_patience=5,
)


2026-01-20 02:49:46.808247: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-20 02:49:46.863959: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-20 02:49:47.840651: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


Configuration:
  dataset_root: ./dataset2
  image_size: (384, 384)
  batch_size: 4
  epochs: 20
  learning_rate: 0.001
  train_val_split: 0.8
  ignore_label: 255
  early_stopping_patience: 10
  reduce_lr_patience: 5


In [None]:
# GPU memory setup.
# Soft device placement only lets TensorFlow fall back to another device when an op
# cannot be placed on the requested one; it does NOT recover from GPU out-of-memory
# errors. To reduce OOMs, memory growth is enabled so TensorFlow allocates GPU memory
# on demand instead of reserving it all up front.
import tensorflow as tf

tf.config.set_soft_device_placement(True)

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth has to be configured before the GPUs are initialized.
        print(f"Could not enable memory growth for {gpu.name}: {err}")


* PKL (pickle) files are Python's way of saving objects to disk (in my case, the class mappings are stored there):


In [None]:
# Load class mappings from the pickle stored in the dataset root.
# NOTE(review): pickle.load can execute arbitrary code contained in the file --
# only load mapping files that were produced by this project.
mappings_path = os.path.join(CONFIG['dataset_root'], "class_mappings_final.pkl")

with open(mappings_path, "rb") as f:
    mappings = pickle.load(f)

# Number of segmentation classes (also the number of model output channels)
NUM_CLASSES = mappings['num_classes']

# Identity mappings: the masks are expected to already contain indices 0 .. NUM_CLASSES-1,
# so class 0 maps to index 0, class 1 to index 1, etc.
# (presumably guaranteed by the preprocessing step -- verify against the mask files)

class_id_to_idx = {i: i for i in range(NUM_CLASSES)}
idx_to_class_id = {i: i for i in range(NUM_CLASSES)}

# Keep the original (non-identity) mapping from the pickle for reference
original_class_id_to_idx = mappings['class_id_to_idx']
class_titles = mappings['class_titles'] # class names as stored in the pickle


In [None]:
# Train/validation split over the mask file names
train_mask_dir = os.path.join(CONFIG['dataset_root'], "processed_final", "train", "masks")

# Sorted listing keeps the split independent of the filesystem's listing order
mask_files = sorted(
    entry for entry in os.listdir(train_mask_dir) if entry.endswith('.png')
)

# The validation share is the complement of the configured training fraction (80/20)
train_files, val_files = train_test_split(
    mask_files,
    random_state=42,
    test_size=1-CONFIG['train_val_split'],
)


Instead of loading all images into memory at once (which would be a problem with large datasets and limited computational resources), the generator loads each batch from disk only when it is needed during training.

In [None]:
# data generator: loads and preprocesses image/mask batches lazily from disk
class SegmentationGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, masks)`` batches read from disk.

    Images are returned as float32 RGB arrays scaled to [0, 1]. Masks are returned
    as uint8 arrays in which every value that is a key of ``class_id_to_idx`` is
    replaced by its mapped index and every other value becomes 255.
    """

    def __init__(self, image_dir, mask_dir, file_list, class_id_to_idx,
                 batch_size=4, image_size=(384, 384), shuffle=True, **kwargs):
        """Store paths and batching options and build the initial sample order.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the mask PNGs.
            file_list: Mask file names; the image name is derived from each entry.
            class_id_to_idx: Dict mapping mask values to class indices.
            batch_size: Number of samples per batch.
            image_size: Target ``(height, width)`` of every sample.
            shuffle: Whether to reshuffle the sample order after each epoch.
            **kwargs: Forwarded to the Keras base class.
        """
        # Initialise the Keras base class; newer Keras versions warn when a
        # Sequence/PyDataset subclass skips this call.
        super().__init__(**kwargs)
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.file_list = file_list
        self.class_id_to_idx = class_id_to_idx
        self.batch_size = batch_size
        # Normalise to a tuple so the comparison with ``array.shape[:2]`` below
        # also works when a list is passed.
        self.image_size = tuple(image_size)
        self.shuffle = shuffle

        self.indexes = np.arange(len(self.file_list))
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.file_list) / self.batch_size))

    def __getitem__(self, index):
        """Load batch ``index`` and return ``(images float32, masks uint8)``."""
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        batch_images = []
        batch_masks = []

        for idx in batch_indexes:
            img, mask = self._load_sample(idx)
            batch_images.append(img)
            batch_masks.append(mask)

        return np.array(batch_images, dtype=np.float32), np.array(batch_masks, dtype=np.uint8)

    def _load_sample(self, idx):
        """Load, resize and preprocess the image/mask pair at position ``idx``.

        Steps:
            1. load image and mask from disk
            2. resize each of them to ``image_size`` if its size differs
            3. remap mask values through ``class_id_to_idx`` (unmapped -> 255)
            4. scale the image to the [0, 1] range

        Raises:
            FileNotFoundError: If the mask file cannot be read.
        """
        # Local import, as in the original cell; repeated imports are cheap
        # because the module is cached after the first one.
        from PIL import Image

        mask_file = self.file_list[idx]

        # The image normally drops the trailing '.png' of '<name>.jpg.png';
        # fall back to the mask's own file name if that file does not exist.
        img_path = self.image_dir / mask_file.replace('.jpg.png', '.jpg')
        if not img_path.exists():
            img_path = self.image_dir / mask_file

        # Context manager closes the underlying file handle once decoded.
        with Image.open(str(img_path)) as pil_image:
            image = np.array(pil_image.convert('RGB'))

        mask_path = self.mask_dir / mask_file
        mask = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # Resize image and mask independently: previously the mask was only
        # resized when the *image* size differed from the target.
        target_h, target_w = self.image_size
        if image.shape[:2] != (target_h, target_w):
            image = cv2.resize(image, (target_w, target_h))  # cv2 expects (width, height)
        if mask.shape[:2] != (target_h, target_w):
            # Nearest neighbour keeps label values intact (no blended classes).
            mask = cv2.resize(mask, (target_w, target_h),
                              interpolation=cv2.INTER_NEAREST)

        # Remap class ids to indices; anything not in the mapping stays 255.
        mask_remapped = np.full_like(mask, 255, dtype=np.uint8)
        for class_id, new_idx in self.class_id_to_idx.items():
            mask_remapped[mask == class_id] = new_idx

        # Normalize image to [0, 1]
        image = image.astype(np.float32) / 255.0

        return image, mask_remapped

    def on_epoch_end(self):
        """Reshuffle the sample order after each epoch when ``shuffle`` is set."""
        if self.shuffle:
            np.random.shuffle(self.indexes)


The generators will automatically provide batches in the format the model expects: (batch_size, 384, 384, 3) for images and (batch_size, 384, 384) for masks.

In [None]:
# Create the training and validation data generators
train_img_dir = os.path.join(CONFIG['dataset_root'], "resized", "train", "images")

# Training batches: sample order is reshuffled after every epoch
train_gen = SegmentationGenerator(
    image_dir=train_img_dir,
    mask_dir=train_mask_dir,
    file_list=train_files,
    class_id_to_idx=class_id_to_idx,
    batch_size=CONFIG['batch_size'],
    image_size=CONFIG['image_size'],
    shuffle=True,
)

# Validation batches: same directories (the validation files are a held-out part
# of the train folder), fixed order
val_gen = SegmentationGenerator(
    image_dir=train_img_dir,
    mask_dir=train_mask_dir,
    file_list=val_files,
    class_id_to_idx=class_id_to_idx,
    batch_size=CONFIG['batch_size'],
    image_size=CONFIG['image_size'],
    shuffle=False,
)

### Why Double Convolution?

U-Net architecture uses **two consecutive 3x3 convolutions** per block to:
- Extract richer features
- Increase receptive field without large kernels
- Follow the original U-Net paper design

In [None]:
# U-Net encoder block
def encoder_block(x, filters, name, pool=True):
    """Apply two Conv-BatchNorm-ReLU stages, optionally followed by 2x2 max pooling.

    Args:
        x: Input tensor.
        filters: Number of filters of both 3x3 convolutions.
        name: Prefix for the layer names (keeps the model summary readable).
        pool: When True, also return the max-pooled features.

    Returns:
        ``(features, pooled)`` when ``pool`` is True -- ``features`` is the skip
        connection and ``pooled`` the input to the next level; otherwise just
        ``features`` (used for the bottleneck, which does not pool).
    """
    for stage in (1, 2):
        x = layers.Conv2D(filters, 3, padding='same', name=f'{name}_conv{stage}')(x)
        x = layers.BatchNormalization(name=f'{name}_bn{stage}')(x)
        x = layers.ReLU(name=f'{name}_relu{stage}')(x)

    if not pool:
        return x

    pooled = layers.MaxPooling2D(2, name=f'{name}_pool')(x)
    return x, pooled

Without skips: blurry segmentation (lost spatial details during downsampling)
With skips: sharp boundaries (encoder's precise localization + decoder's semantics)

In [None]:
# U-Net decoder block
def decoder_block(x, skip, filters, name):
    """Upsample ``x`` by 2, concatenate the encoder skip and apply two conv stages.

    Args:
        x: Input tensor from the previous (deeper) level.
        skip: Skip-connection tensor from the matching encoder level.
        filters: Number of filters of the transposed and both 3x3 convolutions.
        name: Prefix for the layer names.

    Returns:
        The output tensor of the second Conv-BatchNorm-ReLU stage.
    """
    # 2x2 transposed convolution with stride 2 doubles the spatial size
    upsampled = layers.Conv2DTranspose(
        filters, 2, strides=2, padding='same', name=f'{name}_upsample'
    )(x)

    # Merge with the skip connection from the encoder
    features = layers.Concatenate(name=f'{name}_concat')([upsampled, skip])

    for stage in (1, 2):
        features = layers.Conv2D(filters, 3, padding='same', name=f'{name}_conv{stage}')(features)
        features = layers.BatchNormalization(name=f'{name}_bn{stage}')(features)
        features = layers.ReLU(name=f'{name}_relu{stage}')(features)

    return features


In [None]:
# build complete U-Net model
def build_unet(input_shape=(384, 384, 3), num_classes=NUM_CLASSES, base_filters=32, depth=4):
    """Build a U-Net that outputs per-pixel class logits.

    Args:
        input_shape: Input image shape ``(H, W, C)``. Known spatial dimensions
            must be divisible by ``2 ** depth`` so that the upsampled decoder
            features match the encoder skip connections.
        num_classes: Number of output classes (channels of the logits).
        base_filters: Filters of the first encoder level; doubled at each level.
        depth: Number of encoder/decoder levels (pooling steps).

    Returns:
        A Keras ``Model`` named ``'unet'``; the output layer has no activation,
        so it produces raw logits.

    Raises:
        ValueError: If ``depth`` is smaller than 1 or a known spatial dimension
            is not divisible by ``2 ** depth``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    # Each level halves the resolution; a non-divisible size would otherwise only
    # fail later with a shape mismatch at the skip concatenation.
    total_stride = 2 ** depth
    for dim in input_shape[:2]:
        if dim is not None and dim % total_stride != 0:
            raise ValueError(
                f"Spatial size {tuple(input_shape[:2])} must be divisible by "
                f"{total_stride} for a U-Net with depth={depth}"
            )

    inputs = layers.Input(shape=input_shape, name='input')

    # Encoder: with the defaults 384 -> 192 -> 96 -> 48 -> 24 (32/64/128/256 filters)
    skips = []
    x = inputs
    for level in range(1, depth + 1):
        skip, x = encoder_block(x, base_filters * 2 ** (level - 1), f'enc{level}', pool=True)
        skips.append(skip)

    # Bottleneck: no pooling, stays at the lowest resolution (24x24 with the defaults)
    x = encoder_block(x, base_filters * 2 ** depth, 'bottleneck', pool=False)

    # Decoder: mirrors the encoder, 24 -> 48 -> 96 -> 192 -> 384 with the defaults
    for level in range(depth, 0, -1):
        x = decoder_block(x, skips[level - 1], base_filters * 2 ** (level - 1), f'dec{level}')

    # Output: 1x1 convolution producing one logit per class
    outputs = layers.Conv2D(num_classes, 1, activation=None, name='output')(x)

    model = Model(inputs, outputs, name='unet')
    return model

# Create the model: input is (height, width, 3) taken from CONFIG['image_size'],
# with one output channel per class
model = build_unet(
    input_shape=(*CONFIG['image_size'], 3),
    num_classes=NUM_CLASSES
)

In [11]:
# Print the layer-by-layer model summary
model.summary()

In [None]:
# loss functions - masked sparse categorical cross-entropy
def categorical_crossentropy_loss(y_true, y_pred, ignore_label=255):
    """Mean sparse categorical cross-entropy over pixels that are not ignored.

    Args:
        y_true: Integer class labels; pixels equal to ``ignore_label`` are skipped.
        y_pred: Logits with the class scores on the last axis.
        ignore_label: Label value excluded from the loss.

    Returns:
        Scalar tensor: summed per-pixel loss divided by the number of valid pixels.
    """
    # Mask of valid pixels
    valid = tf.not_equal(y_true, ignore_label)

    # Replace ignored labels with a valid class index before computing the loss:
    # out-of-range labels make sparse cross-entropy raise on CPU and yield NaN on
    # GPU, and NaN * 0 would still be NaN after masking.
    safe_labels = tf.where(valid, y_true, tf.zeros_like(y_true))

    # Per-pixel loss (no reduction)
    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    loss = loss_fn(safe_labels, y_pred)

    # Apply mask and normalize by the number of valid pixels
    mask = tf.cast(valid, loss.dtype)
    loss = loss * mask
    return tf.reduce_sum(loss) / (tf.reduce_sum(mask) + 1e-7)


Dice loss - helps with class imbalance (especially for rare classes)

In [None]:
# loss functions - dice loss
def dice_loss(y_true, y_pred, ignore_label=255, smooth=1e-6):
    """Soft Dice loss computed per sample and per class, ignoring masked pixels.

    Args:
        y_true: Integer class labels; pixels equal to ``ignore_label`` are skipped.
        y_pred: Logits with the class scores on the last axis.
            # assumes y_pred is (batch, H, W, classes) and y_true (batch, H, W),
            # since the sums below run over axes 1 and 2 -- TODO confirm
        ignore_label: Label value excluded from the loss.
        smooth: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean(dice)`` over all samples and classes.
    """
    y_true = tf.cast(y_true, tf.int32)
    # Convert logits to probabilities
    y_pred = tf.nn.softmax(y_pred, axis=-1)

    # Take the number of classes from the logits instead of a module-level global,
    # so the loss always matches the model it is applied to.
    num_classes = y_pred.shape[-1]
    if num_classes is None:
        num_classes = tf.shape(y_pred)[-1]

    # One-hot encode ground truth (out-of-range labels such as 255 become all-zero rows)
    y_true_one_hot = tf.one_hot(y_true, depth=num_classes)

    # Mask of valid pixels, broadcast over the class axis
    mask = tf.not_equal(y_true, ignore_label)
    mask = tf.cast(mask, tf.float32)
    mask = tf.expand_dims(mask, axis=-1)

    # Apply mask
    y_true_one_hot = y_true_one_hot * mask
    y_pred = y_pred * mask

    # Dice coefficient per sample and class
    intersection = tf.reduce_sum(y_true_one_hot * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true_one_hot, axis=[1, 2]) + tf.reduce_sum(y_pred, axis=[1, 2])

    dice = (2.0 * intersection + smooth) / (union + smooth)

    # Loss is 1 - mean Dice
    return 1.0 - tf.reduce_mean(dice)



In [None]:
# Combined Loss Function
def get_combined_loss(ce_weight=0.5, dice_weight=0.5, ignore_label=255):
    """Create a loss that is a weighted sum of masked cross-entropy and Dice loss.

    Args:
        ce_weight: Weight of the cross-entropy term.
        dice_weight: Weight of the Dice term.
        ignore_label: Label value excluded from both terms.

    Returns:
        A function ``combined_loss(y_true, y_pred)`` returning a scalar tensor;
        ``y_pred`` is expected to hold logits.
    """
    # from_logits=True because the model outputs logits; no reduction so that the
    # mean can be taken over valid pixels only.
    per_pixel_ce = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )

    def combined_loss(y_true, y_pred):
        # Mask of valid pixels
        valid = tf.not_equal(y_true, ignore_label)

        # Ignored labels are out of range for sparse cross-entropy (error on CPU,
        # NaN on GPU), so swap them for a valid index before masking them out.
        safe_labels = tf.where(valid, y_true, tf.zeros_like(y_true))
        ce = per_pixel_ce(safe_labels, y_pred)

        # Average over valid pixels only (not over all pixels of the batch)
        mask = tf.cast(valid, ce.dtype)
        loss_ce = tf.reduce_sum(ce * mask) / (tf.reduce_sum(mask) + 1e-7)

        # Dice loss
        loss_dice = dice_loss(y_true, y_pred, ignore_label)

        # Combined loss
        return ce_weight * loss_ce + dice_weight * loss_dice

    return combined_loss

IoU = (Predicted ∩ Ground Truth) / (Predicted ∪ Ground Truth)
This metric gives a reliable measure of segmentation quality during training


In [None]:
# metrics - mean IoU
class MeanIoU(keras.metrics.Metric):
    """Mean intersection-over-union accumulated in a confusion matrix.

    Pixels labelled ``ignore_label`` are excluded; predictions are the argmax of
    ``y_pred`` over the last axis.
    """

    def __init__(self, num_classes, ignore_label=255, name='mean_iou', **kwargs):
        """Create the persistent ``num_classes x num_classes`` confusion matrix."""
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.total_cm = self.add_weight(
            name='total_confusion_matrix',
            shape=(num_classes, num_classes),
            initializer='zeros',
            dtype=tf.float32,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the confusion matrix of this batch (``sample_weight`` is unused)."""
        # Class predictions from the logits, labels as int32
        predicted = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
        labels = tf.cast(y_true, tf.int32)

        # Keep only pixels that are not marked as ignored
        keep = tf.not_equal(labels, self.ignore_label)
        batch_cm = tf.math.confusion_matrix(
            tf.boolean_mask(labels, keep),
            tf.boolean_mask(predicted, keep),
            num_classes=self.num_classes,
            dtype=tf.float32,
        )
        self.total_cm.assign_add(batch_cm)

    def result(self):
        """Return the mean IoU over classes that have a non-zero union."""
        column_totals = tf.reduce_sum(self.total_cm, axis=0)
        row_totals = tf.reduce_sum(self.total_cm, axis=1)
        true_positives = tf.linalg.diag_part(self.total_cm)

        # |truth| + |prediction| - |intersection| per class
        union = column_totals + row_totals - true_positives
        per_class_iou = true_positives / (union + 1e-7)

        # Classes that never appear in labels or predictions are left out
        return tf.reduce_mean(tf.boolean_mask(per_class_iou, union > 0))

    def reset_state(self):
        """Zero the accumulated confusion matrix."""
        self.total_cm.assign(
            tf.zeros((self.num_classes, self.num_classes), dtype=tf.float32)
        )


For this task Mean IoU is the better primary metric, but the Dice score is still useful as an additional metric.

In [None]:
# Metrics - mean per-class dice score
class DiceScore(keras.metrics.Metric):
    """Dice score averaged over the classes present in each batch, then over batches.

    Pixels labelled ``ignore_label`` are excluded; predictions are the argmax of
    ``y_pred`` over the last axis. A class counts as present in a batch when it
    occurs in the labels or in the predictions of that batch.
    """

    def __init__(self, num_classes, ignore_label=255, name='dice_score', **kwargs):
        """Create the running sum of batch scores and the batch counter."""
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.dice_sum = self.add_weight(name='dice_sum', initializer='zeros')
        self.count = self.add_weight(name='count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the mean Dice of this batch (``sample_weight`` is unused)."""
        # argmax of the logits equals argmax of the softmax, so no softmax is needed
        y_pred = tf.cast(tf.argmax(y_pred, axis=-1), tf.int32)
        y_true = tf.cast(y_true, tf.int32)

        # Drop ignored pixels
        valid = tf.not_equal(y_true, self.ignore_label)
        cm = tf.math.confusion_matrix(
            tf.boolean_mask(y_true, valid),
            tf.boolean_mask(y_pred, valid),
            num_classes=self.num_classes,
            dtype=tf.float32,
        )

        # Per class: intersection = diagonal, |truth| + |prediction| = row sum + column sum
        intersection = tf.linalg.diag_part(cm)
        totals = tf.reduce_sum(cm, axis=0) + tf.reduce_sum(cm, axis=1)

        # Classes absent from both labels and predictions have totals == 0 and
        # contribute 0 here; they are excluded from the average below instead of
        # counting as a perfect score, which previously inflated the metric.
        dice = (2.0 * intersection) / tf.maximum(totals, 1.0)
        num_present = tf.reduce_sum(tf.cast(totals > 0, tf.float32))
        mean_dice = tf.reduce_sum(dice) / tf.maximum(num_present, 1.0)

        self.dice_sum.assign_add(mean_dice)
        # A batch without any valid pixel does not count towards the average
        self.count.assign_add(tf.minimum(num_present, 1.0))

    def result(self):
        """Return the average batch Dice over all counted batches."""
        return self.dice_sum / (self.count + 1e-7)

    def reset_state(self):
        """Reset the running sum and the batch counter."""
        self.dice_sum.assign(0.0)
        self.count.assign(0.0)


In [None]:
# compile model
# Cosine-decay learning-rate schedule spanning the whole run: decay_steps is the
# total number of optimizer steps (epochs * batches per epoch), and the rate
# decays from the initial value down to alpha * initial.
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=CONFIG['learning_rate'],
    decay_steps=CONFIG['epochs'] * len(train_gen),
    alpha=0.1
)

optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

# Compile with the combined loss. The ignore label is taken from CONFIG for the
# loss as well, so loss and metrics always agree on which pixels are skipped.
# NOTE(review): 'pixel_accuracy' is the stock sparse accuracy and presumably counts
# ignored pixels as misclassified -- confirm if the masks contain ignore pixels.
model.compile(
    optimizer=optimizer,
    loss=get_combined_loss(
        ce_weight=0.5,
        dice_weight=0.5,
        ignore_label=CONFIG['ignore_label'],
    ),
    metrics=[
        MeanIoU(num_classes=NUM_CLASSES, ignore_label=CONFIG['ignore_label']),
        DiceScore(num_classes=NUM_CLASSES, ignore_label=CONFIG['ignore_label']),
        keras.metrics.SparseCategoricalAccuracy(name='pixel_accuracy')
    ]
)

print("✅ Model compiled with combined loss (CE + Dice)")

In [None]:
# callbacks
os.makedirs('./checkpoints_unet', exist_ok=True)

callbacks = [
    # Save the best model based on validation mIoU
    keras.callbacks.ModelCheckpoint(
        filepath='./checkpoints_unet/best_model.keras',
        monitor='val_mean_iou',
        mode='max',
        save_best_only=True,
        verbose=1
    ),

    # Save the model after every epoch
    keras.callbacks.ModelCheckpoint(
        filepath='./checkpoints_unet/model_epoch_{epoch:02d}.keras',
        save_freq='epoch',
        verbose=0
    ),

    # NOTE: no ReduceLROnPlateau here. The optimizer is created with a
    # LearningRateSchedule (cosine decay), whose learning rate cannot be set from
    # a callback; ReduceLROnPlateau would raise as soon as a plateau triggered it.
    # Re-add it only if the optimizer is switched to a plain float learning rate.

    # Early stopping on validation mIoU, restoring the best weights
    keras.callbacks.EarlyStopping(
        monitor='val_mean_iou',
        patience=CONFIG['early_stopping_patience'],
        mode='max',
        restore_best_weights=True,
        verbose=1
    )
]

In [None]:
# Training: fit on the training generator and evaluate on the validation generator
# after each epoch; the returned history is kept for later inspection.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=CONFIG['epochs'],
    callbacks=callbacks,
    verbose=1
)


In [None]:
# Save the final model (state at the end of training) next to the checkpoints
model.save('./checkpoints_unet/final_model.keras')

In [None]:
def visualize_predictions(model, generator, num_samples=5, start_index=None, random=False,
                          save_path='./checkpoints_unet/predictions.png', batch_index=17):
    """Plot image, ground truth, prediction and overlay for samples of one batch.

    Args:
        model: Model whose ``predict`` returns per-pixel class scores.
        generator: Indexable batch source (supports ``len`` and ``[]``) returning
            ``(images, masks)`` tuples or images only.
        num_samples: Maximum number of samples to show; limited by the batch size.
        start_index: Index of the first sample inside the batch; ``None`` starts
            at the beginning of the batch.
        random: Accepted for backward compatibility; currently has no effect.
        save_path: File the figure is written to.
        batch_index: Index of the batch to visualise; clamped to the valid range.

    Returns:
        None. The figure is saved to ``save_path`` and shown.
    """
    num_batches = len(generator)
    if num_batches == 0:
        print("error generator has no batches")
        return

    # Clamp so a generator with fewer batches than the default index still works
    batch_index = max(0, min(batch_index, num_batches - 1))
    batch_data = generator[batch_index]

    if isinstance(batch_data, tuple) and len(batch_data) == 2:
        images, masks = batch_data
    else:
        images = batch_data[0] if isinstance(batch_data, list) else batch_data
        masks = None

    if start_index is not None:
        # Shift the window back so that num_samples fit inside the batch
        start_index = min(start_index, len(images) - num_samples)
        start_index = max(0, start_index)
        indices = list(range(start_index, min(start_index + num_samples, len(images))))
    else:
        # First few samples
        indices = list(range(min(num_samples, len(images))))

    if not indices:
        # plt.subplots cannot create a grid with zero rows
        print("error no samples to visualize")
        return

    selected_images = np.array([images[i] for i in indices])
    selected_masks = np.array([masks[i] for i in indices]) if masks is not None else None

    try:
        predictions = model.predict(selected_images, verbose=1, batch_size=len(selected_images))
        pred_masks = np.argmax(predictions, axis=-1)
    except Exception as e:
        print(f"error {e}")
        return

    num_plots = len(indices)
    # squeeze=False always yields a 2-D axes array, also for a single row
    fig, axes = plt.subplots(num_plots, 4, figsize=(20, 5 * num_plots), squeeze=False)

    # Same normalisation as imshow(vmin=0, vmax=NUM_CLASSES-1), so the overlay
    # colours match the prediction panel
    color_scale = float(max(NUM_CLASSES - 1, 1))

    for i in range(num_plots):
        # Original image
        axes[i, 0].imshow(selected_images[i])
        axes[i, 0].set_title(f'Image #{indices[i]}', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        # Ground truth mask
        if selected_masks is not None:
            # Hide ignored pixels (255) instead of drawing them as the last class
            ground_truth = np.ma.masked_equal(selected_masks[i], 255)
            axes[i, 1].imshow(ground_truth, cmap='tab20', vmin=0, vmax=NUM_CLASSES-1)
            axes[i, 1].set_title('Ground Truth', fontsize=12, fontweight='bold')
        else:
            axes[i, 1].imshow(np.zeros_like(pred_masks[i]), cmap='tab20', vmin=0, vmax=NUM_CLASSES-1)
            axes[i, 1].set_title('No Ground Truth', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # Predicted mask
        axes[i, 2].imshow(pred_masks[i], cmap='tab20', vmin=0, vmax=NUM_CLASSES-1)
        axes[i, 2].set_title('Prediction', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

        # Overlay: blend the coloured prediction into the image
        overlay = selected_images[i].copy()
        mask_colored = plt.cm.tab20(pred_masks[i] / color_scale)[:, :, :3]

        if selected_masks is not None:
            # Blend only where the ground truth is not ignored
            valid_mask = selected_masks[i] != 255
            overlay[valid_mask] = overlay[valid_mask] * 0.6 + mask_colored[valid_mask] * 0.4
        else:
            # No mask available: blend everywhere
            overlay = overlay * 0.6 + mask_colored * 0.4

        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay', fontsize=12, fontweight='bold')
        axes[i, 3].axis('off')

    plt.tight_layout()

    # Make sure the target directory exists before saving
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
visualize_predictions(model, val_gen, num_samples=4, start_index=0) # show the first 4 samples of the selected validation batch
