<a href="https://colab.research.google.com/github/hanvocado/pneumonia_detection/blob/linh/src/preprocess/split_dataset_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
import shutil
import random
import numpy as np
from glob import glob
from tqdm import tqdm
import cv2

# ENVIRONMENT & PATH SETUP
# Walk up from the working directory until the project root
# ('pneumonia_detection') is reached, so the script can be started from a
# sub-folder (e.g. src/preprocess) as well as from the project root.
_start_directory = os.getcwd()
current_directory = _start_directory

while not current_directory.endswith('pneumonia_detection'):
    parent = os.path.dirname(current_directory)
    if parent == current_directory:
        print("⚠️ Không tìm thấy thư mục 'pneumonia_detection'. Dùng thư mục hiện tại.")
        # Reached the filesystem root without a match. Fall back to the
        # directory we started in (as the message promises) instead of
        # silently using the filesystem root.
        current_directory = _start_directory
        break
    current_directory = parent

root_directory = current_directory
os.chdir(root_directory)
# Make project-local modules importable no matter where the script started.
sys.path.insert(0, root_directory)

INPUT_DIR = os.path.join(root_directory, 'data_processed')       # source images
OUTPUT_DIR = os.path.join(root_directory, 'data_processed_new')  # deleted and rebuilt by stratified_split

# Per-class fractions for train / val / test. stratified_split does not read
# TEST_RATIO: the test split receives whatever remains after train and val.
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# SUMMARY TABLE

def print_summary_table(title):
    """Print how many NORMAL / PNEUMONIA files each split of OUTPUT_DIR holds.

    Every entry matched by ``*`` inside OUTPUT_DIR/<split>/<label> is counted.
    There is one row per split (train, val, test), plus a per-row total.

    Args:
        title: Heading shown centred inside a 45-character '=' banner.
    """
    banner = '=' * 45
    divider = '-' * 45

    print(f"\n{banner}")
    print(f"{title:^45}")
    print(banner)
    print(f"{'Split':<10} {'NORMAL':<10} {'PNEUMONIA':<10} {'Total':<10}")
    print(divider)

    for split_name in ('train', 'val', 'test'):
        normal_n, pneumonia_n = (
            len(glob(os.path.join(OUTPUT_DIR, split_name, cls, '*')))
            for cls in ('NORMAL', 'PNEUMONIA')
        )
        row_sum = normal_n + pneumonia_n
        print(f"{split_name.capitalize():<10} {normal_n:<10,} {pneumonia_n:<10,} {row_sum:<10,}")
    print(divider + "\n")

# DATA SPLITTING

def _collect_label_images(label):
    """Return a sorted, de-duplicated list of .jpeg/.jpg paths for *label* under INPUT_DIR.

    Looks both one level down (INPUT_DIR/<any folder>/<label>/) and directly
    in INPUT_DIR/<label>/.
    """
    patterns = [
        os.path.join(INPUT_DIR, '*', label, '*.jpeg'),
        os.path.join(INPUT_DIR, '*', label, '*.jpg'),
        os.path.join(INPUT_DIR, label, '*.jpeg'),
        os.path.join(INPUT_DIR, label, '*.jpg'),
    ]
    # sorted() instead of list(set(...)): the iteration order of a set of str
    # depends on hash randomization, which would make the later shuffle
    # irreproducible even with a fixed seed.
    return sorted({path for pattern in patterns for path in glob(pattern)})


def _unique_destination(src_path, dst_dir):
    """Return a path in *dst_dir* for *src_path* that does not overwrite an existing file."""
    name = os.path.basename(src_path)
    stem, ext = os.path.splitext(name)
    candidate = os.path.join(dst_dir, name)
    n = 1
    while os.path.exists(candidate):
        # The same basename can come from different source folders; renaming
        # avoids silently overwriting (and thus losing) an earlier image.
        candidate = os.path.join(dst_dir, f"{stem}_dup{n}{ext}")
        n += 1
    return candidate


