# Dataset Curation for GouGAN using WikiArt and ArtiFact

## Setup
If you get errors running this cell, ensure the relevant libraries are installed first.

In [2]:
# Data source
import kagglehub

# OS
import os
from pathlib import Path

# Data Processing
import random

import shutil
from shutil import copy2

# Image Processing
from PIL import Image

# Parallel
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool

# Modelling
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

# Logging
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Every local dataset copy and derived split lives under this one directory
_DATASETS_ROOT = "./datasets"

# Local workspace copies of the raw Kaggle downloads
PATH_WIKIART = f"{_DATASETS_ROOT}/wikiart"
PATH_ARTIFACT = f"{_DATASETS_ROOT}/artifact"

# WikiArt: train/val/test split, and its class-balanced counterpart
PATH_STYLE = f"{_DATASETS_ROOT}/style"
PATH_STYLE_BALANCED = f"{PATH_STYLE}_balanced"

# Artifact: flat real/fake pools, the train/val/test split, and its class-balanced counterpart
PATH_REAL_FAKE_SPLIT = f"{_DATASETS_ROOT}/real-fake-split"
PATH_REAL_FAKE = f"{_DATASETS_ROOT}/real-fake"
PATH_REAL_FAKE_BALANCED = f"{PATH_REAL_FAKE}_balanced"

## Download Datasets
- WARNING: Downloads will take a while & take up a lot of storage

In [3]:
# Fetch both datasets from Kaggle via Kaggle Hub and report where each one landed
_DOWNLOAD_MESSAGE = "Path to dataset files:"

path_wikiart = kagglehub.dataset_download("steubk/wikiart")
print(_DOWNLOAD_MESSAGE, path_wikiart)

path_artifact = kagglehub.dataset_download("awsaf49/artifact-dataset")
print(_DOWNLOAD_MESSAGE, path_artifact)

Path to dataset files: /home/peter/.cache/kagglehub/datasets/steubk/wikiart/versions/1
Path to dataset files: /home/peter/.cache/kagglehub/datasets/awsaf49/artifact-dataset/versions/1


## Helper Functions

### General

In [4]:
# Copy datasets from cache to local workspace
def copy_folder(source_folder, destination_folder):
    """
    Recursively copy a directory tree and describe the outcome as text.

    Nothing is raised to the caller: every outcome, including unexpected
    errors, is reported through the returned message.

    Args:
        source_folder (str): Directory to copy from.
        destination_folder (str): Directory to create; it must not exist yet.

    Returns:
        str: A message stating that the copy succeeded, that the source is
        missing, that the destination already exists, or which error occurred.
    """
    try:
        if not os.path.exists(source_folder):
            # Nothing to copy from
            outcome = f"Source folder '{source_folder}' does not exist."
        elif os.path.exists(destination_folder):
            # Refuse to touch an existing destination
            outcome = f"Destination folder '{destination_folder}' already exists. Choose a different path or delete it."
        else:
            shutil.copytree(source_folder, destination_folder)
            outcome = f"Folder successfully copied to '{destination_folder}'."
    except Exception as err:
        outcome = f"An error occurred: {err}"
    return outcome

# Print the outcome of BOTH copies: a notebook cell only displays its last bare
# expression, so the WikiArt message would otherwise be dropped silently.
# Destinations come from the path constants instead of repeating the literals.
print(copy_folder(path_wikiart, PATH_WIKIART))
print(copy_folder(path_artifact, PATH_ARTIFACT))

"Destination folder './datasets/artifact' already exists. Choose a different path or delete it."

