In [None]:
import os

# Set the path to images
image_dir = 'dataset/compiled-old/images'

# Get list of image files (filtering out non-files and hidden files)
image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f)) and not f.startswith('.')]

# Print total number of images
print(f"Total number of old images: {len(image_files)}")

Total number of old images: 9585


In [None]:
# Set the path to images
image_dir = 'dataset/compiled/images'

# Get list of image files (filtering out non-files and hidden files)
image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f)) and not f.startswith('.')]

# Print total number of images
print(f"Total number of new images: {len(image_files)}")

Total number of new images: 967


almost 90% drop on dataset based on previous attempts. will limit data cleaning for better results.

In [7]:
import imagehash
import pandas as pd
from PIL import Image
import os

image_dir = 'dataset/compiled-new/images'
label_dir = 'dataset/compiled-new/labels'

def find_duplicates(image_dir, label_dir, hash_size=8):
    """Returns a list of duplicate image pairs with their annotations."""
    hashes = {}
    duplicates = []

    # Supported image extensions
    image_exts = ('.jpg', '.jpeg', '.png', '.webp')

    for img_file in os.listdir(image_dir):
        if not img_file.lower().endswith(image_exts):
            continue

        img_path = os.path.join(image_dir, img_file)
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(label_dir, label_file)

        try:
            # Calculate perceptual hash
            with Image.open(img_path) as img:
                img_hash = imagehash.average_hash(img, hash_size=hash_size)

            # Check for duplicates
            if img_hash in hashes:
                duplicates.append((img_path, label_path, hashes[img_hash]))
            else:
                hashes[img_hash] = (img_path, label_path)
        except Exception as e:
            print(f"Error processing {img_file}: {str(e)}")

    return duplicates

def handle_duplicates(duplicates, strategy='keep_best_annotation'):
    """
    More robust duplicate handler that:
    1. Checks if files exist before processing
    2. Provides detailed error reporting
    3. Handles edge cases
    """
    deleted_files = []
    errors = []

    for img1, label1, (img2, label2) in duplicates:
        try:
            # Verify all files exist
            existing_files = []
            for path in [img1, label1, img2, label2]:
                if os.path.exists(path):
                    existing_files.append(path)
                else:
                    errors.append(f"Missing file: {path}")

            # Skip if critical files are missing
            if len(existing_files) < 4:
                continue

            # Determine which pair to keep
            if strategy == 'keep_first':
                to_delete = (img2, label2)
            elif strategy == 'keep_best_annotation':
                # Safely compare annotations
                count1, count2 = 0, 0
                try:
                    with open(label1, 'r') as f1:
                        count1 = len([line for line in f1 if line.strip()])
                    with open(label2, 'r') as f2:
                        count2 = len([line for line in f2 if line.strip()])
                except Exception as e:
                    errors.append(f"Error reading annotations: {str(e)}")
                    continue

                to_delete = (img2, label2) if count1 >= count2 else (img1, label1)
            elif strategy == 'keep_higher_res':
                try:
                    size1 = os.path.getsize(img1)
                    size2 = os.path.getsize(img2)
                except Exception as e:
                    errors.append(f"Error getting file sizes: {str(e)}")
                    continue
                to_delete = (img2, label2) if size1 >= size2 else (img1, label1)

            # Delete files (with existence check)
            for file in to_delete:
                if os.path.exists(file):
                    os.remove(file)
                    deleted_files.append(file)
                else:
                    errors.append(f"Tried to delete non-existent file: {file}")

        except Exception as e:
            errors.append(f"Error processing pair {img1} vs {img2}: {str(e)}")

    # Print summary
    print(f"Deleted {len(deleted_files)//2} duplicate pairs")
    if errors:
        print("\nEncountered errors:")
        for error in errors[:5]:  # Show first 5 errors
            print(f" - {error}")
        if len(errors) > 5:
            print(f" - ...and {len(errors)-5} more errors")

    return deleted_files, errors

# Step 1: Find all duplicates
duplicates = find_duplicates(image_dir, label_dir)

