In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import layers

IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# Data augmentation pipelines
# General augmentation
common_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.2),
    keras.layers.RandomZoom(0.1),
    keras.layers.RandomTranslation(0.1, 0.1),
    keras.layers.RandomContrast(0.2),
])

# Stronger augmentation for minority class
strong_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.3),
    keras.layers.RandomZoom(0.2),
    keras.layers.RandomTranslation(0.2, 0.2),
    keras.layers.RandomContrast(0.3),
    keras.layers.RandomBrightness(0.3),
    keras.layers.RandomWidth(0.2),
    keras.layers.RandomHeight(0.2),
])

# Normalization layer
normalization_layer = keras.layers.Rescaling(1./255)

# Load data
def preprocess_dataset(dataset, class_id):
    if class_id == 2:  # Apply stronger augmentation to minority class
        return dataset.map(lambda x, y: (strong_augmentation(normalization_layer(x)), y))
    else:  # Apply common augmentation to other classes
        return dataset.map(lambda x, y: (common_augmentation(normalization_layer(x)), y))

train_ds = keras.utils.image_dataset_from_directory(
    "dataset-v2/dataset_classified_split/train",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
    color_mode="rgb"
)
val_ds = keras.utils.image_dataset_from_directory(
    "dataset-v2/dataset_classified_split/val",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
    color_mode="rgb"
)
test_ds = keras.utils.image_dataset_from_directory(
    "dataset-v2/dataset_classified_split/test",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
    color_mode="rgb"
)

# Apply class-specific augmentation
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))

Found 780 files belonging to 3 classes.
Found 196 files belonging to 3 classes.
Found 196 files belonging to 3 classes.
Found 245 files belonging to 3 classes.
Found 245 files belonging to 3 classes.


In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.applications import ResNet50

# Build Hybrid ResNet + ViT model for burn classification
def build_resnet_vit(input_shape=(224, 224, 3), num_classes=3, embed_dim=256, num_heads=4, transformer_layers=2, mlp_dim=512, dropout=0.3):
    inputs = Input(shape=input_shape)
    # CNN Backbone (ResNet50, no top)
    resnet = ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
    resnet.trainable = False
    x = resnet(inputs)
    # Flatten spatial grid to sequence of patches
    patches = layers.Reshape((49, 2048))(x)
    # Linear projection to transformer dimension
    x = layers.Dense(embed_dim)(patches)
    # Add positional encoding
    positions = tf.range(start=0, limit=49, delta=1)
    pos_embed = layers.Embedding(input_dim=49, output_dim=embed_dim)(positions)
    x = x + tf.reshape(pos_embed, [1, 49, embed_dim])
    # Transformer encoder blocks
    for _ in range(transformer_layers):
        # Layer norm
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Multi-head attention
        attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)(x1, x1)
        x2 = layers.Add()([attn, x])
        # Layer norm
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        mlp = layers.Dense(mlp_dim, activation='gelu')(x3)
        mlp = layers.Dropout(dropout)(mlp)
        mlp = layers.Dense(embed_dim)(mlp)
        mlp = layers.Dropout(dropout)(mlp)
        x = layers.Add()([x2, mlp])
    # Global average pooling over sequence
    x = layers.GlobalAveragePooling1D()(x)
    # Classification head
    x = layers.Dropout(dropout)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(dropout)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs)

# Build the model
model = build_resnet_vit()
model.summary()
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)


In [4]:
import numpy as np

# MixUp augmentation
@tf.function
def mixup(batch_x, batch_y, alpha=0.2):
    batch_size = tf.shape(batch_x)[0]
    lambda_val = tf.random.uniform([], 0, 1)
    index = tf.random.shuffle(tf.range(batch_size))

    mixed_x = lambda_val * batch_x + (1 - lambda_val) * tf.gather(batch_x, index)
    mixed_y = lambda_val * tf.cast(batch_y, tf.float32) + (1 - lambda_val) * tf.cast(tf.gather(batch_y, index), tf.float32)
    return mixed_x, mixed_y

