In [None]:
import os
import random
from glob import glob

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_datasets as tfds
from sklearn.utils.class_weight import compute_class_weight

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Activation

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import imageio

In [None]:
# Single source of truth for randomness: seeds Python's `random`, NumPy and
# TensorFlow in one call so shuffles, weight init and dropout are reproducible.
seed = 40
tf.keras.utils.set_random_seed(seed)

In [None]:
IMG_HEIGHT = 512
IMG_WIDTH = 512
BATCH_SIZE = 16

val_split = 0.2  # fraction of the dataset held out for validation

# class_names[i] names mask label i (the i-th one-hot channel).
# For the 4-class setup swap in the commented list; n_classes follows automatically.
class_names = ["soil", "residue"]
# class_names = ["residue_sunlit", "residue_shaded", "background_sunlit", "background_shaded"]
n_classes = len(class_names)

# Evaluation Metrics

In [None]:
def compute_confusion_matrix(y_true, y_pred, num_classes=n_classes):
    """Pixel-wise confusion matrix for a segmentation result.

    Every element of ``y_true`` / ``y_pred`` is counted as one classification
    label, so both arrays are expected to hold integer class indices rather
    than one-hot maps.

    Returns:
        (num_classes, num_classes) array from sklearn's ``confusion_matrix``:
        rows are true labels, columns are predicted labels.
    """
    label_ids = np.arange(num_classes)
    return confusion_matrix(y_true.reshape(-1), y_pred.reshape(-1), labels=label_ids)

