In [None]:
# Importing the necessary libraries

import tensorflow as tf
# import albumentations as albu # Not needed as augmentations are removed
import numpy as np
import gc
# import pickle # Not used in the provided snippet
import matplotlib.pyplot as plt
from keras.callbacks import CSVLogger
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score, precision_score, recall_score, accuracy_score, f1_score
from ModelArchitecture.DiceLoss import dice_metric_loss
from ModelArchitecture import DUCK_Net
from ImageLoader import ImageLoader2D
from functools import partial # Added for tf.data.Dataset.from_generator
import os # Added for path joining

In [None]:
# Report how many GPU devices TensorFlow can see

gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Run configuration: input geometry, data location, optimizer, and output paths

IMG_HEIGHT = 608
IMG_WIDTH = 448
BATCH_SIZE = 4 # Adjust based on your GPU memory

# IMPORTANT: Set this path to the root directory containing your 'train', 'masks', and 'test' subfolders
DATA_FOLDER_PATH = "/workspaces/DUCK-Net/sample_data/" # Replace with your actual data path

dataset_name_suffix = f"custom_h{IMG_HEIGHT}_w{IMG_WIDTH}"
learning_rate = 1e-4
seed_value = 58800 # Used as random_state for the train/validation/hold-out splits
filters = 17 # Number of filters, the paper presents the results with 17 and 34
optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

# Timestamp that makes every output file of this run unique
ct = datetime.now().strftime("%Y%m%d_%H%M%S")

model_type = "DuckNet"

# Create directories if they don't exist
os.makedirs('ProgressFull', exist_ok=True)
os.makedirs(f'ModelSaveTensorFlow/{dataset_name_suffix}', exist_ok=True)

# Per-epoch CSV log (written by CSVLogger) and free-form text progress log
progress_path = f'ProgressFull/{dataset_name_suffix}_progress_csv_{model_type}_filters_{filters}_{ct}.csv'
progressfull_path = f'ProgressFull/{dataset_name_suffix}_progress_{model_type}_filters_{filters}_{ct}.txt'
# plot_path = f'ProgressFull/{dataset_name_suffix}_progress_plot_{model_type}_filters_{str(filters)}_{ct}.png' # Plotting code not shown, can be added if needed
# NOTE(review): this path has no file extension. Keras 3 `model.save` requires a
# `.keras` or `.h5` suffix - confirm the installed Keras version accepts this path.
model_path = f'ModelSaveTensorFlow/{dataset_name_suffix}/{model_type}_filters_{filters}_{ct}'

EPOCHS = 100 # Adjust as needed, original was 600
# Best validation loss seen so far. Starting at +inf guarantees that the first
# validated epoch writes a checkpoint, so the evaluation cell always has a model
# to load (a fixed threshold such as 0.2 may never be reached).
min_loss_for_saving = float('inf')

In [None]:
# Discover the training image, training mask, and final test image paths

print(f"Loading data from: {DATA_FOLDER_PATH}")
# The data root is passed as an argument; if ImageLoader2D reads a module-level
# folder_path instead, set ImageLoader2D.folder_path = DATA_FOLDER_PATH before this call.
all_train_img_files, all_train_mask_files, final_test_img_files = ImageLoader2D.get_image_paths(
    DATA_FOLDER_PATH
)

n_train_images = len(all_train_img_files)
n_train_masks = len(all_train_mask_files)
print(f"Found {n_train_images} training images and {n_train_masks} training masks.")
print(f"Found {len(final_test_img_files)} final test images.")

# Fail fast when either list is empty
if not (all_train_img_files and all_train_mask_files):
    raise ValueError("No training images or masks found. Check DATA_FOLDER_PATH and data structure.")

# Images and masks are split together downstream, so the two lists must have equal length
if n_train_images != n_train_masks:
    raise ValueError(
        f"Mismatch between number of training images ({n_train_images}) and masks ({n_train_masks}). "
        "Please check your dataset and ImageLoader2D.get_image_paths logic."
    )


In [None]:
# Partition the labelled paths into training, validation, and hold-out test sets.
# Stage 1 keeps 80% for train+val and sets 20% aside as a hold-out test set.
# Stage 2 splits train+val 80/20 again, so validation is 0.2 * 0.8 = 16% of the labelled data.

(train_val_img_paths,
 hold_out_test_img_paths,
 train_val_mask_paths,
 hold_out_test_mask_paths) = train_test_split(
    all_train_img_files,
    all_train_mask_files,
    test_size=0.2,
    random_state=seed_value,
    shuffle=True,
)