# CutMix augmentation
@tf.function
def cutmix(batch_x, batch_y, alpha=1.0):
    batch_size = tf.shape(batch_x)[0]
    lambda_val = tf.random.uniform([], 0, 1)
    index = tf.random.shuffle(tf.range(batch_size))

    height, width = tf.shape(batch_x)[1], tf.shape(batch_x)[2]

    # Calculate cut size
    cut_ratio = tf.sqrt(1 - lambda_val)
    cut_w = tf.cast(tf.cast(width, tf.float32) * cut_ratio, tf.int32)
    cut_h = tf.cast(tf.cast(height, tf.float32) * cut_ratio, tf.int32)

    # Random center point
    cx = tf.random.uniform([], cut_w // 2, width - cut_w // 2, dtype=tf.int32)
    cy = tf.random.uniform([], cut_h // 2, height - cut_h // 2, dtype=tf.int32)

    # Create mask
    y_indices = tf.range(height)
    x_indices = tf.range(width)
    yy, xx = tf.meshgrid(y_indices, x_indices, indexing='ij')

    mask = tf.logical_and(
        tf.logical_and(yy >= cy - cut_h // 2, yy < cy + cut_h // 2),
        tf.logical_and(xx >= cx - cut_w // 2, xx < cx + cut_w // 2)
    )
    mask = tf.cast(mask, tf.float32)
    mask = tf.expand_dims(mask, 0)  # Add batch dimension
    mask = tf.expand_dims(mask, -1)  # Add channel dimension
    mask = tf.tile(mask, [batch_size, 1, 1, 3])  # Broadcast to all batches and channels

    mixed_x = batch_x * (1 - mask) + tf.gather(batch_x, index) * mask
    mixed_y = lambda_val * tf.cast(batch_y, tf.float32) + (1 - lambda_val) * tf.cast(tf.gather(batch_y, index), tf.float32)

    return mixed_x, mixed_y

def apply_combined_augmentation(x, y):
    x = common_augmentation(x)
    rand_val = tf.random.uniform([])
    if rand_val > 0.7:  # 30% chance for mixup
        return mixup(x, y)
    elif rand_val > 0.4:  # 30% chance for cutmix
        return cutmix(x, y)
    else:  # 40% chance for no additional augmentation
        return x, tf.cast(y, tf.float32)  # Cast y to float32 for consistency

from sklearn.utils.class_weight import compute_class_weight

# Compute class weights
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array([0, 1, 2]),  # Convert to numpy array
    y=np.array([0] * 339 + [1] * 312 + [2] * 129)  # Convert to numpy array
)
class_weights = {i: weight for i, weight in enumerate(class_weights)}

# Callbacks
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]
EPOCHS = 30

In [5]:
# Start with simpler augmentation - just basic augmentation without mixup/cutmix
# Apply basic augmentation to training data
train_ds_augmented = train_ds.map(lambda x, y: (common_augmentation(x), y))

# Train the model
history = model.fit(
    train_ds_augmented,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks
)

# Evaluate the model
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test accuracy: {test_acc:.4f}")

model.save_weights('checkpoints/cnn-transformer/stage1_frozen_backbone.weights.h5')

print("Saved Stage 1 weights.")

Epoch 1/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 1s/step - accuracy: 0.3205 - loss: 2.0351 - val_accuracy: 0.4388 - val_loss: 1.0630 - learning_rate: 0.0010
Epoch 2/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 1s/step - accuracy: 0.3205 - loss: 2.0351 - val_accuracy: 0.4388 - val_loss: 1.0630 - learning_rate: 0.0010
Epoch 2/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 1s/step - accuracy: 0.3179 - loss: 1.2155 - val_accuracy: 0.1633 - val_loss: 1.1289 - learning_rate: 0.0010
Epoch 3/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 1s/step - accuracy: 0.3179 - loss: 1.2155 - val_accuracy: 0.1633 - val_loss: 1.1289 - learning_rate: 0.0010
Epoch 3/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 1s/step - accuracy: 0.3000 - loss: 1.1693 - val_accuracy: 0.3980 - val_loss: 1.1073 - learning_rate: 0.0010
Epoch 4/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 1s/s

In [6]:
# Stage 2: Fine-tuning the ResNet backbone
# Unfreeze the last 30 layers of the ResNet backbone
for layer in model.layers[1].layers[-30:]:
    layer.trainable = True

# Recompile the model with a lower learning rate for fine-tuning
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model with the updated configuration
history_stage2 = model.fit(
    train_ds_augmented,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=callbacks
)

# Evaluate the model after fine-tuning
test_loss_stage2, test_acc_stage2 = model.evaluate(test_ds)
print(f"Stage 2 Test accuracy: {test_acc_stage2:.4f}")

model.save_weights('checkpoints/cnn-transformer/stage2_fine_tuned.weights.h5')
print("Saved Stage 2 weights.")

Epoch 1/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - accuracy: 0.3590 - loss: 1.2203 - val_accuracy: 0.4388 - val_loss: 1.0625 - learning_rate: 1.0000e-05
Epoch 2/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - accuracy: 0.3590 - loss: 1.2203 - val_accuracy: 0.4388 - val_loss: 1.0625 - learning_rate: 1.0000e-05
Epoch 2/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 2s/step - accuracy: 0.3987 - loss: 1.1015 - val_accuracy: 0.4388 - val_loss: 1.0547 - learning_rate: 1.0000e-05
Epoch 3/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 2s/step - accuracy: 0.3987 - loss: 1.1015 - val_accuracy: 0.4388 - val_loss: 1.0547 - learning_rate: 1.0000e-05
Epoch 3/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 2s/step - accuracy: 0.4167 - loss: 1.0798 - val_accuracy: 0.4439 - val_loss: 1.0496 - learning_rate: 1.0000e-05
Epoch 4/30
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

In [7]:
# Unfreeze the last 50% of ResNet backbone layers for progressive fine-tuning
resnet_layers = model.layers[1].layers
num_layers = len(resnet_layers)
num_to_unfreeze = num_layers // 2  # Unfreeze last 50%
for layer in resnet_layers[-num_to_unfreeze:]:
    if not layer.trainable:
        layer.trainable = True

# Recompile the model with a very low learning rate for fine-tuning
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),  # Very low LR for stability
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# One-hot encode the labels for advanced augmentation
# (MixUp and CutMix require one-hot labels)
def one_hot_encode(x, y):
    return x, tf.one_hot(y, depth=3)

train_ds_one_hot = train_ds.map(one_hot_encode)
val_ds_one_hot = val_ds.map(one_hot_encode)
test_ds_one_hot = test_ds.map(one_hot_encode)

# Apply the combined augmentation
train_ds_stage3 = train_ds_one_hot.map(apply_combined_augmentation)

# Switch to categorical_crossentropy for one-hot labels
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model with the updated configuration
print("Starting Stage 3: Progressive fine-tuning (last 50% of ResNet) with advanced augmentations.")
history_stage3 = model.fit(
    train_ds_stage3,
    validation_data=val_ds_one_hot,
    epochs=50,
    class_weight=class_weights,
    callbacks=callbacks
)

# Evaluate the model after Stage 3
test_loss_stage3, test_acc_stage3 = model.evaluate(test_ds_one_hot)
print(f"Stage 3 Test accuracy: {test_acc_stage3:.4f}")

model.save_weights('checkpoints/cnn-transformer/stage3_progressive_fine_tuned.weights.h5')
print("Saved Stage 3 weights.")

Starting Stage 3: Progressive fine-tuning (last 50% of ResNet) with advanced augmentations.
Epoch 1/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 2s/step - accuracy: 0.3936 - loss: 1.1873 - val_accuracy: 0.4439 - val_loss: 1.0542 - learning_rate: 1.0000e-05
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 2s/step - accuracy: 0.3936 - loss: 1.1873 - val_accuracy: 0.4439 - val_loss: 1.0542 - learning_rate: 1.0000e-05
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 2s/step - accuracy: 0.3821 - loss: 1.1379 - val_accuracy: 0.4439 - val_loss: 1.0373 - learning_rate: 1.0000e-05
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 2s/step - accuracy: 0.3821 - loss: 1.1379 - val_accuracy: 0.4439 - val_loss: 1.0373 - learning_rate: 1.0000e-05
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m57s[0m 2s/step - accuracy: 0.4038 - loss: 1.0819 - val_accuracy: 0.4439 - val_loss: 1.0

In [8]:
# Stage 4: Progressive fine-tuning (unfreeze last 85% of ResNet layers)
resnet_layers = model.layers[1].layers
num_layers = len(resnet_layers)
num_to_unfreeze = int(num_layers * 0.85)  # Unfreeze last 85%
for layer in resnet_layers[-num_to_unfreeze:]:
    if not layer.trainable:
        layer.trainable = True

# Recompile the model with a very low learning rate for fine-tuning
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Apply the combined augmentation (MixUp/CutMix) to the training data
train_ds_stage4 = train_ds.map(apply_combined_augmentation)

# Train the model with the updated configuration
print("Starting Stage 4: Progressive fine-tuning (last 85% of ResNet) with advanced augmentations.")
history_stage4 = model.fit(
    train_ds_stage4,
    validation_data=val_ds,
    epochs=50,
    class_weight=class_weights,
    callbacks=callbacks
)

# Evaluate the model after Stage 4
test_loss_stage4, test_acc_stage4 = model.evaluate(test_ds)
print(f"Stage 4 Test accuracy: {test_acc_stage4:.4f}")

Starting Stage 4: Progressive fine-tuning (last 85% of ResNet) with advanced augmentations.
Epoch 1/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 4s/step - accuracy: 0.2231 - loss: 1.1141 - val_accuracy: 0.4388 - val_loss: 1.0689 - learning_rate: 1.0000e-05
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m126s[0m 4s/step - accuracy: 0.2231 - loss: 1.1141 - val_accuracy: 0.4388 - val_loss: 1.0689 - learning_rate: 1.0000e-05
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 4s/step - accuracy: 0.3231 - loss: 1.0830 - val_accuracy: 0.4388 - val_loss: 1.0380 - learning_rate: 1.0000e-05
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 4s/step - accuracy: 0.3231 - loss: 1.0830 - val_accuracy: 0.4388 - val_loss: 1.0380 - learning_rate: 1.0000e-05
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 4s/step - accuracy: 0.2526 - loss: 1.0071 - val_accuracy: 0.4337 - val_loss: 