In [None]:
def compute_metrics(cm):
    """Per-class precision, recall and F1 from a confusion matrix.

    Args:
        cm: square confusion matrix with true labels on rows and predictions
            on columns (sklearn's ``confusion_matrix`` layout).

    Returns:
        ``(precision, recall, f1)`` as 1-D arrays, one entry per class.
        A 1e-6 epsilon in every denominator keeps empty classes at 0
        instead of producing NaN.
    """
    eps = 1e-6
    hits = np.diag(cm)
    false_alarms = np.sum(cm, axis=0) - hits  # predicted as c, actually other
    misses = np.sum(cm, axis=1) - hits        # actually c, predicted as other

    precision = hits / (hits + false_alarms + eps)
    recall = hits / (hits + misses + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    return precision, recall, f1

In [None]:
def compute_iou(cm):
    """Intersection-over-Union for each class plus their unweighted mean.

    IoU_c = TP_c / (TP_c + FP_c + FN_c), with FP_c = column sum - diagonal and
    FN_c = row sum - diagonal of ``cm`` (true labels on rows). A 1e-6 epsilon
    keeps classes that never occur at 0 rather than NaN.

    Returns:
        ``(iou, mean_iou)``: per-class IoU array and its arithmetic mean.
    """
    diag = np.diag(cm)
    false_pos = np.sum(cm, axis=0) - diag
    false_neg = np.sum(cm, axis=1) - diag
    per_class = diag / (diag + false_pos + false_neg + 1e-6)
    return per_class, np.mean(per_class)

# Preparing dataset

In [None]:
def load_and_preprocess_image(image_path):
    """Read a JPEG image from disk and resize it to the network input size.

    Args:
        image_path: scalar string tensor (or str) pointing at a JPEG file.

    Returns:
        uint8 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 3) with values in [0, 255].
        uint8 keeps the cached dataset 8x smaller than int64; the model's
        Rescaling(1./255) layer converts it to float, and ``plt.imshow``
        renders uint8 RGB directly.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    # Bilinear resize returns float32; round (instead of truncating) before
    # narrowing back to 8 bits so resampled pixels are not biased downwards.
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    return tf.saturate_cast(tf.round(image), tf.uint8)  # (IMG_HEIGHT, IMG_WIDTH, 3)
    
def load_and_preprocess_mask(mask_path):
    """Read a single-channel label mask, resize it and one-hot encode it.

    Args:
        mask_path: scalar string tensor (or str) pointing at a mask image whose
            pixel values are class indices — presumably 0-based in
            [0, n_classes); confirm against the mask encoding.

    Returns:
        float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, n_classes). Pixels whose
        value is outside [0, n_classes) become all-zero vectors (tf.one_hot).
    """
    mask = tf.io.read_file(mask_path)
    # decode_image detects the actual format, so PNG masks and masks stored as
    # JPEG (the mask folder is named "jpg_files") both decode; decode_png would
    # reject JPEG bytes. NOTE(review): JPEG compression can alter label values
    # near region borders — lossless PNG masks are preferable.
    mask = tf.io.decode_image(mask, channels=1, expand_animations=False)
    # Nearest-neighbour keeps labels discrete (no interpolated in-between classes).
    mask = tf.image.resize(mask, [IMG_HEIGHT, IMG_WIDTH], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Drop the channel axis before one-hot: (H, W) -> (H, W, n_classes)
    return tf.one_hot(tf.cast(mask[..., 0], tf.int32), depth=n_classes)

def extract_image_name(image_path):
    """Return the file name of ``image_path`` without its final extension.

    Only the last ``.suffix`` is removed, so names containing dots
    (``plot_2023.05.01.jpg`` -> ``plot_2023.05.01``) stay distinct; splitting
    on the first dot would collapse them and make saved predictions overwrite
    each other. Names without a dot are returned unchanged.
    """
    filename = tf.strings.split(image_path, os.sep)[-1]  # basename
    return tf.strings.regex_replace(filename, r"\.[^.]*$", "")  # strip last extension
    
def parse_image_and_mask(image_path, mask_path):
    """tf.data map function: turn one (image path, mask path) pair into the
    triple (preprocessed image, one-hot mask, image name)."""
    return (
        load_and_preprocess_image(image_path),
        load_and_preprocess_mask(mask_path),
        extract_image_name(image_path),
    )

def create_dataset(image_paths, mask_paths):
    """Build a lazily decoded tf.data pipeline of (image, mask, name) triples.

    ``image_paths[i]`` is paired with ``mask_paths[i]``, so both sequences must
    be aligned. Decoding runs in parallel with AUTOTUNE.
    """
    pairs = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    decoded = pairs.map(parse_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
    return decoded

# Specify the base directory
root = "/kaggle/input/crop-residue-new-jpg/Training/"

# Input images and masks are paired purely by sorted order, so both folders
# must contain the same number of files with corresponding names.
image_paths = sorted(glob(os.path.join(root, "input/*")))
mask_paths = sorted(glob(os.path.join(root, "output/converted/jpg_files/*")))

# Fail fast with a readable message (instead of an empty dataset or a shape
# error from from_tensor_slices) when a path is wrong.
if not image_paths:
    raise FileNotFoundError(f"No input images found under {os.path.join(root, 'input')}")
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(mask_paths)} masks; "
        "they are paired by sorted order and must match one-to-one."
    )

# Create the dataset
dataset = create_dataset(image_paths, mask_paths)

In [None]:
def save_predicted_image(image_name, pred, output_dir="/kaggle/working"):
    """Save a predicted segmentation as an 8-bit label PNG.

    Args:
        image_name (tf.Tensor): scalar string tensor with the image name
            (without extension); used as the output file name.
        pred (array-like): predicted mask of shape (H, W, n_classes), one-hot
            or probabilities; the class is the argmax over the last axis.
        output_dir (str): directory the PNG is written to; created if missing.

    Returns:
        str: path of the written PNG.

    Stored pixel values are class index + 1, i.e. classes [0, n_classes-1]
    are written as [1, n_classes]; the value 0 never appears.
    """
    image_name_str = image_name.numpy().decode('utf-8')
    # Shift classes from [0, n_classes-1] to [1, n_classes]
    pred_mask = np.argmax(pred, axis=-1) + 1

    os.makedirs(output_dir, exist_ok=True)
    pred_img_path = os.path.join(output_dir, f"{image_name_str}.png")
    imageio.imwrite(pred_img_path, pred_mask.astype(np.uint8))
    return pred_img_path

In [None]:
def plot_example(image, mask, image_name, pred=None):
    """
    Show an input image with its per-class ground-truth masks and, optionally,
    per-class predictions on a grid with ``n_classes`` columns.

    Row 1 holds the image, row 2 one mask channel per class, and row 3 (only
    when ``pred`` is given) one prediction channel per class.

    Args:
        image (numpy array or tensor): the input image (H, W, 3).
        mask (numpy array or tensor): one-hot ground truth (H, W, n_classes).
        image_name (tf.Tensor): scalar string tensor with the image name.
        pred (numpy array or tensor, optional): one-hot prediction
            (H, W, n_classes). Defaults to None (no prediction row).
    """
    title_name = image_name.numpy().decode('utf-8')
    has_pred = pred is not None
    grid_rows = 3 if has_pred else 2
    grid_cols = n_classes

    plt.figure(figsize=(20, 10))

    def _panel(position, data, title):
        # One cell of the grid: draw, title, hide ticks.
        plt.subplot(grid_rows, grid_cols, position)
        plt.imshow(data)
        plt.title(title)
        plt.axis("off")

    _panel(1, image, f"Image: {title_name}")

    for cls in range(n_classes):
        _panel(n_classes + cls + 1, mask[:, :, cls], f"Mask: {class_names[cls]}")

    if has_pred:
        for cls in range(n_classes):
            _panel(2 * n_classes + cls + 1, pred[:, :, cls], f"Prediction: {class_names[cls]}")

    plt.tight_layout()
    plt.show()

In [None]:
# Sanity check of the input pipeline: display the first decoded sample.
first_sample = next(iter(dataset))
image, mask, image_name = first_sample
plot_example(image, mask, image_name)

In [None]:
# Determine dataset size and validation split
dataset_size = len(dataset)
val_size = int(val_split * dataset_size)

print(f"Training samples: {dataset_size - val_size}")
print(f"Validation samples: {val_size}")

# Shuffle ONCE into a fixed order before splitting. With the default
# reshuffle_each_iteration=True, every pass over `shuffled_dataset` (each epoch,
# and each separate iteration of train_ds / val_ds) gets a new permutation, so
# take()/skip() would pick different files each time and validation images
# would leak into training. NOTE: with more than 1000 samples a 1000-element
# buffer gives only an approximate shuffle.
shuffled_dataset = dataset.shuffle(buffer_size=1000, seed=seed, reshuffle_each_iteration=False)

# Split into training and validation sets
train_ds = shuffled_dataset.skip(val_size)
val_ds = shuffled_dataset.take(val_size)

# Training pipeline: cache the decoded samples first, THEN shuffle, so every
# epoch sees a fresh order (shuffling before cache() would replay the first
# epoch's order forever).
train_ds = (train_ds
            .cache()
            .shuffle(buffer_size=1000, seed=seed)
            .batch(BATCH_SIZE)
            .prefetch(buffer_size=tf.data.AUTOTUNE))

# Validation order does not matter; batch once and cache the batches.
val_ds = (val_ds
          .batch(BATCH_SIZE)
          .cache()
          .prefetch(buffer_size=tf.data.AUTOTUNE))

# Set class weights

In [None]:
# Balanced class weights from pixel frequencies over the WHOLE training split:
#     w_c = total_pixels / (n_classes * pixels_of_class_c)
# (the formula sklearn's compute_class_weight("balanced") uses). Counting
# globally instead of averaging per-batch weights gives the true balance and
# does not crash on batches where one class is absent (compute_class_weight
# raises when a requested class is missing from y).
train_masks = tfds.as_numpy(train_ds.map(lambda img, mask, image_name: mask))

pixel_counts = np.zeros(n_classes, dtype=np.int64)
for i, batch in enumerate(train_masks):
    print(f"Processing batch {i + 1}/{len(train_masks)}")
    labels = np.argmax(batch, axis=-1).ravel()  # one-hot -> class index per pixel
    pixel_counts += np.bincount(labels, minlength=n_classes)

missing = [class_names[c] for c in range(n_classes) if pixel_counts[c] == 0]
if missing:
    raise ValueError(f"Classes never present in the training masks: {missing}")

class_weights = pixel_counts.sum() / (n_classes * pixel_counts)

# Print final class weights
print(f"Pixel counts: {pixel_counts}")
print(f"Class weights: {class_weights}")
print({class_names[i]: class_weights[i] for i in range(n_classes)})

# U-Net Model

In [None]:
def conv_block(x, filters, dropout_rate, l2_reg):
    """Two 3x3 same-padded convolutions, each followed by BatchNorm + ReLU.

    Dropout (only when ``dropout_rate > 0``) sits between the two convolutions.
    The convolutions have no bias (the following BatchNorm supplies the shift)
    and carry an L2 kernel penalty of ``l2_reg``.
    """
    for conv_index in range(2):
        if conv_index == 1 and dropout_rate > 0:
            x = Dropout(dropout_rate)(x)
        x = Conv2D(filters, (3, 3), padding='same', use_bias=False,
                   kernel_regularizer=tf.keras.regularizers.L2(l2_reg))(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

def encoder_block(x, filters, dropout_rate, l2_reg):
    """Down-sampling stage: ``conv_block`` followed by a 2x2 max-pool.

    Returns:
        ``(features, pooled)``: full-resolution features for the skip
        connection and the half-resolution tensor for the next stage.
    """
    features = conv_block(x, filters, dropout_rate, l2_reg)
    pooled = MaxPooling2D((2, 2))(features)
    return features, pooled
    
def decoder_block(x, skip_connection, filters, dropout_rate, l2_reg):
    """Up-sampling stage: a 2x2 stride-2 transposed conv doubles H and W, the
    result is concatenated on the last axis with the matching encoder
    features, and a ``conv_block`` fuses them."""
    upsampled = Conv2DTranspose(
        filters,
        kernel_size=(2, 2),
        strides=(2, 2),
        padding='same',
        kernel_regularizer=tf.keras.regularizers.L2(l2_reg),
    )(x)
    merged = concatenate([upsampled, skip_connection])
    return conv_block(merged, filters, dropout_rate, l2_reg)

def multi_unet_model(n_classes, img_height, img_width, img_channels):
    """Build a 6-level U-Net for per-pixel multi-class segmentation.

    Encoder filters double from 16 to 512, the bottleneck has 1024 filters, and
    the decoder mirrors the encoder through skip connections. Deeper, wider
    levels get stronger L2 (1e-4 -> 1e-3 -> 1e-2).

    Args:
        n_classes: number of output channels (softmax over the last axis).
        img_height, img_width: input spatial size; when given, each must be
            divisible by 2**6 = 64 because of the six 2x2 poolings.
        img_channels: number of input channels.

    Returns:
        Keras ``Model`` mapping (img_height, img_width, img_channels) inputs to
        (img_height, img_width, n_classes) class probabilities.

    Raises:
        ValueError: if a fixed spatial size is not divisible by 64 (otherwise a
            decoder skip concatenation fails with an opaque shape error).
    """
    # (filters, l2_reg) per encoder level; the decoder walks this in reverse.
    levels = [(16, 0.0001), (32, 0.0001), (64, 0.0001),
              (128, 0.001), (256, 0.001), (512, 0.01)]
    downsample = 2 ** len(levels)
    if any(d is not None and d % downsample for d in (img_height, img_width)):
        raise ValueError(
            f"img_height and img_width must be divisible by {downsample}; "
            f"got {img_height}x{img_width}"
        )

    inputs = Input((img_height, img_width, img_channels))

    # Contraction path, encoder
    skips = []
    x = inputs
    for filters, l2_reg in levels:
        skip, x = encoder_block(x, filters=filters, dropout_rate=0, l2_reg=l2_reg)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, filters=1024, dropout_rate=0, l2_reg=0.01)

    # Expansive path, decoder (deepest skip first)
    for (filters, l2_reg), skip in zip(reversed(levels), reversed(skips)):
        x = decoder_block(x, skip, filters=filters, dropout_rate=0, l2_reg=l2_reg)

    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
unet_model = multi_unet_model(
    n_classes=n_classes,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    img_channels=3,
)

# Quick architecture sanity check.
trainable_param_count = np.sum([np.prod(var.shape) for var in unet_model.trainable_variables])
for label, value in (("Input shape", unet_model.input_shape),
                     ("Output shape", unet_model.output_shape),
                     ("Trainable params", trainable_param_count)):
    print(f"{label}: {value}")

# Training

In [None]:
# Hyperparameters
EPOCHS = 50
INITIAL_LR = 1e-4
FOCAL_GAMMA = 2
CLIP_NORM = 1.0
CHECKPOINT_PATH = "training_2/model.h5"


# Model definition: scale raw 0-255 pixel values to [0, 1] inside the model.
model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1./255),  # Normalize input
    unet_model,
])

# Learning rate: constant start value, halved by ReduceLROnPlateau below.
# It must be a float: Keras does not allow overwriting the learning rate of an
# optimizer built from a LearningRateSchedule, so the former ExponentialDecay
# schedule made ReduceLROnPlateau raise TypeError the first time it fired.
# (That schedule decayed only every len(train_ds) * 400 steps, i.e. never within
# 50 epochs, so the effective learning rate before any plateau is unchanged.)
optimizer = tf.keras.optimizers.Adam(learning_rate=INITIAL_LR, clipnorm=CLIP_NORM)

# Loss function (Categorical Focal Crossentropy); per-class alpha = class weights
loss_fn = tf.keras.losses.CategoricalFocalCrossentropy(alpha=class_weights, gamma=FOCAL_GAMMA)

# Model compilation
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# Sanity check for the first-epoch loss with a near-uniform softmax (p = 1/n):
# focal CE per pixel is alpha_true * (1 - 1/n)**gamma * log(n). With balanced
# class weights the pixel-weighted mean of alpha_true is ~1, giving the value
# below. The reported loss also includes the L2 kernel penalties, so expect it
# to be somewhat higher.
expected_loss = (1 - 1 / n_classes) ** FOCAL_GAMMA * np.log(n_classes)
print(f"Expected initial loss: {expected_loss}")

# Make sure the checkpoint directory exists before the first save.
checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Callbacks for training
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        CHECKPOINT_PATH, monitor='val_loss', save_best_only=True, save_freq='epoch'
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1
    )
]

# Prepare datasets (remove image names)
train_ds_for_fit = train_ds.map(lambda image, mask, image_name: (image, mask))  # Keep only (image, mask)
val_ds_for_fit = val_ds.map(lambda image, mask, image_name: (image, mask))

# Model training
history = model.fit(
    train_ds_for_fit,
    validation_data=val_ds_for_fit,
    epochs=EPOCHS,
    callbacks=callbacks,
)


# Accuracy and Loss plots

In [None]:
# Extract training history
history_data = history.history
epochs = np.arange(1, len(history_data["loss"]) + 1)

# One panel per tracked metric: (history key, legend/axis noun, panel title)
panels = (
    ("loss", "Loss", "Categorical Focal Cross Entropy"),
    ("accuracy", "Accuracy", "Model Accuracy"),
)

fig, axes = plt.subplots(1, 2, figsize=(20, 5))
for ax, (key, noun, title) in zip(axes, panels):
    ax.plot(epochs, history_data[key], label=f"Training {noun}")
    ax.plot(epochs, history_data[f"val_{key}"], label=f"Validation {noun}")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(noun)
    ax.grid(True)
    ax.legend()

plt.show()

# Evaluation

In [None]:
def _drop_image_names(image, mask, image_name):
    """Keep only the (input, target) pair that Keras expects."""
    return image, mask


train_ds_for_eval = train_ds.map(_drop_image_names)
val_ds_for_eval = val_ds.map(_drop_image_names)

# evaluate() returns [loss, accuracy] (metrics=["accuracy"] at compile time).
train_loss = model.evaluate(train_ds_for_eval)[0]
val_loss = model.evaluate(val_ds_for_eval)[0]

for split_name, split_loss in (("Training", train_loss), ("Validation", val_loss)):
    print(f"Final {split_name} Loss: {split_loss:.4f}")

In [None]:
# Per-image accumulators filled by the prediction cells below.
y_true_list_train, y_pred_list_train = [], []
y_true_list_validation, y_pred_list_validation = [], []

In [None]:
# Predict every training image (one forward pass per batch instead of one
# model.predict call per image) and keep one-hot ground truth / predictions,
# each of shape (H, W, n_classes), for the metrics cells.
# Lists are reset here so re-running this cell does not duplicate entries.
y_true_list_train = []
y_pred_list_train = []

for imageb, maskb, image_nameb in train_ds:
    probs = model.predict_on_batch(imageb)                         # (B, H, W, n_classes) probabilities
    pred_labels = np.argmax(probs, axis=-1)                        # (B, H, W) class indices
    pred_onehot = np.eye(n_classes, dtype=np.uint8)[pred_labels]   # (B, H, W, n_classes)
    true_onehot = maskb.numpy().astype(np.uint8)                   # (B, H, W, n_classes)

    y_true_list_train.extend(true_onehot)
    y_pred_list_train.extend(pred_onehot)

print(f"Predicted {len(y_pred_list_train)} training images")

In [None]:
# Predict every validation image (one forward pass per batch), keep one-hot
# ground truth / predictions, each (H, W, n_classes), for the metrics cells,
# save each predicted label map, and record which files formed the split.
# Lists are reset here so re-running this cell does not duplicate entries.
y_true_list_validation = []
y_pred_list_validation = []
validation_image_names = []

for imageb, maskb, image_nameb in val_ds:
    probs = model.predict_on_batch(imageb)                         # (B, H, W, n_classes) probabilities
    pred_labels = np.argmax(probs, axis=-1)                        # (B, H, W) class indices
    pred_onehot = np.eye(n_classes, dtype=np.uint8)[pred_labels]   # (B, H, W, n_classes)
    true_onehot = maskb.numpy().astype(np.uint8)                   # (B, H, W, n_classes)

    for i in range(len(pred_onehot)):
        validation_image_names.append(image_nameb[i].numpy().decode('utf-8'))
        y_true_list_validation.append(true_onehot[i])
        y_pred_list_validation.append(pred_onehot[i])
        save_predicted_image(image_nameb[i], pred_onehot[i])

print(f"Predicted {len(y_pred_list_validation)} validation images")

# Write validation image names to txt (one name per line)
with open("validation_image_names.txt", "w", encoding="utf-8") as file:
    file.writelines(f"{item}\n" for item in validation_image_names)


In [None]:
# Stack the per-image maps. Every stored map is one-hot, so each array has
# shape (num_images, H, W, n_classes).
y_true_train = np.array(y_true_list_train)
y_pred_train = np.array(y_pred_list_train)
y_true_validation = np.array(y_true_list_validation)
y_pred_validation = np.array(y_pred_list_validation)

assert y_true_train.shape == y_pred_train.shape, "train truth/prediction shapes differ"
assert y_true_validation.shape == y_pred_validation.shape, "validation truth/prediction shapes differ"

print("y_true_train total train imgs:", len(y_true_list_train))
print("y_pred_train total train imgs:", len(y_pred_list_train))
print("y_true_validation total validation imgs:", len(y_true_list_validation))
print("y_pred_validation total validation imgs:", len(y_pred_list_validation))

# .size counts one-hot entries: num_images * H * W * n_classes
print("y_true_train total entries (N*H*W*n_classes):", y_true_train.size)
print("y_pred_train total entries (N*H*W*n_classes):", y_pred_train.size)
print("y_true_validation total entries (N*H*W*n_classes):", y_true_validation.size)
print("y_pred_validation total entries (N*H*W*n_classes):", y_pred_validation.size)

In [None]:

# The stored maps are one-hot (..., n_classes). Collapse them to class indices
# BEFORE flattening: flattening the one-hot arrays directly feeds the 0/1
# indicator entries to the metrics as if they were class labels. With 2 classes
# that gives both classes identical (overall) precision/recall, and with more
# classes the confusion matrix and accuracy are simply wrong.
y_true_train_flat = np.argmax(y_true_train, axis=-1).ravel()
y_pred_train_flat = np.argmax(y_pred_train, axis=-1).ravel()
y_true_validation_flat = np.argmax(y_true_validation, axis=-1).ravel()
y_pred_validation_flat = np.argmax(y_pred_validation, axis=-1).ravel()

# Pixel-wise confusion matrices: rows = true class, columns = predicted class
cm_train = confusion_matrix(y_true_train_flat, y_pred_train_flat, labels=np.arange(n_classes))
print("Confusion Matrix Training:\n", cm_train)

cm_validation = confusion_matrix(y_true_validation_flat, y_pred_validation_flat, labels=np.arange(n_classes))
print("Confusion Matrix Validation:\n", cm_validation)

In [None]:
def _split_metrics(split_label, cm):
    """Compute per-class precision/recall/F1 and IoU from ``cm``, print them
    prefixed with ``split_label``, and return
    ``(precision, recall, f1, iou, mean_iou)``."""
    precision, recall, f1 = compute_metrics(cm)
    iou, mean_iou = compute_iou(cm)
    print(f"{split_label} Precision: {precision}")
    print(f"{split_label} Recall: {recall}")
    print(f"{split_label} F1 Score: {f1}")
    print(f"{split_label} IoU per class: {iou}")
    print(f"{split_label} Mean IoU: {mean_iou}")
    return precision, recall, f1, iou, mean_iou


(precision_train, recall_train, f1_score_train,
 iou_train, mean_iou_train) = _split_metrics("Train", cm_train)

(precision_validation, recall_validation, f1_score_validation,
 iou_validation, mean_iou_validation) = _split_metrics("Validation", cm_validation)

In [None]:
# Pixel accuracy per split
accuracy_train = accuracy_score(y_true_train_flat, y_pred_train_flat)
print(f"Train Accuracy: {accuracy_train:.4f}")

accuracy_validation = accuracy_score(y_true_validation_flat, y_pred_validation_flat)
print(f"Validation Accuracy: {accuracy_validation:.4f}")

# Cross-check of the confusion-matrix metrics with sklearn (per class).
# labels= fixes the output order/length to classes 0..n_classes-1 even when a
# class is missing from a split; zero_division=0 matches the epsilon-guarded
# manual metrics (0 instead of a warning + NaN for empty classes).
sk_kwargs = dict(labels=np.arange(n_classes), average=None, zero_division=0)

precision_train_2 = precision_score(y_true_train_flat, y_pred_train_flat, **sk_kwargs)
recall_train_2 = recall_score(y_true_train_flat, y_pred_train_flat, **sk_kwargs)
f1_train_2 = f1_score(y_true_train_flat, y_pred_train_flat, **sk_kwargs)

precision_val_2 = precision_score(y_true_validation_flat, y_pred_validation_flat, **sk_kwargs)
recall_val_2 = recall_score(y_true_validation_flat, y_pred_validation_flat, **sk_kwargs)
f1_val_2 = f1_score(y_true_validation_flat, y_pred_validation_flat, **sk_kwargs)

print(f"Train Precision method2: {precision_train_2}")
print(f"Train Recall method2: {recall_train_2}")
print(f"Train F1 Score method2: {f1_train_2}")

print(f"Validation Precision method2: {precision_val_2}")
print(f"Validation Recall method2: {recall_val_2}")
print(f"Validation F1 Score method2: {f1_val_2}")

In [None]:
# Visual check on the first validation batch: one batched forward pass, then
# plot input, per-class ground truth and per-class one-hot prediction per image.
# (Predicted label maps are saved to disk by the validation prediction cell.)
imageb, maskb, image_nameb = next(iter(val_ds))

probs = model.predict_on_batch(imageb)                                        # (B, H, W, n_classes) probabilities
pred_onehot = np.eye(n_classes, dtype=np.int64)[np.argmax(probs, axis=-1)]    # (B, H, W, n_classes) one-hot

for i in range(len(imageb)):
    plot_example(imageb[i], maskb[i], image_nameb[i], pred_onehot[i])