(train_img_paths,
 val_img_paths,
 train_mask_paths,
 val_mask_paths) = train_test_split(
    train_val_img_paths,
    train_val_mask_paths,
    test_size=0.2,
    random_state=seed_value,
    shuffle=True,
)

# Summarise the size of every split
print("\n".join([
    f"Training images: {len(train_img_paths)}",
    f"Validation images: {len(val_img_paths)}",
    f"Hold-out test images (from training data): {len(hold_out_test_img_paths)}",
    f"Final test images (unlabeled): {len(final_test_img_files)}",
]))


# Element signatures for tf.data.Dataset.from_generator:
# labelled datasets yield (image, mask) pairs, the unlabeled test dataset yields images only.
output_signature_train = (
    tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, 1), dtype=tf.uint8)
)
output_signature_test = tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32)

# Generators with the image size and the labelled/unlabeled mode pre-bound
train_generator = partial(ImageLoader2D.tf_dataset_generator,
                          img_height=IMG_HEIGHT, img_width=IMG_WIDTH, is_test_set=False)
test_generator = partial(ImageLoader2D.tf_dataset_generator,
                         img_height=IMG_HEIGHT, img_width=IMG_WIDTH, is_test_set=True)


def _build_dataset(generator, image_paths, mask_paths, signature, empty_warning):
    """Create a batched, prefetched tf.data.Dataset over the given paths.

    Args:
        generator: callable accepting `image_paths` and `mask_paths` keyword
            arguments and returning an iterable of dataset elements.
        image_paths: list of image paths; an empty list yields no dataset.
        mask_paths: list of mask paths, or None for unlabeled data.
        signature: `output_signature` describing one generated element.
        empty_warning: message printed when `image_paths` is empty.

    Returns:
        The dataset batched by BATCH_SIZE, or None when `image_paths` is empty.
    """
    if not image_paths:
        print(empty_warning)
        return None
    # The paths are bound through this function's arguments, so the dataset keeps
    # reading the lists it was built from even if the module-level names are rebound.
    return tf.data.Dataset.from_generator(
        lambda: generator(image_paths=image_paths, mask_paths=mask_paths),
        output_signature=signature
    ).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_ds = _build_dataset(
    train_generator, train_img_paths, train_mask_paths,
    output_signature_train, "Warning: No training data after split.")

val_ds = _build_dataset(
    train_generator, val_img_paths, val_mask_paths,
    output_signature_train, "Warning: No validation data after split.")

hold_out_test_ds = _build_dataset(
    train_generator, hold_out_test_img_paths, hold_out_test_mask_paths,
    output_signature_train, "Warning: No hold-out test data after split.")

# The unlabeled test set has no masks, so mask_paths is None
final_test_ds = _build_dataset(
    test_generator, final_test_img_files, None,
    output_signature_test, "Warning: No final (unlabeled) test data found.")

# Validation and test data are optional, training data is not
if train_ds is None:
    raise ValueError("Training dataset is empty. Cannot proceed.")


In [None]:
# Augmentation is intentionally disabled in this notebook; this cell only prints a notice
# and can be deleted.
# The previous albumentations-based pipeline is kept below, commented out, for reference.
# NOTE(review): it refers to `x_train` / `y_train`, which are not defined in the visible
# cells (data is streamed through tf.data datasets), and the `albumentations` import is
# commented out in the imports cell, so it would need adapting before being re-enabled.
# Original content:
# aug_train = albu.Compose([
#     albu.HorizontalFlip(),
#     albu.VerticalFlip(),
#     albu.ColorJitter(brightness=(0.6,1.6), contrast=0.2, saturation=0.1, hue=0.01, always_apply=True),
#     albu.Affine(scale=(0.5,1.5), translate_percent=(-0.125,0.125), rotate=(-180,180), shear=(-22.5,22), always_apply=True),
# ])

# def augment_images():
#     x_train_out = []
#     y_train_out = []

#     for i in range (len(x_train)):
#         ug = aug_train(image=x_train[i], mask=y_train[i])
#         x_train_out.append(ug['image'])  
#         y_train_out.append(ug['mask'])

#     return np.array(x_train_out), np.array(y_train_out)
print("Augmentation cell skipped as per requirement.")

In [None]:
# Build the DUCK-Net network: 3 input channels, 1 output class

model = DUCK_Net.create_model(
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    input_chanels=3,
    out_classes=1,
    starting_filters=filters,
)
model.summary()  # Print the layer-by-layer structure

In [None]:
# Attach the optimizer, loss, and metrics to the model

model.compile(
    optimizer=optimizer,
    loss=dice_metric_loss,
    metrics=[dice_metric_loss],  # The loss is also registered as a metric for easier logging
)

In [None]:
# Training the model, one epoch per fit() call so custom bookkeeping can run between epochs