# Step 3: Handle duplicates (choose strategy)
handle_duplicates(duplicates, strategy='keep_best_annotation')

# Step 4: Verify remaining files
print(f"Remaining images: {len(os.listdir(image_dir))}")
print(f"Remaining labels: {len(os.listdir(label_dir))}")


Deleted 24 duplicate pairs

Encountered errors:
 - Missing file: dataset/compiled-new/images\v3_img1414_jpg.rf.579eb39c366bb48651ad5a7216b9b843.jpg
 - Missing file: dataset/compiled-new/labels\v3_img1414_jpg.rf.579eb39c366bb48651ad5a7216b9b843.txt
 - Missing file: dataset/compiled-new/images\v3_img209_jpg.rf.69418453eb5bc826a35acf8bd440ecdc.jpg
 - Missing file: dataset/compiled-new/labels\v3_img209_jpg.rf.69418453eb5bc826a35acf8bd440ecdc.txt
 - Missing file: dataset/compiled-new/images\v3_img209_jpg.rf.92b0690ceeb149945e5fd1cee86eafd5.jpg
 - ...and 167 more errors
Remaining images: 9561
Remaining labels: 9561


In [9]:
import os
from collections import defaultdict

def validate_pairs(image_dir, label_dir):
    """Comprehensive validation of image-label pairs with detailed reporting"""
    # Get all files (case-insensitive)
    image_files = defaultdict(list)
    label_files = set()

    # Supported image extensions
    img_exts = ('.jpg', '.jpeg', '.png', '.webp')

    # Scan image directory
    for f in os.listdir(image_dir):
        base, ext = os.path.splitext(f)
        if ext.lower() in img_exts:
            image_files[base.lower()].append(f)  # Store with original case

    # Scan label directory
    for f in os.listdir(label_dir):
        if f.lower().endswith('.txt'):
            label_files.add(os.path.splitext(f)[0].lower())

    # Find mismatches
    missing_labels = set(image_files.keys()) - label_files
    missing_images = label_files - set(image_files.keys())

    # Generate detailed report
    report = {
        'images_without_labels': [],
        'labels_without_images': [],
        'multiple_images': []
    }

    # Check for images without labels
    for base in missing_labels:
        for img_file in image_files[base]:
            report['images_without_labels'].append({
                'image': img_file,
                'possible_label': f"{os.path.splitext(img_file)[0]}.txt"
            })

    # Check for labels without images
    for base in missing_images:
        report['labels_without_images'].append({
            'label': f"{base}.txt",
            'possible_images': [f"{base}{ext}" for ext in img_exts]
        })

    # Check for multiple image extensions
    for base, files in image_files.items():
        if len(files) > 1:
            report['multiple_images'].append({
                'base': base,
                'files': files,
                'label_exists': base in label_files
            })

    return report

def print_report(report):
    """Print a human-readable validation report"""
    print("\n=== Dataset Validation Report ===")

    # Section 1: Images without labels
    print(f"\n[!] {len(report['images_without_labels'])} images missing labels:")
    for item in report['images_without_labels'][:5]:  # Show first 5 examples
        print(f"  - Image: {item['image']}")
        print(f"    Expected label: {item['possible_label']}")
    if len(report['images_without_labels']) > 5:
        print(f"    (...and {len(report['images_without_labels']) - 5} more)")

    # Section 2: Labels without images
    print(f"\n[!] {len(report['labels_without_images'])} labels missing images:")
    for item in report['labels_without_images'][:5]:
        print(f"  - Label: {item['label']}")
        print(f"    Expected image variants: {', '.join(item['possible_images'])}")
    if len(report['labels_without_images']) > 5:
        print(f"    (...and {len(report['labels_without_images']) - 5} more)")

    # Section 3: Multiple image extensions
    if report['multiple_images']:
        print(f"\n[!] {len(report['multiple_images'])} base names with multiple images:")
        for item in report['multiple_images'][:3]:
            print(f"  - Base: {item['base']}")
            print(f"    Files: {', '.join(item['files'])}")
            print(f"    Has label: {'Yes' if item['label_exists'] else 'No'}")
        if len(report['multiple_images']) > 3:
            print(f"    (...and {len(report['multiple_images']) - 3} more)")


