# Data Preparation - 2 Classes Only (Aortic Enlargement, Cardiomegaly)

**Goal:** Prepare a dataset with only 2 classes:
- Aortic enlargement (Phình động mạch chủ)
- Cardiomegaly (Tim to)

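**No Normal class**: instead of a separate "Normal" class, YOLO learns the background from negative samples (images without target-class boxes, saved with empty label files).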

## 1. Setup and Imports

In [11]:
# Set working directory
%cd /home/minhquana/workspace/project_DeepLearning/computer_vision/Abnormal-Prediction-In-Chest-X-Ray

/home/minhquana/workspace/project_DeepLearning/computer_vision/Abnormal-Prediction-In-Chest-X-Ray


In [12]:
# Import libraries
import os
import shutil
import json
import yaml
from pathlib import Path
from collections import Counter, defaultdict
from typing import Dict, List, Set

import numpy as np
from PIL import Image
from tqdm import tqdm

# Import preprocessing
import sys
sys.path.insert(0, str(Path.cwd()))
from backend.src.utils.preprocessing import preprocess_image

print("✓ Imports successful")

✓ Imports successful


## 2. Configuration

In [13]:
# Paths
DATA_DIR = Path("data")
SOURCE_DIR = DATA_DIR  
OUTPUT_DIR = DATA_DIR / "preprocessed_2classes"

# Classes to keep (only 2 classes)
CLASSES_TO_KEEP = [
    "Aortic enlargement",
    "Cardiomegaly"
]

# Vietnamese mapping
CLASS_MAPPING_VI = {
    "Aortic enlargement": "Phình động mạch chủ",
    "Cardiomegaly": "Tim to"
}

print("Configuration:")
print(f"  Source: {SOURCE_DIR}")
print(f"  Output: {OUTPUT_DIR}")
print(f"  Classes to keep: {len(CLASSES_TO_KEEP)}")
for cls in CLASSES_TO_KEEP:
    print(f"    - {cls} ({CLASS_MAPPING_VI[cls]})")

Configuration:
  Source: data
  Output: data/preprocessed_2classes
  Classes to keep: 2
    - Aortic enlargement (Phình động mạch chủ)
    - Cardiomegaly (Tim to)


## 3. Load Original Dataset Info

In [14]:
# Load original data.yaml
data_yaml_path = SOURCE_DIR / "data.yaml"

with open(data_yaml_path, 'r') as f:
    original_data_config = yaml.safe_load(f)

print("Original Dataset:")
print(f"  Classes: {original_data_config['nc']}")
print(f"  Names: {original_data_config['names']}")

# Get class indices for classes to keep
original_class_names = original_data_config['names']
class_indices_to_keep = [
    original_class_names.index(cls) for cls in CLASSES_TO_KEEP
]

print(f"\nClass indices to keep: {class_indices_to_keep}")

Original Dataset:
  Classes: 14
  Names: ['Aortic enlargement', 'Atelectasis', 'Calcification', 'Cardiomegaly', 'Consolidation', 'ILD', 'Infiltration', 'Lung Opacity', 'Nodule-Mass', 'Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax', 'Pulmonary fibrosis']

Class indices to keep: [0, 3]
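The index lookup above uses `list.index`, which fails with a bare `ValueError` if a class name is misspelled. Section 5 later turns these indices into a remapping to contiguous IDs. Below is a minimal sketch that combines both steps and reports missing names explicitly; `build_class_remap` is a hypothetical helper, not part of the project code.

```python
from typing import Dict, List

def build_class_remap(original_names: List[str], classes_to_keep: List[str]) -> Dict[int, int]:
    """Map original class indices to new contiguous IDs 0..N-1, in the order of classes_to_keep."""
    missing = [c for c in classes_to_keep if c not in original_names]
    if missing:
        # Fail with a clear message instead of list.index's generic ValueError
        raise ValueError(f"Classes not found in data.yaml: {missing}")
    return {original_names.index(name): new_id for new_id, name in enumerate(classes_to_keep)}

names = ['Aortic enlargement', 'Atelectasis', 'Calcification', 'Cardiomegaly']
print(build_class_remap(names, ['Aortic enlargement', 'Cardiomegaly']))  # {0: 0, 3: 1}
```

The order of `classes_to_keep` decides the new IDs, so it must match the `names` list written to the new `data.yaml`.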


## 4. Analyze Dataset - Count Images per Class

In [15]:
def count_images_per_class(split_dir: Path, label_dir: Path, class_indices: List[int]) -> dict:
    """
    Count images per class (one image can contain several classes).
    
    Returns:
        Dict with keys:
        - Per class: {class_id: number of images containing that class}
        - 'unique_images': total number of unique images containing any target class
    """
    class_counts = defaultdict(int)
    unique_images = set()  # Track unique images with target classes
    
    label_files = list(label_dir.glob('*.txt'))
    
    for label_file in tqdm(label_files, desc=f"Counting {label_dir.parent.name}"):  # parent.name is the split name
        with open(label_file, 'r') as f:
            lines = f.readlines()
            
        # Get unique class IDs in this file
        classes_in_file = set()
        for line in lines:
            parts = line.strip().split()
            if len(parts) >= 5:
                class_id = int(parts[0])
                if class_id in class_indices:
                    classes_in_file.add(class_id)
        
        # If this image has any target class, add to unique set
        if classes_in_file:
            unique_images.add(label_file.stem)
        
        # Increment count for each class found
        for class_id in classes_in_file:
            class_counts[class_id] += 1
    
    result = dict(class_counts)
    result['unique_images'] = len(unique_images)
    return result

# Count for each split
print("\nAnalyzing dataset...")
print("=" * 80)

for split in ['train', 'valid', 'test']:
    split_dir = SOURCE_DIR / split / 'images'
    label_dir = SOURCE_DIR / split / 'labels'
    
    if not split_dir.exists():
        print(f"\n{split.upper()}: Not found")
        continue
    
    # Count images per class
    class_counts = count_images_per_class(split_dir, label_dir, class_indices_to_keep)
    
    print(f"\n{split.upper()}:")
    per_class_total = 0
    for cls_idx, cls_name in zip(class_indices_to_keep, CLASSES_TO_KEEP):
        count = class_counts.get(cls_idx, 0)
        per_class_total += count
        print(f"  {cls_name}: {count:,} images")
    
    unique_count = class_counts.get('unique_images', 0)
    # With exactly two target classes, per-class total - unique images = images containing both
    overlap_count = per_class_total - unique_count
    
    print("  ---")
    print(f"  Per-class total: {per_class_total:,} (with overlap)")
    print(f"  Unique images: {unique_count:,} (actual count)")
    print(f"  Overlap (images with both classes): {overlap_count:,}")

print("\n" + "=" * 80)


Analyzing dataset...



Counting images: 100%|██████████| 10499/10499 [00:00<00:00, 88183.11it/s]


TRAIN:
  Aortic enlargement: 2,134 images
  Cardiomegaly: 1,590 images
  ---
  Per-class total: 3,724 (with overlap)
  Unique images: 2,362 (actual count)
  Overlap (images with both classes): 1,362



Counting images: 100%|██████████| 3000/3000 [00:00<00:00, 97872.75it/s]



VALID:
  Aortic enlargement: 632 images
  Cardiomegaly: 492 images
  ---
  Per-class total: 1,124 (with overlap)
  Unique images: 704 (actual count)
  Overlap (images with both classes): 420


Counting images: 100%|██████████| 1499/1499 [00:00<00:00, 95316.42it/s]


TEST:
  Aortic enlargement: 301 images
  Cardiomegaly: 218 images
  ---
  Per-class total: 519 (with overlap)
  Unique images: 328 (actual count)
  Overlap (images with both classes): 191
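The overlap line above is computed as per-class total minus unique images. By inclusion-exclusion this equals "images with both classes" only because there are exactly two target classes; with three or more, the difference would count multi-class images more than once. A minimal sketch that counts the overlap directly instead, assuming label contents are already read into strings (`summarize_label_texts` is a hypothetical helper):

```python
from collections import Counter
from typing import Dict, Iterable, List

def summarize_label_texts(label_texts: Iterable[str], class_indices: List[int]) -> Dict[str, object]:
    """Per-class image counts, unique images, and images containing every target class."""
    targets = set(class_indices)
    per_class: Counter = Counter()
    unique = 0
    all_targets = 0
    for text in label_texts:
        # Target class IDs present in this image (YOLO line: class x_center y_center width height)
        rows = (line.split() for line in text.splitlines())
        present = {int(p[0]) for p in rows if len(p) >= 5} & targets
        per_class.update(present)  # each class counted at most once per image
        if present:
            unique += 1
        if present == targets:
            all_targets += 1  # counted directly, no subtraction needed
    return {"per_class": dict(per_class), "unique_images": unique, "all_targets": all_targets}

texts = ["0 0.5 0.5 0.2 0.2\n3 0.4 0.4 0.1 0.1", "3 0.5 0.5 0.3 0.3", "5 0.1 0.1 0.1 0.1", ""]
print(summarize_label_texts(texts, [0, 3]))
```

For the two-class case, `sum(per_class) - unique_images == all_targets`, which is the identity the notebook relies on.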






## 5. Filter and Copy Images

Copy only images that contain at least one of the two target classes, and rewrite their labels so the kept classes get the new IDs 0 and 1.
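The core of this step is a per-line filter over YOLO label files. The sketch below isolates it as a pure function so it can be tested without touching the disk; `filter_and_remap_lines` is a hypothetical name, and the logic mirrors the loop inside `filter_and_copy_data`.

```python
from typing import Dict, List

def filter_and_remap_lines(lines: List[str], class_id_mapping: Dict[int, int]) -> List[str]:
    """Keep only YOLO label lines whose class is in class_id_mapping, rewriting the class ID."""
    kept = []
    for line in lines:
        parts = line.split()
        if len(parts) < 5:
            continue  # skip blank or malformed lines
        old_id = int(parts[0])
        if old_id in class_id_mapping:
            # Normalized box coordinates are unchanged; only the class ID is rewritten
            kept.append(f"{class_id_mapping[old_id]} {' '.join(parts[1:])}\n")
    return kept

lines = ["0 0.51 0.40 0.20 0.18\n", "7 0.30 0.60 0.10 0.10\n", "3 0.50 0.55 0.45 0.35\n"]
print(filter_and_remap_lines(lines, {0: 0, 3: 1}))
# ['0 0.51 0.40 0.20 0.18\n', '1 0.50 0.55 0.45 0.35\n']
```

An empty result means the image has no target class and is skipped here; such images become candidates for negative sampling.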

In [16]:
def filter_and_copy_data(source_dir: Path, output_dir: Path, class_indices: List[int], split: str) -> int:
    """
    Copy images that contain at least one target class and write their filtered labels,
    with class IDs remapped to 0..len(class_indices)-1.

    Returns:
        Number of images copied (0 if the split does not exist).
    """
    # Create output directories
    output_img_dir = output_dir / split / 'images'
    output_lbl_dir = output_dir / split / 'labels'
    output_img_dir.mkdir(parents=True, exist_ok=True)
    output_lbl_dir.mkdir(parents=True, exist_ok=True)
    
    # Source directories
    source_img_dir = source_dir / split / 'images'
    source_lbl_dir = source_dir / split / 'labels'
    
    if not source_img_dir.exists():
        print(f"  {split}: Not found, skipping")
        return 0
    
    # Get all label files
    label_files = list(source_lbl_dir.glob('*.txt'))
    
    # Create mapping from old class indices to new (0, 1)
    class_id_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(class_indices)}
    
    copied_count = 0
    skipped_count = 0
    
    for label_file in tqdm(label_files, desc=f"Filtering {split}"):
        # Read label file
        with open(label_file, 'r') as f:
            lines = f.readlines()
        
        # Filter lines - keep only lines with target classes
        filtered_lines = []
        for line in lines:
            parts = line.strip().split()
            if len(parts) >= 5:
                class_id = int(parts[0])
                if class_id in class_indices:
                    # Remap class ID (0 or 1)
                    new_class_id = class_id_mapping[class_id]
                    filtered_lines.append(f"{new_class_id} {' '.join(parts[1:])}\n")
        
        # Skip if no relevant classes found
        if not filtered_lines:
            skipped_count += 1
            continue
        
        # Find corresponding image file
        img_name = label_file.stem
        img_file = None
        for ext in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
            candidate = source_img_dir / f"{img_name}{ext}"
            if candidate.exists():
                img_file = candidate
                break
        
        if img_file is None:
            print(f"    Warning: Image not found for {img_name}")
            continue
        
        # Copy image
        shutil.copy2(img_file, output_img_dir / img_file.name)
        
        # Write filtered label file
        output_label_file = output_lbl_dir / label_file.name
        with open(output_label_file, 'w') as f:
            f.writelines(filtered_lines)
        
        copied_count += 1
    
    print(f"  {split}: Copied {copied_count:,} images, skipped {skipped_count:,}")
    return copied_count

# Filter and copy data
print("\nFiltering and copying data...")
print("=" * 80)

total_copied = 0
for split in ['train', 'valid', 'test']:
    count = filter_and_copy_data(SOURCE_DIR, OUTPUT_DIR, class_indices_to_keep, split)
    if count:
        total_copied += count

print(f"\nTotal images copied: {total_copied:,}")
print("=" * 80)


Filtering and copying data...


Filtering train: 100%|██████████| 10499/10499 [00:00<00:00, 29480.56it/s]


  train: Copied 2,362 images, skipped 8,137


Filtering valid: 100%|██████████| 3000/3000 [00:00<00:00, 28000.48it/s]



  valid: Copied 704 images, skipped 2,296


Filtering test: 100%|██████████| 1499/1499 [00:00<00:00, 26898.53it/s]

  test: Copied 328 images, skipped 1,171

Total images copied: 3,394





## 5.5. Find and Sample Negative Samples 

In [17]:
import random

def find_negative_samples(source_dir: Path, split: str, class_indices: List[int]) -> List[Path]:
    """
    Find images that have no label file, an empty label file, or only classes we are not keeping.
    
    Args:
        source_dir: Source data directory
        split: Dataset split (train/valid/test)
        class_indices: List of class indices we're keeping
        
    Returns:
        List of image paths that are negative samples
    """
    source_img_dir = source_dir / split / 'images'
    source_lbl_dir = source_dir / split / 'labels'
    
    if not source_img_dir.exists():
        return []
    
    negative_samples = []
    
    # Get all images
    image_files = list(source_img_dir.glob('*.jpg')) + list(source_img_dir.glob('*.png'))
    
    for img_file in tqdm(image_files, desc=f"Finding negative samples in {split}"):
        label_file = source_lbl_dir / (img_file.stem + '.txt')
        
        # Case 1: Label file doesn't exist
        if not label_file.exists():
            negative_samples.append(img_file)
            continue
        
        # Case 2: Label file is empty
        if label_file.stat().st_size == 0:
            negative_samples.append(img_file)
            continue
        
        # Case 3: Label file only contains classes we're NOT keeping
        with open(label_file, 'r') as f:
            lines = f.readlines()
        
        has_target_class = False
        for line in lines:
            parts = line.strip().split()
            if len(parts) >= 5:
                class_id = int(parts[0])
                if class_id in class_indices:
                    has_target_class = True
                    break
        
        if not has_target_class:
            negative_samples.append(img_file)
    
    return negative_samples


def sample_negative_samples(
    negative_samples: List[Path],
    target_count: int,
    random_seed: int = 42
) -> List[Path]:
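    """Randomly pick up to `target_count` negatives, reproducibly, using a local RNG."""
    rng = random.Random(random_seed)  # local RNG: does not reset the global `random` state
    
    if len(negative_samples) <= target_count:
        return list(negative_samples)
    
    return rng.sample(negative_samples, target_count)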


# Find negative samples for each split
print("\nFinding negative samples...")
print("=" * 80)

negative_samples = {}
for split in ['train', 'valid', 'test']:
    neg_samples = find_negative_samples(SOURCE_DIR, split, class_indices_to_keep)
    negative_samples[split] = neg_samples
    print(f"  {split.upper()}: Found {len(neg_samples):,} negative samples")

print("=" * 80)


Finding negative samples...



Finding negative samples in train: 100%|██████████| 10499/10499 [00:00<00:00, 68839.91it/s]



  TRAIN: Found 8,137 negative samples


Finding negative samples in valid: 100%|██████████| 3000/3000 [00:00<00:00, 71841.19it/s]


  VALID: Found 2,296 negative samples


Finding negative samples in test: 100%|██████████| 1499/1499 [00:00<00:00, 73348.21it/s]

  TEST: Found 1,171 negative samples
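The three cases in `find_negative_samples` are the exact complement of the filter in Section 5, which is why the counts line up here: 10,499 − 2,362 = 8,137 for train, 3,000 − 704 = 2,296 for valid, and 1,499 − 328 = 1,171 for test. A minimal sketch of the per-image decision as a pure function, assuming the label text is passed in (or `None` when the file is missing); `is_negative_sample` is a hypothetical helper:

```python
from typing import List, Optional

def is_negative_sample(label_text: Optional[str], class_indices: List[int]) -> bool:
    """True if an image has no label file (None), an empty one, or only non-target classes."""
    if label_text is None or not label_text.strip():
        return True  # cases 1 and 2: missing or empty label file
    targets = set(class_indices)
    for line in label_text.splitlines():
        parts = line.split()
        if len(parts) >= 5 and int(parts[0]) in targets:
            return False  # at least one target-class box: this is a positive sample
    return True  # case 3: only classes we are not keeping
```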





In [18]:
# Sample negatives so they stay in proportion to the positive samples
print("\nSampling negative samples to balance dataset...")
print("=" * 80)

# Get counts of images already copied (positive samples)
positive_counts = {}
for split in ['train', 'valid', 'test']:
    img_dir = OUTPUT_DIR / split / 'images'
    if img_dir.exists():
        count = len(list(img_dir.glob('*.jpg'))) + len(list(img_dir.glob('*.png')))
        positive_counts[split] = count
    else:
        positive_counts[split] = 0

# Sample negatives: half as many as the positives (2:1 positive-to-negative ratio)
sampled_negative = {}
for split in ['train', 'valid', 'test']:
    target_count = round(positive_counts[split] / 2)  # half the positive count
    sampled = sample_negative_samples(negative_samples[split], target_count)
    sampled_negative[split] = sampled
    
    print(f"  {split.upper()}:")
    print(f"    Positive samples: {positive_counts[split]:,}")
    print(f"    Available negative samples: {len(negative_samples[split]):,}")
    print(f"    Sampled negative samples: {len(sampled):,}")
    print(f"    Total after balance: {positive_counts[split] + len(sampled):,}")

print("=" * 80)


Sampling negative samples to balance dataset...
  TRAIN:
    Positive samples: 2,362
    Available negative samples: 8,137
    Sampled negative samples: 1,181
    Total after balance: 3,543
  VALID:
    Positive samples: 704
    Available negative samples: 2,296
    Sampled negative samples: 352
    Total after balance: 1,056
  TEST:
    Positive samples: 328
    Available negative samples: 1,171
    Sampled negative samples: 164
    Total after balance: 492
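The cell above samples half as many negatives as positives, so negatives make up about one third of each split (1,181 / 3,543 for train). A small sketch that makes the ratio an explicit parameter; `negative_target` and `neg_per_pos` are hypothetical names, and `neg_per_pos=1.0` would give a 1:1 balance instead:

```python
def negative_target(positive_count: int, neg_per_pos: float = 0.5) -> int:
    """Number of negatives to sample for a given number of positives."""
    # Python's round() rounds halves to even: round(1181.5) == 1182, round(0.5) == 0
    return round(positive_count * neg_per_pos)

for split, pos in {"train": 2362, "valid": 704, "test": 328}.items():
    neg = negative_target(pos)
    print(split, pos, neg, pos + neg)
```

With the default ratio this reproduces the totals printed above (3,543, 1,056 and 492).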


In [19]:
# Copy the sampled negatives into the output directory
print("\nCopying negative samples...")
print("=" * 80)

negative_copied = {}
for split in ['train', 'valid', 'test']:
    output_img_dir = OUTPUT_DIR / split / 'images'
    output_lbl_dir = OUTPUT_DIR / split / 'labels'
    
    copied = 0
    for img_file in tqdm(sampled_negative[split], desc=f"Copying {split} negative samples"):
        # Copy image
        shutil.copy2(img_file, output_img_dir / img_file.name)
        
        # Create empty label file (no bounding boxes for negative samples)
        label_file = output_lbl_dir / (img_file.stem + '.txt')
        label_file.touch()  # Create empty file
        
        copied += 1
    
    negative_copied[split] = copied
    print(f"  {split.upper()}: Copied {copied:,} negative samples")

print("=" * 80)
print(f"\nTotal negative samples copied: {sum(negative_copied.values()):,}")


Copying negative samples...


Copying train negative samples: 100%|██████████| 1181/1181 [00:00<00:00, 10840.56it/s]


  TRAIN: Copied 1,181 negative samples


Copying valid negative samples: 100%|██████████| 352/352 [00:00<00:00, 7018.45it/s]


  VALID: Copied 352 negative samples


Copying test negative samples: 100%|██████████| 164/164 [00:00<00:00, 8198.93it/s]

  TEST: Copied 164 negative samples

Total negative samples copied: 1,697
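Negatives are written with an empty label file. To our understanding, Ultralytics YOLO treats an image whose label file is empty (or missing) as a background image, so it contributes only "no object here" signal. A minimal sketch for auditing a split after copying; `audit_split` is a hypothetical helper, and it assumes `.jpg`/`.jpeg`/`.png` images with same-stem `.txt` labels:

```python
from pathlib import Path
from typing import Dict

def audit_split(img_dir: Path, lbl_dir: Path) -> Dict[str, int]:
    """Count positive images, background images (empty label file) and images missing a label."""
    stats = {"positive": 0, "background": 0, "missing_label": 0}
    for img in sorted(img_dir.glob("*")):
        if img.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
            continue  # ignore non-image files
        lbl = lbl_dir / f"{img.stem}.txt"
        if not lbl.exists():
            stats["missing_label"] += 1
        elif lbl.read_text().strip() == "":
            stats["background"] += 1
        else:
            stats["positive"] += 1
    return stats
```

For example, `audit_split(OUTPUT_DIR / "train" / "images", OUTPUT_DIR / "train" / "labels")` should agree with the counts printed above (2,362 positive and 1,181 background images for train, none missing).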





## 6. Apply Preprocessing

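Apply the project's preprocessing pipeline (`preprocess_image` from `backend.src.utils.preprocessing`) to every image in the output directory. Images are overwritten in place, so re-running this cell on the same directory would preprocess them twice: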
1. Grayscale conversion
2. Histogram equalization
3. Normalization
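The internals of `preprocess_image` are not shown in this notebook. As an illustration only, here is a NumPy-only sketch of the three listed steps (BT.601 grayscale, global histogram equalization via the CDF, min-max normalization to [0, 1]); `preprocess_sketch` is hypothetical and the project's function may differ, for example by using CLAHE or another normalization. Its [0, 1] output matches the `* 255` conversion used in the cell below.

```python
import numpy as np

def preprocess_sketch(img: np.ndarray) -> np.ndarray:
    """Grayscale -> histogram equalization -> min-max normalization to [0, 1] (float32)."""
    # 1. Grayscale: ITU-R BT.601 luma weights for RGB(A) input; 2-D input is already gray
    if img.ndim == 3:
        img = img[..., :3].astype(np.float64) @ np.array([0.299, 0.587, 0.114])
    gray = np.clip(img, 0, 255).astype(np.uint8)

    # 2. Histogram equalization: map each intensity through the normalized CDF
    hist = np.bincount(gray.ravel(), minlength=256)
    cdf = hist.cumsum()
    cdf_min = cdf[cdf > 0][0]
    denom = max(cdf[-1] - cdf_min, 1)  # avoid division by zero on constant images
    lut = np.clip(np.round((cdf - cdf_min) / denom * 255), 0, 255).astype(np.uint8)
    equalized = lut[gray]

    # 3. Normalization: min-max scale to [0, 1]
    out = equalized.astype(np.float32)
    value_range = out.max() - out.min()
    return (out - out.min()) / value_range if value_range > 0 else np.zeros_like(out)
```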

In [20]:
def apply_preprocessing_to_split(output_dir: Path, split: str):
    """Preprocess every image of a split in place (overwrites the files)."""
    img_dir = output_dir / split / 'images'
    
    if not img_dir.exists():
        print(f"  {split}: Not found, skipping")
        return
    
    image_files = list(img_dir.glob('*.jpg')) + list(img_dir.glob('*.png'))
    
    for img_file in tqdm(image_files, desc=f"Preprocessing {split}"):
        # Load the pixels and close the file before overwriting it
        with Image.open(img_file) as img:
            img_array = np.array(img)
        
        # Apply preprocessing (grayscale, histogram equalization, normalization)
        preprocessed = preprocess_image(
            img_array, 
            target_size=None,  # Keep original size
            apply_normalization=True
        )
        
        # Assumes the normalized output lies in [0, 1]; clip before casting so
        # out-of-range values cannot wrap around in uint8
        preprocessed_uint8 = (np.clip(preprocessed, 0.0, 1.0) * 255).astype(np.uint8)
        
        # Save (overwrites the original file)
        Image.fromarray(preprocessed_uint8).save(img_file)
    
    print(f"  {split}: Preprocessed {len(image_files):,} images")

# Apply preprocessing
print("\nApplying preprocessing...")
print("=" * 80)

for split in ['train', 'valid', 'test']:
    apply_preprocessing_to_split(OUTPUT_DIR, split)

print("\n✓ Preprocessing complete")
print("=" * 80)


Applying preprocessing...



Preprocessing train: 100%|██████████| 3543/3543 [11:33<00:00,  5.11it/s]


  train: Preprocessed 3,543 images


Preprocessing valid: 100%|██████████| 1056/1056 [03:29<00:00,  5.04it/s]


  valid: Preprocessed 1,056 images


Preprocessing test: 100%|██████████| 492/492 [01:36<00:00,  5.10it/s]

  test: Preprocessed 492 images

✓ Preprocessing complete





## 7. Create data.yaml Files

In [21]:
# Create data.yaml (English)
data_yaml = {
    'path': str(OUTPUT_DIR.absolute()),
    'train': 'train/images',
    'val': 'valid/images',
    'test': 'test/images',
    'nc': len(CLASSES_TO_KEEP),
    'names': CLASSES_TO_KEEP
}

data_yaml_path = OUTPUT_DIR / 'data.yaml'
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)

print(f"✓ Created {data_yaml_path}")
print("\nContents:")
with open(data_yaml_path, 'r') as f:
    print(f.read())

✓ Created data/preprocessed_2classes/data.yaml

Contents:
path: /home/minhquana/workspace/project_DeepLearning/computer_vision/Abnormal-Prediction-In-Chest-X-Ray/data/preprocessed_2classes
train: train/images
val: valid/images
test: test/images
nc: 2
names:
- Aortic enlargement
- Cardiomegaly
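With `nc: 2`, every label line in the output directory must use class ID 0 or 1 and four normalized coordinates. A minimal sanity check before training, assuming plain detection boxes (five values per line, no segmentation polygons); `find_bad_labels` is a hypothetical helper:

```python
from pathlib import Path
from typing import List

def find_bad_labels(label_dir: Path, nc: int) -> List[str]:
    """Return 'file:line' entries with a class ID outside 0..nc-1 or a malformed/out-of-range box."""
    problems = []
    for lbl in sorted(label_dir.glob("*.txt")):
        for lineno, line in enumerate(lbl.read_text().splitlines(), start=1):
            parts = line.split()
            if not parts:
                continue  # blank lines (e.g. background images) are fine
            try:
                cls = int(parts[0])
                coords = [float(p) for p in parts[1:]]
            except ValueError:
                problems.append(f"{lbl.name}:{lineno}")
                continue
            if not (0 <= cls < nc) or len(coords) != 4 or not all(0.0 <= c <= 1.0 for c in coords):
                problems.append(f"{lbl.name}:{lineno}")
    return problems
```

For example, `find_bad_labels(OUTPUT_DIR / "train" / "labels", nc=data_yaml["nc"])` should return an empty list if the remapping in Section 5 worked.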



In [22]:
# Create data_vi.yaml (Vietnamese)
data_yaml_vi = {
    'path': str(OUTPUT_DIR.absolute()),
    'train': 'train/images',
    'val': 'valid/images',
    'test': 'test/images',
    'nc': len(CLASSES_TO_KEEP),
    'names': [CLASS_MAPPING_VI[cls] for cls in CLASSES_TO_KEEP]
}

data_yaml_vi_path = OUTPUT_DIR / 'data_vi.yaml'
with open(data_yaml_vi_path, 'w', encoding='utf-8') as f:
    yaml.dump(data_yaml_vi, f, default_flow_style=False, sort_keys=False, allow_unicode=True)

print(f"✓ Created {data_yaml_vi_path}")
print("\nContents:")
with open(data_yaml_vi_path, 'r', encoding='utf-8') as f:
    print(f.read())

✓ Created data/preprocessed_2classes/data_vi.yaml

Contents:
path: /home/minhquana/workspace/project_DeepLearning/computer_vision/Abnormal-Prediction-In-Chest-X-Ray/data/preprocessed_2classes
train: train/images
val: valid/images
test: test/images
nc: 2
names:
- Phình động mạch chủ
- Tim to



## 8. Update Class Mapping Configs

In [23]:
# Write configs/class_mapping_2classes.json (English -> Vietnamese class names)
config_dir = Path("configs")
config_dir.mkdir(exist_ok=True)

class_mapping_path = config_dir / 'class_mapping_2classes.json'
with open(class_mapping_path, 'w', encoding='utf-8') as f:
    json.dump(CLASS_MAPPING_VI, f, ensure_ascii=False, indent=2)

print(f"✓ Created {class_mapping_path}")
print("\nContents:")
print(json.dumps(CLASS_MAPPING_VI, ensure_ascii=False, indent=2))

✓ Created configs/class_mapping_2classes.json

Contents:
{
  "Aortic enlargement": "Phình động mạch chủ",
  "Cardiomegaly": "Tim to"
}
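A downstream consumer (for example, an app showing predictions) would map a model class ID to its display name through this file. A minimal sketch, assuming the model's class IDs follow the order of `CLASSES_TO_KEEP` as written to `data.yaml`; `load_display_names` is a hypothetical helper that falls back to the English name when no translation exists:

```python
import json
from pathlib import Path
from typing import List

def load_display_names(mapping_path: Path, class_names: List[str]) -> List[str]:
    """Return display names indexed by model class ID, falling back to the English name."""
    mapping = json.loads(mapping_path.read_text(encoding="utf-8"))
    return [mapping.get(name, name) for name in class_names]
```

Usage: `load_display_names(class_mapping_path, CLASSES_TO_KEEP)[1]` would give the Vietnamese name for class ID 1 (Cardiomegaly).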


## 9. Summary

In [24]:
print("\n" + "=" * 80)
print("DATA PREPARATION COMPLETE")
print("=" * 80)

print(f"\nOutput directory: {OUTPUT_DIR}")
print(f"\nClasses (nc={len(CLASSES_TO_KEEP)}):")
for i, en in enumerate(CLASSES_TO_KEEP):  # index = class ID in data.yaml
    print(f"  {i}: {en} ({CLASS_MAPPING_VI[en]})")

print("\nDataset splits:")
for split in ['train', 'valid', 'test']:
    img_dir = OUTPUT_DIR / split / 'images'
    if img_dir.exists():
        count = len(list(img_dir.glob('*.jpg'))) + len(list(img_dir.glob('*.png')))
        print(f"  {split}: {count:,} images")

print("\nFiles created:")
print(f"  - {data_yaml_path}")
print(f"  - {data_yaml_vi_path}")
print(f"  - {class_mapping_path}")

print("\nNext steps:")
print("  1. Review preprocessed images in:", OUTPUT_DIR)
print("  2. Train model using: notebooks/train_yolov11s.ipynb")
print(f"  3. Use data.yaml: {data_yaml_path}")

print("\n" + "=" * 80)


DATA PREPARATION COMPLETE

Output directory: data/preprocessed_2classes

Classes (nc=2):
  0: Aortic enlargement (Phình động mạch chủ)
  1: Cardiomegaly (Tim to)

Dataset splits:
  train: 3,543 images
  valid: 1,056 images
  test: 492 images

Files created:
  - data/preprocessed_2classes/data.yaml
  - data/preprocessed_2classes/data_vi.yaml
  - configs/class_mapping_2classes.json

Next steps:
  1. Review preprocessed images in: data/preprocessed_2classes
  2. Train model using: notebooks/train_yolov11s.ipynb
  3. Use data.yaml: data/preprocessed_2classes/data.yaml



## 10. (Optional) Create Augmented Dataset with Gaussian Blur

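Create an augmented version of the training data using Gaussian blur.
- Augment only the **training set**; valid and test images are copied unchanged
- Generate 1 augmented version per training image (labels are copied as-is, since blurring does not move the boxes)
- Save everything to `data/preprocessed_2classes_aug/`

**Note:** This section is optional.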
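The project's `augment_image` (from `backend.src.utils.augmentation`) is not shown here. As an assumption-based stand-in, the sketch below blurs a grayscale image with Pillow's `ImageFilter.GaussianBlur` at a random radius; `gaussian_blur_augment` and its `radius_range` default are hypothetical. Because blur changes pixel values but not geometry, the original YOLO boxes remain valid, which is why the function below copies labels unchanged.

```python
import random
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageFilter

def gaussian_blur_augment(
    img_array: np.ndarray,
    radius_range: Tuple[float, float] = (0.5, 1.5),
    rng: Optional[random.Random] = None,
) -> np.ndarray:
    """Blur a grayscale uint8 image with a random Gaussian radius; shape and boxes are unchanged."""
    if rng is None:
        rng = random.Random()
    radius = rng.uniform(*radius_range)  # hypothetical range, tune for X-ray resolution
    blurred = Image.fromarray(img_array).filter(ImageFilter.GaussianBlur(radius=radius))
    return np.array(blurred)
```

Passing a seeded `random.Random` makes the augmented dataset reproducible across runs.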

In [25]:
def create_augmented_dataset(
    source_dir: Path,
    output_dir: Path,
    augment_train_only: bool = True,
    num_augmentations: int = 1,
):
    """
    Create augmented dataset with Gaussian blur.
    
    Args:
        source_dir: Source preprocessed directory
        output_dir: Output directory for augmented data
        augment_train_only: Only augment training set
        num_augmentations: Number of augmented versions per image
    """
    from backend.src.utils.augmentation import augment_image
    
    print(f"\nCreating Augmented Dataset")
    print("=" * 80)
    print(f"  Source: {source_dir}")
    print(f"  Output: {output_dir}")
    print(f"  Augment train only: {augment_train_only}")
    print(f"  Augmentations per image: {num_augmentations}")
    print("=" * 80)
    
    # Determine which splits to augment
    splits_to_augment = ['train'] if augment_train_only else ['train', 'valid', 'test']
    all_splits = ['train', 'valid', 'test']
    
    aug_stats = {}
    
    for split in all_splits:
        source_images_dir = source_dir / split / 'images'
        source_labels_dir = source_dir / split / 'labels'
        
        output_images_dir = output_dir / split / 'images'
        output_labels_dir = output_dir / split / 'labels'
        
        output_images_dir.mkdir(parents=True, exist_ok=True)
        output_labels_dir.mkdir(parents=True, exist_ok=True)
        
        if not source_images_dir.exists():
            continue
        
        print(f"\nProcessing {split.upper()} split...")
        
        # Get all images
        image_files = list(source_images_dir.glob('*.jpg')) + list(source_images_dir.glob('*.png'))
        
        original_count = 0
        augmented_count = 0
        
        # Copy original images
        for img_path in tqdm(image_files, desc=f"  Copying originals"):
            # Copy image
            shutil.copy(img_path, output_images_dir / img_path.name)
            
            # Copy label
            label_path = source_labels_dir / (img_path.stem + '.txt')
            if label_path.exists():
                shutil.copy(label_path, output_labels_dir / label_path.name)
            
            original_count += 1
        
        # Create augmented versions (only for specified splits)
        if split in splits_to_augment:
            for img_path in tqdm(image_files, desc=f"  Creating augmented versions"):
                # Load image
                img = Image.open(img_path).convert('L')
                img_array = np.array(img)
                
                # Create N augmented versions
                for aug_idx in range(num_augmentations):
                    # Apply Gaussian blur augmentation
                    img_augmented = augment_image(img_array, augmentation_probability=1.0)
                    
                    # Save augmented image with suffix
                    aug_img_name = f"{img_path.stem}_aug{aug_idx+1}{img_path.suffix}"
                    aug_img_path = output_images_dir / aug_img_name
                    Image.fromarray(img_augmented).save(aug_img_path)
                    
                    # Copy label with same suffix
                    label_path = source_labels_dir / (img_path.stem + '.txt')
                    if label_path.exists():
                        aug_label_name = f"{img_path.stem}_aug{aug_idx+1}.txt"
                        aug_label_path = output_labels_dir / aug_label_name
                        shutil.copy(label_path, aug_label_path)
                    
                    augmented_count += 1
        
        total_count = original_count + augmented_count
        aug_stats[split] = {
            'original': original_count,
            'augmented': augmented_count,
            'total': total_count
        }
        
        print(f"    ✓ Original: {original_count:,}")
        print(f"    ✓ Augmented: {augmented_count:,}")
        print(f"    ✓ Total: {total_count:,}")
    
    # Copy and update data.yaml
    source_yaml = source_dir / 'data.yaml'
    output_yaml = output_dir / 'data.yaml'
    
    with open(source_yaml, 'r') as f:
        data_yaml = yaml.safe_load(f)
    
    # Update path
    data_yaml['path'] = str(output_dir.absolute())
    
    with open(output_yaml, 'w') as f:
        yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)
    
    # Copy data_vi.yaml if exists
    source_yaml_vi = source_dir / 'data_vi.yaml'
    if source_yaml_vi.exists():
        output_yaml_vi = output_dir / 'data_vi.yaml'
        with open(source_yaml_vi, 'r', encoding='utf-8') as f:
            data_yaml_vi = yaml.safe_load(f)
        data_yaml_vi['path'] = str(output_dir.absolute())
        with open(output_yaml_vi, 'w', encoding='utf-8') as f:
            yaml.dump(data_yaml_vi, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
    
    print("\n" + "=" * 80)
    print("✓ Augmented dataset created successfully!")
    print(f"  Output directory: {output_dir.absolute()}")
    print("=" * 80)
    
    return aug_stats


# Create augmented dataset
augmented_output_dir = Path('data/preprocessed_2classes_aug')

aug_stats = create_augmented_dataset(
    source_dir=OUTPUT_DIR,
    output_dir=augmented_output_dir,
    augment_train_only=True,
    num_augmentations=1,
)


Creating Augmented Dataset
  Source: data/preprocessed_2classes
  Output: data/preprocessed_2classes_aug
  Augment train only: True
  Augmentations per image: 1

Processing TRAIN split...


  Copying originals: 100%|██████████| 3543/3543 [00:00<00:00, 8650.22it/s]
  Creating augmented versions: 100%|██████████| 3543/3543 [01:21<00:00, 43.27it/s]


    ✓ Original: 3,543
    ✓ Augmented: 3,543
    ✓ Total: 7,086

Processing VALID split...


  Copying originals: 100%|██████████| 1056/1056 [00:00<00:00, 8821.07it/s]


    ✓ Original: 1,056
    ✓ Augmented: 0
    ✓ Total: 1,056

Processing TEST split...


  Copying originals: 100%|██████████| 492/492 [00:00<00:00, 8973.37it/s]

    ✓ Original: 492
    ✓ Augmented: 0
    ✓ Total: 492

✓ Augmented dataset created successfully!
  Output directory: /home/minhquana/workspace/project_DeepLearning/computer_vision/Abnormal-Prediction-In-Chest-X-Ray/data/preprocessed_2classes_aug





In [26]:
# Summary of the augmented dataset
print("\nAugmented Dataset Summary:")
print("=" * 80)
for split in ['train', 'valid', 'test']:
    print(f"\n{split.upper()}:")
    print(f"  Original images: {aug_stats[split]['original']:,}")
    print(f"  Augmented images: {aug_stats[split]['augmented']:,}")
    print(f"  Total images: {aug_stats[split]['total']:,}")

total_original = sum(aug_stats[s]['original'] for s in ['train', 'valid', 'test'])
total_augmented = sum(aug_stats[s]['augmented'] for s in ['train', 'valid', 'test'])
total_all = sum(aug_stats[s]['total'] for s in ['train', 'valid', 'test'])

print("\n" + "=" * 80)
print(f"\nGRAND TOTAL:")
print(f"  Original: {total_original:,}")
print(f"  Augmented: {total_augmented:,}")
print(f"  Total: {total_all:,}")
print(f"  Augmentation ratio: {total_augmented/total_original*100:.1f}%")
print("=" * 80)

print("\n✓ Use this for training:")
print(f"  data.yaml: {augmented_output_dir / 'data.yaml'}")
print(f"  data_vi.yaml: {augmented_output_dir / 'data_vi.yaml'}")


Augmented Dataset Summary:

TRAIN:
  Original images: 3,543
  Augmented images: 3,543
  Total images: 7,086

VALID:
  Original images: 1,056
  Augmented images: 0
  Total images: 1,056

TEST:
  Original images: 492
  Augmented images: 0
  Total images: 492


GRAND TOTAL:
  Original: 5,091
  Augmented: 3,543
  Total: 8,634
  Augmentation ratio: 69.6%

✓ Use this for training:
  data.yaml: data/preprocessed_2classes_aug/data.yaml
  data_vi.yaml: data/preprocessed_2classes_aug/data_vi.yaml
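The "Augmentation ratio: 69.6%" line divides augmented images by originals across all splits (3,543 / 5,091). Since only train is augmented, the per-split picture is clearer: train doubles (100%) while valid and test are untouched. A small sketch that reports both views from a stats dict shaped like `aug_stats`; `augmentation_ratios` is a hypothetical helper:

```python
def augmentation_ratios(stats: dict) -> dict:
    """Augmented/original ratio per split and over all splits, in percent."""
    ratios = {s: 100.0 * v["augmented"] / v["original"] for s, v in stats.items() if v["original"]}
    total_orig = sum(v["original"] for v in stats.values())
    total_aug = sum(v["augmented"] for v in stats.values())
    ratios["all"] = 100.0 * total_aug / total_orig if total_orig else 0.0
    return ratios

stats = {"train": {"original": 3543, "augmented": 3543},
         "valid": {"original": 1056, "augmented": 0},
         "test": {"original": 492, "augmented": 0}}
print({k: round(v, 1) for k, v in augmentation_ratios(stats).items()})
# {'train': 100.0, 'valid': 0.0, 'test': 0.0, 'all': 69.6}
```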
