# Dataset Inspection: SynthBuster+

This notebook inspects the SynthBuster+ dataset loaded from `data/datasets/synthbuster-plus`.

We'll explore:
1. Dataset structure and statistics
2. Sample images from each class
3. Label distribution
4. Image properties and metadata
5. Paired sample visualization using `plot_collage()`

## Setup

In [None]:
from pathlib import Path
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from datasets import load_from_disk
from PIL import Image

# Import plot_collage from clip_cues
import sys
sys.path.insert(0, '../')
from src.clip_cues import plot_collage

# Set seaborn style
sns.set_theme(style="whitegrid")

## Load Dataset

Load the SynthBuster+ dataset from the local directory.

In [None]:
# Load dataset from disk
dataset_path = Path("../data/datasets/synthbuster-plus")

# Fail fast when the dataset is missing: every later cell reads `dataset`,
# so carrying on would only surface as a confusing NameError further down.
if not dataset_path.exists():
    raise FileNotFoundError(
        f"⚠ Dataset not found at {dataset_path}\n"
        "Please run: python scripts/download_dataset.py synthbuster-plus"
    )

dataset = load_from_disk(str(dataset_path))
print("✓ Dataset loaded successfully!")
print(f"  Location: {dataset_path}")

## Dataset Overview

Examine the dataset structure, splits, and features.

In [None]:
# Summarise every split: size, feature names and feature types
separator = '=' * 60

print("Dataset Splits:")
print(separator)
for split_name, split_data in dataset.items():
    feature_spec = split_data.features
    print(f"\n{split_name}:")
    print(f"  Number of examples: {len(split_data):,}")
    print(f"  Features: {list(feature_spec.keys())}")
    print(f"  Feature types:")
    for feature_name, feature_type in feature_spec.items():
        print(f"    - {feature_name}: {feature_type}")

print(f"\n{separator}")

## Label Distribution

Analyze the distribution of real vs synthetic images.

