In [None]:
import gc
import os

import numpy as np
import tensorflow as tf
# NOTE(review): `tensorflow.python.keras` is TensorFlow's private legacy Keras copy; in recent
# TF releases it is not the backend `tf.keras` uses. Prefer `tf.keras.backend` in new code.
import tensorflow.python.keras.backend as K
from sklearn.metrics import precision_recall_curve, precision_score as sklearn_precision_score, \
    recall_score as sklearn_recall_score, roc_curve
from tensorflow.keras.losses import BinaryFocalCrossentropy

# --- Hardware configuration -------------------------------------------------
# Enable on-demand GPU memory allocation. This has to happen before any GPU is
# initialised; otherwise TensorFlow raises RuntimeError, which is only reported.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth must be configured identically on every GPU.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        print(err)

# Replicate across GPUs only when more than one is present; otherwise use the
# default distribution strategy.
multi_gpu = len(gpus) > 1
strategy = tf.distribute.MirroredStrategy() if multi_gpu else tf.distribute.get_strategy()
if multi_gpu:
    print(f"Training with {strategy.num_replicas_in_sync} GPUs")
else:
    print("Training with single GPU or CPU")

# --- Training constants ------------------------------------------------------
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64  # larger batches fit comfortably in GPU memory

# --- Dataset and artifact locations (everything lives under one Drive folder) -
DATA_ROOT = '/content/drive/MyDrive/DermaLyticsAI/small dataset'

train_samples = 11200
train_tfrecords = [f'{DATA_ROOT}/train.tfrecord']

val_samples = 4800
val_tfrecords = [f'{DATA_ROOT}/validation.tfrecord']

total_test_samples = 4000
test_tfrecords = [f'{DATA_ROOT}/test.tfrecord']

checkpoint_path = f'{DATA_ROOT}/modelCheckpoint.weights.h5'
log_dir = f'{DATA_ROOT}/logs'
modelSave_path = f'{DATA_ROOT}/model/derma_model.h5'
modelKerasSave_path = f'{DATA_ROOT}/model/derma_model_keras.keras'

# --- Mount Google Drive (Colab) so the paths above resolve --------------------
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def parse_function(proto, image_size, is_training=True):
    """Parse one serialized TFRecord example into model inputs and a label.

    Args:
        proto: Scalar string tensor holding one serialized ``tf.train.Example``.
        image_size: Target ``(height, width)`` passed to ``tf.image.resize`` for
            both the image and the mask.
        is_training: When True, apply random flips and random 90-degree
            rotations. The image and its mask always receive the same transform.

    Returns:
        ``(features, label)`` where ``features`` is a dict with
        ``"image_input"`` and ``"mask_input"`` (3-channel float32, values
        clipped to [0, 1]) and ``"metadata_input"`` (the 11 float32 metadata
        values clipped to [-100, 100]); ``label`` is a length-1 float32 tensor
        clipped to [0, 1].
    """
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([1], tf.float32),
        "metadata": tf.io.FixedLenFeature([11], tf.float32),
        "mask": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(proto, feature_description)

    def _decode_and_scale(jpeg_bytes):
        # Decode a 3-channel JPEG, resize it and scale pixel values into [0, 1].
        decoded = tf.io.decode_jpeg(jpeg_bytes, channels=3)
        resized = tf.image.resize(decoded, image_size)
        return tf.clip_by_value(tf.cast(resized, tf.float32) / 255.0, 0.0, 1.0)

    image = _decode_and_scale(example["image"])
    mask = _decode_and_scale(example["mask"])

    if is_training:
        # Stateless ops are deterministic for a given seed. Reusing a seed for
        # the image and its mask keeps them aligned. The left-right and up-down
        # flips must use DIFFERENT seeds, though. With one shared seed both
        # flips fire together (equivalent to a 180-degree rotation), so together
        # with rot90 the pipeline could never produce a mirrored image.
        flip_seeds = tf.random.uniform([2, 2], minval=0, maxval=2**31 - 1, dtype=tf.int32)
        lr_seed, ud_seed = flip_seeds[0], flip_seeds[1]

        image = tf.image.stateless_random_flip_left_right(image, lr_seed)
        mask = tf.image.stateless_random_flip_left_right(mask, lr_seed)
        image = tf.image.stateless_random_flip_up_down(image, ud_seed)
        mask = tf.image.stateless_random_flip_up_down(mask, ud_seed)

        # Rotate both by the same random multiple of 90 degrees.
        num_rotations = tf.random.uniform([], 0, 4, dtype=tf.int32)
        image = tf.image.rot90(image, k=num_rotations)
        mask = tf.image.rot90(mask, k=num_rotations)

    # Bound metadata to guard against extreme values.
    metadata = tf.clip_by_value(example["metadata"], -100.0, 100.0)

    # Force the label into the valid binary-probability range.
    label = tf.clip_by_value(example["label"], 0.0, 1.0)

    return ({"image_input": image, "mask_input": mask, "metadata_input": metadata}, label)

