In [5]:
import os
import shutil
import random

In [6]:
# Root of the original dataset; merge_datasets looks for "train", "test" and
# "val" sub-directories directly beneath it.
source_base_dir = "../../dataset"
# Output directory that receives one merged folder per class.
destination_dir = "../../2_classes_merged_dataset"

# Class folder names, used both inside each source split and in the destination.
class_normal = "NORMAL"
class_pneumonia = "PNEUMONIA"

# Seed the global `random` generator; balance_classes draws from it when it
# picks which files to delete.
random.seed(42)

In [7]:
def balance_classes(normal_dir, pneumonia_dir):
    """
    Balances the number of samples in NORMAL and PNEUMONIA directories
    by reducing each to the size of the smallest class.

    WARNING: this is destructive. Surplus files are deleted from disk, so
    only point it at a copy of the data.

    Only regular files count as samples; sub-directories are neither counted
    nor removed. Listings are sorted before sampling because os.listdir
    returns entries in arbitrary order, which would otherwise make the
    selection differ between machines even with a fixed random seed.

    Args:
        normal_dir: Path to the directory holding the NORMAL samples.
        pneumonia_dir: Path to the directory holding the PNEUMONIA samples.

    Returns:
        None. Progress and the resulting counts are printed.

    Raises:
        OSError: If a directory cannot be listed or a file cannot be removed.
    """
    def list_samples(class_dir):
        # Sorted, regular files only: deterministic input for random.sample.
        return sorted(
            name for name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, name))
        )

    class_dirs = {"NORMAL": normal_dir, "PNEUMONIA": pneumonia_dir}
    samples = {cls: list_samples(path) for cls, path in class_dirs.items()}

    print("\nInitial class counts:")
    for cls, files in samples.items():
        print(f"  {cls}: {len(files)} samples")

    # Determine the minimum class size
    min_count = min(len(files) for files in samples.values())
    print(f"\nBalancing each class to {min_count} samples.")

    # Balance each class
    for cls, class_dir in class_dirs.items():
        files = samples[cls]
        # normpath drops a trailing separator so the label is never empty.
        label = os.path.basename(os.path.normpath(class_dir))
        surplus = len(files) - min_count
        if surplus > 0:
            # Randomly sample files to remove
            for file_name in random.sample(files, surplus):
                os.remove(os.path.join(class_dir, file_name))
            print(f"Balanced {label}: {len(list_samples(class_dir))} samples")
        else:
            print(f"No balancing needed for {label}.")

    print("\nPost-balancing class counts:")
    for cls, class_dir in class_dirs.items():
        print(f"  {cls}: {len(list_samples(class_dir))} samples")


def _copy_class_files(class_split_dir, dest_class_dir, split, taken_names):
    """
    Copies the regular files of one split/class folder into dest_class_dir.

    taken_names is the set of destination file names already written for this
    class during the current merge; it is updated in place. A file whose name
    is already taken is stored under a split-prefixed name instead of
    overwriting the earlier file. Returns the number of files copied.
    """
    copied_count = 0
    for file_name in sorted(os.listdir(class_split_dir)):
        src_path = os.path.join(class_split_dir, file_name)
        if not os.path.isfile(src_path):
            # shutil.copy2 cannot copy a directory; skip anything that is not a file.
            continue

        dest_name = file_name
        while dest_name in taken_names:
            dest_name = f"{split}_{dest_name}"
        if dest_name != file_name:
            print(f"Warning: '{file_name}' from split '{split}' clashes with an earlier split; saved as '{dest_name}'.")
        taken_names.add(dest_name)

        shutil.copy2(src_path, os.path.join(dest_class_dir, dest_name))
        copied_count += 1
    return copied_count