# A single CSVLogger instance is shared by every fit() call; append=True keeps earlier rows
csv_logger = CSVLogger(progress_path, append=True, separator=';')
history_list = [] # One history dict per trained epoch

print(f"Starting training for {EPOCHS} epochs...")
for epoch in range(EPOCHS):
    print(f'\nTraining, epoch {epoch+1}/{EPOCHS}')
    print('Learning Rate: ' + str(tf.keras.backend.get_value(model.optimizer.learning_rate))) # Get current LR

    # initial_epoch/epochs make Keras train exactly one epoch while reporting the real
    # epoch index to callbacks. With a plain epochs=1 every call restarts at epoch 0,
    # so every row written by CSVLogger would be labelled epoch 0.
    history = model.fit(
        train_ds,
        initial_epoch=epoch,
        epochs=epoch + 1,
        validation_data=val_ds,
        callbacks=[csv_logger],
        verbose=1
    )
    epoch_logs = history.history
    history_list.append(epoch_logs)

    # Validation results: fit() has just evaluated val_ds with the end-of-epoch weights,
    # so its logged values are reused instead of running a second pass over val_ds.
    if val_ds is not None:
        loss_valid = epoch_logs['val_loss'][-1]
        print(f"Validation Loss: {loss_valid:.4f}")
        # Any other 'val_*' entry is a compiled metric; the first one is reported
        val_metric_values = [
            values[-1] for key, values in epoch_logs.items()
            if key.startswith('val_') and key != 'val_loss'
        ]
        if val_metric_values:
            print(f"Validation Dice Metric: {val_metric_values[0]:.4f}")
    else:
        loss_valid = float('inf') # No validation set, cannot determine validation loss
        print("No validation set to evaluate.")

    # Evaluate on hold-out test set (from training data); logged only, never used for model selection
    if hold_out_test_ds is not None:
        test_results = model.evaluate(hold_out_test_ds, verbose=0)
        loss_test = test_results[0] # Loss is the first item, compiled metrics follow
        print(f"Hold-out Test Loss: {loss_test:.4f}")
        if len(test_results) > 1:
            print(f"Hold-out Test Dice Metric: {test_results[1]:.4f}")
    else:
        loss_test = float('inf')
        print("No hold-out test set to evaluate.")

    # Append this epoch's summary to the text progress log (epoch index is zero-based)
    with open(progressfull_path, 'a') as f:
        f.write(f'epoch: {epoch}\nval_loss: {loss_valid:.4f}\ntest_loss: {loss_test:.4f}\n\n')

    if val_ds is not None and loss_valid < min_loss_for_saving: # Save based on validation loss
        min_loss_for_saving = loss_valid
        print(f"Improved validation loss to {min_loss_for_saving:.4f}. Saving model to {model_path}")
        model.save(model_path)
    elif val_ds is None and epoch % 10 == 0: # If no validation set, save periodically
        print(f"No validation set. Saving model at epoch {epoch} to {model_path}")
        model.save(model_path)

    gc.collect() # Garbage collection

In [None]:
# Computing the metrics and saving the results

# Reload the checkpoint written by the training loop. If training never saved one
# (for example the save condition was never met), keep the in-memory model from the
# training cell and say so, rather than failing with a loader error.
if os.path.exists(model_path):
    print(f"Loading the best model from: {model_path}")
    # The custom loss/metric function must be supplied so Keras can deserialize the model
    model = tf.keras.models.load_model(model_path, custom_objects={'dice_metric_loss': dice_metric_loss})
else:
    print(f"Warning: no saved model found at {model_path}. Evaluating the in-memory model instead.")

# Collect labels and model predictions for every batch of a dataset
def get_all_preds_labels(dataset, model_to_eval):
    """Run `model_to_eval` over `dataset` and gather labels and predictions.

    Args:
        dataset: iterable of (images, labels) batches, or None. Each `labels`
            object must provide a `.numpy()` method.
        model_to_eval: object with a `predict_on_batch(images)` method.

    Returns:
        A (labels, predictions) tuple, each concatenated along axis 0.
        Two empty arrays are returned when `dataset` is None or yields no batches.
    """
    if dataset is None:
        return np.array([]), np.array([])

    label_batches, pred_batches = [], []
    for batch_images, batch_labels in dataset:
        batch_preds = model_to_eval.predict_on_batch(batch_images)
        label_batches.append(batch_labels.numpy())
        pred_batches.append(batch_preds)

    # Nothing was yielded: np.concatenate would fail on an empty list
    if len(label_batches) == 0:
        return np.array([]), np.array([])

    stacked_labels = np.concatenate(label_batches, axis=0)
    stacked_preds = np.concatenate(pred_batches, axis=0)
    return stacked_labels, stacked_preds

