In [None]:
import numpy as np
import pandas as pd

# Load the training metadata and derive the absolute path of every image.
data_df = pd.read_csv('/kaggle/input/grand-xray-slam-division-b/train2.csv')
data_df['Image_path'] = '/kaggle/input/grand-xray-slam-division-b/train2/' + data_df['Image_name']

# Column order with Image_path moved to the front.
cols = ['Image_path', *(c for c in data_df.columns if c != 'Image_path')]

# Images that are excluded from the dataset.
broken_images = [
    '/kaggle/input/grand-xray-slam-division-b/train2/00043046_001_001.jpg',
    '/kaggle/input/grand-xray-slam-division-b/train2/00052495_001_001.jpg',
    '/kaggle/input/grand-xray-slam-division-b/train2/00056890_001_001.jpg',
]

# Reorder, drop the excluded rows, then discard the identifier columns.
# NOTE(review): Patient_ID is removed here, so later cells cannot split by patient.
data_df = (
    data_df[cols]
    .loc[lambda frame: ~frame['Image_path'].isin(broken_images)]
    .reset_index(drop=True)
    .drop(columns=['Image_name', 'Study', 'Patient_ID'])
)

# Impute missing values: mean for Age, most frequent value for Sex.
data_df = data_df.fillna({
    'Age': data_df['Age'].mean(),
    'Sex': data_df['Sex'].mode()[0],
})

# Column names with spaces become snake-style (e.g. "Lung Opacity" -> "Lung_Opacity").
data_df = data_df.rename(columns=lambda column: column.replace(' ', '_'))

data_df.head()

In [None]:
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())  # Should return True
print(torch.cuda.device_count())  # Should be >=1
if torch.cuda.is_available():
    # get_device_name() raises when no CUDA device is available, which would
    # abort the cell on a CPU-only runtime even though `device` is valid.
    print(torch.cuda.get_device_name(0))  # Name of GPU
print(device)

In [None]:
# Numeric feature columns, in DataFrame order. select_dtypes matches every
# integer/float width instead of only float64/int64, and no loop variable is
# left behind to overwrite the `cols` list defined in the loading cell.
num_cols = data_df.select_dtypes(include='number').columns.tolist()

print('Numerical columns: ',num_cols, '\n')

# Object (string) columns, excluding the image path which is not a feature.
cat_cols = [
    column
    for column in data_df.select_dtypes(include='object').columns
    if column != 'Image_path'
]

print('Object columns: ',cat_cols)

# **Ensemble training**

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import pandas as pd
import numpy as np
import os

# GPU optimization settings
# Enable XLA JIT compilation in the TensorFlow graph optimizer.
tf.config.optimizer.set_jit(True)
gpus = tf.config.list_physical_devices('GPU')
try:
    # One logical device per physical GPU, capped at memory_limit=15000.
    for gpu in gpus:
        tf.config.set_logical_device_configuration(
            gpu,
            [tf.config.LogicalDeviceConfiguration(memory_limit=15000)],
        )
    if gpus:
        print(f"Found {len(gpus)} GPU(s)")
except RuntimeError as e:
    # Report the problem and continue with TensorFlow's current device setup.
    print(e)

# ------------------- CONFIGURATION -------------------
IMG_SIZE = (224, 224)          # (height, width) every image is resized to
BATCH_SIZE = 64
EPOCHS = 10
AUTOTUNE = tf.data.AUTOTUNE

# Paths to your saved weights (if available); None means "no saved weights".
WEIGHTS_PATHS = dict(
    resnet50='/kaggle/input/resnet-model/tensorflow2/default/1/best_resnet_model.h5',
    efficientnet=None,  # Set to your path if available
    densenet=None,      # Set to your path if available
)

from sklearn.model_selection import train_test_split

# Names of the 14 binary target columns, in model output order.
label_cols = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 
              'Enlarged_Cardiomediastinum', 'Fracture', 'Lung_Lesion', 
              'Lung_Opacity', 'No_Finding', 'Pleural_Effusion', 'Pleural_Other', 
              'Pneumonia', 'Pneumothorax', 'Support_Devices']

# Assume data_df is already loaded
print(f"Using dataframe with {len(data_df)} samples")

# Fixed-seed 80/20 random split.
train_seqs, val_seqs = train_test_split(data_df, test_size=0.2, random_state=42)
train_seqs = train_seqs.reset_index(drop=True)
val_seqs = val_seqs.reset_index(drop=True)

