This notebook generates visualizations of different versions of image datasets under different conditions (e.g., differing degrees of enhancement, different enhancement algorithms, different guide models, etc.). We used it to find example images for many of the figures in the L-WISE paper.

In [None]:
import os

# When the kernel starts inside "notebooks/" (detected by make_figs.ipynb sitting in the
# current directory), move up one level so the relative paths used below resolve from the
# directory that contains "notebooks".
parent_of_cwd = os.path.dirname(os.getcwd())
changed_dir = False
if os.path.exists("./make_figs.ipynb"):
    os.chdir(parent_of_cwd)
    changed_dir = True

# Whether or not we moved, the notebook must now be reachable as notebooks/make_figs.ipynb.
assert os.path.exists("./notebooks/make_figs.ipynb"), "Make sure your working directory starts in 'notebooks'"

In [None]:
import random
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision.transforms.functional import pil_to_tensor
from pathlib import Path

def find_split_directory(base_dir, split):
    """
    Find the split directory (e.g. train/val/test) inside a dataset directory.

    Two layouts are supported: ``base_dir/<split>`` and ``base_dir/<anything>/<split>``
    (exactly one level of nesting; deeper levels are not searched).

    Args:
        base_dir: Path (str or Path) of the dataset directory to search.
        split: Name of the split folder to look for.

    Returns:
        A Path to the split directory, or None if base_dir is not a directory or no
        matching split folder exists. When several nested candidates exist, the one
        whose parent sorts first is returned so the result is deterministic.
    """
    base_path = Path(base_dir)

    # A missing path or a plain file cannot contain a split; report "not found"
    # instead of letting iterdir() raise.
    if not base_path.is_dir():
        return None

    # Case 1: the split directory sits directly in base_dir
    direct_split = base_path / split
    if direct_split.is_dir():
        return direct_split

    # Case 2: the split directory sits one level down. iterdir() yields entries in
    # arbitrary order, so sort to always pick the same candidate.
    for subdir in sorted(base_path.iterdir()):
        # is_dir() is False when subdir itself is a file, so no separate check is needed.
        split_dir = subdir / split
        if split_dir.is_dir():
            return split_dir

    return None