report = validate_pairs(image_dir, label_dir)
print_report(report)

# Optional: Auto-clean with confirmation
if input("\nClean up missing files? (y/n): ").lower() == 'y':
    # Clean images without labels
    for item in report['images_without_labels']:
        img_path = os.path.join(image_dir, item['image'])
        os.remove(img_path)
        print(f"Removed {img_path}")

    # Clean labels without images
    for item in report['labels_without_images']:
        label_path = os.path.join(label_dir, item['label'])
        if os.path.exists(label_path):
            os.remove(label_path)
            print(f"Removed {label_path}")

    print("Cleanup complete!")
else:
    print("No files were deleted")


=== Dataset Validation Report ===

[!] 0 images missing labels:

[!] 0 labels missing images:
Cleanup complete!


In [14]:
# Set the path to images
image_dir = 'dataset/compiled-new/images'

# Get list of image files (filtering out non-files and hidden files)
image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f)) and not f.startswith('.')]

# Print total number of images
print(f"Total number of images: {len(image_files)}")

Total number of images: 9561


In [11]:
import os
from collections import Counter

# Directories
label_dir = 'dataset/compiled-new/labels'

# Initialize class counter
class_counts = Counter()

# Go through each label file
for label_file in os.listdir(label_dir):
    if label_file.endswith('.txt'):
        with open(os.path.join(label_dir, label_file), 'r') as f:
            lines = f.readlines()
            for line in lines:
                class_id = line.strip().split()[0]  # get the class ID
                class_counts[class_id] += 1

# Map class IDs to degree names
class_map = {'0': 'First Degree Burn', '1': 'Second Degree Burn', '2': 'Third Degree Burn'}

# Print distribution
print("Class Distribution:")
for class_id, count in class_counts.items():
    print(f"{class_map.get(class_id, 'Unknown')} (Class {class_id}): {count} images")

# Optional: show total images with labels
total_labeled = sum(class_counts.values())
print(f"\nTotal labeled images: {total_labeled}")


Class Distribution:
First Degree Burn (Class 0): 5047 images
Third Degree Burn (Class 2): 2126 images
Second Degree Burn (Class 1): 5657 images

Total labeled images: 12830


In [16]:
import os
from collections import Counter

label_dir = 'dataset/compiled-new/labels'
class_counts = Counter()

for label_file in os.listdir(label_dir):
    if label_file.endswith('.txt'):
        with open(os.path.join(label_dir, label_file), 'r') as f:
            first_line = f.readline()
            if first_line:
                class_id = first_line.strip().split()[0]
                class_counts[class_id] += 1

class_map = {'0': 'First Degree Burn', '1': 'Second Degree Burn', '2': 'Third Degree Burn'}

print("Class Distribution (one label per image):")
for class_id, count in class_counts.items():
    print(f"{class_map.get(class_id, 'Unknown')} (Class {class_id}): {count} images")

total_labeled = sum(class_counts.values())
print(f"\nTotal labeled images: {total_labeled}")

Class Distribution (one label per image):
First Degree Burn (Class 0): 3925 images
Third Degree Burn (Class 2): 1564 images
Second Degree Burn (Class 1): 4072 images

Total labeled images: 9561


In [15]:
# Check how many label files have no corresponding image
label_dir = 'dataset/compiled-new/labels'
image_stems = set(os.path.splitext(f)[0] for f in image_files)
label_stems = set(os.path.splitext(f)[0] for f in os.listdir(label_dir) if f.endswith('.txt'))

no_image = label_stems - image_stems
print(f"Label files with no matching image: {len(no_image)}")
if no_image:
    print("Example:", list(no_image)[:5])


Label files with no matching image: 0


In [17]:
import os
import shutil
from sklearn.model_selection import train_test_split