# A class is "small" when fewer than 20% of the training images carry it.
# The positive count comes from the training split, so the denominator must be
# the training split too; dividing by len(data_df) silently moved the effective
# threshold from 20% to 25% prevalence.
small_classes = [
    col for col in label_cols 
    if train_seqs[col].sum() / len(train_seqs) < 0.2
]

print("Small classes:", small_classes)

# Compute label weights from the training split only, so that validation
# labels do not influence the loss being validated.
class_counts = train_seqs[label_cols].sum().values
total_samples = len(train_seqs)
pos_weights = np.maximum(class_counts, 1)      # avoid division by zero
neg_weights = total_samples - pos_weights
label_weights = neg_weights / pos_weights      # negative/positive ratio per label
label_weights = np.clip(label_weights, 0.5, 3.0)
label_weights = tf.constant(label_weights, dtype=tf.float32)

print("Label weights:", label_weights.numpy())

# ------------------- LOSS FUNCTION -------------------
@tf.function(jit_compile=True)
def weighted_bce(y_true, y_pred):
    """Mean sigmoid cross-entropy on logits, scaled per label by `label_weights`.

    Args:
        y_true: Ground-truth labels.
        y_pred: Raw logits (no sigmoid applied).

    Returns:
        Scalar tensor: the mean of the weighted element-wise losses.
    """
    per_label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=y_pred, labels=y_true
    )
    # label_weights broadcasts over the leading dimension(s) of the loss.
    return tf.reduce_mean(per_label_loss * label_weights)

# ------------------- DATA PIPELINE -------------------
@tf.function
def decode_and_process_train(path, label):
    """Load one training image, augment it and apply ResNet50 preprocessing.

    Args:
        path: Scalar string tensor with the JPEG file path.
        label: Label vector for the image, ordered like `label_cols`.

    Returns:
        (img, label): the augmented, preprocessed image and the unchanged label.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    # The image is kept on the 0-255 scale (see the clip to [0, 255] below).
    img = tf.image.resize(img, IMG_SIZE, method='bilinear')

    # Base augmentations
    img = tf.image.random_flip_left_right(img)
    if tf.random.uniform([]) < 0.5:
        # k is drawn from {-1, 0, 1} quarter turns.
        k = tf.random.uniform([], -1, 2, dtype=tf.int32)
        img = tf.image.rot90(img, k=k)
    # random_brightness adds its delta directly to the pixel values. On the
    # 0-255 scale a delta of 0.1 is invisible, so the 10% budget is scaled.
    img = tf.image.random_brightness(img, 0.1 * 255.0)
    img = tf.image.random_contrast(img, 0.85, 1.15)

    # Extra augmentations for smaller classes.
    # The indices are given an explicit integer dtype: an empty Python list
    # would otherwise become a float tensor, which tf.gather rejects.
    small_class_idx = tf.constant(
        [label_cols.index(c) for c in small_classes], dtype=tf.int32
    )
    if tf.reduce_any(tf.greater(tf.gather(label, small_class_idx), 0)):
        # NOTE(review): if the source images are grayscale, saturation and hue
        # changes presumably have no visible effect - confirm.
        img = tf.image.random_saturation(img, 0.8, 1.3)
        img = tf.image.random_hue(img, 0.05)
        img = tf.image.random_flip_up_down(img)
        if tf.random.uniform([]) < 0.3:
            noise = tf.random.normal(tf.shape(img), mean=0.0, stddev=5.0)
            img = tf.clip_by_value(img + noise, 0, 255)

    img = tf.keras.applications.resnet50.preprocess_input(img)
    return img, label

@tf.function
def decode_and_process_val(path, label):
    """Load one image without augmentation and apply ResNet50 preprocessing.

    Args:
        path: Scalar string tensor with the JPEG file path.
        label: Label vector, passed through unchanged.

    Returns:
        (img, label): the resized, preprocessed image and the label.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(pixels, IMG_SIZE, method='bilinear')
    # NOTE(review): the same ResNet50 preprocessing is fed to every ensemble
    # member - confirm that is intended for the non-ResNet backbones.
    return tf.keras.applications.resnet50.preprocess_input(resized), label

# Prepare data
train_paths = train_seqs['Image_path'].values
train_labels = train_seqs[label_cols].values.astype("float32")
val_paths = val_seqs['Image_path'].values
val_labels = val_seqs[label_cols].values.astype("float32")

print(f"Training samples: {len(train_paths)}")
print(f"Validation samples: {len(val_paths)}")