def merge_datasets(source_dir, dest_dir, class_names=None):
    """
    Merges images from train, test, and val splits into a single directory with
    one subdirectory per class (NORMAL and PNEUMONIA by default).

    Files are copied, so source_dir is left untouched. Files already present
    in the destination under the same name are overwritten, which keeps
    re-running the merge idempotent. Within one run, a file name that occurs
    in more than one split is kept under a split-prefixed name
    (e.g. 'test_<name>') rather than silently replacing the earlier copy.

    Args:
        source_dir: Directory containing the 'train', 'test' and 'val' splits.
        dest_dir: Directory that receives the merged class subdirectories;
            created if missing.
        class_names: Optional iterable of class folder names. Defaults to the
            notebook-level (class_normal, class_pneumonia).

    Returns:
        None. Progress is printed.

    Raises:
        OSError: If a directory cannot be created/listed or a file cannot be copied.
    """
    if class_names is None:
        # Looked up at call time so the notebook's config cell stays the default.
        class_names = (class_normal, class_pneumonia)
    class_names = list(class_names)
    splits = ["train", "test", "val"]

    # Create destination class directories
    dest_class_dirs = {}
    for class_name in class_names:
        dest_class_dirs[class_name] = os.path.join(dest_dir, class_name)
        os.makedirs(dest_class_dirs[class_name], exist_ok=True)

    print(f"\nCreated/Verified destination directories in '{dest_dir}'.")

    # Destination names written so far in this run, per class (collision tracking).
    taken_names = {class_name: set() for class_name in class_names}

    # Iterate over each split
    for split in splits:
        split_dir = os.path.join(source_dir, split)
        if not os.path.isdir(split_dir):
            print(f"Warning: Split directory '{split_dir}' does not exist. Skipping.")
            continue

        for class_name in class_names:
            class_split_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_split_dir):
                print(f"Warning: Directory '{class_split_dir}' does not exist. Skipping {class_name} class for split '{split}'.")
                continue

            copied_count = _copy_class_files(
                class_split_dir,
                dest_class_dirs[class_name],
                split,
                taken_names[class_name],
            )
            print(f"Copied {copied_count} {class_name} images from split '{split}'.")

    # Report the actual destination folder name instead of a hard-coded one.
    dest_label = os.path.basename(os.path.normpath(dest_dir))
    print(f"\nAll splits have been merged into '{dest_label}'.")


def print_class_counts(base_dir, class_names=None):
    """
    Prints the number of samples in each class within the base directory.

    Only regular files are counted as samples. A class whose directory is
    missing is reported as such instead of raising.

    Args:
        base_dir: Directory containing one subdirectory per class.
        class_names: Optional iterable of class folder names. Defaults to the
            notebook-level (class_normal, class_pneumonia).

    Returns:
        None. The counts are printed.
    """
    if class_names is None:
        # Looked up at call time so the notebook's config cell stays the default.
        class_names = (class_normal, class_pneumonia)

    # Name the directory actually inspected rather than a hard-coded folder.
    base_label = os.path.basename(os.path.normpath(base_dir))
    print(f"\nFinal class counts in '{base_label}':")
    for class_name in class_names:
        class_dir = os.path.join(base_dir, class_name)
        if os.path.isdir(class_dir):
            count = sum(
                1 for name in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, name))
            )
            print(f"  {class_name}: {count} samples")
        else:
            print(f"  {class_name}: Directory does not exist.")

In [None]:
# Step 1: copy every split's images into per-class folders under destination_dir.
# NOTE: re-running this cell copies the source files again, so files deleted by
# an earlier balancing run reappear before being balanced anew.
merge_datasets(source_base_dir, destination_dir)

normal_dest_dir = os.path.join(destination_dir, class_normal)
pneumonia_dest_dir = os.path.join(destination_dir, class_pneumonia)

# Step 2: delete randomly chosen files from the larger class until both classes
# have the same size. This removes files from the merged copy on disk; the
# directories under source_base_dir are not passed in here.
balance_classes(normal_dest_dir, pneumonia_dest_dir)

# Step 3: report the per-class counts left in the merged dataset.
print_class_counts(destination_dir)


Created/Verified destination directories in '../../2_classes_merged_dataset'.
Copied 1341 NORMAL images from split 'train'.
Copied 3875 PNEUMONIA images from split 'train'.
Copied 234 NORMAL images from split 'test'.
Copied 390 PNEUMONIA images from split 'test'.
Copied 8 NORMAL images from split 'val'.
Copied 8 PNEUMONIA images from split 'val'.

All splits have been merged into '2_classes_merged_dataset'.

Initial class counts:
  NORMAL: 1583 samples
  PNEUMONIA: 4273 samples

Balancing each class to 1583 samples.
No balancing needed for NORMAL.
Balanced PNEUMONIA: 1583 samples

Post-balancing class counts:
  NORMAL: 1583 samples
  PNEUMONIA: 1583 samples

Final class counts in '2_classes_merged_dataset':
  NORMAL: 1583 samples
  PNEUMONIA: 1583 samples

Examples of NORMAL vs PNEUMONIA images:

NORMAL examples:
  - NORMAL2-IM-1258-0001.jpeg
  - IM-0505-0001-0002.jpeg
  - IM-0036-0001.jpeg

PNEUMONIA examples:
  - person918_bacteria_2843.jpeg
  - person19_virus_50.jpeg
  - person1670