In [None]:
def load_tfrecord_dataset(tfrecord_paths, batch_size, image_size, epochs=1, is_training=True,
                          num_samples=None):
    """Build a batched, prefetched ``tf.data`` pipeline over TFRecord files.

    Args:
        tfrecord_paths: List of TFRecord file paths to read.
        batch_size: Number of examples per batch.
        image_size: ``(height, width)`` forwarded to ``parse_function``.
        epochs: Number of passes over the data. Only applied when
            ``is_training`` is True; evaluation datasets make a single pass.
        is_training: Enables augmentation (inside ``parse_function``),
            shuffling and repetition.
        num_samples: Number of examples in ``tfrecord_paths``. When omitted it
            is taken from the notebook-level ``train_samples`` / ``val_samples``
            / ``total_test_samples`` constants, as before.

    Returns:
        ``(dataset, steps)``. For training, ``steps`` is the number of full
        batches per epoch (floor division), used as ``steps_per_epoch``. For
        evaluation, ``steps`` counts every batch including the trailing partial
        one (ceil division), so ``evaluate(..., steps=steps)`` sees every
        example.
    """
    parse_fn = lambda x: parse_function(x, image_size, is_training)

    dataset = tf.data.TFRecordDataset(tfrecord_paths, num_parallel_reads=tf.data.AUTOTUNE)
    dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)

    if is_training:
        # Shuffle individual examples before they are grouped into batches.
        dataset = dataset.shuffle(1000)

    dataset = dataset.batch(batch_size)

    if is_training:
        # Repeat after batching so no batch straddles two passes over the data.
        dataset = dataset.repeat(epochs)

    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    if num_samples is None:
        # Backward-compatible lookup from the notebook-level constants.
        if is_training:
            num_samples = train_samples
        elif tfrecord_paths[0] == val_tfrecords[0]:
            num_samples = val_samples
        else:
            num_samples = total_test_samples

    if is_training:
        steps = num_samples // batch_size
    else:
        # Evaluation data is not repeated, so the final partial batch is real
        # data. Floor division would silently skip it.
        steps = -(-num_samples // batch_size)

    return dataset, steps

In [None]:
def build_model(image_size=(224, 224)):
    """Build the image + mask + metadata binary classifier.

    Architecture (every layer is trainable):
      * A 1x1 sigmoid conv over the mask produces a single-channel gate that
        multiplies the image before it enters an EfficientNetV2B3 backbone
        (randomly initialised, no classification head).
      * A sigmoid conv gate re-weights the backbone feature map, which is then
        global-average pooled.
      * A small two-layer CNN extracts a separate embedding from the mask.
      * The 11 metadata values are split into the first 3 ("demographic") and
        the remaining 8 ("location") columns, each encoded by a Dense unit.
      * All embeddings are concatenated and fed through a residual dense
        block, a second dense block and a single sigmoid output.

    Args:
        image_size: ``(height, width)`` of the image and mask inputs.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs ``image_input``,
        ``metadata_input`` and ``mask_input`` and a single sigmoid output.
    """
    # One seeded Glorot-normal initializer shared by the custom layers.
    kernel_init = tf.keras.initializers.GlorotNormal(seed=42)

    def dense_bn_dropout(inputs, units, dropout_rate):
        # Dense(ReLU, L2=0.001) -> BatchNorm -> Dropout: the repeated MLP unit.
        out = tf.keras.layers.Dense(
            units,
            activation="relu",
            kernel_initializer=kernel_init,
            kernel_regularizer=tf.keras.regularizers.l2(0.001),
        )(inputs)
        out = tf.keras.layers.BatchNormalization()(out)
        return tf.keras.layers.Dropout(dropout_rate)(out)

    # Input names must match the feature-dict keys emitted by parse_function.
    image_input = tf.keras.Input(shape=(*image_size, 3), name="image_input")
    mask_input = tf.keras.Input(shape=(*image_size, 3), name="mask_input")
    metadata_input = tf.keras.Input(shape=(11,), name="metadata_input")

    # Spatial gate from the mask: a (H, W, 1) sigmoid map broadcast over the
    # three image channels.
    spatial_gate = tf.keras.layers.Conv2D(
        1, (1, 1), activation='sigmoid', kernel_initializer=kernel_init
    )(mask_input)
    modulated_image = tf.keras.layers.Multiply()([image_input, spatial_gate])

    # Backbone.
    # - Built as a standalone sub-model rather than via input_tensor=, so its
    #   layers never include the mask gate above.
    # - include_preprocessing=False: pixels are already scaled to [0, 1]. The
    #   B-variants' built-in preprocessing expects [0, 255] and would divide by
    #   255 again, squashing the input to ~[0, 0.004] before normalisation.
    # - Nothing is frozen. With weights=None every layer starts random, and
    #   freezing random layers would pin them at noise for the whole run.
    base_model = tf.keras.applications.EfficientNetV2B3(
        include_top=False,
        weights=None,
        input_shape=(*image_size, 3),
        include_preprocessing=False,
    )
    backbone_features = base_model(modulated_image)

    # Sigmoid gate over the backbone feature map.
    feature_gate = tf.keras.layers.Conv2D(
        backbone_features.shape[-1],
        (1, 1),
        activation='sigmoid',
        kernel_initializer=kernel_init,
    )(backbone_features)
    weighted_features = tf.keras.layers.Multiply()([backbone_features, feature_gate])

    x = tf.keras.layers.GlobalAveragePooling2D()(weighted_features)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    # Independent mask embedding from a small CNN.
    mask_features = tf.keras.layers.Conv2D(
        16, (3, 3), padding='same', activation='relu', kernel_initializer=kernel_init
    )(mask_input)
    mask_features = tf.keras.layers.BatchNormalization()(mask_features)
    mask_features = tf.keras.layers.MaxPooling2D()(mask_features)
    mask_features = tf.keras.layers.Conv2D(
        32, (3, 3), padding='same', activation='relu', kernel_initializer=kernel_init
    )(mask_features)
    mask_features = tf.keras.layers.BatchNormalization()(mask_features)
    mask_features = tf.keras.layers.GlobalAveragePooling2D()(mask_features)
    mask_features = tf.keras.layers.BatchNormalization()(mask_features)

    # Metadata: columns 0-2 treated as demographic, columns 3-10 as location.
    # NOTE(review): Lambda layers wrapping Python lambdas may need
    # safe_mode=False when reloading the .keras file; confirm for the
    # installed Keras version.
    demographic_features = tf.keras.layers.Lambda(lambda t: t[:, :3])(metadata_input)
    location_features = tf.keras.layers.Lambda(lambda t: t[:, 3:])(metadata_input)

    location_encoded = dense_bn_dropout(location_features, 16, 0.1)
    demographic_encoded = dense_bn_dropout(demographic_features, 16, 0.1)

    metadata_features = tf.keras.layers.Concatenate()([location_encoded, demographic_encoded])
    metadata_features = tf.keras.layers.BatchNormalization()(metadata_features)

    # Fuse image, metadata and mask embeddings.
    combined = tf.keras.layers.Concatenate()([x, metadata_features, mask_features])
    combined = tf.keras.layers.BatchNormalization()(combined)

    # Residual dense block: block1 + f(block1).
    block1 = dense_bn_dropout(combined, 256, 0.3)
    block1_res = dense_bn_dropout(block1, 256, 0.3)
    block1 = tf.keras.layers.Add()([block1, block1_res])
    block1 = tf.keras.layers.BatchNormalization()(block1)

    block2 = dense_bn_dropout(block1, 128, 0.2)

    output = tf.keras.layers.Dense(
        1,
        activation="sigmoid",
        kernel_initializer=kernel_init,
        kernel_regularizer=tf.keras.regularizers.l2(0.001),
    )(block2)

    return tf.keras.Model(inputs=[image_input, metadata_input, mask_input], outputs=output)

In [None]:
# Callbacks shared by both training phases.
early_stopping_auc = tf.keras.callbacks.EarlyStopping(
    monitor="val_AUC",
    patience=5,
    restore_best_weights=True,
    mode="max",
)
early_stopping_loss = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=7,
    restore_best_weights=True,
    mode="min",
)
terminate_on_nan = tf.keras.callbacks.TerminateOnNaN()

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    update_freq='epoch',
)

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_AUC',
    mode='max',
    verbose=1,
)

