In this notebook, we convert the raw datasets to Parquet format so that they load faster during training and evaluation.

The raw format of released datasets is as follows:
```python
# train set
/train/real/...
/train/fake/...
/train/masks/...
# valid set
/valid/real/...
/valid/fake/...
/valid/masks/...
```

In [13]:
import os
from datasets import Dataset, DatasetDict
from datasets import Features, Image, Value
from typing import List, Optional


def load_images_from_dir(directory: str) -> List[str]:
    """Return the paths of the image files found directly in ``directory``.

    An entry counts as an image when its name ends in ``.jpg``, ``.jpeg`` or
    ``.png``, compared case-insensitively. Only the immediate entries of
    ``directory`` are considered.

    Args:
        directory: Directory to list.

    Returns:
        ``directory`` joined with each matching file name, ordered by file
        name.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``directory`` cannot be
            listed.
    """
    # The leading dot keeps names such as "notjpg" from matching.
    image_suffixes = (".jpg", ".jpeg", ".png")
    return [
        os.path.join(directory, fname)
        # os.listdir gives no ordering guarantee; sort so the result is
        # reproducible and lists from sibling directories line up by name.
        for fname in sorted(os.listdir(directory))
        if fname.lower().endswith(image_suffixes)
    ]


def create_split(root_dir: str, split: str) -> Optional[Dataset]:
    """Build the dataset for one split from the raw directory layout.

    For ``split == "test"`` the images are read from ``<root_dir>/test`` and
    the result has the columns ``path`` and ``image``. For any other split the
    images are read from ``<root_dir>/<split>/fake``, ``.../masks`` and
    ``.../real``; fake and real images are concatenated (fake first) into
    ``path``/``image``, and ``mask`` holds the mask for each fake image and
    ``None`` for each real image.

    Args:
        root_dir: Root directory of the raw dataset.
        split: Name of the split sub-directory.

    Returns:
        The dataset for the split, or ``None`` when anything goes wrong while
        building it (the error is printed, not raised).
    """
    # Best-effort by design: any failure is reported and mapped to None so the
    # caller can simply leave the split out.
    try:
        if split == "test":
            image_dir = os.path.join(root_dir, split)
            # Sorted so the row order does not depend on os.listdir.
            images = sorted(load_images_from_dir(image_dir))
            print(f"Split: {split}, Images: {len(images)}")
            return Dataset.from_dict(
                {"path": images, "image": images},
                features=Features({"path": Value(dtype="string"), "image": Image()}),
            )
        fake_dir = os.path.join(root_dir, split, "fake")
        masks_dir = os.path.join(root_dir, split, "masks")
        real_dir = os.path.join(root_dir, split, "real")

        # Fake images and masks are paired by position below, so both lists
        # must be in the same order; os.listdir order is arbitrary and may
        # differ between the two directories.
        # NOTE(review): assumes a mask has the same file name as its fake
        # image — confirm against the released data.
        fake_images = sorted(load_images_from_dir(fake_dir))
        mask_images = sorted(load_images_from_dir(masks_dir))
        real_images = sorted(load_images_from_dir(real_dir))
        print(
            f"Split: {split}, Fake images: {len(fake_images)}, Masks: {len(mask_images)}, Real images: {len(real_images)}"
        )
        # An explicit raise (instead of assert) survives ``python -O`` and
        # gives the error message printed below some content.
        if len(fake_images) != len(mask_images):
            raise ValueError(
                f"{len(fake_images)} fake images but {len(mask_images)} masks"
            )

        return Dataset.from_dict(
            {
                "path": fake_images + real_images,
                "image": fake_images + real_images,
                # Real images have no mask.
                "mask": mask_images + [None] * len(real_images),
            },
            features=Features(
                {"path": Value(dtype="string"), "image": Image(), "mask": Image()}
            ),
        )
    except Exception as e:
        print(f"Error processing split {split}: {e}")
        return None


def create_dataset(root_dir: str) -> DatasetDict:
    """Assemble the train/valid/test splits found under ``root_dir``.

    Each split is built with ``create_split``; a split for which it returns
    ``None`` is left out of the result.

    Args:
        root_dir: Root directory of the raw dataset.

    Returns:
        A ``DatasetDict`` keyed by split name.
    """
    built_splits = {}
    for name in ("train", "valid", "test"):
        split_dataset = create_split(root_dir, name)
        if split_dataset is None:
            continue
        built_splits[name] = split_dataset
    return DatasetDict(built_splits)


# Root of the raw dataset; the split directories are looked up beneath it.
root_dir = "/gemini/space/lye/track1"
# Directory the parquet files are written to.
save_dir = "/gemini/space/jyc/track1"

We merge `real/` and `fake/` into a single `image` column for simplicity. An image is real if it has no corresponding mask.

In [14]:
# Build every split that can be read from root_dir; the bare name on the last
# line makes the notebook display the result.
dataset = create_dataset(root_dir)
dataset

Split: train, Fake images: 798831, Masks: 798831, Real images: 156100
Split: valid, Fake images: 199708, Masks: 199708, Real images: 39025
Split: test, Images: 222847


DatasetDict({
    train: Dataset({
        features: ['path', 'image', 'mask'],
        num_rows: 954931
    })
    valid: Dataset({
        features: ['path', 'image', 'mask'],
        num_rows: 238733
    })
    test: Dataset({
        features: ['path', 'image'],
        num_rows: 222847
    })
})

Then save processed datasets to parquet.

In [15]:
# Write each split to <save_dir>/<split>.parquet, creating save_dir if needed.
# NOTE(review): confirm whether to_parquet stores the image bytes or only the
# file paths before moving the parquet files away from the raw images.
os.makedirs(save_dir, exist_ok=True)
for split, split_dataset in dataset.items():
    parquet_path = os.path.join(save_dir, f"{split}.parquet")
    split_dataset.to_parquet(parquet_path)
    print(f"Saved {split} split to {save_dir}/{split}.parquet")

Creating parquet from Arrow format: 100%|██████████| 9550/9550 [00:01<00:00, 4956.62ba/s]


Saved train split to /gemini/space/jyc/track1/train.parquet


Creating parquet from Arrow format: 100%|██████████| 2388/2388 [00:00<00:00, 4971.15ba/s]


Saved valid split to /gemini/space/jyc/track1/valid.parquet


Creating parquet from Arrow format: 100%|██████████| 2229/2229 [00:00<00:00, 7133.34ba/s]


Saved test split to /gemini/space/jyc/track1/test.parquet


Load the processed datasets and use them however you like.

In [None]:
import os
from datasets import load_dataset

# Read the validation parquet file back through the generic "parquet" loader.
# NOTE(review): no split= is passed, so this presumably returns a DatasetDict
# (with the file under a default split name) rather than a Dataset — confirm.
validset = load_dataset(
    "parquet", data_files=os.path.join(save_dir, "valid.parquet")
)
# Bare name so the notebook displays the loaded object.
validset