def get_image_files(directory):
    """
    Collect every image file found anywhere beneath ``directory``.

    A file counts as an image when its name, compared case-insensitively, ends in
    .png, .jpg, .jpeg, .bmp or .gif. Results are Path objects in os.walk order.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    return [
        Path(folder) / name
        for folder, _, names in os.walk(directory)
        for name in names
        if name.lower().endswith(image_suffixes)
    ]

def _perturbation_norm(reference_img, image):
    """Return the L2 norm of (reference - image) with pixel values divided by 255, or None if not comparable."""
    if reference_img is None:
        return None
    reference = pil_to_tensor(reference_img).float() / 255
    candidate = pil_to_tensor(image).float() / 255
    # Different sizes or channel counts would either raise or silently broadcast.
    if reference.shape != candidate.shape:
        return None
    return torch.norm(reference - candidate).item()


def display_images(data_dir, num_images_to_display=10, subdirs_order=None, subdir_title_dict=None, split=None,
                   output_path="./notebooks/fig_outputs/perturb_methods_compare.pdf"):
    """
    Plot the same randomly sampled images across several dataset versions and save the grid.

    Each column is one subdirectory of ``data_dir``; each row is one image sampled from the
    first column's directory and looked up in the other columns. Every panel title shows
    "eps", the L2 distance to the first column's image ("n/a" when it cannot be computed).

    Args:
        data_dir: Directory whose subdirectories are the dataset versions to compare.
        num_images_to_display: Maximum number of rows (images) to sample; must be >= 1.
        subdirs_order: Optional list of subdirectory paths relative to data_dir, giving the
            column order. Entries that do not exist are dropped. When omitted, all
            subdirectories are used, with 'originals', 'natural' and 'robustness_natural'
            sorted first.
        subdir_title_dict: Optional mapping from subdirectory name to the column title to show.
        split: Optional split folder name (e.g. "val") located via find_split_directory;
            versions lacking it are skipped with a printed warning.
        output_path: Where the figure is saved; parent directories are created as needed and
            the file format follows the extension.

    Raises:
        ValueError: If num_images_to_display < 1, no subdirectories or split directories are
            found, or the first directory contains no images.
    """
    if num_images_to_display < 1:
        raise ValueError(f"num_images_to_display must be at least 1, got {num_images_to_display}")

    titles = subdir_title_dict if subdir_title_dict is not None else {}
    data_path = Path(data_dir)

    # Get all valid subdirectories
    if subdirs_order:
        subdirs = [data_path / d for d in subdirs_order]
        subdirs = [d for d in subdirs if d.exists()]
    else:
        subdirs = [d for d in data_path.iterdir() if d.is_dir()]
        # Sort with "originals" or "natural" first
        priority_names = ['originals', 'natural', 'robustness_natural']
        subdirs.sort(key=lambda x: (x.name.lower() not in priority_names, x.name.lower()))

    if not subdirs:
        raise ValueError(f"No valid subdirectories found in {data_dir}")

    # Find split directories for each subdir
    split_dirs = []
    for subdir in subdirs:
        if split:
            split_dir = find_split_directory(subdir, split)
            if split_dir:
                split_dirs.append((subdir, split_dir))
            else:
                print(f"Warning: Split directory '{split}' not found in {subdir}")
        else:
            split_dirs.append((subdir, subdir))

    if not split_dirs:
        raise ValueError(f"No valid split directories found for split '{split}'")

    # Get random images from the first directory
    reference_root = split_dirs[0][1]
    first_subdir_images = get_image_files(reference_root)
    if not first_subdir_images:
        raise ValueError(f"No images found in {reference_root}")

    selected_images = random.sample(first_subdir_images, min(num_images_to_display, len(first_subdir_images)))

    # squeeze=False keeps axs two-dimensional even for a single row and/or column,
    # so axs[row, col] indexing below always works.
    fig, axs = plt.subplots(len(selected_images), len(split_dirs),
                            figsize=(5*len(split_dirs), 5*len(selected_images)),
                            dpi=600, squeeze=False)

    # Display images
    for row, orig_image_path in enumerate(selected_images):
        orig_img = None
        base_image_name = orig_image_path.name
        relative_image_path = orig_image_path.relative_to(reference_root)
        print(orig_image_path)

        for col, (subdir, split_dir) in enumerate(split_dirs):
            ax = axs[row, col]
            title = titles.get(subdir.name, subdir.name)

            # Prefer the file at the same relative location (unambiguous when file names
            # repeat across folders); otherwise fall back to the first file with that name.
            image_path = split_dir / relative_image_path
            if not image_path.is_file():
                image_path = next(split_dir.rglob(base_image_name), None)

            if image_path is None:
                ax.axis('off')
                ax.set_title(f"{title}\nImage not found")
                continue

            image = Image.open(image_path)
            if col == 0:
                orig_img = image
                eps = 0.0
            else:
                eps = _perturbation_norm(orig_img, image)

            ax.imshow(image)
            ax.axis('off')
            eps_label = "n/a" if eps is None else f"{eps:.4f}"
            ax.set_title(f"{title}\neps={eps_label}")

    plt.tight_layout()
    # Make sure the output folder exists; savefig does not create directories.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=600, bbox_inches='tight')
    plt.show()

In [None]:
# Compare five randomly sampled validation images across the dataset versions listed below.
# Columns follow the order of subdirs_order; display_images drops entries that do not
# exist under data_dir.
print(os.getcwd())

data_dir = "imgproc_code/data"
num_images_to_display = 5
subdirs_order = [
    "imagenet16_resized",
    "imagenet16_wormholes_vanilla_resnet50",
    "imagenet16_cutmix_resnet50",
    "imagenet16_wormholes_eps1_resnet50",
    "imagenet16_wormholes_eps3_resnet50",
    "imagenet16_wormholes_eps10_resnet50",
    "imagenet16_xcit_augmented",
]

display_images(data_dir, num_images_to_display, subdirs_order=subdirs_order, split="val")

In [None]:
# Compare ten randomly sampled validation images between the resized dataset and three
# checkpoint folders under imagenet16_epochs. Columns follow the order of subdirs_order.
print(os.getcwd())

data_dir = "imgproc_code/data"
num_images_to_display = 10
subdirs_order = [
    "imagenet16_resized",
    "imagenet16_epochs/0_checkpoint",
    "imagenet16_epochs/9_checkpoint",
    "imagenet16_epochs/89_checkpoint",
]

display_images(data_dir, num_images_to_display, subdirs_order=subdirs_order, split="val")