In [5]:
def balance_datasets(source_folder, target_folder, categories, oversample=False):
    """
    Copies a class-balanced version of the train, val and test splits.

    For each split, every category ends up with the same number of files:
    the size of the smallest category (undersampling, the default) or of the
    largest category (``oversample=True``). Files are copied, never moved.
    Files already present in the target folders are not removed.

    Args:
        source_folder (str): Folder containing train/val/test, each with one subfolder per category.
        target_folder (str): Folder where the balanced train/val/test folders are written.
        categories (iterable of str): Names of the category subfolders to balance.
        oversample (bool): If True, smaller categories are padded with randomly
            chosen duplicates (saved under ``<stem>_dup<N><ext>`` so that they
            do not overwrite each other). Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If ``categories`` is empty.
    """
    def _list_files(folder):
        """Return the sorted regular-file names in `folder` ([] if it is not a directory)."""
        if not os.path.isdir(folder):
            return []
        return sorted(
            name for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
        )

    categories = list(categories)
    if not categories:
        raise ValueError("categories must contain at least one category name.")

    splits = ['train', 'val', 'test']

    for split in splits:
        print(f"Balancing {split} dataset...")
        split_source_folder = os.path.join(source_folder, split)
        split_target_folder = os.path.join(target_folder, split)

        # Create target split directory
        os.makedirs(split_target_folder, exist_ok=True)

        # List every category once; a missing folder counts as an empty category
        files_by_category = {}
        for category in categories:
            category_folder = os.path.join(split_source_folder, category)
            if not os.path.isdir(category_folder):
                print(f"Warning: Folder '{category_folder}' does not exist.")
            files_by_category[category] = _list_files(category_folder)

        # Minimum class size for undersampling, maximum for oversampling
        class_sizes = [len(names) for names in files_by_category.values()]
        target_size = max(class_sizes) if oversample else min(class_sizes)
        print(f"Target size for balancing: {target_size}")

        # Balance each category
        for category in categories:
            source_category_folder = os.path.join(split_source_folder, category)
            target_category_folder = os.path.join(split_target_folder, category)
            os.makedirs(target_category_folder, exist_ok=True)

            # Names were sorted above, so the shuffle depends only on the RNG state
            files = list(files_by_category[category])
            random.shuffle(files)

            # Undersample: keep at most target_size distinct files
            selected = files[:target_size]
            for file in selected:
                copy2(os.path.join(source_category_folder, file), os.path.join(target_category_folder, file))

            # Oversample: pad with duplicates under unique names
            shortfall = target_size - len(selected)
            if shortfall > 0 and files:
                for dup_index, file in enumerate(random.choices(files, k=shortfall), start=1):
                    stem, ext = os.path.splitext(file)
                    copy2(
                        os.path.join(source_category_folder, file),
                        os.path.join(target_category_folder, f"{stem}_dup{dup_index}{ext}"),
                    )
            elif shortfall > 0:
                print(f"Warning: category '{category}' has no files to oversample from.")

        print(f"{split} dataset balanced successfully.\n")


In [6]:
def get_image_statistics(dataset_folder, categories):
    """
    Counts the files stored for every (split, category) pair of a dataset.

    The dataset is expected to be laid out as
    `<dataset_folder>/<split>/<category>/`, with split being one of
    train, val and test. Only regular files directly inside a category
    folder are counted. A missing folder is reported with a warning and
    counted as 0.

    Args:
        dataset_folder (str): Root folder that holds the train/val/test folders.
        categories (list of str): Category subfolder names to count.

    Returns:
        dict: `{split: {category: file_count}}` for all three splits.
    """
    stats = {}

    for split in ('train', 'val', 'test'):
        per_category = {}
        for category in categories:
            folder_path = os.path.join(dataset_folder, split, category)
            num_images = 0
            if not os.path.exists(folder_path):
                print(f"Warning: Folder '{folder_path}' does not exist.")
            else:
                # Sub-directories are skipped; only plain files count
                num_images = sum(
                    1 for entry in os.listdir(folder_path)
                    if os.path.isfile(os.path.join(folder_path, entry))
                )
            per_category[category] = num_images
        stats[split] = per_category

    return stats

### For WikiArt