In [None]:
# Analyze label distribution for each split
for split_name in dataset.keys():
    split_data = dataset[split_name]

    # Count labels (0 is reported as real, 1 as synthetic)
    labels = split_data['label']
    label_counts = Counter(labels)
    n_total = len(labels)

    print(f"\n{split_name} - Label Distribution:")
    print(f"{'-'*40}")

    # An empty split would divide by zero in the percentages below
    if n_total == 0:
        print("  (empty split)")
        continue

    # Compute counts and percentages once; reused by the printout and the plot
    labels_list = ['Real', 'Synthetic']
    counts = [label_counts[0], label_counts[1]]
    percentages = [count / n_total * 100 for count in counts]
    colors = ['green', 'red']

    print(f"  Real (0):      {counts[0]:,} ({percentages[0]:.1f}%)")
    print(f"  Synthetic (1): {counts[1]:,} ({percentages[1]:.1f}%)")
    print(f"  Total:         {n_total:,}")

    # Visualize with seaborn
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    # hue mirrors x so each bar receives its own palette colour
    sns.barplot(
        x=labels_list,
        y=counts, hue=labels_list, legend=False, palette=colors, alpha=0.7, edgecolor='black', ax=ax)

    ax.set_ylabel('Count', fontsize=12)
    ax.set_xlabel('')
    ax.set_title(f'{split_name} - Label Distribution', fontsize=14)

    # Add count labels on bars
    for i, (count, percentage) in enumerate(zip(counts, percentages)):
        ax.text(i, count, f'{count:,} ({percentage:.1f}%)',
                ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.show()

## Inspect Sample Examples

Look at the first and last training examples to understand the data structure.

In [None]:
# Peek at the first training example: summarise the image, print other fields verbatim
if 'train' in dataset:
    example = dataset['train'][0]
    rule = '=' * 60

    print("Example structure:")
    print(rule)
    for key, value in example.items():
        shown = f"PIL Image ({value.size}, mode={value.mode})" if key == 'image' else value
        print(f"  {key}: {shown}")
    print(rule)

In [None]:
# Get last example from train split (index -1)
if 'train' in dataset:
    example = dataset['train'][-1]

    print("Example structure:")
    print(f"{'='*60}")
    for key, value in example.items():
        if key == 'image':
            # Summarise the image by size and mode instead of printing the object
            print(f"  {key}: PIL Image ({value.size}, mode={value.mode})")
        else:
            print(f"  {key}: {value}")
    print(f"{'='*60}")

In [None]:
# Get last example from train split (index -1)
# NOTE(review): this cell repeats the previous cell exactly; consider removing it
# or pointing it at a different index.
if 'train' in dataset:
    example = dataset['train'][-1]

    print("Example structure:")
    print(f"{'='*60}")
    for key, value in example.items():
        if key == 'image':
            # Summarise the image by size and mode instead of printing the object
            print(f"  {key}: PIL Image ({value.size}, mode={value.mode})")
        else:
            print(f"  {key}: {value}")
    print(f"{'='*60}")

## Visualize Sample Images

Display sample images from each class (real and synthetic).

In [None]:
def visualize_samples(dataset_split, n_samples=5, seed=42):
    """
    Visualize sample images from each class.

    Draws a 2 x ``n_samples`` grid: the top row shows randomly chosen examples
    with label 0 (titled "Real"), the bottom row examples with label 1 (titled
    "Synthetic"). If a class has fewer than ``n_samples`` examples, the
    remaining cells of its row are left blank.

    Args:
        dataset_split: Dataset split to visualize
        n_samples: Number of samples per class (must be at least 1)
        seed: Random seed for reproducibility

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # A private RandomState seeded with `seed` produces the same draws as
    # np.random.seed(seed) followed by np.random.choice, but does not reset the
    # notebook-wide global RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Read the label column once and split row indices by class
    labels = list(dataset_split['label'])
    real_indices = [i for i, label in enumerate(labels) if label == 0]
    synthetic_indices = [i for i, label in enumerate(labels) if label == 1]

    # Sample random indices (real first, then synthetic, to keep the draw order)
    real_sample_idx = rng.choice(real_indices, min(n_samples, len(real_indices)), replace=False)
    synthetic_sample_idx = rng.choice(synthetic_indices, min(n_samples, len(synthetic_indices)), replace=False)

    # squeeze=False keeps `axes` two-dimensional even when n_samples == 1
    fig, axes = plt.subplots(2, n_samples, figsize=(3*n_samples, 6), squeeze=False)

    # Hide every axis up front so cells without an image stay blank
    for ax in axes.ravel():
        ax.axis('off')

    rows = (
        (real_sample_idx, "Real", 'green'),
        (synthetic_sample_idx, "Synthetic", 'red'),
    )
    for row, (sample_idx, class_name, color) in enumerate(rows):
        for col, idx in enumerate(sample_idx):
            example = dataset_split[int(idx)]
            ax = axes[row, col]
            ax.imshow(example['image'])
            ax.axis('off')

            # Add title with metadata
            title = f"{class_name}\n"
            if 'source' in example:
                title += f"Source: {example['source']}"
            ax.set_title(title, fontsize=10, color=color, fontweight='bold')

    plt.tight_layout()
    plt.show()

# Show random real/synthetic examples drawn from the train split
if 'train' in dataset:
    print("Sample images from training set:")
    visualize_samples(dataset_split=dataset['train'], n_samples=5)

## Visualize More Samples

Display additional samples drawn with a different random seed.

In [None]:
# Same visualization, re-drawn with another seed to show different examples
if 'train' in dataset:
    print("Different random samples:")
    visualize_samples(dataset_split=dataset['train'], seed=123, n_samples=5)

## Visualize Paired Samples with plot_collage()

Use the `plot_collage()` function to create a grid visualization of paired samples (real image + synthetic versions).

In [None]:
def visualize_paired_samples_collage(dataset, split_name='train', n_pairs=6, seed=42):
    """
    Visualize paired samples using plot_collage: real image + synthetic versions.

    Random real images (label 0) are sampled, every example sharing one of their
    ``image_id`` values is collected, and each group that contains both a real
    (label 0) and a synthetic (label 1) example becomes one row of the collage,
    with the real image first.

    Args:
        dataset: Dataset dict
        split_name: Dataset split to use
        n_pairs: Number of image pairs (rows) to display; clamped to the number
            of real images available in the split
        seed: Random seed for reproducibility

    Returns:
        The ``(fig, ax)`` pair returned by ``plot_collage``, or ``(None, None)``
        when no paired samples are found, so callers can always unpack two values.
    """
    split_data = dataset[split_name]

    # Metadata columns as arrays; positions in these arrays are row indices of the split
    labels = np.array(split_data["label"])
    image_ids = np.array(split_data["image_id"])
    idx_real = np.where(labels == 0)[0]

    # A private RandomState gives the same draws as np.random.seed(seed) +
    # np.random.choice without resetting the global RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Sampling without replacement raises if asked for more items than exist
    n_requested = min(n_pairs, len(idx_real))
    if n_requested <= 0:
        print("No paired samples found in this dataset.")
        return None, None

    sampled_real_indices = rng.choice(idx_real, n_requested, replace=False)
    sampled_image_ids = np.unique(image_ids[sampled_real_indices])

    # Get all indices with these image_ids (real + synthetic versions)
    idx_to_select = np.where(np.isin(image_ids, sampled_image_ids))[0]
    ds_subset = split_data.select(idx_to_select.tolist())

    print(f"Selected {len(sampled_real_indices)} real images")
    print(f"Subset size: {len(ds_subset)} (includes synthetic versions)")

    # Position k of the subset corresponds to idx_to_select[k], so ids and labels
    # can be taken from the column arrays instead of fetching whole examples
    # (image included) just to read metadata.
    subset_ids = image_ids[idx_to_select].tolist()
    subset_labels = labels[idx_to_select].tolist()

    # Group subset positions by image_id
    image_groups = defaultdict(list)
    for position, image_id in enumerate(subset_ids):
        image_groups[image_id].append(position)

    # Keep groups that have both real and synthetic versions
    paired_groups = []
    for base_id, positions in image_groups.items():
        group_labels = {subset_labels[p] for p in positions}
        if 0 in group_labels and 1 in group_labels:
            paired_groups.append((base_id, positions))

    print(f"Found {len(paired_groups)} image groups with both real and synthetic versions")

    if len(paired_groups) == 0:
        print("No paired samples found in this dataset.")
        return None, None

    # Prepare data for plot_collage
    # Each row: real image (first) + synthetic versions
    n_rows = len(paired_groups)
    n_cols = max(len(positions) for _, positions in paired_groups)

    images = []
    row_labels = []
    col_labels_collected = []

    for base_id, positions in paired_groups:
        # Each example is fetched exactly once, only for rows that get plotted
        examples = [ds_subset[p] for p in positions]
        # Stable sort: real (label 0) first, synthetic versions keep dataset order
        sorted_examples = sorted(examples, key=lambda ex: ex['label'])

        row_images = [ex['image'] for ex in sorted_examples]
        row_sources = [ex['source'] for ex in sorted_examples]

        # Take column labels from the widest row (first one on ties) so that every
        # populated column gets a label even when the first row is shorter.
        if len(row_sources) > len(col_labels_collected):
            col_labels_collected = list(row_sources)

        # Pad short rows with blank images so every row has n_cols entries
        for _ in range(n_cols - len(row_images)):
            row_images.append(Image.new('RGB', (224, 224), color='white'))

        images.extend(row_images)
        row_labels.append(base_id)

    # Defensive: keep the label list aligned with the grid width
    col_labels_collected += [""] * (n_cols - len(col_labels_collected))

    # Plot using plot_collage without individual captions
    fig, ax = plot_collage(
        images=images,
        captions=None,  # No individual captions
        row_labels=row_labels,
        col_labels=col_labels_collected,
        nrows=n_rows,
        ncols=n_cols,
        title="SynthBuster+: Paired Real (Raise1K) and Synthetic Images"
    )

    return fig, ax


# Visualize paired samples
if 'train' in dataset:
    collage = visualize_paired_samples_collage(dataset, split_name='train', n_pairs=6, seed=123)

    # The helper may come back without a figure when no paired samples are found,
    # so only lay out and save when there is something to save.
    fig, ax = collage if collage is not None else (None, None)
    if fig is not None:
        fig.tight_layout()
        output_path = Path("../examples/synthbuster-plus_paired_samples_collage.png")
        # savefig does not create missing directories
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

## Analyze Image Properties

Examine image dimensions and color modes on a random sample of images.

In [None]:
def analyze_image_properties(dataset_split, n_samples=1000, seed=42):
    """
    Analyze properties of images in the dataset.

    Samples up to ``n_samples`` examples without replacement, prints width,
    height and image-mode statistics, and draws a width/height scatter plot
    coloured by label (0 shown as "Real", anything else as "Synthetic").

    Args:
        dataset_split: Dataset split to analyze
        n_samples: Number of samples to analyze (capped at the split size)
        seed: Random seed for the sample draw, so repeated runs analyze the same
            images regardless of which cells ran before; pass None for a
            non-deterministic draw
    """
    # Sample indices
    n_samples = min(n_samples, len(dataset_split))
    if n_samples <= 0:
        # min()/max() below would fail on empty lists
        print("No images to analyze.")
        return

    # Private generator: does not depend on, or disturb, the global NumPy RNG
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(dataset_split), n_samples, replace=False)

    widths = []
    heights = []
    labels = []
    modes = []

    print(f"Analyzing {n_samples} images...")

    for idx in indices:
        example = dataset_split[int(idx)]
        img = example['image']
        # img.size is read as (width, height)
        widths.append(img.size[0])
        heights.append(img.size[1])
        labels.append(example['label'])
        modes.append(img.mode)

    mean_width = np.mean(widths)
    mean_height = np.mean(heights)

    # Print statistics
    print(f"\nImage Statistics (based on {n_samples} samples):")
    print(f"{'='*60}")
    print(f"\nWidth:")
    print(f"  Min: {min(widths)}px")
    print(f"  Max: {max(widths)}px")
    print(f"  Mean: {mean_width:.1f}px")
    print(f"  Median: {np.median(widths):.1f}px")

    print(f"\nHeight:")
    print(f"  Min: {min(heights)}px")
    print(f"  Max: {max(heights)}px")
    print(f"  Mean: {mean_height:.1f}px")
    print(f"  Median: {np.median(heights):.1f}px")

    print(f"\nImage Modes:")
    mode_counts = Counter(modes)
    for mode, count in mode_counts.most_common():
        print(f"  {mode}: {count} ({count/len(modes)*100:.1f}%)")

    # Create scatterplot (height vs width)
    fig, ax = plt.subplots(figsize=(10, 8))

    # Prepare data for seaborn
    label_names = ['Real' if l == 0 else 'Synthetic' for l in labels]

    # Create scatterplot with seaborn
    sns.scatterplot(x=widths, y=heights, hue=label_names,
                    palette={'Real': 'green', 'Synthetic': 'red'},
                    alpha=0.6, s=50, edgecolor='black', linewidth=0.5, ax=ax)

    # Add mean lines
    ax.axvline(mean_width, color='blue', linestyle='--', linewidth=2,
               label=f'Mean Width: {mean_width:.1f}px', alpha=0.7)
    ax.axhline(mean_height, color='orange', linestyle='--', linewidth=2,
               label=f'Mean Height: {mean_height:.1f}px', alpha=0.7)

    ax.set_xlabel('Width (pixels)', fontsize=12)
    ax.set_ylabel('Height (pixels)', fontsize=12)
    ax.set_title('Image Dimensions: Height vs Width', fontsize=14, fontweight='bold')
    ax.grid(alpha=0.3)
    # Re-create the legend so it also lists the mean lines added above
    ax.legend(fontsize=10)

    plt.tight_layout()
    plt.show()

# Run the image-property analysis on a 1000-image sample of the train split
if 'train' in dataset:
    analyze_image_properties(dataset_split=dataset['train'], n_samples=1000)

## Analyze Source Distribution

If the dataset has a 'source' field, analyze the distribution of different sources.

In [None]:
# Analyze the 'source' column of every split that has one, rather than gating
# the whole analysis on a 'train' split being present.
splits_with_source = [
    name for name in dataset.keys() if 'source' in dataset[name].features
]

if not splits_with_source:
    print("No 'source' field found in dataset.")

for split_name in splits_with_source:
    split_data = dataset[split_name]

    # Count sources; most_common() orders them from most to least frequent
    sources = split_data['source']
    source_counts = Counter(sources)
    ranked_sources = source_counts.most_common()

    print(f"\n{split_name} - Source Distribution:")
    print(f"{'='*60}")
    for source, count in ranked_sources:
        # str() keeps the width-30 format spec valid for non-string values such as None
        print(f"  {str(source):30}: {count:6,} ({count/len(sources)*100:5.1f}%)")

    # Visualize if there aren't too many sources
    if len(ranked_sources) <= 20:
        fig, ax = plt.subplots(figsize=(12, 6))

        sources_list = [str(s) for s, _ in ranked_sources]
        counts = [c for _, c in ranked_sources]

        ax.barh(sources_list, counts, alpha=0.7, edgecolor='black')
        ax.set_xlabel('Count', fontsize=12)
        ax.set_title(f'{split_name} - Source Distribution', fontsize=14, fontweight='bold')
        ax.grid(axis='x', alpha=0.3)

        plt.tight_layout()
        plt.show()

## Summary

Print a final summary of the dataset.

In [None]:
print("\n" + "="*60)
print("DATASET SUMMARY")
print("="*60)
print(f"\nDataset: SynthBuster+")
print(f"Location: {dataset_path}")
print(f"\nSplits:")
for split_name in dataset.keys():
    split_data = dataset[split_name]
    labels = split_data['label']
    label_counts = Counter(labels)
    print(f"  {split_name:10}: {len(split_data):6,} samples (Real: {label_counts[0]:6,}, Synthetic: {label_counts[1]:6,})")

# Report features from 'train' when present, otherwise from the first split,
# instead of raising KeyError on a dataset without a 'train' split.
split_names = list(dataset.keys())
if split_names:
    feature_split = 'train' if 'train' in split_names else split_names[0]
    print(f"\nFeatures: {list(dataset[feature_split].features.keys())}")
print("="*60)