# Build datasets
# Training: shuffle paths, decode/augment in parallel, fixed-size batches.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    .shuffle(3000, reshuffle_each_iteration=True)
    .map(decode_and_process_train, num_parallel_calls=AUTOTUNE, deterministic=False)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=3)
)

# Validation keeps its final partial batch: dropping it would silently exclude
# up to BATCH_SIZE - 1 images from val_loss / val_auc.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    .map(decode_and_process_val, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=False)
    .prefetch(buffer_size=2)
)

print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {len(train_paths) // BATCH_SIZE}")

# ------------------- BUILD INDIVIDUAL MODELS -------------------
print("\n" + "="*50)
print("Building Ensemble Models...")
print("="*50)

# Mixed precision: set globally before any model layer is created.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision: {policy.name}")

def build_model(base_architecture, name):
    """Build a multi-label classifier on top of the given backbone.

    Args:
        base_architecture: One of 'resnet50', 'efficientnet', 'densenet'.
        name: Name of the returned model; also prefixes the output layer name.

    Returns:
        A Keras model mapping an IMG_SIZE + (3,) image to len(label_cols) logits.

    Raises:
        ValueError: If `base_architecture` is not a supported backbone.
    """
    if base_architecture == 'resnet50':
        # No ImageNet weights here: the notebook loads saved weights afterwards.
        base_model = tf.keras.applications.ResNet50(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights=None
        )
    elif base_architecture == 'efficientnet':
        base_model = tf.keras.applications.EfficientNetB3(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights='imagenet'
        )
    elif base_architecture == 'densenet':
        base_model = tf.keras.applications.DenseNet121(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights='imagenet'
        )
    else:
        # Previously an unknown name fell through to an unbound `base_model`.
        raise ValueError(
            f"Unknown base_architecture {base_architecture!r}; "
            "expected 'resnet50', 'efficientnet' or 'densenet'"
        )
    
    # Build head
    x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    x = tf.keras.layers.Dropout(0.3)(x)
    # The output layer is pinned to float32 and emits raw logits.
    logits = tf.keras.layers.Dense(
        len(label_cols),
        activation=None,
        dtype='float32',
        name=f'{name}_output'
    )(x)
    
    model = tf.keras.Model(inputs=base_model.input, outputs=logits, name=name)
    
    # Make all layers trainable
    base_model.trainable = True
    
    return model

# Build the three models
model_resnet = build_model('resnet50', 'resnet50')
model_efficientnet = build_model('efficientnet', 'efficientnet')
model_densenet = build_model('densenet', 'densenet')

models_list = [model_resnet, model_efficientnet, model_densenet]
model_names = ['resnet50', 'efficientnet', 'densenet']

# Load pre-trained weights if available
if WEIGHTS_PATHS['resnet50'] and os.path.exists(WEIGHTS_PATHS['resnet50']):
    model_resnet.load_weights(WEIGHTS_PATHS['resnet50'])
    print(f"✓ Loaded ResNet50 weights from: {WEIGHTS_PATHS['resnet50']}")
else:
    # The ResNet50 backbone is created with weights=None, so without this file
    # it starts from random initialisation. Say so instead of staying silent.
    print(f"WARNING: ResNet50 weights not found at {WEIGHTS_PATHS['resnet50']}; "
          "the ResNet50 member starts from random initialisation")

print("\n" + "="*50)
print("Model Summary:")
for model, name in zip(models_list, model_names):
    # Sum of element counts over the trainable weight tensors.
    trainable_count = sum([tf.size(w).numpy() for w in model.trainable_weights])
    print(f"{name}: {trainable_count:,} trainable parameters")
print("="*50)

# ------------------- BUILD ENSEMBLE MODEL -------------------
print("\n" + "="*50, "Building Ensemble Model...", "="*50, sep="\n")

# Shared input that is fed to every member model.
input_layer = tf.keras.Input(name='ensemble_input', shape=IMG_SIZE + (3,))

# Logits of the three members, evaluated in ResNet, EfficientNet, DenseNet order.
pred1, pred2, pred3 = [
    member(input_layer)
    for member in (model_resnet, model_efficientnet, model_densenet)
]

# The ensemble output is the element-wise mean of the member logits.
averaging_layer = tf.keras.layers.Average(name='ensemble_average')
ensemble_output = averaging_layer([pred1, pred2, pred3])

ensemble_model = tf.keras.Model(
    name='ensemble',
    inputs=input_layer,
    outputs=ensemble_output,
)

print("✓ Ensemble model created")
print(f"Total models in ensemble: {len(models_list)}")

# ------------------- COMPILE ENSEMBLE MODEL -------------------
print("\n" + "="*50)
print("Compiling ensemble model...")
print("="*50)

fine_tune_lr = 1e-5

# AdamW with gradient-norm clipping.
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=fine_tune_lr,
    weight_decay=1e-6,
    clipnorm=1.0
)

