In [None]:
# ------------------------------------------------------------------------------
# LOAD DATA
# Load NIH Chest X-ray dataset labels from CSV
# ------------------------------------------------------------------------------

import pandas as pd
import os

# Load the CSV file that contains all the labels.
# The path is assembled with os.path.join instead of a literal backslash:
# '\' is only a path separator on Windows, so the hard-coded form could not
# be opened on Linux/macOS (and disagreed with the forward-slash glob pattern
# used later to locate the images).
csv_path = os.path.join('archive (1)', 'Data_Entry_2017.csv')
df = pd.read_csv(csv_path)

In [None]:
# ------------------------------------------------------------------------------
# PREPARE MULTI-CLASS CLASSIFICATION DATA (15 CLASSES)
# Extract and count all 15 disease categories from the dataset
# ------------------------------------------------------------------------------

from collections import Counter

# Flatten every pipe-separated label string into one long list of label names
all_diseases = [
    disease
    for label_string in df['Finding Labels']
    for disease in label_string.split('|')
]

# Tally how many images mention each label
disease_counts = Counter(all_diseases)

# Class list in alphabetical order; the printed index is the position in this list
CLASS_NAMES = sorted(disease_counts)

print(f"Disease Categories: {len(CLASS_NAMES)}")
for idx, name in enumerate(CLASS_NAMES):
    print(f"  [{idx:2}] {name:25} - {disease_counts[name]:6,} images")

In [None]:
# ------------------------------------------------------------------------------
# KEEP ALL IMAGES (MULTI-LABEL CLASSIFICATION)
# Prepare the dataset for multi-label classification
# ------------------------------------------------------------------------------

# Work on a copy of the full table: no image is dropped for having several labels
multilabel_df = df.copy()

# A literal '|' in the label string means the image carries more than one label
has_multiple = df['Finding Labels'].str.contains('|', regex=False)
multi_count = has_multiple.sum()
single_count = (~has_multiple).sum()

# Report the absolute counts and their share of the whole dataset
total_images = len(df)
print(f"Total images: {total_images:,}")
for tag, n_images in (("Single-label", single_count), ("Multi-label", multi_count)):
    print(f"{tag}: {n_images:,} ({n_images/total_images*100:.1f}%)")
print(f"\nUsing ALL images for multi-label classification!")

In [None]:
# ------------------------------------------------------------------------------
# LOAD SEGMENTATION DATA (BOUNDING BOXES)
# Load bounding box annotations for future segmentation tasks
# ------------------------------------------------------------------------------

# Assemble the path with os.path.join rather than a literal backslash so the
# annotation file is also found on Linux/macOS, where '\' is not a separator.
bbox_path = os.path.join('archive (1)', 'BBox_List_2017.csv')
bbox_df = pd.read_csv(bbox_path)

# Flag every image that has at least one row in the bounding-box table
multilabel_df['Has_BBox'] = multilabel_df['Image Index'].isin(bbox_df['Image Index'])

print(f"Bounding box annotations: {len(bbox_df):,}")
print(f"Images with annotations: {multilabel_df['Has_BBox'].sum():,}")

In [None]:
# ------------------------------------------------------------------------------
# SAMPLE THE DATASET (HANDLE SIZE)
# Use a subset for faster training while keeping multi-label distribution
# ------------------------------------------------------------------------------

from sklearn.utils import resample

# Take 40,000 images (maintains multi-label distribution)
# This gives us ~28k train, ~6k val, ~6k test
MAX_TOTAL_IMAGES = 40000

if len(multilabel_df) > MAX_TOTAL_IMAGES:
    # replace=False is essential here: resample() defaults to bootstrap
    # sampling (replace=True), which would put the same image into the subset
    # several times. Those duplicates would then be scattered by the random
    # split, so identical images could appear in both train and val/test.
    sampled_df = resample(
        multilabel_df,
        replace=False,
        n_samples=MAX_TOTAL_IMAGES,
        random_state=42,
    )
else:
    sampled_df = multilabel_df

# Shuffle the rows and renumber the index from 0
sampled_df = sampled_df.sample(frac=1, random_state=42).reset_index(drop=True)

# True where the label string contains a literal '|', i.e. more than one label
sampled_is_multi = sampled_df['Finding Labels'].str.contains('|', regex=False)

print(f"Sampled dataset: {len(sampled_df):,} images (multi-label)")
print(f"Single-label: {(~sampled_is_multi).sum():,}")
print(f"Multi-label: {sampled_is_multi.sum():,}")

In [None]:
# ------------------------------------------------------------------------------
# CREATE TRAIN/VALIDATION/TEST SPLITS
# Random split: 70% train, 15% validation, 15% test 
# ------------------------------------------------------------------------------

from sklearn.model_selection import GroupShuffleSplit, train_test_split

# Column used to keep all images of one patient in the same split.
# NOTE(review): assumes the metadata CSV carries a 'Patient ID' column — confirm;
# when it is absent the code below falls back to the plain random split.
PATIENT_COL = 'Patient ID'