In [7]:
# Split data to train, val, test sets for torch
def prepare_dataloader_folders_wikiart(wikiart_folder, style_folder):
    """
    Prepares a style folder with train/val/test subfolders for impressionist and non-impressionist images.

    Images from the "Impressionism" folder form the 'impressionist' category.
    Images from every other style folder except "Post_Impressionism" form the
    'non-impressionist' category. Each style folder is shuffled and split
    80/10/10 on its own, and the "*.jpg" files are copied (not moved) to
    `<style_folder>/<split>/<category>/`.

    Args:
        wikiart_folder (str): Path to the 'wikiart' folder containing one subfolder per art style.
        style_folder (str): Path to the 'style' folder to be created.

    Raises:
        FileNotFoundError: If `wikiart_folder` or its "Impressionism" subfolder does not exist.
    """
    random.seed(413)  # Ensure reproducibility (this reseeds the global `random` generator)

    splits = ('train', 'val', 'test')
    excluded_styles = ("Impressionism", "Post_Impressionism")
    impressionist_folder = os.path.join(wikiart_folder, "Impressionism")

    # Check the root first so the error names the folder that is actually missing
    if not os.path.exists(wikiart_folder):
        raise FileNotFoundError(f"{wikiart_folder} not found.")
    if not os.path.exists(impressionist_folder):
        raise FileNotFoundError(f"{impressionist_folder} not found.")

    # Create target folders
    for split in splits:
        for category in ['impressionist', 'non-impressionist']:
            os.makedirs(os.path.join(style_folder, split, category), exist_ok=True)

    # Helper function to copy images with tqdm and parallelism
    def copy_images_parallel(src_files, dst_folder, description="Copying images"):
        with ThreadPoolExecutor() as executor:
            # list(...) drains the lazy map so every copy finishes (and any error surfaces) here
            list(tqdm(
                executor.map(lambda file: shutil.copy(file, dst_folder), src_files),
                total=len(src_files), desc=description, unit="file",
            ))

    # Helper function to shuffle one source folder, split it 80/10/10 and copy each part
    def split_and_copy(folder, category, display_name, suffix=""):
        # Sort before shuffling: directory listing order is not guaranteed,
        # so the fixed seed alone would not make the split reproducible.
        images = sorted(Path(folder).glob("*.jpg"))
        random.shuffle(images)
        num_total = len(images)
        train_split, val_split = int(0.8 * num_total), int(0.9 * num_total)
        parts = {
            'train': images[:train_split],
            'val': images[train_split:val_split],
            'test': images[val_split:],
        }
        for split in splits:
            copy_images_parallel(
                parts[split],
                os.path.join(style_folder, split, category),
                f"Copying {display_name} ({split.capitalize()}){suffix}",
            )

    # Process Impressionist images
    split_and_copy(impressionist_folder, 'impressionist', "Impressionist")

    # Process Non-Impressionist images (exclude Post-Impressionism)
    # NOTE(review): all styles share one flat target folder, so two styles that
    # contain the same file name would overwrite each other — confirm names are unique.
    for folder in sorted(Path(wikiart_folder).iterdir()):
        if folder.is_dir() and folder.name not in excluded_styles:
            split_and_copy(folder, 'non-impressionist', "Non-Impressionist", f" from {folder.name}")

    print(f"Data split and organization complete under {style_folder}.")

### For Artifact

In [8]:
def split_real_fake(artifact_folder: str, output_folder: str):
    """
    Splits images from artifact dataset into real and fake categories.
    Real images come from imagenet and afhq folders.
    Fake images come from all other folders (a '.git' folder is ignored).

    Images are searched recursively (.jpg, .jpeg, .png, any letter case) and
    copied into the flat folders `<output_folder>/real` and `<output_folder>/fake`.
    Each copy is named after the source folder plus the image's path inside
    it, joined with underscores, so that equally named files from different
    subfolders do not overwrite each other.

    Args:
        artifact_folder (str): Path to the artifact dataset root folder
        output_folder (str): Path where real/fake folders will be created

    Raises:
        FileNotFoundError: If `artifact_folder` is not an existing directory.
    """
    artifact_path = Path(artifact_folder)
    output_path = Path(output_folder)

    # Fail before creating any output if there is nothing to read
    if not artifact_path.is_dir():
        raise FileNotFoundError(f"{artifact_folder} not found.")

    # Create output directories
    real_path = output_path / 'real'
    fake_path = output_path / 'fake'
    real_path.mkdir(parents=True, exist_ok=True)
    fake_path.mkdir(parents=True, exist_ok=True)

    image_extensions = {'.jpg', '.jpeg', '.png'}

    def copy_images_from_dir(src_dir, dst_dir, desc):
        """Helper function to copy all images from a directory; returns how many were found"""
        # Get all image files recursively, matching the extension case-insensitively
        image_files = sorted(
            path for path in src_dir.rglob('*')
            if path.is_file() and path.suffix.lower() in image_extensions
        )

        # Copy files with progress bar
        for img_path in tqdm(image_files, desc=desc):
            # The destination is flat, so the name must encode the sub-path:
            # the basename alone repeats across subfolders.
            relative_parts = img_path.relative_to(src_dir).parts
            unique_name = "_".join((src_dir.name,) + relative_parts)
            shutil.copy2(img_path, dst_dir / unique_name)

        return len(image_files)

    # Process real folders (imagenet and afhq)
    real_folders = ['imagenet', 'afhq']
    real_count = 0
    for folder in real_folders:
        folder_path = artifact_path / folder
        if folder_path.exists():
            count = copy_images_from_dir(folder_path, real_path, f"Copying {folder} (real)")
            real_count += count
            print(f"Processed {count} images from {folder}")
        else:
            print(f"Warning: real folder '{folder_path}' does not exist; skipping.")

    # Process fake folders (everything else except real folders)
    skipped_folders = real_folders + ['.git']
    fake_count = 0
    for folder_path in sorted(artifact_path.iterdir()):
        if folder_path.is_dir() and folder_path.name not in skipped_folders:
            count = copy_images_from_dir(folder_path, fake_path, f"Copying {folder_path.name} (fake)")
            fake_count += count
            print(f"Processed {count} images from {folder_path.name}")

    print("\nComplete!")
    print(f"Total real images: {real_count}")
    print(f"Total fake images: {fake_count}")
    print(f"Output directory: {output_folder}")