ensemble_model.compile(
    optimizer=optimizer,
    loss=weighted_bce,
    metrics=[
        tf.keras.metrics.AUC(name="auc", from_logits=True),
        # The model outputs logits and BinaryAccuracy thresholds y_pred as-is,
        # so the cut for probability 0.5 is logit 0.0 (sigmoid(0) == 0.5).
        # A threshold of 0.5 on logits corresponds to p ~ 0.62.
        tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)
    ],
    jit_compile=True
)

print(f"✓ Compiled with learning rate: {fine_tune_lr:.2e}")

# ------------------- CALLBACKS -------------------
# Start from a clean slate: remove a checkpoint left over from an earlier run.
checkpoint_path = 'ensemble_model.weights.h5'
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)
    print(f"Removed existing checkpoint: {checkpoint_path}")

# Stop after 3 epochs without a better val_auc and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    mode='max',
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without a better val_loss.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=2,
    factor=0.5,
    min_lr=1e-7,
    verbose=1,
)

# Keep only the weights of the epoch with the highest val_auc.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_auc',
    mode='max',
    save_weights_only=True,  # Save weights only to avoid HDF5 issues
    save_best_only=True,
    verbose=1,
)

class TimingCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints how long each epoch took, in minutes."""

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the wall-clock time at which the epoch started."""
        from time import time as now
        self.epoch_start = now()

    def on_epoch_end(self, epoch, logs=None):
        """Print the time elapsed since the matching on_epoch_begin call."""
        from time import time as now
        elapsed_minutes = (now() - self.epoch_start) / 60
        print(f"\nEpoch time: {elapsed_minutes:.1f} minutes")


timing_cb = TimingCallback()

# ------------------- TRAIN ENSEMBLE -------------------
print("\n" + "="*50)
print("Training ensemble model...")
print("="*50 + "\n")

history = ensemble_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[early_stopping, reduce_lr, checkpoint, timing_cb],
    verbose=1
)

# ------------------- RESULTS -------------------
print("\n" + "="*50)
print("Ensemble training completed!")
print("="*50)
print(f"Best validation AUC: {max(history.history['val_auc']):.4f}")
print(f"Best validation loss: {min(history.history['val_loss']):.4f}")
print(f"Final learning rate: {ensemble_model.optimizer.learning_rate.numpy():.2e}")

# Save final ensemble model
# Use the same `.weights.h5` suffix as the checkpoint in this notebook: Keras 3
# rejects save_weights() targets that do not end in `.weights.h5`, which would
# otherwise fail only after the whole training run has finished.
final_model_path = 'final_ensemble_model.weights.h5'
if os.path.exists(final_model_path):
    os.remove(final_model_path)
ensemble_model.save_weights(final_model_path)
print(f"\n✓ Ensemble model weights saved to: {final_model_path}")

