In [None]:
# import os
# import shutil

# #this will classify images based on their labels into directories with corresponding class IDs
# image_dir = "dataset/compiled/images_v2"
# label_dir = "dataset/compiled/labels"
# output_dir = "dataset/dataset_classified"

# os.makedirs(output_dir, exist_ok=True)

# for img_name in os.listdir(image_dir):
#     if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
#         continue
#     label_name = os.path.splitext(img_name)[0] + ".txt"
#     label_path = os.path.join(label_dir, label_name)
#     img_path = os.path.join(image_dir, img_name)
#     if not os.path.exists(label_path):
#         print(f"Label not found for {img_name}, skipping.")
#         continue
#     with open(label_path, "r") as f:
#         first_line = f.readline().strip()
#         if not first_line:
#             print(f"Empty label for {img_name}, skipping.")
#             continue
#         class_id = first_line.split()[0]
#     class_dir = os.path.join(output_dir, class_id)
#     os.makedirs(class_dir, exist_ok=True)
#     shutil.copy(img_path, os.path.join(class_dir, img_name))

In [2]:
import os

# Print how many image files (.jpg/.jpeg/.png, case-insensitive) each class
# sub-directory of the classified dataset contains, in sorted class order.
# Non-directory entries at the top level are ignored.
dataset_root = "dataset/dataset_classified"
image_suffixes = ('.jpg', '.jpeg', '.png')

for label in sorted(os.listdir(dataset_root)):
    label_dir = os.path.join(dataset_root, label)
    if not os.path.isdir(label_dir):
        continue
    image_count = sum(
        1 for entry in os.listdir(label_dir)
        if entry.lower().endswith(image_suffixes)
    )
    print(f"Class {label}: {image_count} images")

Class 0: 189 images
Class 1: 475 images
Class 2: 303 images


In [None]:
import os
import random
import shutil

# Split each class folder of `src_dir` into train/val/test folders under
# `dst_dir`, producing <dst_dir>/<split>/<class_name>/<image>.
SEED = 42
IMAGE_EXTS = ('.jpg', '.jpeg', '.png')
src_dir = "dataset/dataset_classified"
dst_dir = "dataset/dataset_classified_split"
splits = {'train': 0.7, 'val': 0.15, 'test': 0.15}
assert abs(sum(splits.values()) - 1.0) < 1e-9, "split ratios must sum to 1"

# Rebuild the destination from scratch. Copies are additive, so re-running into
# an existing tree with a different shuffle would place the same image in more
# than one split and leak val/test images into training.
if os.path.isdir(dst_dir):
    shutil.rmtree(dst_dir)
os.makedirs(dst_dir)

# Local seeded RNG: the split is reproducible and the global `random` state
# is left untouched.
rng = random.Random(SEED)

# Sort class names so the RNG is consumed in a fixed order across runs.
for class_name in sorted(os.listdir(src_dir)):
    class_path = os.path.join(src_dir, class_name)
    if not os.path.isdir(class_path):
        continue
    # Sort before shuffling: os.listdir order is arbitrary, so without this the
    # split would depend on the filesystem, not just on SEED.
    images = sorted(f for f in os.listdir(class_path) if f.lower().endswith(IMAGE_EXTS))
    rng.shuffle(images)
    n = len(images)
    n_train = int(n * splits['train'])
    n_val = int(n * splits['val'])
    # Test takes the remainder, so every image is assigned to exactly one split.
    bounds = {
        'train': (0, n_train),
        'val': (n_train, n_train + n_val),
        'test': (n_train + n_val, n),
    }
    for split, (start, end) in bounds.items():
        split_dir = os.path.join(dst_dir, split, class_name)
        # Created even when empty, so every split lists the same class folders.
        os.makedirs(split_dir, exist_ok=True)
        for img_name in images[start:end]:
            shutil.copy(os.path.join(class_path, img_name), os.path.join(split_dir, img_name))

In [None]:
import os

from tensorflow import keras

BATCH_SIZE = 32
IMG_SIZE = (224, 224)
SEED = 42
# Must match `dst_dir` in the split cell, which writes <SPLIT_DIR>/<split>/<class_id>/.
SPLIT_DIR = "dataset/dataset_classified_split"


def load_split(split, shuffle):
    """Load one split folder (train/val/test) as a batched (images, labels) dataset.

    Labels are inferred from the class sub-folder names. Images are resized to
    IMG_SIZE and still hold raw pixel values (not yet rescaled).
    """
    return keras.utils.image_dataset_from_directory(
        os.path.join(SPLIT_DIR, split),
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        seed=SEED,
    )


