In [None]:
# ------------------------------------------------------------------------------
# LOAD DATA
# Load NIH Chest X-ray dataset labels from CSV
# ------------------------------------------------------------------------------

import pandas as pd
import os

# Load the CSV file that contains all the labels.
# The path is assembled with os.path.join so the separator matches the host
# OS; a hard-coded backslash only resolves on Windows.
csv_path = os.path.join('archive (1)', 'Data_Entry_2017.csv')
df = pd.read_csv(csv_path)

In [None]:
# ------------------------------------------------------------------------------
# PREPARE MULTI-CLASS CLASSIFICATION DATA (15 CLASSES)
# Extract and count all 15 disease categories from the dataset
# ------------------------------------------------------------------------------

from collections import Counter

# Flatten every '|'-separated label string into one list of individual findings,
# so an image carrying several findings contributes once to each of them.
all_diseases = [
    finding
    for label_string in df['Finding Labels']
    for finding in label_string.split('|')
]

# Tally how often each individual finding occurs.
disease_counts = Counter(all_diseases)

# Class list = the distinct findings in alphabetical order; a class's position
# in this list is used as its numeric label elsewhere in the notebook.
CLASS_NAMES = sorted(disease_counts)

print(f"Disease Categories: {len(CLASS_NAMES)}")
for idx, name in enumerate(CLASS_NAMES):
    print(f"  [{idx:2}] {name:25} - {disease_counts[name]:6,} images")

Disease Categories: 15
  [ 0] Atelectasis               - 11,559 images
  [ 1] Cardiomegaly              -  2,776 images
  [ 2] Consolidation             -  4,667 images
  [ 3] Edema                     -  2,303 images
  [ 4] Effusion                  - 13,317 images
  [ 5] Emphysema                 -  2,516 images
  [ 6] Fibrosis                  -  1,686 images
  [ 7] Hernia                    -    227 images
  [ 8] Infiltration              - 19,894 images
  [ 9] Mass                      -  5,782 images
  [10] No Finding                - 60,361 images
  [11] Nodule                    -  6,331 images
  [12] Pleural_Thickening        -  3,385 images
  [13] Pneumonia                 -  1,431 images
  [14] Pneumothorax              -  5,302 images


In [None]:
# ------------------------------------------------------------------------------
# FILTER SINGLE-LABEL IMAGES
# Remove multi-label images to ensure clean single-class classification
# ------------------------------------------------------------------------------

# A literal '|' in the label string means the image carries several findings;
# keep only the rows without it. regex=False makes '|' a plain character
# rather than regex alternation.
is_multi_label = df['Finding Labels'].str.contains('|', regex=False)
single_label_df = df[~is_multi_label].copy()

# Numeric class label = position of the finding inside CLASS_NAMES.
single_label_df['Class_Label'] = single_label_df['Finding Labels'].map(CLASS_NAMES.index)

n_total = len(df)
n_single = len(single_label_df)
print(f"Total images: {n_total:,}")
print(f"Single-label: {n_single:,}")
print(f"Multi-label: {n_total - n_single:,}")

Total images: 112,120
Single-label: 91,324
Multi-label: 20,796


In [None]:
# ------------------------------------------------------------------------------
# LOAD SEGMENTATION DATA (BOUNDING BOXES)
# Load bounding box annotations for future segmentation tasks
# ------------------------------------------------------------------------------

# Build the path with os.path.join so the separator matches the host OS;
# a hard-coded backslash only resolves on Windows.
bbox_path = os.path.join('archive (1)', 'BBox_List_2017.csv')
bbox_df = pd.read_csv(bbox_path)

# Flag (rather than merge) the single-label images that have at least one
# bounding-box row; the box coordinates themselves stay in bbox_df.
single_label_df['Has_BBox'] = single_label_df['Image Index'].isin(bbox_df['Image Index'])

print(f"Bounding box annotations: {len(bbox_df):,}")
print(f"Images with annotations: {single_label_df['Has_BBox'].sum():,}")

Bounding box annotations: 984
Images with annotations: 256


In [None]:
# ------------------------------------------------------------------------------
# BALANCE THE DATASET (HANDLE CLASS IMBALANCE)
# Undersample majority classes to prevent model bias toward common diseases
# ------------------------------------------------------------------------------

from sklearn.utils import resample

# Balance dataset by limiting samples per class
MAX_SAMPLES_PER_CLASS = 5000

balanced_dfs = []
for class_name in CLASS_NAMES:
    class_data = single_label_df[single_label_df['Finding Labels'] == class_name]

    if len(class_data) > MAX_SAMPLES_PER_CLASS:
        # replace=False is essential: resample() draws WITH replacement by
        # default (bootstrap), which would put the same image into the capped
        # class several times. Those duplicate rows would later be scattered
        # across train/val/test and leak evaluation images into training.
        class_data = resample(
            class_data,
            replace=False,
            n_samples=MAX_SAMPLES_PER_CLASS,
            random_state=42,
        )

    balanced_dfs.append(class_data)