# ------------------- INFERENCE FUNCTION -------------------
def predict_ensemble(image_paths):
    """
    Make predictions using the ensemble model
    
    Args:
        image_paths: List (or array) of image file paths; may be empty
    
    Returns:
        predictions: float array of shape (len(image_paths), len(label_cols))
            with probabilities for each class (sigmoid applied to logits)
    """
    image_paths = [str(p) for p in image_paths]

    # An empty list cannot be turned into a string dataset; return an empty
    # result with the right number of columns instead of crashing.
    if not image_paths:
        return np.zeros((0, len(label_cols)), dtype=np.float32)

    # Create dataset. decode_and_process_val expects a label, so a dummy
    # all-zero label is passed; element order is preserved by map().
    ds = (
        tf.data.Dataset.from_tensor_slices(image_paths)
        .map(lambda x: decode_and_process_val(x, tf.zeros(len(label_cols))),
             num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(buffer_size=2)
    )
    
    # Get logits from ensemble
    logits = ensemble_model.predict(ds, verbose=1)
    
    # Convert to probabilities
    predictions = tf.sigmoid(logits).numpy()
    
    return predictions

print("\n" + "="*50, "Use predict_ensemble(image_paths) for inference", "="*50, sep="\n")

# Optional: Clear memory
# Run a garbage-collection pass, then reset Keras' global state.
import gc
gc.collect()
tf.keras.backend.clear_session()

# **Normal ensemble fine-tuning**

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import pandas as pd
import numpy as np
import os

# GPU optimization settings
# Enable XLA JIT compilation in the TensorFlow graph optimizer.
tf.config.optimizer.set_jit(True)
gpus = tf.config.list_physical_devices('GPU')
try:
    # One logical device per physical GPU, capped at memory_limit=15000.
    for gpu in gpus:
        tf.config.set_logical_device_configuration(
            gpu,
            [tf.config.LogicalDeviceConfiguration(memory_limit=15000)],
        )
    if gpus:
        print(f"Found {len(gpus)} GPU(s)")
except RuntimeError as e:
    # Report the problem and continue with TensorFlow's current device setup.
    print(e)

# ------------------- CONFIGURATION -------------------
IMG_SIZE = (224, 224)          # (height, width) every image is resized to
BATCH_SIZE = 32                # Smaller batch size for fine-tuning
EPOCHS = 5                     # Fewer epochs for fine-tuning
AUTOTUNE = tf.data.AUTOTUNE

# Path to your trained ensemble weights
ENSEMBLE_WEIGHTS_PATH = 'ensemble_model.weights.h5'

# Number of layers to unfreeze from the top of each base model
UNFREEZE_LAYERS = 50           # Adjust this: higher = more layers unfrozen

from sklearn.model_selection import train_test_split

# Names of the 14 binary target columns, in model output order.
label_cols = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 
              'Enlarged_Cardiomediastinum', 'Fracture', 'Lung_Lesion', 
              'Lung_Opacity', 'No_Finding', 'Pleural_Effusion', 'Pleural_Other', 
              'Pneumonia', 'Pneumothorax', 'Support_Devices']

# Assume data_df is already loaded
print(f"Using dataframe with {len(data_df)} samples")

# Fixed-seed 80/20 random split.
train_seqs, val_seqs = train_test_split(data_df, test_size=0.2, random_state=42)
train_seqs = train_seqs.reset_index(drop=True)
val_seqs = val_seqs.reset_index(drop=True)

# A class is "small" when fewer than 20% of the training images carry it.
# The positive count comes from the training split, so the denominator must be
# the training split too; dividing by len(data_df) silently moved the effective
# threshold from 20% to 25% prevalence.
small_classes = [
    col for col in label_cols 
    if train_seqs[col].sum() / len(train_seqs) < 0.2
]

print("Small classes:", small_classes)

# Compute label weights from the training split only, so that validation
# labels do not influence the loss being validated.
class_counts = train_seqs[label_cols].sum().values
total_samples = len(train_seqs)
pos_weights = np.maximum(class_counts, 1)      # avoid division by zero
neg_weights = total_samples - pos_weights
label_weights = neg_weights / pos_weights      # negative/positive ratio per label
label_weights = np.clip(label_weights, 0.5, 3.0)
label_weights = tf.constant(label_weights, dtype=tf.float32)

print("Label weights:", label_weights.numpy())

# ------------------- LOSS FUNCTION -------------------
@tf.function(jit_compile=True)
def weighted_bce(y_true, y_pred):
    """Mean sigmoid cross-entropy on logits, scaled per label by `label_weights`.

    Args:
        y_true: Ground-truth labels.
        y_pred: Raw logits (no sigmoid applied).

    Returns:
        Scalar tensor: the mean of the weighted element-wise losses.
    """
    per_label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=y_pred, labels=y_true
    )
    # label_weights broadcasts over the leading dimension(s) of the loss.
    return tf.reduce_mean(per_label_loss * label_weights)

