# In[ ]:
import os
import shutil
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, as_completed, wait
from pathlib import Path

import numpy as np
from datasets import load_from_disk
from PIL import Image as PILImage, ImageFile
from tqdm.notebook import tqdm

# Pillow setup
ImageFile.LOAD_TRUNCATED_IMAGES = True  # decode partially written files instead of raising
# Disable Pillow's decompression-bomb check; process_image enforces MAX_PIXELS
# itself by downscaling oversized images instead of refusing them.
PILImage.MAX_IMAGE_PIXELS = None
# Same value as Pillow's default error threshold (2 * 89_478_485 pixels).
MAX_PIXELS = 178_956_970

# User settings
where_you_saved_data = r"/scratch/szp2fv/ID_AI_Project/DS6050_Ai_Detection"  # root holding <split>/<dtype> Arrow folders
SCALE_FACTOR = 0.1  # per-side scale for images above MAX_PIXELS (0.1 per side -> 1% of the pixels)
MAX_WORKERS = 8  # worker threads; NOTE(review): reported to break with more than 8 — cause unknown
LOG_INTERVAL_IMAGES = 1000  # print a progress line every N processed images

# Worker function
def process_image(idx, row, out_folder):
    """Write one dataset row's image into ``out_folder`` as a JPEG or PNG file.

    Parameters
    ----------
    idx : int
        Row position, used for the fallback file name ``{idx:06d}.jpg`` when
        the row has no usable ``"filename"`` (missing, None or empty).
    row : dict
        Dataset row. ``row["image"]`` may be a PIL image (anything with a
        ``save`` method), a NumPy array, or already-encoded ``bytes``.
    out_folder : pathlib.Path
        Directory that receives the file (must already exist).

    Returns
    -------
    tuple[str, int]
        ``(status, resized)``. ``status`` is ``"saved"`` or ``"skipped"``.
        ``resized`` is 1 if the image exceeded ``MAX_PIXELS`` and was scaled
        by ``SCALE_FACTOR``, else 0.

    Per-image failures never raise. They are printed and returned as
    ``("skipped", 0)`` so that one bad image cannot abort the whole batch.
    """
    try:
        img_data = row["image"]
        # `or` also covers a present-but-None/empty filename, which used to
        # crash Path() or produce a bare ".jpg" name.
        fname = Path(row.get("filename") or f"{idx:06d}.jpg").name
        ext = Path(fname).suffix.lower()
        if ext not in (".jpg", ".jpeg", ".png"):
            ext = ".jpg"  # unknown/absent extensions are re-encoded as JPEG
        dst_path = out_folder / f"{Path(fname).stem}{ext}"

        if hasattr(img_data, "save"):  # PIL image
            img = img_data
        elif isinstance(img_data, np.ndarray):
            img = PILImage.fromarray(img_data)
        elif isinstance(img_data, bytes):
            # Already encoded: written verbatim, no resize or re-encode.
            dst_path.write_bytes(img_data)
            return "saved", 0
        else:
            print(f"[skip] row {idx}: unsupported image type {type(img_data).__name__}", flush=True)
            return "skipped", 0

        img_format = "PNG" if ext == ".png" else "JPEG"

        # Normalize the mode BEFORE resizing. Pillow resizes "P" images with
        # nearest-neighbour regardless of the requested filter. The savers
        # also reject some modes (JPEG cannot store e.g. "LA" or "I;16").
        if img_format == "JPEG":
            if img.mode not in ("1", "L", "RGB", "CMYK"):
                img = img.convert("RGB")
        elif img.mode not in ("1", "L", "LA", "I", "I;16", "RGB", "RGBA"):
            # Palette images expand to RGBA so their transparency survives.
            img = img.convert("RGBA" if img.mode in ("P", "PA") else "RGB")

        resized = 0
        if img.width * img.height > MAX_PIXELS:
            # max(1, ...) stops extremely thin images collapsing to 0 px,
            # which would make resize() raise.
            new_size = (
                max(1, int(img.width * SCALE_FACTOR)),
                max(1, int(img.height * SCALE_FACTOR)),
            )
            img = img.resize(new_size, PILImage.Resampling.LANCZOS)
            resized = 1

        save_kwargs = {"format": img_format}
        if img_format == "JPEG":
            save_kwargs.update(quality=95, optimize=True)

        img.save(dst_path, **save_kwargs)
        return "saved", resized
    except Exception as e:  # best-effort boundary: report, don't abort the batch
        print(f"[skip] row {idx}: {type(e).__name__}: {e}", flush=True)
        return "skipped", 0