In [9]:
# Split data to train, val, test sets for torch
def prepare_dataloader_folders_artifact(artifact_folder, real_fake_folder):
    """
    Prepares a real_fake folder with train/val/test subfolders for real and fake images.

    The images directly inside `<artifact_folder>/real` and
    `<artifact_folder>/fake` (.jpg, .jpeg, .png, any letter case) are shuffled,
    split 80/10/10 per category and copied (not moved) to
    `<real_fake_folder>/<split>/<category>/`.

    Args:
        artifact_folder (str): Path to the folder containing the flat 'real' and 'fake' subfolders.
        real_fake_folder (str): Path to the 'real_fake' folder to be created.

    Raises:
        FileNotFoundError: If `artifact_folder`, its 'real' subfolder or its 'fake' subfolder does not exist.
    """
    random.seed(413)  # Ensure reproducibility (this reseeds the global `random` generator)

    splits = ('train', 'val', 'test')
    # Same extensions that split_real_fake collects, so no copied image is dropped here
    image_extensions = {'.jpg', '.jpeg', '.png'}

    # Define source folders (insertion order keeps 'real' before 'fake')
    source_folders = {
        'real': os.path.join(artifact_folder, "real"),
        'fake': os.path.join(artifact_folder, "fake"),
    }

    # Ensure source folders exist; a missing 'fake' folder would otherwise yield an empty class
    if not os.path.exists(artifact_folder):
        raise FileNotFoundError(f"{artifact_folder} not found.")
    for folder in source_folders.values():
        if not os.path.exists(folder):
            raise FileNotFoundError(f"{folder} not found.")

    # Create target folders
    for split in splits:
        for category in source_folders:
            os.makedirs(os.path.join(real_fake_folder, split, category), exist_ok=True)

    # Helper function to copy images with tqdm and parallelism
    def copy_images_parallel(src_files, dst_folder, description="Copying images"):
        with ThreadPoolExecutor() as executor:
            # list(...) drains the lazy map so every copy finishes (and any error surfaces) here
            list(tqdm(
                executor.map(lambda file: shutil.copy(file, dst_folder), src_files),
                total=len(src_files), desc=description, unit="file",
            ))

    # Process Real images, then Fake images
    for category, folder in source_folders.items():
        # Sort before shuffling: directory listing order is not guaranteed,
        # so the fixed seed alone would not make the split reproducible.
        images = sorted(
            path for path in Path(folder).iterdir()
            if path.is_file() and path.suffix.lower() in image_extensions
        )
        random.shuffle(images)
        num_total = len(images)
        train_split, val_split = int(0.8 * num_total), int(0.9 * num_total)
        parts = {
            'train': images[:train_split],
            'val': images[train_split:val_split],
            'test': images[val_split:],
        }
        for split in splits:
            copy_images_parallel(
                parts[split],
                os.path.join(real_fake_folder, split, category),
                f"Copying {category.capitalize()} ({split.capitalize()})",
            )

    print(f"Data split and organization complete under {real_fake_folder}.")