# ------------------- DATA PIPELINE -------------------
@tf.function
def decode_and_process_train(path, label):
    """Load one training image, augment it and apply ResNet50 preprocessing.

    Args:
        path: Scalar string tensor with the JPEG file path.
        label: Label vector for the image, ordered like `label_cols`.

    Returns:
        (img, label): the augmented, preprocessed image and the unchanged label.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    # The image is kept on the 0-255 scale (see the clip to [0, 255] below).
    img = tf.image.resize(img, IMG_SIZE, method='bilinear')

    # Base augmentations
    img = tf.image.random_flip_left_right(img)
    if tf.random.uniform([]) < 0.5:
        # k is drawn from {-1, 0, 1} quarter turns.
        k = tf.random.uniform([], -1, 2, dtype=tf.int32)
        img = tf.image.rot90(img, k=k)
    # random_brightness adds its delta directly to the pixel values. On the
    # 0-255 scale a delta of 0.1 is invisible, so the 10% budget is scaled.
    img = tf.image.random_brightness(img, 0.1 * 255.0)
    img = tf.image.random_contrast(img, 0.85, 1.15)

    # Extra augmentations for smaller classes.
    # The indices are given an explicit integer dtype: an empty Python list
    # would otherwise become a float tensor, which tf.gather rejects.
    small_class_idx = tf.constant(
        [label_cols.index(c) for c in small_classes], dtype=tf.int32
    )
    if tf.reduce_any(tf.greater(tf.gather(label, small_class_idx), 0)):
        # NOTE(review): if the source images are grayscale, saturation and hue
        # changes presumably have no visible effect - confirm.
        img = tf.image.random_saturation(img, 0.8, 1.3)
        img = tf.image.random_hue(img, 0.05)
        img = tf.image.random_flip_up_down(img)
        if tf.random.uniform([]) < 0.3:
            noise = tf.random.normal(tf.shape(img), mean=0.0, stddev=5.0)
            img = tf.clip_by_value(img + noise, 0, 255)

    img = tf.keras.applications.resnet50.preprocess_input(img)
    return img, label

@tf.function
def decode_and_process_val(path, label):
    """Load one image without augmentation and apply ResNet50 preprocessing.

    Args:
        path: Scalar string tensor with the JPEG file path.
        label: Label vector, passed through unchanged.

    Returns:
        (img, label): the resized, preprocessed image and the label.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(pixels, IMG_SIZE, method='bilinear')
    # NOTE(review): the same ResNet50 preprocessing is fed to every ensemble
    # member - confirm that is intended for the non-ResNet backbones.
    return tf.keras.applications.resnet50.preprocess_input(resized), label

# Prepare data
train_paths = train_seqs['Image_path'].values
train_labels = train_seqs[label_cols].values.astype("float32")
val_paths = val_seqs['Image_path'].values
val_labels = val_seqs[label_cols].values.astype("float32")

print(f"Training samples: {len(train_paths)}")
print(f"Validation samples: {len(val_paths)}")

# Build datasets
# Training: shuffle paths, decode/augment in parallel, fixed-size batches.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    .shuffle(3000, reshuffle_each_iteration=True)
    .map(decode_and_process_train, num_parallel_calls=AUTOTUNE, deterministic=False)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=3)
)

# Validation keeps its final partial batch: dropping it would silently exclude
# up to BATCH_SIZE - 1 images from val_loss / val_auc.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_paths, val_labels))
    .map(decode_and_process_val, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=False)
    .prefetch(buffer_size=2)
)

print(f"Batch size: {BATCH_SIZE}")
print(f"Steps per epoch: {len(train_paths) // BATCH_SIZE}")

# ------------------- REBUILD ENSEMBLE MODEL -------------------
print("\n" + "="*50)
print("Rebuilding Ensemble Architecture...")
print("="*50)

# Mixed precision: set globally before any model layer is created.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"Mixed precision: {policy.name}")

def build_model(base_architecture, name):
    """Build a multi-label classifier on top of the given backbone.

    All backbones are created without pretrained weights; the notebook loads
    the trained ensemble weights afterwards.

    Args:
        base_architecture: One of 'resnet50', 'efficientnet', 'densenet'.
        name: Name of the returned model; also prefixes the output layer name.

    Returns:
        (model, base_model): the full classifier producing len(label_cols)
        logits, and the backbone it is built on (used for selective unfreezing).

    Raises:
        ValueError: If `base_architecture` is not a supported backbone.
    """
    if base_architecture == 'resnet50':
        base_model = tf.keras.applications.ResNet50(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights=None
        )
    elif base_architecture == 'efficientnet':
        base_model = tf.keras.applications.EfficientNetB3(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights=None
        )
    elif base_architecture == 'densenet':
        base_model = tf.keras.applications.DenseNet121(
            include_top=False,
            input_shape=IMG_SIZE + (3,),
            weights=None
        )
    else:
        # Previously an unknown name fell through to an unbound `base_model`.
        raise ValueError(
            f"Unknown base_architecture {base_architecture!r}; "
            "expected 'resnet50', 'efficientnet' or 'densenet'"
        )
    
    # Build head
    x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    x = tf.keras.layers.Dropout(0.3)(x)
    # The output layer is pinned to float32 and emits raw logits.
    logits = tf.keras.layers.Dense(
        len(label_cols),
        activation=None,
        dtype='float32',
        name=f'{name}_output'
    )(x)
    
    model = tf.keras.Model(inputs=base_model.input, outputs=logits, name=name)
    
    return model, base_model