# Stack the per-class frames, then shuffle so rows are not grouped by class.
balanced_df = pd.concat(balanced_dfs)
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"Balanced dataset: {len(balanced_df):,} images ({len(CLASS_NAMES)} classes)")

Balanced dataset: 31,416 images (15 classes)


In [None]:
# ------------------------------------------------------------------------------
# CREATE TRAIN/VALIDATION/TEST SPLITS
# Stratified split: 70% train, 15% validation, 15% test
# ------------------------------------------------------------------------------

from sklearn.model_selection import train_test_split

# Stage 1: hold out 30% of the balanced data, keeping per-class proportions.
# NOTE(review): the split is made per image row; if the label CSV also carries
# a patient identifier, images of one patient may land in different splits --
# confirm whether a grouped split is needed.
train_df, temp_df = train_test_split(
    balanced_df,
    test_size=0.3,
    stratify=balanced_df['Finding Labels'],
    random_state=42,
)

# Stage 2: cut the held-out 30% in half -> 15% validation and 15% test.
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['Finding Labels'],
    random_state=42,
)

# Report split sizes together with how many rows have bounding-box annotations.
print(f"Train: {len(train_df):,} images ({train_df['Has_BBox'].sum():,} with bboxes)")
print(f"Val:   {len(val_df):,} images ({val_df['Has_BBox'].sum():,} with bboxes)")
print(f"Test:  {len(test_df):,} images ({test_df['Has_BBox'].sum():,} with bboxes)")

Train: 21,991 images (171 with bboxes)
Val:   4,712 images (43 with bboxes)
Test:  4,713 images (36 with bboxes)


In [None]:
# ------------------------------------------------------------------------------
# CREATE FOLDER STRUCTURE (15 CLASSES)
# Build directory hierarchy for PyTorch ImageFolder DataLoader
# ------------------------------------------------------------------------------

import shutil
from pathlib import Path

# Root of the organised dataset; each split gets one sub-folder per class.
base_dir = Path('data')

for split in ('train', 'val', 'test'):
    split_root = base_dir / split
    for class_name in CLASS_NAMES:
        # Folder names use underscores instead of spaces (e.g. "No Finding").
        class_dir = split_root / class_name.replace(' ', '_')
        # exist_ok=True makes re-running this cell harmless.
        class_dir.mkdir(parents=True, exist_ok=True)

print(f"Created folder structure: data/{{train,val,test}}/{{15 classes}}")

Created folder structure: data/{train,val,test}/{15 classes}


In [None]:
# ------------------------------------------------------------------------------
# ORGANIZE FILES (PARALLEL PROCESSING)
# Copy images to class folders using multi-threaded processing for speed
# ------------------------------------------------------------------------------

from tqdm import tqdm
import concurrent.futures
from functools import lru_cache
import glob

# Check if data is already organized.
# The train tree is walked once and the count reused, instead of running the
# recursive glob twice (once for the test, once for the message).
train_root = base_dir / 'train'
existing_train_pngs = (
    sum(1 for _ in train_root.rglob('*.png')) if train_root.exists() else 0
)

if existing_train_pngs > 0:
    print(f"Data already organized: {existing_train_pngs} images found")
    # The guard only proves that *some* PNG exists. Compare against the number
    # of distinct file names the current train split expects, so a partial or
    # stale copy does not go unnoticed.
    expected_train_pngs = train_df['Image Index'].nunique()
    if existing_train_pngs != expected_train_pngs:
        print(
            f"Warning: train split expects {expected_train_pngs} distinct images; "
            f"'{train_root}' may be incomplete or stale - delete it to re-copy"
        )