## WikiArt

In [None]:
# Build the train/val/test style split, then write a class-balanced copy of it
prepare_dataloader_folders_wikiart(wikiart_folder=PATH_WIKIART, style_folder=PATH_STYLE)
balance_datasets(PATH_STYLE, PATH_STYLE_BALANCED, ["impressionist", "non-impressionist"])

In [10]:
# Report the per-split, per-category image counts of the unbalanced style split
statistics = get_image_statistics(PATH_STYLE, categories=["impressionist", "non-impressionist"])
print("Style")
for split in statistics:
    counts = statistics[split]
    print(f"\n{split.capitalize()}:")
    for category in counts:
        count = counts[category]
        print(f"  {category.capitalize()}: {count} images")

Style

Train:
  Impressionist: 10448 images
  Non-impressionist: 48885 images

Val:
  Impressionist: 1306 images
  Non-impressionist: 6186 images

Test:
  Impressionist: 1306 images
  Non-impressionist: 6195 images


In [11]:
# Report the per-split, per-category image counts of the balanced style split
statistics = get_image_statistics(PATH_STYLE_BALANCED, categories=["impressionist", "non-impressionist"])
print("Style (Balanced)")
for split in statistics:
    counts = statistics[split]
    print(f"\n{split.capitalize()}:")
    for category in counts:
        count = counts[category]
        print(f"  {category.capitalize()}: {count} images")

Style (Balanced)

Train:
  Impressionist: 10448 images
  Non-impressionist: 10448 images

Val:
  Impressionist: 1306 images
  Non-impressionist: 1306 images

Test:
  Impressionist: 1306 images
  Non-impressionist: 1306 images


## Artifact

In [None]:
# Pool the images into real/fake, split them into train/val/test, then write a class-balanced copy
split_real_fake(artifact_folder=PATH_ARTIFACT, output_folder=PATH_REAL_FAKE_SPLIT)
prepare_dataloader_folders_artifact(artifact_folder=PATH_REAL_FAKE_SPLIT, real_fake_folder=PATH_REAL_FAKE)
balance_datasets(PATH_REAL_FAKE, PATH_REAL_FAKE_BALANCED, ["real", "fake"])

In [12]:
# Report the per-split, per-category image counts of the unbalanced real/fake split
statistics = get_image_statistics(PATH_REAL_FAKE, categories=["real", "fake"])
print("Real-Fake")
for split in statistics:
    counts = statistics[split]
    print(f"\n{split.capitalize()}:")
    for category in counts:
        count = counts[category]
        print(f"  {category.capitalize()}: {count} images")

Real-Fake

Train:
  Real: 102976 images
  Fake: 779780 images

Val:
  Real: 12872 images
  Fake: 97472 images

Test:
  Real: 12873 images
  Fake: 97473 images


In [13]:
# Report the per-split, per-category image counts of the balanced real/fake split
statistics = get_image_statistics(PATH_REAL_FAKE_BALANCED, categories=["real", "fake"])
print("Real-Fake (Balanced)")
for split in statistics:
    counts = statistics[split]
    print(f"\n{split.capitalize()}:")
    for category in counts:
        count = counts[category]
        print(f"  {category.capitalize()}: {count} images")

Real-Fake (Balanced)

Train:
  Real: 102976 images
  Fake: 102976 images

Val:
  Real: 12872 images
  Fake: 12872 images

Test:
  Real: 12873 images
  Fake: 12873 images


## Tests
- Loading each dataset using PyTorch dataloaders
- Perform augmentations

In [None]:
# Testing Torch Augmentation and Dataloaders

# Values shared by every pipeline and loader below (previously repeated inline)
IMAGENET_MEAN = [0.485, 0.456, 0.406]   # Normalize using ImageNet stats
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (200, 200)                  # Resize to 200x200
BATCH_SIZE = 32


def _build_transforms(augment=False):
    """
    Build the preprocessing pipeline: resize, optional augmentation, tensor conversion, normalization.

    Args:
        augment (bool): If True, add a random horizontal flip (p=0.5) and a
            random rotation (10 degrees) after the resize.

    Returns:
        transforms.Compose: The composed pipeline.
    """
    steps = [transforms.Resize(IMAGE_SIZE)]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),  # Augment with flipping
            transforms.RandomRotation(degrees=10),   # Augment with slight rotation
        ]
    steps += [
        transforms.ToTensor(),                       # Convert to tensor
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)