# Build the three models (same order as in training: ResNet, EfficientNet, DenseNet)
model_names = ['resnet50', 'efficientnet', 'densenet']
(
    (model_resnet, base_resnet),
    (model_efficientnet, base_efficientnet),
    (model_densenet, base_densenet),
) = (build_model(arch, arch) for arch in model_names)

models_list = [model_resnet, model_efficientnet, model_densenet]
base_models_list = [base_resnet, base_efficientnet, base_densenet]

# Shared input that is fed to every member model.
input_layer = tf.keras.Input(name='ensemble_input', shape=IMG_SIZE + (3,))

# Logits of the three members.
pred1, pred2, pred3 = [member(input_layer) for member in models_list]

# The ensemble output is the element-wise mean of the member logits.
averaging_layer = tf.keras.layers.Average(name='ensemble_average')
ensemble_output = averaging_layer([pred1, pred2, pred3])

ensemble_model = tf.keras.Model(
    name='ensemble',
    inputs=input_layer,
    outputs=ensemble_output,
)

print("✓ Ensemble architecture rebuilt")

# ------------------- LOAD TRAINED WEIGHTS -------------------
print("\n" + "="*50, "Loading trained ensemble weights...", "="*50, sep="\n")

# Fine-tuning is pointless without the trained weights, so a missing file aborts.
if not os.path.exists(ENSEMBLE_WEIGHTS_PATH):
    print(f"WARNING: Weights file not found at {ENSEMBLE_WEIGHTS_PATH}")
    print("Please check the path!")
    raise FileNotFoundError(f"Weights not found: {ENSEMBLE_WEIGHTS_PATH}")

ensemble_model.load_weights(ENSEMBLE_WEIGHTS_PATH)
print(f"✓ Loaded weights from: {ENSEMBLE_WEIGHTS_PATH}")

# ------------------- FREEZE STRATEGY -------------------
print(
    "\n" + "="*50,
    f"Unfreezing top {UNFREEZE_LAYERS} layers of each base model...",
    "="*50,
    sep="\n",
)

def unfreeze_top_layers(base_model, num_layers_to_unfreeze):
    """Freeze a model except for its last N layers.

    Args:
        base_model: Object exposing `.layers` (a list of layers with a
            `trainable` attribute) and a settable `.trainable` attribute.
        num_layers_to_unfreeze: How many layers, counted from the end, stay
            trainable. Values above the layer count unfreeze everything;
            zero or negative values leave every layer frozen.

    Returns:
        The number of layers that are trainable afterwards.
    """
    # The container itself must be trainable, otherwise the per-layer flags
    # set below would not take effect.
    base_model.trainable = True
    for layer in base_model.layers:
        layer.trainable = False

    total_layers = len(base_model.layers)
    layers_to_unfreeze = max(0, min(num_layers_to_unfreeze, total_layers))

    # Slice from an explicit start index: `layers[-0:]` is the whole list, so
    # the negative-index form would unfreeze everything when N == 0.
    for layer in base_model.layers[total_layers - layers_to_unfreeze:]:
        layer.trainable = True

    return sum(1 for layer in base_model.layers if layer.trainable)

# Unfreeze top layers for each base model
print("\nUnfreezing layers:")
for base_model, name in zip(base_models_list, model_names):
    trainable = unfreeze_top_layers(base_model, UNFREEZE_LAYERS)
    total = len(base_model.layers)
    print(f"  {name}: {trainable}/{total} layers trainable")

# The head layers (GlobalAveragePooling, Dropout, Dense) are always trainable
print("\nHead layers (GAP, Dropout, Dense) are always trainable")

# Count total trainable parameters.
# tf.size(...).numpy() yields NumPy int32 scalars; they are converted to Python
# ints so that `100 * trainable_count` below cannot overflow 32-bit arithmetic.
trainable_count = sum(int(tf.size(w).numpy()) for w in ensemble_model.trainable_weights)
non_trainable_count = sum(int(tf.size(w).numpy()) for w in ensemble_model.non_trainable_weights)
total_count = trainable_count + non_trainable_count

print("\n" + "="*50)
print(f"Trainable parameters: {trainable_count:,} ({100*trainable_count/total_count:.1f}%)")
print(f"Non-trainable parameters: {non_trainable_count:,}")
print(f"Total parameters: {total_count:,}")
print("="*50)

# ------------------- COMPILE WITH VERY LOW LEARNING RATE -------------------
# Compilation happens after the trainable flags were changed above.
print("\n" + "="*50)
print("Compiling for fine-tuning...")
print("="*50)