# Paths
image_dir = "dataset/compiled-new/images"
label_dir = "dataset/compiled-new/labels"
output_base = "dataset/dataset_classified_split"

# Gather (image, class_id) pairs
samples = []
for img_name in os.listdir(image_dir):
    if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    label_name = os.path.splitext(img_name)[0] + ".txt"
    label_path = os.path.join(label_dir, label_name)
    img_path = os.path.join(image_dir, img_name)
    if not os.path.exists(label_path):
        continue
    with open(label_path, "r") as f:
        first_line = f.readline().strip()
        if not first_line:
            continue
        class_id = first_line.split()[0]
    samples.append((img_path, class_id))

# Stratified split
img_paths, class_ids = zip(*samples)
train_imgs, test_imgs, train_labels, test_labels = train_test_split(
    img_paths, class_ids, test_size=0.2, stratify=class_ids, random_state=42
)
train_imgs, val_imgs, train_labels, val_labels = train_test_split(
    train_imgs, train_labels, test_size=0.2, stratify=train_labels, random_state=42
)

splits = [
    ("train", train_imgs, train_labels),
    ("val", val_imgs, val_labels),
    ("test", test_imgs, test_labels),
]

# Copy images into split/class folders
for split_name, imgs, labels in splits:
    for img_path, class_id in zip(imgs, labels):
        split_dir = os.path.join(output_base, split_name, class_id)
        os.makedirs(split_dir, exist_ok=True)
        shutil.copy(img_path, os.path.join(split_dir, os.path.basename(img_path)))

In [18]:
from collections import Counter

print("Final Split Distribution:")
for name, labels in [("Train", train_labels), ("Val", val_labels), ("Test", test_labels)]:
    counts = Counter(labels)
    print(f"{name}: {dict(counts)}")

Final Split Distribution:
Train: {'0': 2512, '1': 2605, '2': 1001}
Val: {'0': 628, '2': 250, '1': 652}
Test: {'2': 313, '0': 785, '1': 815}


In [1]:
import tensorflow as tf
from tensorflow import keras

IMG_SIZE = (224, 224)
BATCH_SIZE = 32

#1.1 Data augmentation pipeline
data_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
    keras.layers.RandomZoom(0.1),
    keras.layers.RandomBrightness(0.1)
])

# Normalization layer
normalization_layer = keras.layers.Rescaling(1./255)

In [2]:
#1.2 load data

train_ds = keras.utils.image_dataset_from_directory(
    "dataset/dataset_classified_split/train",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True
)
val_ds = keras.utils.image_dataset_from_directory(
    "dataset/dataset_classified_split/val",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)