# Load datasets. The training set is repeated 40 times to cover phase 1's epochs.
train_dataset, train_steps = load_tfrecord_dataset(
    train_tfrecords,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    epochs=40,
    is_training=True,
)
val_dataset, val_steps = load_tfrecord_dataset(
    val_tfrecords,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    epochs=1,  # validation is a single un-repeated pass
    is_training=False,
)

print(f"Steps per epoch: {train_steps}, Validation steps: {val_steps}")

# Build and compile in one code path. For a single GPU or CPU, `strategy` is the
# default strategy, whose scope adds nothing, so no separate branch is needed.
with strategy.scope():
    model = build_model(IMAGE_SIZE)

    # Staircase decay: the LR is multiplied by 0.95 every 5 epochs' worth of steps.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4,
        decay_steps=train_steps * 5,
        decay_rate=0.95,
        staircase=True,
    )

    # Gradient-norm clipping guards against exploding updates.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=lr_schedule,
        clipnorm=1.0,
        epsilon=1e-7,
    )

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.5),
            tf.keras.metrics.AUC(name="AUC", curve='ROC'),
            tf.keras.metrics.Precision(name="precision", thresholds=0.5),
            tf.keras.metrics.Recall(name="recall", thresholds=0.5),
        ],
    )

model.summary()

