In [None]:
# Importing the necessary libraries

import tensorflow as tf
import albumentations as albu
import numpy as np
import gc
import os
import cv2
import matplotlib.pyplot as plt
from keras.callbacks import CSVLogger
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import jaccard_score, precision_score, recall_score, accuracy_score, f1_score
from ModelArchitecture.DiceLoss import dice_metric_loss
from ModelArchitecture import DUCK_Net
from PIL import Image

In [None]:
# Report how many GPU devices TensorFlow can see on this machine

gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Memory-efficient data loading functions using TensorFlow data API

def _list_image_files(directory, extensions=('.jpg', '.png', '.jpeg', '.bmp', '.tif')):
    """Return the sorted full paths of the image files directly inside `directory`.

    The extension check is case-insensitive, so files such as `IMG_01.JPG`
    are picked up as well.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith(extensions)
    )


def get_file_paths(train_path, mask_path, test_path=None):
    """
    Get file paths instead of loading all images into memory

    Images and masks are returned as two independently sorted lists, so an
    image and its mask are expected to sort to the same position (e.g. by
    sharing a file name).

    Parameters:
    train_path: Path to training images
    mask_path: Path to mask images
    test_path: Path to test images (optional)

    Returns:
    train_files: List of training image paths
    mask_files: List of mask image paths
    test_files: List of test image paths (None if test_path is not provided)
    """
    # NOTE(review): '.tif' is accepted here, but tf.image.decode_image (used by
    # load_image/load_mask) documents only BMP, GIF, JPEG and PNG - confirm
    # that TIFF inputs actually decode before relying on them.
    train_files = _list_image_files(train_path)
    mask_files = _list_image_files(mask_path)

    print(f"Found {len(train_files)} training images and {len(mask_files)} mask images")

    # Pairs are matched purely by position, so differing counts mean that at
    # least one image would be paired with the wrong mask (or with none).
    if len(train_files) != len(mask_files):
        print("WARNING: number of training images and masks differ; image/mask pairing by sorted order will be wrong")

    test_files = None
    if test_path:
        test_files = _list_image_files(test_path)
        print(f"Found {len(test_files)} test images")

    return train_files, mask_files, test_files

def load_image(image_path):
    """
    Load a single image file as a float32 RGB tensor scaled to [0, 1]

    The image is not resized; it keeps the size it has on disk.

    Parameters:
    image_path: Path of the image file (string or scalar string tensor)

    Returns:
    img: float32 tensor (height, width, 3) with values in [0, 1]
    """
    img = tf.io.read_file(image_path)
    # expand_animations=False makes decode_image always return a rank-3 tensor
    # with a known channel dimension. With the default (True) the result has
    # no static shape inside Dataset.map, and GIF files decode to 4-D tensors,
    # both of which break batching and feeding the model downstream.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.cast(img, tf.float32) / 255.0
    return img

def load_mask(mask_path):
    """
    Load a single mask file as a binary float32 tensor

    Parameters:
    mask_path: Path of the mask file (string or scalar string tensor)

    Returns:
    mask: float32 tensor (height, width, 1); 1.0 where the decoded pixel
          value is greater than 127, 0.0 elsewhere
    """
    mask = tf.io.read_file(mask_path)
    # expand_animations=False keeps the decoded mask rank-3 with a static
    # channel dimension (the default leaves the shape unknown inside
    # Dataset.map and yields 4-D tensors for GIF files).
    mask = tf.image.decode_image(mask, channels=1, expand_animations=False)
    # Convert to binary
    mask = tf.cast(mask > 127, tf.float32)
    return mask

def create_dataset(image_paths, mask_paths=None, batch_size=4, shuffle=True, augment=False):
    """
    Create a TensorFlow dataset that loads images on-the-fly

    Parameters:
    image_paths: Sequence of image file paths
    mask_paths: Sequence of mask file paths aligned with image_paths, or None
                for an image-only (test) dataset
    batch_size: Number of samples per batch
    shuffle: Whether to shuffle the samples (reshuffled on every iteration)
    augment: Accepted for compatibility but currently ignored - no
             augmentation is applied by this function

    Returns:
    dataset: Batched, prefetched tf.data.Dataset yielding (images, masks)
             batches, or image batches when mask_paths is None
    """
    has_masks = mask_paths is not None

    if has_masks:
        # Training/validation dataset with images and masks
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    else:
        # Test dataset with only images
        dataset = tf.data.Dataset.from_tensor_slices(image_paths)

    if shuffle:
        # Shuffle the path strings BEFORE decoding. Shuffling after map() would
        # keep the whole shuffle buffer filled with decoded float32 images in
        # memory; shuffling paths is cheap, which also allows a buffer that
        # covers the full dataset (a uniform shuffle).
        dataset = dataset.shuffle(buffer_size=max(len(image_paths), 1))

    if has_masks:
        dataset = dataset.map(lambda x, y: (load_image(x), load_mask(y)),
                              num_parallel_calls=tf.data.AUTOTUNE)
    else:
        dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

    # Use prefetch to overlap data preprocessing and model execution
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset

def augment_batch(images, masks):
    """
    Run the `aug_train` pipeline over every image/mask pair of a batch

    Parameters:
    images: Batch of images (object exposing .numpy())
    masks: Batch of masks aligned with images (object exposing .numpy())

    Returns:
    Two NumPy arrays: the augmented images and the augmented masks
    """
    # albumentations works on NumPy arrays, so leave TensorFlow first
    image_array = images.numpy()
    mask_array = masks.numpy()

    # One augmentation call per pair, in batch order
    results = [
        aug_train(image=single_image, mask=single_mask)
        for single_image, single_mask in zip(image_array, mask_array)
    ]

    out_images = np.array([result['image'] for result in results])
    out_masks = np.array([result['mask'] for result in results])
    return out_images, out_masks

# Compatibility wrapper: existing code calls this name, the work is done by get_file_paths
def load_custom_data(train_path, mask_path, test_path=None):
    """
    Thin wrapper around get_file_paths

    Returns the (train_files, mask_files, test_files) path lists; no image
    data is loaded.
    """
    path_lists = get_file_paths(train_path, mask_path, test_path=test_path)
    return path_lists

In [None]:
# Setting the model parameters

# Image dimensions for the custom dataset (600x450)
# NOTE(review): load_image() does not resize, so the files on disk must already
# be img_height x img_width (rows x columns) - confirm that 600 really is the
# height and not the width of the images.
img_height = 600
img_width = 450

learning_rate = 1e-4
seed_value = 58800
filters = 17  # Number of filters, the paper presents the results with 17 and 34
optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)

ct = datetime.now()
# Filesystem-safe run timestamp: str(ct) contains spaces and ':' characters,
# which are awkward in shells and not allowed in Windows file names.
run_stamp = ct.strftime('%Y-%m-%d_%H-%M-%S')

model_type = "DuckNet"

# Define paths for your dataset
# The data directory can be overridden through the DUCKNET_DATA_ROOT
# environment variable; the previous hard-coded location stays the default.
data_root = os.environ.get('DUCKNET_DATA_ROOT', '/workspaces/DUCK-Net/data')
train_path = os.path.join(data_root, 'train')
mask_path = os.path.join(data_root, 'mask')
test_path = os.path.join(data_root, 'test')
output_path = os.path.join(data_root, 'predictions')

# Create output directory if it doesn't exist
os.makedirs(output_path, exist_ok=True)

# Directories for the training logs and the saved model. They are created up
# front because neither CSVLogger nor open() creates missing parent folders.
progress_dir = 'ProgressFull'
model_dir = 'ModelSaveTensorFlow/custom'
os.makedirs(progress_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# Paths for saving model and progress
run_tag = model_type + '_filters_' + str(filters) + '_' + run_stamp
progress_path = progress_dir + '/custom_progress_csv_' + run_tag + '.csv'
progressfull_path = progress_dir + '/custom_progress_' + run_tag + '.txt'
plot_path = progress_dir + '/custom_progress_plot_' + run_tag + '.png'
# NOTE(review): model_path has no file extension; some Keras versions require
# '.keras' or '.h5' for model.save() - confirm against the installed version.
model_path = model_dir + '/' + run_tag

EPOCHS = 600
# A model is only saved once the validation loss drops below this value
min_loss_for_saving = 0.2

In [None]:
# Collect the image, mask and test file paths (no image data is read here)
train_files, mask_files, test_files = load_custom_data(train_path, mask_path, test_path)
print(f"Training files: {len(train_files)}, Mask files: {len(mask_files)}")

has_test_files = test_files is not None
if has_test_files:
    print(f"Test files: {len(test_files)}")

# Sample counts kept for the metrics calculation
num_train_samples = len(train_files)
if has_test_files:
    num_test_samples = len(test_files)
else:
    num_test_samples = 0

In [None]:
# Hold out 20% of the image/mask path pairs for validation.
# The fixed random_state makes the split reproducible.
(train_files_train, train_files_valid,
 mask_files_train, mask_files_valid) = train_test_split(
    train_files,
    mask_files,
    test_size=0.2,
    shuffle=True,
    random_state=seed_value,
)

print(f"Training set: {len(train_files_train)} images, Validation set: {len(train_files_valid)} images")

# Build the input pipelines: only the training data is shuffled
train_dataset = create_dataset(train_files_train, mask_paths=mask_files_train, batch_size=4, shuffle=True)
valid_dataset = create_dataset(train_files_valid, mask_paths=mask_files_valid, batch_size=4, shuffle=False)

# Image-only pipeline for the test set, when one was found
if test_files is not None:
    test_dataset = create_dataset(test_files, mask_paths=None, batch_size=4, shuffle=False)

In [None]:
# Defining the augmentations (keeping the existing code)
# Random flips use the transforms' default probability; colour jitter and the
# affine warp run on every sample. p=1.0 replaces the deprecated
# always_apply=True flag and behaves the same inside Compose.
aug_train = albu.Compose([
    albu.HorizontalFlip(),
    albu.VerticalFlip(),
    albu.ColorJitter(brightness=(0.6,1.6), contrast=0.2, saturation=0.1, hue=0.01, p=1.0),
    # NOTE(review): shear=(-22.5, 22) is asymmetric - confirm this is intended.
    albu.Affine(scale=(0.5,1.5), translate_percent=(-0.125,0.125), rotate=(-180,180), shear=(-22.5,22), p=1.0),
])

# Graph-compatible wrapper around the NumPy-based augment_batch
@tf.function
def augment_dataset(images, masks):
    """
    Augment a batch of images and masks from inside TensorFlow code

    augment_batch is invoked through tf.py_function; the static shapes of the
    inputs are then copied onto the two float32 results.
    """
    augmented = tf.py_function(
        func=augment_batch,
        inp=[images, masks],
        Tout=[tf.float32, tf.float32],
    )
    aug_images, aug_masks = augmented

    # Re-attach the static shapes of the inputs to the py_function outputs
    aug_images.set_shape(images.shape)
    aug_masks.set_shape(masks.shape)
    return aug_images, aug_masks

In [None]:
# Creating the model

# DUCK-Net built from the configured image size and starting filter count
model = DUCK_Net.create_model(
    img_height=img_height,
    img_width=img_width,
    input_chanels=3,  # keyword is spelled this way in the call
    out_classes=1,
    starting_filters=filters,
)

In [19]:
# Compiling the model

# The Dice loss is the training objective; the optimizer comes from the parameter cell
model.compile(
    loss=dice_metric_loss,
    optimizer=optimizer,
)

In [None]:
# Training the model with memory-efficient data loading
# No augmentation - using default dataset

step = 0
min_val_loss = float('inf')

for epoch in range(0, EPOCHS):

    print(f'Training, epoch {epoch}')
    print('Learning Rate: ' + str(learning_rate))

    step += 1

    # append=True lets all the one-epoch fit() calls write into the same CSV file
    csv_logger = CSVLogger(progress_path, append=True, separator=';')

    # Train for exactly one epoch on the original (non-augmented) dataset.
    # initial_epoch/epochs pass the real epoch index to Keras, so the CSV log
    # is numbered 0, 1, 2, ... instead of recording epoch 0 on every row.
    history = model.fit(
        train_dataset,
        initial_epoch=epoch,
        epochs=epoch + 1,
        validation_data=valid_dataset,
        verbose=1,
        callbacks=[csv_logger]
    )

    # Validate on the entire validation set, one batch at a time
    val_loss = 0
    val_batches = 0

    for x_batch, y_batch in valid_dataset:
        # predict_on_batch runs one forward pass on this batch and returns a
        # NumPy array, without the per-call overhead of model.predict().
        pred_batch = model.predict_on_batch(x_batch)
        batch_loss = dice_metric_loss(y_batch, pred_batch).numpy()
        val_loss += batch_loss
        val_batches += 1

    if val_batches == 0:
        raise RuntimeError("Validation dataset is empty; cannot compute the validation loss")

    # Calculate average validation loss (unweighted mean of the per-batch losses)
    avg_val_loss = val_loss / val_batches
    print("Loss Validation: " + str(avg_val_loss))

    with open(progressfull_path, 'a') as f:
        f.write('epoch: ' + str(epoch) + '\nval_loss: ' + str(avg_val_loss) + '\n\n\n')

    # Save model if improved. min_loss_for_saving starts at the threshold from
    # the parameter cell and is lowered every time a better model is saved.
    if min_loss_for_saving > avg_val_loss:
        min_loss_for_saving = avg_val_loss
        print("Saved model with val_loss: ", avg_val_loss)
        model.save(model_path)

    # Clear memory between epochs
    gc.collect()
    # NOTE(review): clear_session() resets Keras' global state while `model` is
    # still being trained - confirm this is safe with the installed version.
    tf.keras.backend.clear_session()

In [None]:
# Computing metrics on training and validation sets with batch processing

def _collect_flat_labels(eval_model, image_files, mask_files, batch_size=4):
    """
    Predict every image/mask pair and return flattened boolean labels

    Parameters:
    eval_model: Model used for prediction
    image_files: List of image paths
    mask_files: List of mask paths aligned with image_files
    batch_size: Number of samples per prediction batch

    Returns:
    y_true: 1-D boolean array with the ground-truth mask pixels
    y_pred: 1-D boolean array with the predictions thresholded at 0.5
    """
    true_chunks = []
    pred_chunks = []

    for x_batch, y_batch in create_dataset(image_files, mask_files, batch_size=batch_size, shuffle=False):
        pred_batch = eval_model.predict_on_batch(x_batch)

        # Keep one compact boolean array per batch. Extending a Python list
        # with the flattened values would create one Python object per pixel.
        true_chunks.append(y_batch.numpy().astype(bool).ravel())
        pred_chunks.append((pred_batch > 0.5).ravel())

    if not true_chunks:
        return np.zeros(0, dtype=bool), np.zeros(0, dtype=bool)

    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


print("Loading the best model")
model = tf.keras.models.load_model(model_path, custom_objects={'dice_metric_loss': dice_metric_loss})

# Collect predictions and ground truth in batches
print("Computing metrics for training set...")
y_true_train, y_pred_train = _collect_flat_labels(model, train_files_train, mask_files_train)

print("Computing metrics for validation set...")
y_true_valid, y_pred_valid = _collect_flat_labels(model, train_files_valid, mask_files_valid)

# Calculate metrics (all computed pixel-wise over the whole set)
dice_train = f1_score(y_true_train, y_pred_train)
dice_valid = f1_score(y_true_valid, y_pred_valid)

miou_train = jaccard_score(y_true_train, y_pred_train)
miou_valid = jaccard_score(y_true_valid, y_pred_valid)

precision_train = precision_score(y_true_train, y_pred_train)
precision_valid = precision_score(y_true_valid, y_pred_valid)

recall_train = recall_score(y_true_train, y_pred_train)
recall_valid = recall_score(y_true_valid, y_pred_valid)

accuracy_train = accuracy_score(y_true_train, y_pred_train)
accuracy_valid = accuracy_score(y_true_valid, y_pred_valid)

# Print and save metrics
print(f"Dice score - Train: {dice_train:.4f}, Validation: {dice_valid:.4f}")
print(f"IoU score  - Train: {miou_train:.4f}, Validation: {miou_valid:.4f}")
print(f"Precision  - Train: {precision_train:.4f}, Validation: {precision_valid:.4f}")
print(f"Recall     - Train: {recall_train:.4f}, Validation: {recall_valid:.4f}")
print(f"Accuracy   - Train: {accuracy_train:.4f}, Validation: {accuracy_valid:.4f}")

# Results are appended, so repeated runs accumulate in the same file
final_file = 'results_' + model_type + '_' + str(filters) + '_custom.txt'
with open(final_file, 'a') as f:
    f.write('Custom Dataset\n\n')
    f.write(f'dice_train: {dice_train:.4f} dice_valid: {dice_valid:.4f}\n\n')
    f.write(f'miou_train: {miou_train:.4f} miou_valid: {miou_valid:.4f}\n\n')
    f.write(f'precision_train: {precision_train:.4f} precision_valid: {precision_valid:.4f}\n\n')
    f.write(f'recall_train: {recall_train:.4f} recall_valid: {recall_valid:.4f}\n\n')
    f.write(f'accuracy_train: {accuracy_train:.4f} accuracy_valid: {accuracy_valid:.4f}\n\n\n\n')

In [None]:
# Predicting on test images and saving results with batch processing

if test_files is not None:
    print(f"Predicting on {len(test_files)} test images...")

    # Get the list of test image filenames
    test_images = [os.path.basename(f) for f in test_files]

    # Create dataset for test images
    test_dataset = create_dataset(test_files, batch_size=4, shuffle=False)

    # Running index into test_images. The dataset is not shuffled, so the
    # predictions arrive in the order of test_files; counting samples directly
    # keeps the index correct for any batch size.
    img_idx = 0
    for batch in test_dataset:
        # Get predictions for this batch
        batch_preds = model.predict_on_batch(batch)

        # Process each prediction in the batch
        for pred in batch_preds:
            # Skip if we've exceeded the number of test images
            if img_idx >= len(test_images):
                break

            # Convert prediction to binary mask (0 or 255)
            pred_mask = (pred > 0.5).astype(np.uint8) * 255

            # Save the mask under the same file name as the input image.
            # NOTE(review): the output format follows the input extension, so
            # '.jpg' inputs give JPEG-compressed masks - confirm this is wanted.
            output_filename = os.path.join(output_path, test_images[img_idx])
            # cv2.imwrite reports failure through its return value, not an exception
            if not cv2.imwrite(output_filename, pred_mask):
                print(f"WARNING: could not write {output_filename}")

            img_idx += 1

    print(f"Predictions saved to {output_path}")
else:
    print("No test data provided")

# Visualization of Results

Let's visualize some of the test predictions alongside the original images:

In [None]:
# Visualize a few test images and their predictions with batch processing
# `if test_files:` also skips an empty list: zero samples would otherwise ask
# for a batch size of 0 and a zero-height figure.
if test_files:
    num_samples = min(5, len(test_files))

    # Create a dataset with just the first few test images, as one single batch
    sample_test_files = test_files[:num_samples]
    sample_test_dataset = create_dataset(sample_test_files, batch_size=num_samples, shuffle=False)

    # Take that single batch and predict on it
    test_batch = next(iter(sample_test_dataset))
    sample_images = test_batch.numpy()
    sample_preds = model.predict(test_batch, verbose=0)

    # One row per sample: original image on the left, prediction on the right.
    # squeeze=False keeps `axes` two-dimensional even for a single sample.
    fig, axes = plt.subplots(num_samples, 2, figsize=(12, 4*num_samples), squeeze=False)

    for i in range(num_samples):
        image_ax, pred_ax = axes[i]

        # Original image
        image_ax.imshow(sample_images[i])
        image_ax.set_title(f"Test Image {i+1}")
        image_ax.axis("off")

        # Prediction, thresholded at 0.5
        pred_ax.imshow(sample_preds[i, :, :, 0] > 0.5, cmap='gray')
        pred_ax.set_title(f"Prediction {i+1}")
        pred_ax.axis("off")

    fig.tight_layout()
    fig.savefig(os.path.join(output_path, "sample_predictions.png"))
    plt.show()