# Image Grid Showcase

Generate a 1x5 comparison grid from a directory of images.

- The first image (in sorted filename order) is the **reference**, set apart by a wider gap
- The remaining 4 images are **comparisons**, evenly spaced
- Light grey background

In [None]:
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import os

# Setup -- grids are written to <parent of the working directory>/data/grids,
# which is created (with any missing parents) if it does not exist yet.
current_dir = Path.cwd()
OUTPUT_DIR = current_dir.parent.joinpath("data", "grids")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

print("Setup complete!")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
def load_images_from_directory(directory, max_images=5):
    """
    Load up to ``max_images`` images from a directory, in sorted filename order.

    Supports: .png, .jpg, .jpeg, .webp (the extension match is case-insensitive).
    Only regular files are considered; subdirectories are ignored even when
    their names end in an image extension.

    Args:
        directory: Path (or str) of the directory to scan.
        max_images: Maximum number of images to load. A warning is printed
            when fewer matching files are found.

    Returns:
        Tuple ``(images, filenames)``: a list of RGB PIL Images and the
        matching list of file names (without directory component).
    """
    directory = Path(directory)
    extensions = {'.png', '.jpg', '.jpeg', '.webp'}

    # is_file() guards against directories named like images (e.g. "a.png/").
    image_files = sorted(
        f for f in directory.iterdir()
        if f.is_file() and f.suffix.lower() in extensions
    )

    if len(image_files) < max_images:
        print(f"Warning: Found only {len(image_files)} images, expected {max_images}")

    images = []
    filenames = []
    for img_path in image_files[:max_images]:
        # The context manager releases the file handle promptly; convert()
        # returns a new, fully loaded image that no longer needs the file.
        with Image.open(img_path) as img:
            images.append(img.convert('RGB'))
        filenames.append(img_path.name)

    return images, filenames