In [None]:
# Phase 1: train with plain binary cross-entropy.
print("PHASE 1: Main training phase")
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    # train_dataset was built with epochs=40 repeats; with steps_per_epoch set,
    # asking for more epochs than that would run out of data.
    epochs=40,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[
        early_stopping_auc,
        early_stopping_loss,
        checkpoint_callback,
        tensorboard_callback,
        terminate_on_nan,
    ],
    verbose=1,
)

# Collect unreachable Python objects only. clear_session() is deliberately not
# called: `model` is saved, evaluated and re-trained by later cells, so
# resetting global Keras state here frees nothing the model holds. (The `K`
# alias imported above is also the legacy tensorflow.python.keras backend.)
gc.collect()

In [None]:
# Save the trained model in both formats. Each format is attempted
# independently so a failure in one does not prevent the other.
saved_paths = []
for save_path in (modelSave_path, modelKerasSave_path):
    try:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        model.save(save_path)
        saved_paths.append(save_path)
    except Exception as e:
        print(f"Error saving model to {save_path}: {e}")

if saved_paths:
    print(f"Model saved successfully to {', '.join(saved_paths)}")
else:
    # Fall back to weights only. Keras 3 requires save_weights() filenames to
    # end in ".weights.h5"; the old "<model>.h5_weights" name is rejected.
    weights_path = os.path.splitext(modelSave_path)[0] + "_weights.weights.h5"
    try:
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)
        model.save_weights(weights_path)
        print(f"Model weights saved to {weights_path}")
    except Exception as e:
        print(f"Error saving model weights: {e}")

In [None]:
# Evaluation phase: one pass over the held-out test split, with no
# augmentation, shuffling or repetition (is_training=False).
test_dataset, test_steps = load_tfrecord_dataset(
    test_tfrecords,
    BATCH_SIZE,
    IMAGE_SIZE,
    epochs=1,
    is_training=False,
)

In [None]:
# Evaluate on the test split using the metrics' default 0.5 threshold.
print("Evaluating with default threshold (0.5):")
# return_dict=True labels each value with its metric name. Zipping the
# returned list against model.metrics_names is unreliable: in Keras 3 that
# list is ['loss', 'compile_metrics'], so values get dropped and mislabelled.
test_results = model.evaluate(
    test_dataset,
    steps=test_steps,
    verbose=1,
    return_dict=True,
)
print("Test Results:", test_results)

In [None]:
# Tune the decision threshold on the validation split (never on test).
THRESHOLD_BATCH_SIZE = 32  # smaller batches for prediction
threshold_dataset, threshold_steps = load_tfrecord_dataset(
    val_tfrecords,
    batch_size=THRESHOLD_BATCH_SIZE,
    image_size=IMAGE_SIZE,
    epochs=1,
    is_training=False,
)

all_preds = []
all_labels = []

print("Collecting predictions for threshold optimization...")
for i, (inputs, labels) in enumerate(threshold_dataset):
    if i >= threshold_steps:
        break

    # predict_on_batch already returns a NumPy array, which has no .numpy()
    # method. np.asarray also accepts a tensor, should one ever be returned.
    preds = model.predict_on_batch(inputs)
    all_preds.extend(np.asarray(preds).ravel())
    all_labels.extend(labels.numpy().ravel())

    if i % 10 == 0:
        gc.collect()