def _binary_segmentation_metrics(y_true, y_pred, threshold=0.5):
    """Return (dice, mIoU, precision, recall, accuracy) computed over all pixels.

    Labels are cast to bool and predictions are binarised with `> threshold`.
    Both arrays are flattened once and shared by all five scikit-learn scorers.
    """
    truth = np.asarray(y_true).astype(bool).ravel()
    predicted = (np.asarray(y_pred) > threshold).ravel()
    return (
        f1_score(truth, predicted),
        jaccard_score(truth, predicted),
        precision_score(truth, predicted),
        recall_score(truth, predicted),
        accuracy_score(truth, predicted),
    )


def _metrics_or_zeros(y_true, y_pred, empty_warning):
    """Score one split, or print `empty_warning` and return five zeros if it is empty."""
    if y_true.size > 0:
        return _binary_segmentation_metrics(y_true, y_pred)
    print(empty_warning)
    return (0,) * 5


print("Evaluating on training data...")
y_train_np, pred_train_np = get_all_preds_labels(train_ds, model)
print("Evaluating on validation data...")
y_valid_np, pred_valid_np = get_all_preds_labels(val_ds, model)
print("Evaluating on hold-out test data...")
y_test_np, pred_test_np = get_all_preds_labels(hold_out_test_ds, model)

print("Predictions done. Computing metrics...")

# Each split yields (dice, mIoU, precision, recall, accuracy); empty splits score 0
dice_train, miou_train, precision_train, recall_train, accuracy_train = _metrics_or_zeros(
    y_train_np, pred_train_np, "Warning: Training data is empty, metrics set to 0.")

dice_valid, miou_valid, precision_valid, recall_valid, accuracy_valid = _metrics_or_zeros(
    y_valid_np, pred_valid_np, "Warning: Validation data is empty, metrics set to 0.")

dice_test, miou_test, precision_test, recall_test, accuracy_test = _metrics_or_zeros(
    y_test_np, pred_test_np, "Warning: Hold-out test data is empty, metrics set to 0.")

print("Metrics computation finished.")

final_file = f'results_{model_type}_{str(filters)}_{dataset_name_suffix}_{ct}.txt'
print(f"Saving results to: {final_file}")

# Row labels, in the same order as the tuples returned by the metric helpers
metric_labels = ('Dice (F1)', 'mIoU', 'Precision', 'Recall', 'Accuracy')
report_sections = (
    ('Training Set Metrics:',
     (dice_train, miou_train, precision_train, recall_train, accuracy_train)),
    ('Validation Set Metrics:',
     (dice_valid, miou_valid, precision_valid, recall_valid, accuracy_valid)),
    ('Hold-out Test Set Metrics (from training data split):',
     (dice_test, miou_test, precision_test, recall_test, accuracy_test)),
)

with open(final_file, 'w') as f: # 'w' creates a fresh results file for each run
    f.write(f'Dataset: {dataset_name_suffix}\n')
    f.write(f'Model: {model_type}, Filters: {filters}\n')
    f.write(f'Timestamp: {ct}\n\n')

    for section_title, section_values in report_sections:
        f.write(f'{section_title}\n')
        for metric_label, metric_value in zip(metric_labels, section_values):
            f.write(f'  {metric_label}: {metric_value:.4f}\n')
        f.write('\n') # Blank line between sections

print('File done.')

# Optional: run inference on the unlabeled final test set.
# The predicted masks are only kept in memory here; save them to disk if needed.
if not final_test_ds:
    print("\nNo final (unlabeled) test dataset provided or it was empty.")
else:
    print("\nGenerating predictions for the final (unlabeled) test set...")
    final_test_predictions_list = [
        model.predict_on_batch(batch_images) for batch_images in final_test_ds
    ]

    if not final_test_predictions_list:
        print("No predictions generated for the final test set (list was empty).")
    else:
        final_test_predictions_np = np.concatenate(final_test_predictions_list, axis=0)
        print(f"Generated {final_test_predictions_np.shape[0]} predictions for the final test set.")
        # Example: save the first predicted mask as an image
        # if final_test_predictions_np.shape[0] > 0:
        #     plt.imshow(final_test_predictions_np[0, :, :, 0] > 0.5, cmap='gray')
        #     plt.title("Example Prediction from Final Test Set")
        #     plt.savefig(f'example_final_test_prediction_{ct}.png')
        #     print(f"Saved an example prediction mask to example_final_test_prediction_{ct}.png")
        # All predicted masks could be saved here, e.g. with tf.keras.utils.save_img or PIL.