def split_keeping_groups(frame, test_size, group_col=PATIENT_COL, random_state=42):
    """Split ``frame`` in two so that no ``group_col`` value appears in both parts.

    Args:
        frame: DataFrame to split.
        test_size: Fraction handed to the second part. With grouping this is a
            fraction of the *groups*, so the row counts are only approximately
            in that proportion.
        group_col: Column whose values must not be shared between the parts.
        random_state: Seed for a reproducible split.

    Returns:
        ``(first_part, second_part)`` DataFrames. If ``group_col`` is missing,
        a warning is printed and an ordinary row-level random split is returned.
    """
    if group_col not in frame.columns:
        print(f"Warning: '{group_col}' column not found - falling back to a plain random split")
        return train_test_split(frame, test_size=test_size, random_state=random_state)

    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # split() yields positional row indices, hence .iloc below
    first_idx, second_idx = next(splitter.split(frame, groups=frame[group_col]))
    return frame.iloc[first_idx], frame.iloc[second_idx]


# Multi-label targets can't be passed to `stratify`, so the split stays random,
# but it is drawn over patients rather than individual images. A row-level
# split would let images of one patient land in both train and val/test,
# which inflates the validation/test scores.
train_df, temp_df = split_keeping_groups(sampled_df, test_size=0.3)
val_df, test_df = split_keeping_groups(temp_df, test_size=0.5)

print(f"Train: {len(train_df):,} images ({train_df['Has_BBox'].sum():,} with bboxes)")
print(f"Val:   {len(val_df):,} images ({val_df['Has_BBox'].sum():,} with bboxes)")
print(f"Test:  {len(test_df):,} images ({test_df['Has_BBox'].sum():,} with bboxes)")

In [None]:
# ------------------------------------------------------------------------------
# CREATE FOLDER STRUCTURE
# Images go into per-class subfolders of each split (folder = first listed label); the CSV holds the full labels
# ------------------------------------------------------------------------------

import shutil
from pathlib import Path

base_dir = Path('data')

# Folder names are the class names with spaces swapped for underscores
class_folders = [class_name.replace(' ', '_') for class_name in CLASS_NAMES]

# Build data/<split>/<class>/ for every split and every class
for split in ('train', 'val', 'test'):
    split_root = base_dir / split
    for safe_name in class_folders:
        (split_root / safe_name).mkdir(parents=True, exist_ok=True)

print(f"Created folder structure for multi-label classification")

In [None]:
# ------------------------------------------------------------------------------
# ORGANIZE FILES (PARALLEL PROCESSING)
# Copy images to class folders using multi-threaded processing for speed
# ------------------------------------------------------------------------------

from tqdm import tqdm
import concurrent.futures
from functools import lru_cache
import glob

# How many individual failure messages are echoed per split (the rest are only counted)
MAX_REPORTED_ERRORS = 10

splits_to_copy = [('train', train_df), ('val', val_df), ('test', test_df)]


def count_pngs(split_name):
    """Return how many .png files exist anywhere under data/<split_name>."""
    split_root = base_dir / split_name
    if not split_root.exists():
        return 0
    return sum(1 for _ in split_root.rglob('*.png'))


# Check if data is already organized.
# Each split is walked once. A split only counts as done when it holds at least
# as many PNGs as it has distinct images; merely finding *some* PNG in
# data/train is not enough, otherwise an interrupted run would be mistaken for
# a finished one and never completed.
existing_pngs = {name: count_pngs(name) for name, _ in splits_to_copy}
is_complete = existing_pngs['train'] > 0 and all(
    existing_pngs[name] >= frame['Image Index'].nunique()
    for name, frame in splits_to_copy
)

if is_complete:
    print(f"Data already organized: {existing_pngs['train']} images found")