print("Calculating optimal threshold...")
optimal_threshold = 0.5  # default when no better threshold can be computed
if all_preds and all_labels:
    y_true = np.asarray(all_labels)
    y_score = np.asarray(all_preds)
    try:
        fpr, tpr, thresholds = roc_curve(y_true, y_score)
        # Youden's J statistic: pick the threshold maximising TPR - FPR.
        best_idx = int(np.argmax(tpr - fpr))
        # roc_curve's first threshold sits above every score (np.inf or
        # max + 1, depending on the scikit-learn version). Keras Precision and
        # Recall reject thresholds outside [0, 1], so clip before reusing it.
        optimal_threshold = float(np.clip(thresholds[best_idx], 0.0, 1.0))

        for threshold in [0.3, 0.4, 0.5, 0.6, 0.7, optimal_threshold]:
            binary_preds = (y_score >= threshold).astype(int)
            precision = sklearn_precision_score(y_true, binary_preds, zero_division=0)
            recall = sklearn_recall_score(y_true, binary_preds, zero_division=0)
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            print(f"Threshold: {threshold:.4f}, Precision: {precision:.4f}, "
                  f"Recall: {recall:.4f}, F1: {f1:.4f}")

        print(f"Optimal threshold: {optimal_threshold:.4f}")
    except Exception as e:
        print(f"Error calculating optimal threshold: {e}")
        optimal_threshold = 0.5  # Use default
else:
    print("No valid predictions collected for threshold optimization")

# Free the prediction buffers' garbage. clear_session() is not called because
# `model` is recompiled and trained again in the next phase.
gc.collect()

In [None]:
# Phase 2: fine-tune with focal loss at a small constant learning rate.
print("PHASE 2: Final training with Focal Loss and optimized parameters")

PHASE2_LEARNING_RATE = 5e-6  # very conservative fine-tuning rate

# Keras Precision/Recall only accept thresholds in [0, 1].
phase2_threshold = float(min(max(optimal_threshold, 0.0), 1.0))

# Create the new optimizer and metrics inside the strategy scope. For a single
# GPU or CPU this is the default strategy, whose scope adds nothing.
with strategy.scope():
    # A fresh optimizer is required. The phase-1 optimizer was built from an
    # ExponentialDecay schedule, so its learning rate cannot be overridden with
    # K.set_value: the call is either rejected or overwritten by the schedule
    # on the next step. This also resets Adam's moment estimates, which is
    # reasonable when the loss function changes.
    phase2_optimizer = tf.keras.optimizers.Adam(
        learning_rate=PHASE2_LEARNING_RATE,
        clipnorm=1.0,
        epsilon=1e-7,
    )
    model.compile(
        optimizer=phase2_optimizer,
        # alpha only takes effect when apply_class_balancing=True; otherwise
        # BinaryFocalCrossentropy ignores it entirely.
        loss=BinaryFocalCrossentropy(apply_class_balancing=True, alpha=0.75, gamma=2.0),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.5),
            tf.keras.metrics.AUC(name="AUC", curve='ROC'),
            tf.keras.metrics.Precision(name="precision", thresholds=0.5),
            tf.keras.metrics.Recall(name="recall", thresholds=0.5),
            tf.keras.metrics.Precision(name="opt_precision", thresholds=phase2_threshold),
            tf.keras.metrics.Recall(name="opt_recall", thresholds=phase2_threshold),
        ],
    )

In [None]:
# Final training phase with reduced epochs.
# NOTE(review): in Keras, EarlyStopping resets its counters at the start of each
# fit(), while ModelCheckpoint keeps its best val_AUC from phase 1. The
# checkpoint is therefore overwritten only if phase 2 beats the phase-1 best.
# Confirm this for the installed version.
final_history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=[
        early_stopping_auc,
        early_stopping_loss,
        checkpoint_callback,
        tensorboard_callback,
        terminate_on_nan,
    ],
    verbose=1,
)

# Save the final model in both formats, each attempted independently.
saved_paths = []
for save_path in (modelSave_path, modelKerasSave_path):
    try:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        model.save(save_path)
        saved_paths.append(save_path)
    except Exception as e:
        print(f"Error saving final model to {save_path}: {e}")

if saved_paths:
    print(f"Final model saved successfully to {', '.join(saved_paths)}")
else:
    # Keras 3 requires save_weights() filenames to end in ".weights.h5".
    weights_path = os.path.splitext(modelSave_path)[0] + "_final_weights.weights.h5"
    try:
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)
        model.save_weights(weights_path)
        print(f"Final model weights saved to {weights_path}")
    except Exception as e:
        print(f"Could not save final model weights: {e}")

# Final evaluation. return_dict=True keeps metric names paired with values.
print("\nFinal evaluation results:")
test_results = model.evaluate(test_dataset, steps=test_steps, return_dict=True)
print("Test Results:", test_results)

print(f"Training complete. Optimal threshold for prediction: {optimal_threshold:.4f}")