def stratified_split(*, seed=None):
    """Copy INPUT_DIR images into a fresh train/val/test layout under OUTPUT_DIR.

    For each label, all .jpeg/.jpg files in INPUT_DIR/<label>/ and
    INPUT_DIR/*/<label>/ are pooled, shuffled, and cut by TRAIN_RATIO and
    VAL_RATIO. The remainder goes to test. Splitting each label separately
    keeps class proportions roughly equal across the three splits.

    OUTPUT_DIR is deleted and recreated on every call.

    Args:
        seed: Optional seed for a private RNG, making the split reproducible.
            When None, the global ``random`` module state is used, as before.

    Returns:
        False if INPUT_DIR does not exist, True otherwise.
    """
    if not os.path.exists(INPUT_DIR):
        print(f"Lỗi: Không tìm thấy folder nguồn '{INPUT_DIR}'")
        return False

    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)

    for split in ['train', 'val', 'test']:
        for label in ['NORMAL', 'PNEUMONIA']:
            os.makedirs(os.path.join(OUTPUT_DIR, split, label), exist_ok=True)

    print("Đang thực hiện chia dữ liệu...", end="\r")

    # With no seed, keep using the module-level generator so that an earlier
    # random.seed() call still controls the split.
    rng = random if seed is None else random.Random(seed)

    for label in ['NORMAL', 'PNEUMONIA']:
        files = _collect_label_images(label)
        total_files = len(files)
        if total_files == 0:
            continue

        rng.shuffle(files)
        # int() truncates, so any rounding remainder ends up in test.
        train_end = int(total_files * TRAIN_RATIO)
        val_end = train_end + int(total_files * VAL_RATIO)

        split_files = {
            'train': files[:train_end],
            'val': files[train_end:val_end],
            'test': files[val_end:],
        }
        for split, members in split_files.items():
            dst_dir = os.path.join(OUTPUT_DIR, split, label)
            for f in members:
                shutil.copy2(f, _unique_destination(f, dst_dir))

    return True

# DATA AUGMENTATION

def augment_image_logic(image_path, save_path):
    """Write one randomly augmented grayscale copy of an image.

    With equal probability the image is either flipped horizontally or
    translated by a random integer offset in [-10, 10] pixels on each axis.
    The translation uses cv2.warpAffine with its default constant (black)
    border for the uncovered area.

    Args:
        image_path: Path of the source image (read as grayscale).
        save_path: Destination path; cv2.imwrite picks the format from its
            extension.

    Returns:
        True if the augmented image was written, False otherwise. Failures are
        reported and skipped (never raised), so one bad file does not abort a
        batch.
    """
    try:
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"⚠️ Không đọc được ảnh: {image_path}")
            return False

        if random.randint(0, 1) == 0:
            # NOTE(review): a horizontal flip mirrors chest anatomy (heart
            # side); confirm this augmentation is intended for X-rays.
            aug_img = cv2.flip(img, 1)
        else:
            rows, cols = img.shape
            tx = random.randint(-10, 10)  # translation offsets in pixels
            ty = random.randint(-10, 10)
            M = np.float32([[1, 0, tx], [0, 1, ty]])
            aug_img = cv2.warpAffine(img, M, (cols, rows))

        written = cv2.imwrite(save_path, aug_img)
    except Exception as exc:  # best-effort per image, but never silent
        print(f"⚠️ Lỗi khi tăng cường ảnh {image_path}: {exc}")
        return False

    if not written:
        print(f"⚠️ Không ghi được ảnh: {save_path}")
    return bool(written)

def balance_train_data():
    """Oversample the smaller class of the train split with augmented images.

    Counts the files in OUTPUT_DIR/train/NORMAL and OUTPUT_DIR/train/PNEUMONIA.
    For whichever class has fewer, it writes ``aug_<i>_<name>`` files until
    both counts match. Each file is made by augment_image_logic from a source
    image of that class, chosen at random with replacement.

    Only the train split is modified, so val/test keep the original class
    distribution.
    """
    train_dir = os.path.join(OUTPUT_DIR, 'train')
    labels = ['NORMAL', 'PNEUMONIA']

    # Snapshot each class folder before augmenting, so freshly written aug_*
    # files are never picked as augmentation sources.
    class_files = {label: glob(os.path.join(train_dir, label, '*')) for label in labels}
    target_count = max(len(files) for files in class_files.values())

    for label, files in class_files.items():
        needed = target_count - len(files)
        if needed <= 0:
            continue
        if not files:
            # random.choice([]) would raise IndexError: nothing to augment from.
            print(f"⚠️ Lớp {label} không có ảnh nào để tăng cường, bỏ qua.")
            continue

        label_dir = os.path.join(train_dir, label)
        print(f"Đang tăng cường lớp {label} (Sinh thêm {needed} ảnh)...")

        failed = 0
        for i in tqdm(range(needed), desc="Augmenting", unit="img"):
            src_img = random.choice(files)
            dst_path = os.path.join(label_dir, f"aug_{i}_{os.path.basename(src_img)}")
            augment_image_logic(src_img, dst_path)
            # augment_image_logic does not raise on failure; detect a missing
            # output here so the shortfall is visible in the final counts.
            if not os.path.exists(dst_path):
                failed += 1

        if failed:
            print(f"⚠️ Không tạo được {failed}/{needed} ảnh tăng cường cho lớp {label}.")

if __name__ == "__main__":
    # Pipeline: re-split the data, report counts, balance the train split,
    # then report counts again. Nothing else runs if the split step fails.
    if not stratified_split():
        print("Quy trình thất bại.")
    else:
        print_summary_table("KẾT QUẢ SAU KHI CHIA LẠI DỮ LIỆU")
        balance_train_data()
        print_summary_table("KẾT QUẢ SAU KHI TĂNG CƯỜNG")

        print("Hoàn tất!")