# VERY low learning rate for fine-tuning top layers
fine_tune_lr = 5e-6  # Even lower than before

optimizer = tf.keras.optimizers.AdamW(
    learning_rate=fine_tune_lr,
    weight_decay=1e-7,  # Lower weight decay
    clipnorm=1.0
)

ensemble_model.compile(
    optimizer=optimizer,
    loss=weighted_bce,
    metrics=[
        tf.keras.metrics.AUC(name="auc", from_logits=True),
        # The model outputs logits and BinaryAccuracy thresholds y_pred as-is,
        # so the cut for probability 0.5 is logit 0.0 (sigmoid(0) == 0.5).
        # A threshold of 0.5 on logits corresponds to p ~ 0.62.
        tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)
    ],
    jit_compile=True
)

print(f"✓ Compiled with learning rate: {fine_tune_lr:.2e}")

# ------------------- CALLBACKS -------------------
# Start from a clean slate: remove a checkpoint left over from an earlier run.
checkpoint_path = 'finetuned_ensemble.weights.h5'
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)
    print(f"Removed existing checkpoint: {checkpoint_path}")

# Stop after 3 epochs without a better val_auc and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    mode='max',
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without a better val_loss.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=2,
    factor=0.5,
    min_lr=1e-8,
    verbose=1,
)

# Keep only the weights of the epoch with the highest val_auc.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_auc',
    mode='max',
    save_weights_only=True,
    save_best_only=True,
    verbose=1,
)

class TimingCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints how long each epoch took, in minutes."""

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the wall-clock time at which the epoch started."""
        from time import time as now
        self.epoch_start = now()

    def on_epoch_end(self, epoch, logs=None):
        """Print the time elapsed since the matching on_epoch_begin call."""
        from time import time as now
        elapsed_minutes = (now() - self.epoch_start) / 60
        print(f"\nEpoch time: {elapsed_minutes:.1f} minutes")


timing_cb = TimingCallback()

# ------------------- FINE-TUNE -------------------
print("\n" + "="*50, "Starting fine-tuning with top layers unfrozen...", sep="\n")
print("="*50 + "\n")

# Train on train_ds, validating on val_ds after every epoch.
history = ensemble_model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    verbose=1,
    callbacks=[early_stopping, reduce_lr, checkpoint, timing_cb],
)

# ------------------- RESULTS -------------------
print("\n" + "="*50, "Fine-tuning completed!", "="*50, sep="\n")
print(f"Best validation AUC: {max(history.history['val_auc']):.4f}")
print(f"Best validation loss: {min(history.history['val_loss']):.4f}")
print(f"Final learning rate: {ensemble_model.optimizer.learning_rate.numpy():.2e}")

# Save final fine-tuned model (weights only), replacing any earlier file.
final_model_path = 'final_finetuned_ensemble.weights.h5'
if os.path.exists(final_model_path):
    os.remove(final_model_path)
ensemble_model.save_weights(final_model_path)
print(f"\n✓ Fine-tuned ensemble weights saved to: {final_model_path}")

# ------------------- INFERENCE FUNCTION -------------------
def predict_ensemble(image_paths):
    """
    Make predictions using the fine-tuned ensemble model
    
    Args:
        image_paths: List (or array) of image file paths; may be empty
    
    Returns:
        predictions: float array of shape (len(image_paths), len(label_cols))
            with probabilities for each class (sigmoid applied to logits)
    """
    image_paths = [str(p) for p in image_paths]

    # An empty list cannot be turned into a string dataset; return an empty
    # result with the right number of columns instead of crashing.
    if not image_paths:
        return np.zeros((0, len(label_cols)), dtype=np.float32)

    # Create dataset. decode_and_process_val expects a label, so a dummy
    # all-zero label is passed; element order is preserved by map().
    ds = (
        tf.data.Dataset.from_tensor_slices(image_paths)
        .map(lambda x: decode_and_process_val(x, tf.zeros(len(label_cols))),
             num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(buffer_size=2)
    )
    
    # Get logits from ensemble
    logits = ensemble_model.predict(ds, verbose=1)
    
    # Convert to probabilities
    predictions = tf.sigmoid(logits).numpy()
    
    return predictions

print("\n" + "="*50, "Use predict_ensemble(image_paths) for inference", "="*50, sep="\n")

# Optional: Clear memory
# Run a garbage-collection pass, then reset Keras' global state.
import gc
gc.collect()
tf.keras.backend.clear_session()