train_transforms = _build_transforms(augment=True)
val_test_transforms = _build_transforms()

# Paths to dataset splits (derived from the balanced-style constant instead of repeating the literal)
train_path = os.path.join(PATH_STYLE_BALANCED, "train")
val_path = os.path.join(PATH_STYLE_BALANCED, "val")
test_path = os.path.join(PATH_STYLE_BALANCED, "test")

# Datasets
train_dataset = ImageFolder(root=train_path, transform=train_transforms)
val_dataset = ImageFolder(root=val_path, transform=val_test_transforms)
test_dataset = ImageFolder(root=test_path, transform=val_test_transforms)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Un-augmented view of the training split, used for the visualization below
transforms_pipeline = _build_transforms()

# Load the dataset
dataset_path = train_path
dataset = ImageFolder(root=dataset_path, transform=transforms_pipeline)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Class-to-index mapping
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}  # e.g., {0: 'impressionist', 1: 'non-impressionist'}

# Helper function to denormalize and convert tensor to image
def denormalize_and_convert(image_tensor):
    """
    Undo the ImageNet normalization and return an image array Matplotlib can show.

    Args:
        image_tensor (torch.Tensor): Normalized image in channel-first layout
            with 3 channels, i.e. shape (3, H, W).

    Returns:
        numpy.ndarray: Image in (H, W, 3) layout with values clipped to [0, 1].
    """
    # .numpy() fails on tensors that require grad or live on an accelerator,
    # so move a detached copy to the CPU first.
    image_tensor = image_tensor.detach().cpu()
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    # [:, None, None] turns the per-channel values into (3, 1, 1) for broadcasting
    image_tensor = image_tensor * std[:, None, None] + mean[:, None, None]  # Denormalize
    image = image_tensor.permute(1, 2, 0).numpy()  # Convert to HWC format for Matplotlib
    return image.clip(0, 1)  # Ensure pixel values are in range [0, 1]

# Visualize examples
def visualize_examples(dataloader, idx_to_class, num_samples=4):
    """
    Visualizes a few examples from each category in the dataset.

    Batches are read until every category has `num_samples` examples or the
    dataloader is exhausted; one row of images is then shown per category.

    Args:
        dataloader (DataLoader): PyTorch DataLoader yielding (images, labels) batches.
        idx_to_class (dict): Mapping from class index to class label.
        num_samples (int): Number of examples to display per category.

    Returns:
        None
    """
    examples_per_category = {class_name: [] for class_name in idx_to_class.values()}
    for images, labels in dataloader:
        for image, label in zip(images, labels):
            class_name = idx_to_class[label.item()]
            if len(examples_per_category[class_name]) < num_samples:
                examples_per_category[class_name].append(image)
        # Stop reading as soon as every category is full
        if all(len(collected) >= num_samples for collected in examples_per_category.values()):
            break

    # Plot examples
    for category, samples in examples_per_category.items():
        print(f"Category: {category}")
        if not samples:
            # plt.subplots cannot create a grid with zero columns
            print("  (no examples found)")
            continue
        # squeeze=False always returns a 2-D array of axes, so a single
        # example no longer yields a bare Axes object that cannot be indexed.
        fig, axes = plt.subplots(1, len(samples), figsize=(15, 5), squeeze=False)
        for ax, img in zip(axes[0], samples):
            ax.imshow(denormalize_and_convert(img))
            ax.axis("off")
        plt.show()

# Show 10 examples per category from the training split
visualize_examples(dataloader=dataloader, idx_to_class=idx_to_class, num_samples=10)

## Experimental
- **Hybrid Method:** Resize images to $1000 \times 1000$, then create $200 \times 200$ patches

