In [None]:
# Import required libraries
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam

# Import custom DataGenerator and DenseUNetWithSpatialAttention
from DataGenerators import DataGenerator  
from Models.DenseUNetWithSpatialAttention import DenseUNetWithSpatialAttention


In [None]:
# Configuration: every value a reader might want to tune lives in this cell.
# Paths are relative to the notebook's working directory.
DATA_PATH = "LIDC"                     # root folder; `images/` and `masks/` are read from inside it
MODEL_SAVE_PATH = "LIDC/saved_models"  # directory that receives model checkpoints
BATCH_SIZE = 8                         # passed to DataGenerator as batch_size
IMG_HEIGHT = 512                       # passed to DataGenerator as img_size and used as model input height
IMG_WIDTH = 512                        # passed to DataGenerator as img_size and used as model input width
EPOCHS = 50                            # upper bound on epochs; EarlyStopping may end training sooner
LEARNING_RATE = 1e-4                   # Adam learning rate
VALIDATION_SPLIT = 0.2                 # passed to DataGenerator as val_split (presumably a dataset fraction)
TEST_SPLIT = 0.1                       # passed to DataGenerator as test_split (presumably a dataset fraction)
SEED = 42                              # single seed for the Python, NumPy and TensorFlow RNGs

# Seed every RNG up front so weight initialisation (and any shuffling that draws
# from these RNGs) is repeatable across "Restart & Run All". Some GPU kernels can
# still be non-deterministic unless op determinism is enabled as well.
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Build the train / validation / test generators from the configured data folder.
def load_data_with_generators():
    """Create the train, validation and test generators for the dataset.

    Expects ``DATA_PATH`` to contain an ``images`` and a ``masks`` sub-directory.
    Both are handed to the project's ``DataGenerator`` along with the image size,
    batch size and split settings from the configuration cell.

    Returns:
        tuple: ``(train_generator, val_generator, test_generator)``, obtained from
        ``get_train_generator``, ``get_val_generator`` and ``get_test_generator``
        of a single ``DataGenerator`` instance, in that order.
    """
    splitter = DataGenerator(
        os.path.join(DATA_PATH, 'images'),
        os.path.join(DATA_PATH, 'masks'),
        img_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE,
        test_split=TEST_SPLIT,
        val_split=VALIDATION_SPLIT,
    )
    # One DataGenerator instance is asked for all three subsets.
    return (
        splitter.get_train_generator(),
        splitter.get_val_generator(),
        splitter.get_test_generator(),
    )

In [None]:
# Build, compile, train and evaluate the segmentation model.
def train_model():
    """Train DenseUNetWithSpatialAttention on the configured data and report test metrics.

    Steps:
      1. Load the train/val/test generators via ``load_data_with_generators``.
      2. Build the model and compile it with Adam, binary cross-entropy and accuracy.
      3. Fit with a best-``val_loss`` checkpoint and early stopping.
      4. Restore the best-``val_loss`` weights and evaluate on the test generator.

    Returns:
        tuple: ``(model, history)`` -- the trained Keras model (holding the best
        validation-loss weights whenever EarlyStopping recorded them) and the
        ``History`` object returned by ``model.fit``.
    """
    train_generator, val_generator, test_generator = load_data_with_generators()

    # NOTE(review): the model expects 3 input channels; CT slices are often
    # single-channel -- confirm this matches what DataGenerator yields.
    model = DenseUNetWithSpatialAttention(input_size=(IMG_HEIGHT, IMG_WIDTH, 3), output_channels=1).build_model()

    # NOTE(review): with sparse masks, pixel accuracy can look high while the
    # foreground is missed; consider adding Dice/IoU metrics.
    model.compile(optimizer=Adam(learning_rate=LEARNING_RATE),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Keras 3 (TF >= 2.16) rejects a ModelCheckpoint filepath that does not end in
    # ".keras", and passing a bare directory relied on the legacy SavedModel format.
    # Save the best model as a single file inside MODEL_SAVE_PATH instead.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    checkpoint_path = os.path.join(MODEL_SAVE_PATH, 'best_model.keras')

    # Keep a handle on EarlyStopping so its recorded best weights can be reused below.
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1, restore_best_weights=True)
    callbacks = [
        ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True, mode='min', verbose=1),
        early_stopping,
    ]

    history = model.fit(train_generator,
                        validation_data=val_generator,
                        epochs=EPOCHS,
                        steps_per_epoch=len(train_generator),
                        validation_steps=len(val_generator),
                        callbacks=callbacks,
                        verbose=1)

    # On older Keras versions, restore_best_weights only applies when early stopping
    # actually fires; if all EPOCHS run, the model keeps its last-epoch weights.
    # Restore explicitly so the test metrics describe the best-val_loss weights,
    # which should match what the checkpoint kept.
    best_weights = getattr(early_stopping, 'best_weights', None)
    if best_weights is not None:
        model.set_weights(best_weights)

    test_loss, test_accuracy = model.evaluate(test_generator, steps=len(test_generator), verbose=1)
    print(f"Test Loss: {test_loss}")
    print(f"Test Accuracy: {test_accuracy}")

    return model, history

In [None]:
# Run the full training pipeline; bind whatever train_model returns so it can be
# inspected in later cells instead of being discarded.
training_result = train_model()