else:
    # Cache the lookup so a file name that occurs more than once is only
    # globbed for a single time.
    @lru_cache(maxsize=None)
    def find_image_cached(image_name):
        """Return the first path under 'archive (1)/images_*/images/' that
        matches ``image_name``, or None when no such file exists."""
        # glob.escape keeps any glob metacharacter in the file name literal.
        pattern = f"archive (1)/images_*/images/{glob.escape(image_name)}"
        matches = glob.glob(pattern)
        return matches[0] if matches else None

    def copy_single_image(args):
        """Copy one image into its split/class folder.

        Args:
            args: ``(image_name, class_label, split_name)`` tuple; packed into
                a single argument so the function can be fed to
                ``executor.map``.

        Returns:
            None on success, otherwise a human-readable warning/error string.
        """
        image_name, class_label, split_name = args

        # Find source image
        source_path = find_image_cached(image_name)
        if source_path is None:
            return f"Warning: Could not find {image_name}"

        # Copy to destination
        dest_path = base_dir / split_name / class_label / image_name
        try:
            shutil.copy2(source_path, dest_path)
            return None  # Success
        except Exception as e:
            # Deliberately best-effort: report the failure and keep copying.
            return f"Error copying {image_name}: {e}"

    def copy_images_to_folders_parallel(dataframe, split_name, max_workers=8,
                                        max_errors_shown=5):
        """Copy every image listed in ``dataframe`` into ``split_name``'s folders.

        Args:
            dataframe: frame with 'Image Index' and 'Finding Labels' columns.
            split_name: destination split folder ('train', 'val' or 'test').
            max_workers: number of copy threads.
            max_errors_shown: how many failure messages to print; the rest
                are summarised in one line.

        Returns:
            ``(n_success, n_errors)`` where both counts are in rows of
            ``dataframe``.
        """
        # Zip the two needed columns directly; this avoids building a Series
        # object per row as iterrows() does.
        args_list = [
            (image_name, finding.replace(' ', '_'), split_name)
            for image_name, finding in zip(dataframe['Image Index'],
                                           dataframe['Finding Labels'])
        ]

        # Use ThreadPoolExecutor for I/O-bound file copying
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Process files in parallel with progress bar
            results = list(tqdm(
                executor.map(copy_single_image, args_list),
                total=len(args_list),
                desc=f"{split_name}"
            ))

        # Surface the collected messages instead of discarding them: a bare
        # error count gives no hint about which files failed or why.
        errors = [r for r in results if r is not None]
        for message in errors[:max_errors_shown]:
            print(message)
        if len(errors) > max_errors_shown:
            print(f"... and {len(errors) - max_errors_shown} more in '{split_name}'")

        return len(dataframe) - len(errors), len(errors)

    # Copy all splits with parallel processing
    import time
    start_time = time.perf_counter()

    train_success, train_errors = copy_images_to_folders_parallel(train_df, 'train', max_workers=8)
    val_success, val_errors = copy_images_to_folders_parallel(val_df, 'val', max_workers=8)
    test_success, test_errors = copy_images_to_folders_parallel(test_df, 'test', max_workers=8)

    elapsed = time.perf_counter() - start_time
    total_copied = train_success + val_success + test_success
    # Guard the rate against a zero elapsed time (e.g. empty splits).
    rate = total_copied / elapsed if elapsed > 0 else 0.0

    print(f"\nOrganized {total_copied:,} images in {elapsed:.1f}s ({rate:.0f} img/s)")
    if train_errors + val_errors + test_errors > 0:
        print(f"Errors: {train_errors + val_errors + test_errors}")

Data already organized: 21287 images found



In [None]:
# ------------------------------------------------------------------------------
# SAVE SEGMENTATION ANNOTATIONS
# Export bounding box data to JSON for each split
# ------------------------------------------------------------------------------

import json

# Write one JSON file per split, mapping image name -> list of its boxes.
for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    # File names in this split that are flagged as having annotations.
    annotated_images = split_df.loc[split_df['Has_BBox'], 'Image Index'].tolist()

    # Bounding-box rows that belong to those images.
    split_bbox = bbox_df[bbox_df['Image Index'].isin(annotated_images)]

    bbox_dict = {}
    for _, row in split_bbox.iterrows():
        box = {
            'class': row['Finding Label'],
            # 'Bbox [x' and 'h]' are the column names exactly as they are
            # read from the bounding-box CSV.
            'x': row['Bbox [x'],
            'y': row['y'],
            'w': row['w'],
            'h': row['h]'],
        }
        # An image may have several boxes; collect them under one key.
        bbox_dict.setdefault(row['Image Index'], []).append(box)

    # Save to file
    with open(f'data/{split_name}_bboxes.json', 'w') as f:
        json.dump(bbox_dict, f, indent=2)

In [None]:
# ------------------------------------------------------------------------------
# SAVE DATASET METADATA & CLASS MAPPING
# Generate class mapping and summary statistics for training
# ------------------------------------------------------------------------------

import json

# Persist the numeric-label -> class-name lookup (JSON stores the integer
# keys as strings).
class_mapping = dict(enumerate(CLASS_NAMES))
with open('data/class_mapping.json', 'w') as f:
    json.dump(class_mapping, f, indent=2)

# One metadata CSV per split, without the pandas index column.
for split_frame, csv_file in (
    (train_df, 'data/train_metadata.csv'),
    (val_df, 'data/val_metadata.csv'),
    (test_df, 'data/test_metadata.csv'),
):
    split_frame.to_csv(csv_file, index=False)

# Overall summary of the prepared dataset. The bbox counts are cast to plain
# int because json cannot serialise numpy integer scalars.
processed_at = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
summary = {
    'num_classes': len(CLASS_NAMES),
    'class_names': CLASS_NAMES,
    'total_images': len(balanced_df),
    'train_images': len(train_df),
    'val_images': len(val_df),
    'test_images': len(test_df),
    'images_with_bboxes_train': int(train_df['Has_BBox'].sum()),
    'images_with_bboxes_val': int(val_df['Has_BBox'].sum()),
    'images_with_bboxes_test': int(test_df['Has_BBox'].sum()),
    'date_processed': processed_at,
    'task': 'multi-class classification (15 classes) + segmentation'
}

with open('data/dataset_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("Metadata saved: class_mapping.json, CSVs, bboxes, summary")

Metadata saved: class_mapping.json, CSVs, bboxes, summary