# Conversion of one Arrow dataset with a bounded work queue
def _convert_rows(ds, out_folder, label, max_in_flight):
    """Run process_image over every row of ``ds``, keeping at most ``max_in_flight`` rows queued.

    Returns ``(saved_count, resized_count, skipped_count)``.
    """
    total_imgs = len(ds)
    saved_count = skipped_count = resized_count = done = 0
    start_time = time.time()
    rows = enumerate(ds)
    pending = set()

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        with tqdm(total=total_imgs, desc=label, unit="img") as bar:
            while True:
                # Top up the queue. Rows are read on this (main) thread, so at
                # most `max_in_flight` of them are alive at once. Submitting
                # every row up front held the whole split in memory and
                # reported no progress until all rows had been read.
                while len(pending) < max_in_flight:
                    nxt = next(rows, None)
                    if nxt is None:
                        break
                    idx, row = nxt
                    pending.add(executor.submit(process_image, idx, row, out_folder))
                if not pending:
                    break

                finished, pending = wait(pending, return_when=FIRST_COMPLETED)
                for fut in finished:
                    status, resized = fut.result()
                    if status == "saved":
                        saved_count += 1
                    else:
                        skipped_count += 1
                    resized_count += resized
                    done += 1
                    bar.update(1)

                    # Print periodic progress
                    if done % LOG_INTERVAL_IMAGES == 0:
                        elapsed = time.time() - start_time
                        speed = done / elapsed if elapsed > 0 else 0
                        remaining = total_imgs - done
                        eta_minutes = remaining / speed / 60 if speed > 0 else 0
                        print(f"[{label}] Processed {done}/{total_imgs} images "
                              f"(speed: {speed:.2f} img/s, ETA: {eta_minutes:.2f} min, resized: {resized_count}, skipped: {skipped_count})",
                              flush=True)

    return saved_count, resized_count, skipped_count


# Main reconstruction
def reconstruct_dataset(*, max_in_flight=None, keep_source_on_skip=True):
    """Convert the saved Arrow splits into plain image folders, in place.

    ``<root>`` is ``where_you_saved_data``. For each split ("train",
    "validation") and class folder ("real", "fake"), the Arrow dataset at
    ``<root>/<split>/<dtype>`` is loaded. Its images are written by
    ``process_image`` into ``<root>/temp_<split>/<dtype>``, and then the Arrow
    folder is deleted. Once a split is finished, ``<root>/<split>`` is
    replaced by ``<root>/temp_<split>``.

    Parameters
    ----------
    max_in_flight : int, optional
        Maximum number of rows read and queued at once. The default is
        ``4 * MAX_WORKERS``. A bound keeps memory flat and lets progress be
        reported from the start.
    keep_source_on_skip : bool, default True
        If any image of a class folder fails to convert, keep that Arrow
        folder and leave ``<root>/<split>`` untouched. The converted images
        stay in ``temp_<split>``, and a re-run only redoes the folders whose
        source still exists. Pass False to delete the source anyway.
    """
    hf_local_path = Path(where_you_saved_data)
    splits = ["train", "validation"]
    types = ["real", "fake"]
    if max_in_flight is None:
        max_in_flight = 4 * MAX_WORKERS
    max_in_flight = max(1, max_in_flight)

    overall_start = time.time()

    for split in splits:
        temp_split_folder = hf_local_path / f"temp_{split}"
        temp_split_folder.mkdir(exist_ok=True, parents=True)
        print(f"\n=== Starting reconstruction for {split} split ===", flush=True)
        split_complete = True

        for dtype in types:
            folder_start = time.time()
            arrow_path = hf_local_path / split / dtype
            if not arrow_path.exists():
                print(f"Skipping: {arrow_path} does not exist.", flush=True)
                continue

            print(f"Loading Arrow dataset from {arrow_path} ...", flush=True)
            ds = load_from_disk(str(arrow_path))

            out_folder = temp_split_folder / dtype
            out_folder.mkdir(parents=True, exist_ok=True)

            print(f"Processing {len(ds)} images using {MAX_WORKERS} workers...", flush=True)
            saved_count, resized_count, skipped_count = _convert_rows(
                ds, out_folder, f"{split}/{dtype}", max_in_flight
            )
            # Release our handle on the dataset before its files are deleted.
            del ds

            folder_duration = time.time() - folder_start
            print(f"\nFinished {split}/{dtype}: saved {saved_count}, resized {resized_count}, skipped {skipped_count}.", flush=True)
            print(f"Time taken: {folder_duration / 60:.2f} minutes\n", flush=True)

            if skipped_count and keep_source_on_skip:
                # Deleting now would destroy the only copy of the failed images.
                split_complete = False
                print(f"WARNING: {skipped_count} image(s) in {split}/{dtype} were not converted; "
                      f"keeping {arrow_path}.", flush=True)
                continue

            # Remove Arrow folder to save space
            shutil.rmtree(arrow_path, ignore_errors=True)

        split_folder = hf_local_path / split
        if not split_complete:
            print(f"Leaving {split_folder} in place; converted images are in {temp_split_folder}. "
                  f"Re-run after resolving the failures (or with keep_source_on_skip=False).", flush=True)
            continue

        # Swap the converted folder in. No ignore_errors here: a half-deleted
        # folder would otherwise surface as a confusing rename failure.
        if split_folder.exists():
            shutil.rmtree(split_folder)
        temp_split_folder.rename(split_folder)

    overall_duration = time.time() - overall_start
    print(f"\nDataset reconstruction complete! Total time: {overall_duration / 60:.2f} minutes", flush=True)

# ==============================
# Run
# ==============================
# Guarded so that importing this file does not start the conversion, which
# deletes folders. Inside a notebook __name__ is "__main__", so the cell
# still runs it.
if __name__ == "__main__":
    reconstruct_dataset()



=== Starting reconstruction for train split ===
Loading Arrow dataset from /scratch/szp2fv/ID_AI_Project/DS6050_Ai_Detection/train/real ...


Loading dataset from disk:   0%|          | 0/28 [00:00<?, ?it/s]

Processing 57600 images using 8 workers...