test_ds = keras.utils.image_dataset_from_directory(
    "dataset/dataset_classified_split/test",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Apply augmentation and normalization
train_ds = train_ds.map(lambda x, y: (data_augmentation(normalization_layer(x)), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))

Found 6118 files belonging to 3 classes.
Found 1530 files belonging to 3 classes.
Found 1530 files belonging to 3 classes.
Found 1913 files belonging to 3 classes.
Found 1913 files belonging to 3 classes.


In [None]:
# from tensorflow.keras import layers

# #1.3 build model - pretrained resnet50
# def build_resnet_classifier(input_shape=IMG_SIZE + (3,), num_classes=3):
#     base_model = keras.applications.ResNet50(
#         include_top=False,
#         weights="imagenet",
#         input_shape=input_shape,
#         pooling="avg"
#     )
#     base_model.trainable = False  # Fine-tune later if needed

#     inputs = keras.Input(shape=input_shape)
#     x = base_model(inputs, training=False)
#     x = layers.Dropout(0.3)(x)
#     x = layers.Dense(128, activation="relu")(x)
#     x = layers.Dropout(0.3)(x)
#     outputs = layers.Dense(num_classes, activation="softmax")(x)
#     model = keras.Model(inputs, outputs)
#     return model

# resnet_model = build_resnet_classifier()
# resnet_model.compile(
#     optimizer=keras.optimizers.Adam(),
#     loss="sparse_categorical_crossentropy",
#     metrics=["accuracy"]
# )
# resnet_model.summary()

#result = 45% accuracy after 10 epochs

In [None]:
# from tensorflow.keras.applications import ResNet50
# from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, GlobalAveragePooling2D, Reshape, Dense, Input, Dropout, Flatten
# from tensorflow.keras.models import Model

# inputs = Input(shape=(224, 224, 3))
# # CNN backbone
# resnet = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# resnet.trainable = False
# x = resnet(inputs)

# # Convert CNN features to 1D
# x = GlobalAveragePooling2D()(x)         # (batch, 2048)
# x = Reshape((1, -1))(x)                 # (batch, 1, 2048)

# # Transformer block (self-attention)
# x = MultiHeadAttention(num_heads=4, key_dim=64)(x, x)
# x = LayerNormalization()(x)
# x = Flatten()(x)
# x = Dropout(0.3)(x)
# x = Dense(128, activation='relu')(x)
# x = Dropout(0.3)(x)
# outputs = Dense(3, activation='softmax')(x)

# model = Model(inputs=inputs, outputs=outputs)
# model.compile(
#     optimizer=keras.optimizers.Adam(),
#     loss="sparse_categorical_crossentropy",
#     metrics=["accuracy"]
# )
# model.summary()

#result 42% accuracy after 10 epochs
#was only able to classify class 1

In [7]:
# EPOCHS = 10
# history = model.fit(
#     train_ds,
#     validation_data=val_ds,
#     epochs=EPOCHS
# )

# test_loss, test_acc = model.evaluate(test_ds)
# print(f"Test accuracy: {test_acc:.3f}")

In [10]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, GlobalAveragePooling2D, Reshape, Dense, Input, Dropout, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Compute class weights for imbalance
class_counts = [3925, 4072, 1564]  # [class 0, class 1, class 2]
class_labels = np.array([0]*class_counts[0] + [1]*class_counts[1] + [2]*class_counts[2])
class_weights = compute_class_weight('balanced', classes=np.unique(class_labels), y=class_labels)
class_weight_dict = {i: w for i, w in enumerate(class_weights)}
print("Class weights:", class_weight_dict)

inputs = Input(shape=(224, 224, 3))
# CNN backbone
resnet = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
resnet.trainable = True
for layer in resnet.layers[:-10]:
    layer.trainable = False
x = resnet(inputs)
x = GlobalAveragePooling2D()(x)
x = Reshape((1, -1))(x)
x = MultiHeadAttention(num_heads=4, key_dim=64)(x, x)
x = LayerNormalization()(x)
x = Flatten()(x)
x = Dropout(0.3)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.3)(x)
outputs = Dense(3, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
model.summary()

# Callbacks for better training
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6, verbose=1)

EPOCHS = 30
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[early_stop, reduce_lr],
    class_weight=class_weight_dict
)

Class weights: {0: np.float64(0.8119745222929936), 1: np.float64(0.7826620825147348), 2: np.float64(2.03772378516624)}


Epoch 1/30
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m296s[0m 2s/step - accuracy: 0.3601 - loss: 1.1377 - val_accuracy: 0.1634 - val_loss: 1.2019 - learning_rate: 1.0000e-05
Epoch 2/30
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m284s[0m 1s/step - accuracy: 0.3656 - loss: 1.0841 - val_accuracy: 0.4261 - val_loss: 1.0783 - learning_rate: 1.0000e-05
Epoch 3/30
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m289s[0m 2s/step - accuracy: 0.3624 - loss: 1.0853 - val_accuracy: 0.2072 - val_loss: 1.1337 - learning_rate: 1.0000e-05
Epoch 4/30
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.3588 - loss: 1.0709
Epoch 4: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m310s[0m 2s/step - accuracy: 0.3604 - loss: 1.0778 - val_accuracy: 0.3686 - val_loss: 1.1082 - learning_rate: 1.0000e-05
Epoch 5/30
[1m192/192[0m [32m━━━━━━━━━━━━━━━━━━━━