train_ds = load_split("train", shuffle=True)
val_ds = load_split("val", shuffle=True)
test_ds = load_split("test", shuffle=False)

# Every split must map class folders to the same integer labels.
# Check before .map(), which drops the `class_names` attribute.
assert train_ds.class_names == val_ds.class_names == test_ds.class_names, (
    train_ds.class_names, val_ds.class_names, test_ds.class_names)

# Preprocessing
# Augmentation runs on the raw [0, 255] pixels. RandomBrightness shifts by
# factor * (value_range width) and clips to value_range. Applied after rescaling
# to [0, 1], it would add up to ±25.5 and destroy the image.
data_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
    keras.layers.RandomZoom(0.1),
    keras.layers.RandomBrightness(0.1, value_range=(0, 255)),
])
normalization_layer = keras.layers.Rescaling(1.0 / 255)

# training=True makes the random layers active inside tf.data, outside model.fit.
train_ds = train_ds.map(lambda x, y: (normalization_layer(data_augmentation(x, training=True)), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))

In [None]:
# # 3.0 Build a finetunable CNN + Transformer model for burn classification
# from keras.applications import ResNet50
# from keras.layers import GlobalAveragePooling2D, Reshape, MultiHeadAttention, LayerNormalization, Dense, Dropout, Input, Flatten
# from keras.models import Model
# import tensorflow as tf

# # 3.1 Load ResNet50 backbone (pretrained on ImageNet)
# # Set trainable=False for initial training, then unfreeze for finetuning
# base_cnn = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# base_cnn.trainable = False  # Freeze for now
# print("ResNet50 backbone loaded. Trainable:", base_cnn.trainable)

# # 3.2 Extract features
# inputs = Input(shape=(224, 224, 3), name='input_image')
# x = base_cnn(inputs, training=False)
# print("Shape after CNN:", x.shape)

# # 3.3 Global average pooling to flatten spatial dims
# x = GlobalAveragePooling2D()(x)
# print("Shape after GlobalAvePooling2D:", x.shape)

# # 3.4 Reshape for transformer (batch, seq_len, features)
# x = Reshape((1, -1))(x)  # (batch, 1, features)
# print("Shape before transformer:", x.shape)

# # 3.5 Simple Transformer block
# x = MultiHeadAttention(num_heads=4, key_dim=64)(x, x)
# x = LayerNormalization()(x)
# print("Shape after transformer:", x.shape)

# # 3.6 Flatten and dense layers for classification
# x = Flatten()(x)
# x = Dropout(0.3)(x)
# x = Dense(128, activation='relu')(x)
# x = Dropout(0.2)(x)
# outputs = Dense(3, activation='softmax', name='output_class')(x)

# model = Model(inputs=inputs, outputs=outputs)

# # Print model summary for debugging
# print(model.summary())

# # 3.7 Compile the model
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])
# print("Model compiled.")

# # 3.8 Example: Print a batch of labels to check shape
# for images, labels in train_ds.take(1):
#     print("Batch images shape:", images.shape)
#     print("Batch labels shape:", labels['classes'].shape)
#     print("Sample labels:", labels['classes'][:10].numpy())
#     break

# # 3.9 Train the model (warm-up phase, CNN frozen)
# EPOCHS = 20

# def ds_for_classification(ds):
#     return ds.map(lambda x, y: (x, y['classes']))

# # Use these for training/validation
# train_ds_cls = ds_for_classification(train_ds)
# val_ds_cls = ds_for_classification(val_ds)

# for x, y in train_ds_cls.take(1):
#     print("Image batch shape:", x.shape)
#     print("Label batch shape:", y.shape)
#     print("Label sample:", y[:5].numpy())

# history = model.fit(
#     train_ds_cls,
#     validation_data=val_ds_cls,
#     epochs=EPOCHS
# )


ResNet50 backbone loaded. Trainable: False
Shape after CNN: (None, 7, 7, 2048)
Shape after GlobalAvePooling2D: (None, 2048)
Shape before transformer: (None, 1, 2048)
Shape after transformer: (None, 1, 2048)


None
Model compiled.
Batch images shape: (4, 224, 224, 3)
Batch labels shape: (4, 1)
Sample labels: [[0]
 [0]
 [0]
 [0]]
Image batch shape: (4, 224, 224, 3)
Label batch shape: (4, 1)
Label sample: [[0]
 [0]
 [0]
 [0]]
Epoch 1/20


ValueError: Argument `output` must have rank (ndim) `target.ndim - 1`. Received: target.shape=(None, None), output.shape=(None, 3)