else:
    # Cache the find_image results to avoid repeated glob searches
    @lru_cache(maxsize=None)
    def find_image_cached(image_name):
        """Return the first path matching the image under the archive, or None.

        Looks for ``archive (1)/images_*/images/<image_name>``; repeated
        lookups of the same name are answered from the cache.
        """
        pattern = f"archive (1)/images_*/images/{image_name}"
        matches = glob.glob(pattern)
        return matches[0] if matches else None

    def copy_single_image(args):
        """Copy one image into its split/class folder (worker for the thread pool).

        Args:
            args: ``(image_name, primary_class, split_name)`` tuple.

        Returns:
            ``None`` on success (including "already copied"), otherwise a
            human-readable warning/error string. Never raises.
        """
        image_name, primary_class, split_name = args

        # Find source image
        source_path = find_image_cached(image_name)
        if source_path is None:
            return f"Warning: Could not find {image_name}"

        # For multi-label, put in primary class folder (first disease)
        dest_path = base_dir / split_name / primary_class / image_name
        try:
            # Resume support: a destination of identical size is taken as an
            # intact earlier copy; a size mismatch (e.g. a truncated file left
            # by an interrupted run) is copied again.
            if dest_path.exists() and dest_path.stat().st_size == Path(source_path).stat().st_size:
                return None
            # Make sure the class folder exists even if folder creation was skipped
            dest_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source_path, dest_path)
            return None  # Success
        except Exception as e:
            # Deliberately best-effort: report the failure, keep the other copies going
            return f"Error copying {image_name}: {e}"

    def copy_images_to_folders_parallel(dataframe, split_name, max_workers=8):
        """Copy every image of ``dataframe`` into ``data/<split_name>/<class>/``.

        Args:
            dataframe: Rows with 'Image Index' and 'Finding Labels' columns.
            split_name: Split folder name ('train', 'val' or 'test').
            max_workers: Number of copy threads.

        Returns:
            ``(success_count, error_count)``.
        """
        # For multi-label, the first listed disease decides the folder.
        # Zipping the two columns avoids building a Series per row (iterrows).
        args_list = [
            (image_name, labels.split('|')[0].strip().replace(' ', '_'), split_name)
            for image_name, labels in zip(dataframe['Image Index'], dataframe['Finding Labels'])
        ]

        # Use ThreadPoolExecutor for I/O-bound file copying
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(tqdm(
                executor.map(copy_single_image, args_list),
                total=len(args_list),
                desc=f"{split_name}"
            ))

        # Surface the first few failure messages instead of discarding them
        errors = [r for r in results if r is not None]
        for message in errors[:MAX_REPORTED_ERRORS]:
            print(f"  {message}")
        if len(errors) > MAX_REPORTED_ERRORS:
            print(f"  ... and {len(errors) - MAX_REPORTED_ERRORS} more")

        return len(dataframe) - len(errors), len(errors)

    # Copy all splits with parallel processing
    import time
    start_time = time.time()

    train_success, train_errors = copy_images_to_folders_parallel(train_df, 'train', max_workers=8)
    val_success, val_errors = copy_images_to_folders_parallel(val_df, 'val', max_workers=8)
    test_success, test_errors = copy_images_to_folders_parallel(test_df, 'test', max_workers=8)

    elapsed = time.time() - start_time
    total_copied = train_success + val_success + test_success
    total_errors = train_errors + val_errors + test_errors
    # Guard the rate against a zero elapsed time (coarse clocks, empty splits)
    rate = total_copied / elapsed if elapsed > 0 else 0.0

    print(f"\nOrganized {total_copied:,} images in {elapsed:.1f}s ({rate:.0f} img/s)")
    if total_errors > 0:
        print(f"Errors: {total_errors}")

In [None]:
# ------------------------------------------------------------------------------
# SAVE SEGMENTATION ANNOTATIONS
# Export bounding box data to JSON for each split
# ------------------------------------------------------------------------------

import json

# Write one JSON file of bounding boxes per split
for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    # Names of this split's images that are flagged as annotated
    images_with_bbox = split_df.loc[split_df['Has_BBox'], 'Image Index'].tolist()

    # Annotation rows that belong to those images
    split_bbox = bbox_df[bbox_df['Image Index'].isin(images_with_bbox)]

    # image name -> list of its boxes, in the row order of the annotation table
    bbox_dict = {}
    for _, bbox_row in split_bbox.iterrows():
        bbox_dict.setdefault(bbox_row['Image Index'], []).append({
            'class': bbox_row['Finding Label'],
            'x': bbox_row['Bbox [x'],
            'y': bbox_row['y'],
            'w': bbox_row['w'],
            'h': bbox_row['h]'],
        })

    # Persist next to the split folders
    with open(f'data/{split_name}_bboxes.json', 'w') as f:
        json.dump(bbox_dict, f, indent=2)

In [None]:
# ------------------------------------------------------------------------------
# SAVE DATASET METADATA & CLASS MAPPING
# Generate class mapping and summary statistics for training
# ------------------------------------------------------------------------------

import json

# Persist the index -> class-name lookup
class_mapping = dict(enumerate(CLASS_NAMES))
with open('data/class_mapping.json', 'w') as f:
    json.dump(class_mapping, f, indent=2)

# One metadata CSV per split, written in train -> val -> test order
split_frames = (('train', train_df), ('val', val_df), ('test', test_df))
for csv_stem, frame in split_frames:
    frame.to_csv(f'data/{csv_stem}_metadata.csv', index=False)

# True where a sampled image's label string contains '|', i.e. several labels
sampled_multi_mask = sampled_df['Finding Labels'].str.contains('|', regex=False)

# Summary of the whole preparation run; key order matches the written JSON
summary = {
    'num_classes': len(CLASS_NAMES),
    'class_names': CLASS_NAMES,
    'total_images': len(sampled_df),
    **{f'{name}_images': len(frame) for name, frame in split_frames},
    **{
        f'images_with_bboxes_{name}': int(frame['Has_BBox'].sum())
        for name, frame in split_frames
    },
    'date_processed': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
    'task': 'multi-label classification (15 classes) + segmentation',
    'multi_label': True,
    'single_label_count': int((~sampled_multi_mask).sum()),
    'multi_label_count': int(sampled_multi_mask.sum()),
}

with open('data/dataset_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("Metadata saved: class_mapping.json, CSVs, bboxes, summary")
print(f"\nIMPORTANT: This is MULTI-LABEL data!")
print(f"- CSV files contain 'Finding Labels' with pipe-separated diseases")
print(f"- Each image can have 0-15 diseases")
print(f"- Use BCEWithLogitsLoss for training")