In [None]:
def create_patches_for_image(args):
    """
    Resizes an image to 1000x1000 and splits it into square patches.
    We call this the hybrid approach.

    Patches are cut row by row on a grid whose stride is `patch_size - overlap`
    and saved as `<original name>_patch_<NNN>.jpg` in `target_folder`.
    Errors while processing the image are printed, not raised, so that one
    bad file does not stop a batch run.

    Args:
        args (tuple): Contains (image_path, target_folder, patch_size, overlap)

    Raises:
        ValueError: If `patch_size` is not in 1..1000 or `overlap` is not in
            0..patch_size-1. These are configuration errors that would affect
            every image, so they are raised instead of being printed per image.
    """
    image_path, target_folder, patch_size, overlap = args

    resized_side = 1000

    # Validate outside the try-block: a zero stride makes range() fail for every
    # image and a negative stride silently produces no patches at all.
    if not 0 < patch_size <= resized_side:
        raise ValueError(f"patch_size must be between 1 and {resized_side}, got {patch_size}.")
    if not 0 <= overlap < patch_size:
        raise ValueError(f"overlap must be between 0 and patch_size - 1, got {overlap}.")

    try:
        # Create target directory if it doesn't exist
        os.makedirs(target_folder, exist_ok=True)

        # File name without extension, reused for every patch of this image
        original_name = os.path.splitext(os.path.basename(image_path))[0]

        # Open and resize image to 1000x1000
        with Image.open(image_path) as image:
            # Convert to RGB if needed
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize to 1000x1000 with antialiasing
            image = image.resize((resized_side, resized_side), Image.Resampling.LANCZOS)

            # Top-left offsets of the patches along one axis
            step = patch_size - overlap
            offsets = range(0, resized_side - patch_size + 1, step)

            # Create patches
            patch_count = 0
            for i in offsets:
                for j in offsets:
                    patch_count += 1
                    # crop box is (left, upper, right, lower): j is the column, i the row
                    patch = image.crop((j, i, j + patch_size, i + patch_size))

                    # Create new filename with patch number
                    patch_name = f"{original_name}_patch_{patch_count:03d}.jpg"

                    # Save patch with high quality
                    patch.save(
                        os.path.join(target_folder, patch_name),
                        'JPEG',
                        quality=100,
                        optimize=True
                    )

    except Exception as e:
        print(f"Error processing {image_path}: {e}")

def process_style_folder(root_path, style_folder, patch_size=200, overlap=0, num_workers=4):
    """
    Process all images in a style folder (impressionist or non-impressionist).

    Every image is handed to `create_patches_for_image` in a process pool.
    The patches are written into the same folder as their source images, so
    files that are already patches (`*_patch_<digits>`) are skipped; running
    the function again therefore does not create patches of patches.

    Args:
        root_path (str): Base path containing style folders
        style_folder (str): Name of the style folder
        patch_size (int): Size of patches
        overlap (int): Overlap between patches
        num_workers (int): Number of parallel workers
    """
    source_folder = os.path.join(root_path, style_folder)
    target_folder = source_folder  # Save in the same folder

    def _is_patch_file(filename):
        """Return True if the name (without extension) ends in '_patch_<digits>'."""
        stem = os.path.splitext(filename)[0]
        _, marker, number = stem.rpartition("_patch_")
        return bool(marker) and number.isdigit()

    # Get all source image files in a stable order, leaving out earlier patch output
    image_files = [
        os.path.join(source_folder, f)
        for f in sorted(os.listdir(source_folder))
        if f.lower().endswith(('.jpg', '.jpeg', '.png')) and not _is_patch_file(f)
    ]

    if not image_files:
        print(f"No images to process in '{source_folder}'.")
        return

    # Prepare arguments for parallel processing
    args = [
        (image_path, target_folder, patch_size, overlap)
        for image_path in image_files
    ]

    # Process images in parallel with progress bar
    # NOTE(review): multiprocessing start methods that re-import the main module may
    # not find a worker function defined in a notebook cell — confirm on the target platform.
    with Pool(num_workers) as pool:
        list(tqdm(
            pool.imap(create_patches_for_image, args),
            total=len(args),
            desc=f"Processing {style_folder}"
        ))

# Base path (derived from the balanced-style constant instead of repeating the literal)
train_path = os.path.join(PATH_STYLE_BALANCED, "train")

# Process both impressionist and non-impressionist folders
style_folders = ['impressionist', 'non-impressionist']

for style_folder in style_folders:
    print(f"\nProcessing {style_folder} images...")
    process_style_folder(
        root_path=train_path,
        style_folder=style_folder,
        patch_size=200,
        overlap=0,  # 50% would be 100=200*0.5 overlap
        num_workers=4
    )
    print(f"Completed processing {style_folder} images")