In this notebook, we transform raw datasets to parquet format to enable faster loading speed during training and evaluation.

The raw format of released datasets is as follows:
```python
# train set
/train/real/...
/train/fake/...
/train/masks/...
# valid set
/valid/real/...
/valid/fake/...
/valid/masks/...
```

In [15]:
import os
from datasets import Dataset, DatasetDict
from datasets import Features, Image, Value
from typing import List, Optional


def load_images_from_dir(directory: str) -> List[str]:
    return [
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(("jpg", "jpeg", "png", "tif"))
    ]


def create_split(root_dir: str, split: str) -> Optional[Dataset]:
    fake_dir = os.path.join(root_dir, split, "fake")
    real_dir = os.path.join(root_dir, split, "real")

    if all(not os.path.isdir(p) for p in [fake_dir, real_dir]):
        return None

    print(f"Split: {split},", end=" ")
    fake_images, real_images = [], []
    if os.path.isdir(fake_dir):
        fake_images = load_images_from_dir(fake_dir)
        print(f"Fake images: {len(fake_images)}", end="")
    if os.path.isdir(real_dir):
        real_images = load_images_from_dir(real_dir)
        print(f", Real images: {len(real_images)}", end="")
    print()

    return Dataset.from_dict(
        {
            "path": fake_images + real_images,
            "image": fake_images + real_images,
        },
        features=Features(
            {"path": Value(dtype="string"), "image": Image()}
        ),
    )


def create_dataset(root_dir: str) -> DatasetDict:
    return DatasetDict(
        {
            split: d
            for split in ["train", "valid", "test"]
            if (d := create_split(root_dir, split)) is not None
        }
    )


# replace with your own dataset path
root_dir = "../dataset/"
save_dir = "./data_output"

We merge `real/` and `fake/` into `images` column for simplity. A image is real if there is no corresponding mask.

In [16]:
dataset = create_dataset(root_dir)
dataset

Split: train, Fake images: 16880, Real images: 19284
Split: valid, Fake images: 2110, Real images: 2411
Split: test, Fake images: 2110, Real images: 2411


DatasetDict({
    train: Dataset({
        features: ['path', 'image'],
        num_rows: 36164
    })
    valid: Dataset({
        features: ['path', 'image'],
        num_rows: 4521
    })
    test: Dataset({
        features: ['path', 'image'],
        num_rows: 4521
    })
})

Then save processed datasets to parquet.

In [17]:
os.makedirs(save_dir, exist_ok=True)
for split in dataset:
    dataset[split].to_parquet(os.path.join(save_dir, f"{split}.parquet"))
    print(f"Saved {split} split to {save_dir}/{split}.parquet")

Creating parquet from Arrow format: 100%|█████████████| 140/140 [00:00<00:00, 5950.33ba/s]


Saved train split to ./data_output/train.parquet


Creating parquet from Arrow format: 100%|███████████████| 18/18 [00:00<00:00, 5703.09ba/s]


Saved valid split to ./data_output/valid.parquet


Creating parquet from Arrow format: 100%|███████████████| 18/18 [00:00<00:00, 1376.96ba/s]

Saved test split to ./data_output/test.parquet





Load from processed datasets to do whatever you want.

In [18]:
import os
from datasets import load_dataset

trainset = load_dataset("parquet", data_dir=save_dir, split="train")
trainset

Generating train split: 36164 examples [00:00, 920087.17 examples/s]
Generating validation split: 4521 examples [00:00, 1195313.19 examples/s]
Generating test split: 4521 examples [00:00, 1254130.18 examples/s]


Dataset({
    features: ['path', 'image'],
    num_rows: 36164
})

Since the forged components are usually smaller in proportion compared to the real ones, this leads to class imbalance.
For optimal training performance, hyper parameters such as `pixel_forge_weight` and `cls_forge_weight` in `src.loupe.configuration_loupe.LoupeConfig` must be appropriately configured. These parameters control the weights of forged pixels and forged images.

Once suitable parameters are found using the following code snippet, you can set them in `configs/model/cls.yaml` or `configs/model/seg.yaml`.


In [19]:
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm

cls_forge_weight: float  # the ratio of forged images to total images.
# the ratio of forged patches to total patches across all images.
patch_forge_weight: float
# the ratio of forged pixels to total pixels across fake images.
pixel_forge_weight: float

num_subset_samples = min(5000, len(trainset))
subset = trainset.shuffle().select(range(num_subset_samples))
image_size, patch_size = 336, 14


def compute_mask_stats(example):

    if example["mask"] is None:
        return {
            "is_forge": 0,
            "forge_pixel_sum": 0.0,
            "total_pixel_count": 0,
            "forge_patch_sum": 0.0,
        }

    mask = example["mask"].convert("L").resize((image_size, image_size), Image.NEAREST)
    mask_np = np.array(mask, dtype=np.float32)

    if mask_np.max() != mask_np.min():
        mask_np = (mask_np - mask_np.min()) / (mask_np.max() - mask_np.min())
    else:
        mask_np[:] = 0.0

    forged_pixel_sum = mask_np.sum()
    total_pixels = mask_np.size

    reshaped = mask_np.reshape(
        image_size // patch_size, patch_size, image_size // patch_size, patch_size
    )
    patches = reshaped.transpose(0, 2, 1, 3)
    forged_patch_sum = (patches != 0).sum(axis=(2, 3)) / (patch_size * patch_size)
    forged_patch_sum = forged_patch_sum.sum()

    return {
        "is_forge": 1,
        "forge_pixel_sum": forged_pixel_sum,
        "total_pixel_count": total_pixels,
        "forge_patch_sum": forged_patch_sum,
    }


processed = subset.map(compute_mask_stats, num_proc=8, desc="Computing mask stats")

num_forge_images = sum(processed["is_forge"])
num_forge_pixels = sum(processed["forge_pixel_sum"])
num_total_pixels = sum(processed["total_pixel_count"])
num_forge_patches = sum(processed["forge_patch_sum"])
num_total_patches = len(processed) * (image_size // patch_size) ** 2

cls_forge_weight = 1 - num_forge_images / len(processed)
patch_forge_weight = 1 - num_forge_patches / num_total_patches
pixel_forge_weight = 1 - num_forge_pixels / num_total_pixels

print("cls_forge_weight:", cls_forge_weight)
print("patch_forge_weight:", patch_forge_weight)
print("pixel_forge_weight:", pixel_forge_weight)

Computing mask stats (num_proc=8):   0%|                  | 0/5000 [00:00<?, ? examples/s]


KeyError: 'mask'