def create_comparison_grid(images, filenames=None, 
                           gap_after_first=60, 
                           regular_gap=20,
                           padding=30,
                           bg_color=(245, 245, 245),
                           show_labels=True,
                           label_height=30):
    """
    Create a 1xN comparison grid (pure PIL) with extra space after the first image.

    Every image is resized to the size of the first one. When ``show_labels``
    is true, a strip of ``label_height`` pixels is reserved below the row. If
    ``filenames`` is also given, each filename is drawn centred under its
    image. Names longer than 20 characters are shortened to 17 characters
    plus '...', matching ``create_grid_with_matplotlib``.

    Args:
        images: List of PIL Images
        filenames: Optional list of filenames for labels
        gap_after_first: Pixels between image 1 and 2 (larger gap)
        regular_gap: Pixels between images 2-N
        padding: Padding around the entire grid
        bg_color: Background color as RGB tuple
        show_labels: Whether to reserve space for and draw filename labels
        label_height: Height reserved for labels

    Returns:
        PIL Image (mode 'RGB') of the grid

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("No images provided")

    n = len(images)

    # Normalise every image to the first image's size so the columns line up.
    target_size = images[0].size
    resized_images = [
        img if img.size == target_size
        else img.resize(target_size, Image.Resampling.LANCZOS)
        for img in images
    ]
    img_width, img_height = target_size

    # Width = n images + one large gap after the reference + (n - 2) regular
    # gaps between the remaining n - 1 images.
    if n == 1:
        total_width = padding * 2 + img_width
    else:
        total_width = (padding * 2
                       + n * img_width
                       + gap_after_first
                       + (n - 2) * regular_gap)

    label_space = label_height if show_labels else 0
    total_height = padding * 2 + img_height + label_space

    canvas = Image.new('RGB', (total_width, total_height), bg_color)

    # Only set up drawing when there is actually something to label.
    draw = None
    if show_labels and filenames:
        from PIL import ImageDraw
        draw = ImageDraw.Draw(canvas)

    x = padding
    y = padding
    for i, img in enumerate(resized_images):
        canvas.paste(img, (x, y))

        if draw is not None and i < len(filenames):
            label = filenames[i]
            if len(label) > 20:
                label = label[:17] + '...'
            # textbbox at the origin gives the text extent plus its offset;
            # subtract the offset so the visible text is centred exactly.
            left, top, right, bottom = draw.textbbox((0, 0), label)
            text_x = x + (img_width - (right - left)) // 2 - left
            text_y = y + img_height + (label_height - (bottom - top)) // 2 - top
            # (102, 102, 102) == '#666666', the label colour of the matplotlib grid.
            draw.text((text_x, text_y), label, fill=(102, 102, 102))

        # Advance past this image and the gap that follows it (none after the last).
        x += img_width
        if i == 0:
            x += gap_after_first
        elif i < n - 1:
            x += regular_gap

    return canvas


def create_grid_with_matplotlib(images, filenames=None,
                                 gap_after_first=0.15,
                                 regular_gap=0.02,
                                 bg_color='#F5F5F5',
                                 figsize=(20, 5),
                                 show_labels=True,
                                 title=None):
    """
    Lay the images out in a single row of matplotlib axes.

    Gaps are blank spacer columns. Their widths are given relative to an
    image column of width 1: ``gap_after_first`` after the reference (first)
    image and ``regular_gap`` between the remaining images.

    Args:
        images: List of PIL Images
        filenames: Optional list of filenames for labels
        gap_after_first: Relative width of the spacer after the first image
        regular_gap: Relative width of the other spacers
        bg_color: Background color for the figure and every axes
        figsize: Figure size
        show_labels: Draw (truncated) filenames below the images
        title: Optional figure title

    Returns:
        The matplotlib Figure, or None (after printing a message) when
        ``images`` is empty.
    """
    count = len(images)
    if count == 0:
        print("No images to display")
        return None

    # One entry per grid column: (width ratio, index into images, or None for a spacer).
    columns = [(1, 0)]
    for image_index in range(1, count):
        spacer_width = gap_after_first if image_index == 1 else regular_gap
        columns.append((spacer_width, None))
        columns.append((1, image_index))

    fig, axes_grid = plt.subplots(
        1, len(columns), figsize=figsize, squeeze=False,
        gridspec_kw={'width_ratios': [ratio for ratio, _ in columns]},
    )
    fig.patch.set_facecolor(bg_color)

    for ax, (_, image_index) in zip(axes_grid[0], columns):
        ax.set_facecolor(bg_color)
        ax.set_xticks([])
        ax.set_yticks([])

        if image_index is None:
            ax.axis('off')  # spacer column: nothing drawn
            continue

        ax.imshow(images[image_index])
        if show_labels and filenames and image_index < len(filenames):
            name = filenames[image_index]
            shown = name if len(name) <= 20 else name[:17] + '...'
            ax.set_xlabel(shown, fontsize=9, color='#666666')

        if image_index == 0:
            ax.set_title('Reference', fontsize=10, fontweight='bold', color='#333333')

        for spine in ax.spines.values():
            spine.set_visible(False)

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()
    return fig

In [None]:
# Configuration widgets


def _initial_style():
    """Return a fresh style dict that sizes the description label to its text."""
    return {'description_width': 'initial'}


# Directory whose images are placed in the grid
dir_input = widgets.Text(
    value='',
    description='Image Directory:',
    style=_initial_style(),
    layout=widgets.Layout(width='600px'),
    placeholder='Enter full path to directory with 5 images',
)

# Spacing controls; values are widths relative to one image column
gap_slider = widgets.FloatSlider(
    value=0.15, min=0.05, max=0.4, step=0.01,
    description='Gap after reference:',
    style=_initial_style(),
    layout=widgets.Layout(width='400px'),
)

regular_gap_slider = widgets.FloatSlider(
    value=0.02, min=0.0, max=0.1, step=0.005,
    description='Regular gap:',
    style=_initial_style(),
    layout=widgets.Layout(width='400px'),
)

# Display options
show_labels_checkbox = widgets.Checkbox(
    value=True,
    description='Show filenames',
    style=_initial_style(),
)

title_input = widgets.Text(
    value='',
    description='Title (optional):',
    style=_initial_style(),
    layout=widgets.Layout(width='400px'),
)

# Action buttons
generate_button = widgets.Button(
    description='Generate Grid', button_style='success', icon='image',
    layout=widgets.Layout(width='150px'),
)

save_button = widgets.Button(
    description='Save Grid', button_style='primary', icon='save',
    layout=widgets.Layout(width='150px'),
)

output_area = widgets.Output()

# Most recently generated figure and the name of its source directory (read by on_save)
current_fig = None
current_dir_name = None

def on_generate(b):
    """
    Button callback: build and display a grid from the directory in ``dir_input``.

    Loads up to 5 images, shows the grid in ``output_area`` and remembers the
    figure and directory name in ``current_fig`` / ``current_dir_name`` for
    ``on_save``. The previously generated figure is closed once a new one
    exists, so repeated clicks do not accumulate open pyplot figures.

    Args:
        b: The clicked Button (unused).
    """
    global current_fig, current_dir_name

    with output_area:
        output_area.clear_output(wait=True)

        raw_path = dir_input.value.strip()
        if not raw_path:
            print("Please enter a directory path")
            return

        # expanduser() accepts "~/..."; resolve() gives "." or ".." a real
        # name, so the saved file is not called "grid_.png".
        directory = Path(raw_path).expanduser().resolve()
        if not directory.exists():
            print(f"Directory not found: {directory}")
            return
        if not directory.is_dir():
            print(f"Not a directory: {directory}")
            return

        try:
            images, filenames = load_images_from_directory(directory, max_images=5)

            if not images:
                print("No images found in directory")
                return

            print(f"Loaded {len(images)} images from {directory.name}")
            for fn in filenames:
                print(f"  - {fn}")
            print()

            title = title_input.value.strip() or None

            new_fig = create_grid_with_matplotlib(
                images,
                filenames=filenames if show_labels_checkbox.value else None,
                gap_after_first=gap_slider.value,
                regular_gap=regular_gap_slider.value,
                show_labels=show_labels_checkbox.value,
                title=title
            )

            # Release the previous figure only after the new one was built, so
            # a failed generation still leaves the last grid available to save.
            if current_fig is not None and current_fig is not new_fig:
                plt.close(current_fig)
            current_fig = new_fig
            current_dir_name = directory.name

            plt.show()

        except Exception as e:
            # UI boundary: report the error inside the output area.
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()

def on_save(b):
    """
    Button callback: write the last generated grid to OUTPUT_DIR as a PNG.

    The file is named ``grid_<directory name>.png`` and saved at 150 dpi with
    the figure's own background colour.

    Args:
        b: The clicked Button (unused).
    """
    with output_area:
        fig = current_fig
        if fig is None:
            print("No grid to save. Generate a grid first.")
            return

        target = OUTPUT_DIR / f"grid_{current_dir_name}.png"
        fig.savefig(
            target,
            dpi=150,
            bbox_inches='tight',
            facecolor=fig.get_facecolor(),
        )
        print(f"\nSaved: {target}")

generate_button.on_click(on_generate)
save_button.on_click(on_save)

# Interface, top to bottom: heading, directory input, spacing sliders,
# options, action buttons, then the shared output area.
interface_children = [
    widgets.HTML("<h3>Image Grid Generator</h3>"),
    dir_input,
    widgets.HTML("<h4>Spacing</h4>"),
    gap_slider,
    regular_gap_slider,
    widgets.HTML("<h4>Options</h4>"),
    show_labels_checkbox,
    title_input,
    widgets.HTML("<hr>"),
    widgets.HBox([generate_button, save_button]),
    output_area,
]
interface = widgets.VBox(interface_children)

display(interface)

---

## Batch Processing

Generate and save a grid for every subdirectory of a parent directory that contains at least 5 images.

In [None]:
def generate_grid_for_directory(directory, output_dir, 
                                 gap_after_first=0.15,
                                 regular_gap=0.02,
                                 show_labels=True):
    """
    Generate and save a grid for a single directory.

    The first 5 images (sorted by filename) are laid out with
    ``create_grid_with_matplotlib``, titled with the directory name, and
    saved as ``grid_<directory name>.png`` in ``output_dir``, which is
    created if needed.

    Args:
        directory: Directory containing the images.
        output_dir: Directory the PNG is written to.
        gap_after_first: Relative width of the gap after the reference image.
        regular_gap: Relative width of the gaps between comparison images.
        show_labels: Show filename labels below images.

    Returns:
        True if a grid was saved; False if fewer than 5 images were found.
    """
    directory = Path(directory)
    output_dir = Path(output_dir)

    images, filenames = load_images_from_directory(directory, max_images=5)

    if len(images) < 5:
        print(f"  Skipping {directory.name}: only {len(images)} images found")
        return False

    fig = create_grid_with_matplotlib(
        images,
        filenames=filenames if show_labels else None,
        gap_after_first=gap_after_first,
        regular_gap=regular_gap,
        show_labels=show_labels,
        title=directory.name
    )

    filename = f"grid_{directory.name}.png"
    filepath = output_dir / filename
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(filepath, dpi=150, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
    finally:
        # Always release the figure. batch_generate_grids continues after a
        # failed save, and leaked figures would otherwise pile up in pyplot.
        plt.close(fig)

    print(f"  Saved: {filename}")
    return True


def batch_generate_grids(parent_dir, output_dir=None,
                         gap_after_first=0.15,
                         regular_gap=0.02,
                         show_labels=True):
    """
    Generate grids for all subdirectories containing 5+ images.

    Each immediate subdirectory of ``parent_dir`` is processed in sorted order
    with ``generate_grid_for_directory``. An error in one subdirectory is
    printed and the batch continues. The output location is never scanned as
    input: subdirectories named ``grids`` are skipped, and so is ``output_dir``
    itself when it lives inside ``parent_dir``.

    Args:
        parent_dir: Directory containing subdirectories with images
        output_dir: Where to save grids (default: parent_dir/grids)
        gap_after_first: Relative width of the gap after the reference image
        regular_gap: Relative width of the gaps between comparison images
        show_labels: Show filename labels below images

    Returns:
        Number of grids successfully generated.
    """
    parent_dir = Path(parent_dir)
    output_dir = parent_dir / "grids" if output_dir is None else Path(output_dir)

    print(f"Scanning: {parent_dir}")
    print(f"Output: {output_dir}")
    print()

    # Compare resolved paths so a custom output folder (relative or absolute)
    # is excluded too; otherwise saved grids could be re-gridded on later runs.
    output_resolved = output_dir.resolve()
    subdirs = sorted(
        d for d in parent_dir.iterdir()
        if d.is_dir() and d.name != 'grids' and d.resolve() != output_resolved
    )

    success_count = 0
    for subdir in subdirs:
        print(f"Processing: {subdir.name}")
        try:
            if generate_grid_for_directory(subdir, output_dir,
                                           gap_after_first=gap_after_first,
                                           regular_gap=regular_gap,
                                           show_labels=show_labels):
                success_count += 1
        except Exception as e:
            # Best-effort batch: report and move on to the next directory.
            print(f"  Error: {e}")

    print(f"\nGenerated {success_count} grids")
    return success_count

In [None]:
# Example batch usage:
# batch_generate_grids(
#     parent_dir="/path/to/parent/containing/image/folders",
#     output_dir="/path/to/save/grids",  # optional
#     gap_after_first=0.15,
#     regular_gap=0.02,
#     show_labels=True
# )