In [None]:
import pandas as pd
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random
import os
from torchvision import transforms
from torchvision.utils import save_image

from lwise_imgproc_utils.enhance_helpers import *

def ensure_dir_exists(path):
    """Make sure the parent directory of the file path ``path`` exists.

    Only the directory component of ``path`` is created (recursively, and
    without error if it is already there). A bare filename with no directory
    part is a no-op.
    """
    parent = os.path.dirname(path)
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)

def load_and_process_image(image_path):
    """Load an image file and convert it to a float tensor with ``ToTensor``.

    The result is laid out as (C, H, W). The channel count follows the image
    mode (e.g. 3 for RGB, 1 for L). For 8-bit modes the values are scaled to
    [0, 1].
    """
    # ``Image.open`` is lazy and keeps the file open until the pixel data is
    # read. The context manager guarantees the handle is released even if
    # decoding fails or the format is multi-frame.
    with Image.open(image_path) as img:
        return transforms.ToTensor()(img)

def create_visualization_grid(orig_dir, mod_dir, save_dir, n_images=8, specific_images=None, 
                            figsize=(20, 10), figure_save_path=None, seed=None, 
                            smooth_kernel_size=5, alpha=0.7, magnification_factor=5, splits=("val",), convert_to_ext=None):
    """
    Create a 4-row comparison grid and save every individual panel to disk.

    The grid has one column per image. Rows are:
        1. the original image,
        2. the enhanced (modified) image,
        3. the signed difference ``mod - orig`` multiplied by
           ``magnification_factor``, clipped to [-1, 1] and mapped to [0, 1]
           (0.5 = unchanged),
        4. ``create_heatmap(mod - orig, smooth_kernel_size)`` overlaid on the
           original via ``create_overlay(..., alpha)``.

    For each image, five files are written under ``save_dir``, mirroring
    ``im_path``: ``<stem>_ORIG``, ``_ENH``, ``_HEATMAP_OVERLAID_ORIG``,
    ``_DIFF`` and ``_HEATMAP``.

    Args:
        orig_dir: Directory containing the original images. Its dirmap.csv is
            only read when ``specific_images`` is not given.
        mod_dir: Directory containing the modified images at the same relative
            paths.
        save_dir: Directory to save the individual processed images.
        n_images: Number of images to sample at random from dirmap.csv
            (default 8). Ignored when ``specific_images`` is provided.
        specific_images: Relative image paths to use instead of random
            sampling (optional).
        figsize: Base figure size (width, height) in inches for 8 columns.
            The width is scaled by ``n_images / 8``.
        figure_save_path: Where to save the grid figure (optional). The format
            follows the file extension; a path without an extension is saved
            as PDF.
        seed: Seed for Python's ``random`` module, used for sampling
            (optional).
        smooth_kernel_size: Kernel size passed to ``create_heatmap``.
        alpha: Transparency passed to ``create_overlay``.
        magnification_factor: Scale applied to the difference shown in row 3.
        splits: dirmap.csv splits eligible for random sampling.
        convert_to_ext: Extension for the saved panels (e.g. '.png'; a leading
            dot is added if missing). Defaults to each source image's
            extension.

    Returns:
        (fig, axes): the matplotlib figure and a (4, n_images) array of axes.

    Raises:
        ValueError: if no images are selected, or if ``n_images`` exceeds the
            number of sampling candidates (raised by ``random.sample``).
    """
    if seed is not None:
        random.seed(seed)

    # Select images; dirmap.csv is only needed for random sampling.
    if specific_images is not None:
        selected_paths = list(specific_images)
    else:
        dirmap = pd.read_csv(os.path.join(orig_dir, 'dirmap.csv'))
        dirmap = dirmap[dirmap["split"].isin(splits)]
        selected_paths = random.sample(dirmap['im_path'].tolist(), n_images)
    n_images = len(selected_paths)
    if n_images == 0:
        raise ValueError("No images selected for the visualization grid")

    if convert_to_ext and not convert_to_ext.startswith('.'):
        convert_to_ext = '.' + convert_to_ext

    # Scale the width with the number of columns; squeeze=False keeps axes
    # 2-D even for a single column.
    adjusted_figsize = (figsize[0] * n_images / 8, figsize[1])
    fig, axes = plt.subplots(4, n_images, figsize=adjusted_figsize, squeeze=False)
    fig.subplots_adjust(wspace=0.1, hspace=0.2)

    # Row labels are attached to the first column as y-labels so they stay
    # centred on their rows. Fixed figure coordinates do not track the
    # subplot layout.
    row_titles = [
        'Original',
        'Enhanced',
        f'Difference\n(magnified {magnification_factor}x)',
        'Difference\nheatmap',
    ]
    for row, title in enumerate(row_titles):
        axes[row, 0].set_ylabel(title, rotation=0, ha='right', va='center',
                                labelpad=10, fontsize=12, fontweight='bold')

    for col, im_path in enumerate(selected_paths):
        orig_img = load_and_process_image(os.path.join(orig_dir, im_path))
        mod_img = load_and_process_image(os.path.join(mod_dir, im_path))

        # Signed difference: magnify, clip to [-1, 1], then shift to [0, 1]
        # so that mid-grey means "unchanged".
        im_diff = mod_img - orig_img
        diff_clipped = torch.clamp(im_diff * magnification_factor, -1.0, 1.0)
        im_diff_norm = (diff_clipped + 1.0) / 2.0

        # The heatmap is built from the unmagnified difference.
        heatmap = create_heatmap(im_diff, smooth_kernel_size)
        overlay = create_overlay(orig_img, heatmap, alpha)

        # Save the individual panels next to each other, mirroring im_path.
        save_base_path = os.path.join(save_dir, im_path)
        base_path, ext = os.path.splitext(save_base_path)
        ensure_dir_exists(save_base_path)
        if convert_to_ext:
            ext = convert_to_ext

        save_image(orig_img, f"{base_path}_ORIG{ext}")
        save_image(mod_img, f"{base_path}_ENH{ext}")
        save_image(overlay, f"{base_path}_HEATMAP_OVERLAID_ORIG{ext}")
        save_image(im_diff_norm, f"{base_path}_DIFF{ext}")
        save_image(heatmap, f"{base_path}_HEATMAP{ext}")

        # Plot the four panels of this column, top to bottom.
        for row, panel in enumerate((orig_img, mod_img, im_diff_norm, overlay)):
            _show_tensor_image(axes[row, col], panel)
        axes[0, col].set_title(f'Image {col+1}', pad=5)

    fig.tight_layout(rect=[0.05, 0, 1, 1])

    if figure_save_path:
        ensure_dir_exists(figure_save_path)
        # Let matplotlib infer the format from the extension (e.g. '.png' or
        # '.pdf'); forcing 'pdf' wrote PDF bytes into '.png' files. Fall back
        # to PDF only when the path has no extension.
        save_format = None if os.path.splitext(figure_save_path)[1] else 'pdf'
        fig.savefig(figure_save_path, bbox_inches='tight', dpi=600, format=save_format)

    return fig, axes


def _show_tensor_image(ax, img):
    """Draw a (C, H, W) tensor on ``ax`` without ticks.

    Single-channel tensors are shown in greyscale; others are shown as an
    image with values clipped to [0, 1].
    """
    if img.shape[0] == 1:
        ax.imshow(img.squeeze(0).numpy(), cmap='gray')
    else:
        ax.imshow(np.clip(img.permute(1, 2, 0).numpy(), 0, 1))
    ax.set_xticks([])
    ax.set_yticks([])

# Example usage of create_visualization_grid. It is kept inside a bare string
# literal, so it has no effect when this cell runs. Copy the calls out and
# replace the placeholder paths to use them.
"""
# Basic usage with random images:
fig, axes = create_visualization_grid(
    orig_dir='path/to/original/images',
    mod_dir='path/to/modified/images',
    save_dir='path/to/save/directory',
    figure_save_path='visualization_grid.pdf',
    smooth_kernel_size=5,
    alpha=0.7
)

# Usage with specific images:
specific_images = [
    'path/to/image1.png',
    'path/to/image2.png',
    'path/to/image3.png'
]
fig, axes = create_visualization_grid(
    orig_dir='path/to/original/images',
    mod_dir='path/to/modified/images',
    save_dir='path/to/save/directory',
    specific_images=specific_images,
    figure_save_path='visualization_grid_specific.pdf',
    smooth_kernel_size=5,
    alpha=0.7
)
"""

In [None]:
# HAM10000: grid for four hand-picked images.
specific_images = [
    "images/ISIC_0024315.jpg",
    "images/ISIC_0024705.jpg",
    "images/ISIC_0024912.jpg",
    "images/ISIC_0031298.jpg",
]
fig, axes = create_visualization_grid(
    orig_dir='../imgproc_code/data/HAM10000/HAM10000_natural',
    mod_dir='../imgproc_code/data/HAM10000/HAM10000_8_0.5_16_logit_diverge',
    save_dir='heatmap_viz/ham4',
    convert_to_ext='.png',
    # Dataset-specific name: the other grid cells previously wrote the same
    # 'visualization_grid.png' and overwrote this figure.
    figure_save_path='visualization_grid_ham4.png',
    specific_images=specific_images,
    smooth_kernel_size=21,
)

In [None]:
# MHIST: grid for two hand-picked images.
specific_images = [
    "images/MHIST_ayy.png",
    "images/MHIST_bsy.png",
]
fig, axes = create_visualization_grid(
    orig_dir='../imgproc_code/data/mhist/mhist_original',
    mod_dir='../imgproc_code/data/mhist/mhist_8_0.5_16_logit_diverge',
    save_dir='heatmap_viz/mhist',
    convert_to_ext='.png',
    # Dataset-specific name: the other grid cells previously wrote the same
    # 'visualization_grid.png' and overwrote this figure.
    figure_save_path='visualization_grid_mhist.png',
    specific_images=specific_images,
    smooth_kernel_size=21,
)

In [None]:
# Idaea (4 species): one hand-picked training image per class folder.
specific_images = [
    "train/01233_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_aversata/d8e21ccf-cd9c-46f0-826e-d962d61b0a62.jpg",
    "train/01234_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_biselata/b365cd7a-8905-441e-a0f5-735d146d66ff.jpg",
    "train/01239_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_seriata/dc9bae6a-9ec0-4fa7-abae-14596d9d601e.jpg",
    "train/01240_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_tacturata/33dc5c3d-a080-4a3f-8a12-52cdb42c5576.jpg",
]
fig, axes = create_visualization_grid(
    orig_dir='../imgproc_code/data/idaea4/idaea4_natural',
    mod_dir='../imgproc_code/data/idaea4/idaea4_8_0.5_16_logit_diverge',
    save_dir='heatmap_viz/idaea4',
    convert_to_ext='.png',
    # Dataset-specific name: the other grid cells previously wrote the same
    # 'visualization_grid.png' and overwrote this figure.
    figure_save_path='visualization_grid_idaea4.png',
    specific_images=specific_images,
    smooth_kernel_size=21,
    splits=["train", "val"]
)

In [None]:
# ImageNet-16 validation examples as (class folder, validation image number).
imagenet16_picks = [
    ("monkey", "00031957"),
    ("lizard", "00013870"),
    ("insect", "00013395"),
    ("fish", "00047682"),
    ("crab", "00020816"),
    ("bird", "00030431"),
    ("dog", "00030567"),
]
specific_images = [
    f"val/{class_dir}/ILSVRC2012_val_{image_number}.JPEG"
    for class_dir, image_number in imagenet16_picks
]

fig, axes = create_visualization_grid(
    orig_dir='../imgproc_code/data/imagenet16_resized',
    mod_dir='../imgproc_code/data/imagenet16_20_0.5_40_logit',
    save_dir='heatmap_viz/imagenet16',
    specific_images=specific_images,
    figure_save_path='fig_outputs/visualization_grid.pdf',
    convert_to_ext='.png',
    smooth_kernel_size=21,
    splits=["val"],
)

In [None]:
def create_average_heatmap(orig_dir, mod_dir, dirmap_path, save_dir, 
                           figure_save_path=None, smooth_kernel_size=5, 
                           splits=("val",), convert_to_ext=None):
    """
    Create an average difference-magnitude heatmap across all images in the given splits.

    For every image listed in ``dirmap_path`` (filtered to ``splits``), the
    per-pixel magnitude of ``modified - original`` is computed. This is the
    absolute value for single-channel images and the Euclidean norm across
    channels otherwise. The magnitudes are averaged over all images. The
    average is then optionally box-smoothed, min-max normalised to [0, 1] and
    rendered as an RGB heatmap going from blue (low) to red (high).

    Images that raise while loading or differencing are reported and skipped.
    Images whose size differs from the first successfully processed image are
    also skipped.

    Args:
        orig_dir: Directory containing original images.
        mod_dir: Directory containing modified images at the same relative paths.
        dirmap_path: Path to the CSV with ``im_path`` and ``split`` columns.
        save_dir: Directory to save the average heatmap image (skipped if falsy).
        figure_save_path: Path to save a figure with a matching colorbar
            (optional). A bare filename is saved in the current directory.
        smooth_kernel_size: Size of the box filter applied to the averaged
            magnitude. A falsy value disables smoothing. The output keeps the
            input's H x W for odd and even sizes alike.
        splits: Splits to include (e.g. ["train", "val", "test"]).
        convert_to_ext: Extension for the saved heatmap (default '.png'; a
            leading dot is added if missing).

    Returns:
        avg_heatmap: (3, H, W) float tensor. The red channel holds the
        normalised magnitude, green is 0, and blue is 1 - red.

    Raises:
        ValueError: if no image could be processed.
    """
    dirmap = pd.read_csv(dirmap_path)
    dirmap = dirmap[dirmap["split"].isin(splits)]
    print(f"Processing {len(dirmap)} images from splits: {list(splits)}")

    avg_magnitude, count = _mean_diff_magnitude(orig_dir, mod_dir, dirmap['im_path'].tolist())
    print(f"Created average magnitude from {count} images")

    if smooth_kernel_size:
        avg_magnitude = _box_smooth(avg_magnitude, smooth_kernel_size)

    print(f"Average magnitude statistics - Min: {avg_magnitude.min().item()}, Max: {avg_magnitude.max().item()}, Mean: {avg_magnitude.mean().item()}")

    # Normalise after averaging so that relative differences between images
    # are preserved. The epsilon guards against a constant map.
    norm_magnitude = (avg_magnitude - avg_magnitude.min()) / (avg_magnitude.max() - avg_magnitude.min() + 1e-8)

    # RGB heatmap: red = high difference, blue = low difference.
    avg_heatmap = torch.stack([norm_magnitude, torch.zeros_like(norm_magnitude), 1 - norm_magnitude])

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        ext = convert_to_ext or '.png'
        if not ext.startswith('.'):
            ext = '.' + ext
        save_path = os.path.join(save_dir, f"average_heatmap{ext}")
        save_image(avg_heatmap, save_path)
        print(f"Average heatmap saved to {save_path}")

    if figure_save_path:
        _plot_blue_red_heatmap(avg_heatmap, figure_save_path)

    return avg_heatmap


def _diff_magnitude(im_diff):
    """Per-pixel magnitude of a (C, H, W) difference tensor.

    Returns |d| for one channel and the L2 norm across channels otherwise.
    """
    if im_diff.shape[0] == 1:
        return torch.abs(im_diff[0])
    return torch.sqrt(torch.sum(im_diff ** 2, dim=0))


def _mean_diff_magnitude(orig_dir, mod_dir, im_paths):
    """Average ``_diff_magnitude`` over ``im_paths`` and return (mean, n_used).

    Raises ValueError if no image could be used.
    """
    magnitude_sum = None
    count = 0
    total = len(im_paths)
    for idx, im_path in enumerate(im_paths):
        try:
            orig_img = load_and_process_image(os.path.join(orig_dir, im_path))
            mod_img = load_and_process_image(os.path.join(mod_dir, im_path))
            magnitude = _diff_magnitude(mod_img - orig_img)
        except Exception as e:  # best-effort: report the bad file and keep going
            print(f"Error processing {im_path}: {e}")
            continue

        if magnitude_sum is None:
            magnitude_sum = magnitude.clone()
        elif magnitude.shape != magnitude_sum.shape:
            # Images of a different size cannot be averaged pixel-wise.
            print(f"Warning: Image {im_path} has different dimensions. Skipping.")
            continue
        else:
            magnitude_sum += magnitude
        count += 1

        if (idx + 1) % 100 == 0 or (idx + 1) == total:
            print(f"Processed {idx + 1}/{total} images")

    if count == 0:
        raise ValueError("No valid images were processed")
    return magnitude_sum / count, count


def _box_smooth(image_2d, kernel_size):
    """Mean-filter a 2-D tensor with a zero-padded square box, keeping its shape."""
    kernel = torch.ones(1, 1, kernel_size, kernel_size, device=image_2d.device) / (kernel_size * kernel_size)
    # padding='same' keeps H x W for even kernels too. padding=k//2 grew the
    # output by one pixel per axis when k was even; for odd k the two agree.
    return torch.nn.functional.conv2d(image_2d[None, None], kernel, padding='same')[0, 0]


def _plot_blue_red_heatmap(heatmap_rgb, figure_save_path):
    """Save the (3, H, W) ``heatmap_rgb`` with a matching blue-to-red colorbar."""
    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap

    # Handles bare filenames (os.makedirs('') would raise).
    ensure_dir_exists(figure_save_path)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.clip(heatmap_rgb.permute(1, 2, 0).numpy(), 0, 1))

    # Linear blue -> red colormap; this matches the heatmap's channel layout
    # (R = v, G = 0, B = 1 - v).
    blue_to_red = LinearSegmentedColormap.from_list('BlueRed', [(0, 0, 1), (1, 0, 0)])
    sm = plt.cm.ScalarMappable(cmap=blue_to_red, norm=mpl.colors.Normalize(vmin=0, vmax=1))
    sm.set_array([])

    cbar = fig.colorbar(sm, ax=ax, shrink=0.8, pad=0.02)
    cbar.set_label('Normalized mean pixel-value difference', fontsize=28)
    cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    cbar.ax.tick_params(labelsize=26)
    cbar.outline.set_linewidth(3)
    cbar.ax.tick_params(width=3, length=10)

    ax.axis('off')

    fig.tight_layout()
    fig.savefig(figure_save_path, bbox_inches='tight', dpi=600)
    print(f"Figure saved to {figure_save_path}")
    plt.close(fig)

In [None]:
# Unsmoothed average difference-magnitude heatmap over the ImageNet-16
# validation split.
imagenet16_root = '../imgproc_code/data/imagenet16_resized'
avg_heatmap = create_average_heatmap(
    orig_dir=imagenet16_root,
    mod_dir='../imgproc_code/data/imagenet16_20_0.5_40_logit',
    dirmap_path=f'{imagenet16_root}/dirmap.csv',
    save_dir='heatmap_viz/imagenet16/average',
    figure_save_path='fig_outputs/avg_heatmap.pdf',
    convert_to_ext='.png',
    smooth_kernel_size=None,
    splits=["val"],
)

# Filtering images by difficulty (ground truth logit)

In [None]:
import pandas as pd
import os
import shutil
from collections import defaultdict

def filter_and_save_by_logit(orig_dir, save_dir, logit_min, logit_max, splits=("val",), max_images=None, max_per_class=None, seed=None):
    """
    Copy images whose ``robust_gt_logit`` lies in a range, tagging filenames with class and logit.

    A row of ``<orig_dir>/dirmap_logits.csv`` is kept when its ``split`` is in
    ``splits`` and ``logit_min <= robust_gt_logit <= logit_max``. A roughly
    class-balanced random subset of those rows is then copied into
    ``save_dir``, mirroring each image's relative subdirectory, as
    ``<stem>_class_<class>_logit_<logit><ext>``. The logit is printed with 3
    decimals and '-' is spelled 'neg'. A metadata CSV
    (filtered_images_metadata.csv) is written to ``save_dir``.

    Args:
        orig_dir: Directory containing the original images and
            dirmap_logits.csv. The columns used are im_path, split, class and
            robust_gt_logit.
        save_dir: Directory to copy the selected images into.
        logit_min: Minimum robust_gt_logit value to include (inclusive).
        logit_max: Maximum robust_gt_logit value to include (inclusive).
        splits: Dataset splits to include (default: ("val",)).
        max_images: Maximum total number of images to save (optional).
        max_per_class: Maximum number of images per class (optional).
        seed: Random seed for the sampling (optional). ``None`` gives a
            different selection on each call.

    Per-class quota:
        - max_images only: ``max(1, max_images // n_classes)`` per class. If
          the total still exceeds max_images, it is randomly down-sampled to
          max_images.
        - max_per_class only: ``max_per_class`` per class.
        - both: the smaller of the two quotas, plus the max_images cap.
        - neither: the size of the smallest matching class. Every class is
          down-sampled to that count, so not all matching images are saved.
        Classes with fewer matches than the quota contribute all of them.

    Returns:
        pd.DataFrame with columns original_path, saved_path (relative to
        save_dir), robust_gt_logit and class. It is empty (with those
        columns) if no image matched.
    """
    metadata_columns = ['original_path', 'saved_path', 'robust_gt_logit', 'class']

    dirmap = pd.read_csv(os.path.join(orig_dir, 'dirmap_logits.csv'))

    mask = (
        dirmap["split"].isin(splits)
        & (dirmap["robust_gt_logit"] >= logit_min)
        & (dirmap["robust_gt_logit"] <= logit_max)
    )
    filtered_df = dirmap[mask].copy()
    if filtered_df.empty:
        print(f"Warning: no images in splits {list(splits)} have robust_gt_logit in [{logit_min}, {logit_max}]")

    balanced_df = _sample_balanced(filtered_df, max_images, max_per_class, seed)

    os.makedirs(save_dir, exist_ok=True)

    saved_records = []
    class_counts = defaultdict(int)

    for _, row in balanced_df.iterrows():
        orig_path = os.path.join(orig_dir, row['im_path'])

        # Encode class and logit into the filename, e.g. img_class_dog_logit_neg1.250.jpg
        base_name, ext = os.path.splitext(os.path.basename(row['im_path']))
        logit_str = f"{row['robust_gt_logit']:.3f}".replace('-', 'neg')
        new_filename = f"{base_name}_class_{row['class']}_logit_{logit_str}{ext}"

        # Mirror the original relative subdirectory under save_dir.
        save_subdir = os.path.join(save_dir, os.path.dirname(row['im_path']))
        os.makedirs(save_subdir, exist_ok=True)
        save_path = os.path.join(save_subdir, new_filename)

        shutil.copy2(orig_path, save_path)

        class_counts[row['class']] += 1
        saved_records.append({
            'original_path': row['im_path'],
            'saved_path': os.path.relpath(save_path, save_dir),
            'robust_gt_logit': row['robust_gt_logit'],
            'class': row['class']
        })

    # Explicit columns keep the metadata schema stable even when nothing was saved.
    saved_df = pd.DataFrame(saved_records, columns=metadata_columns)

    metadata_path = os.path.join(save_dir, 'filtered_images_metadata.csv')
    saved_df.to_csv(metadata_path, index=False)

    print(f"\nSaved {len(saved_df)} images total with robust_gt_logit values between {logit_min:.3f} and {logit_max:.3f}")
    print("\nImages per class:")
    for class_name, count in class_counts.items():
        print(f"Class {class_name}: {count} images")
    print(f"\nMetadata saved to: {metadata_path}")

    return saved_df


def _sample_balanced(filtered_df, max_images, max_per_class, seed):
    """Draw the roughly class-balanced random subset used by filter_and_save_by_logit.

    Returns an empty frame (same columns) when there is nothing to sample.
    """
    groups = [group for _, group in filtered_df.groupby('class')]
    if not groups:
        return filtered_df.iloc[0:0]
    n_classes = len(filtered_df['class'].unique())

    if max_images is not None:
        images_per_class = max(1, max_images // n_classes)
        if max_per_class is not None:
            images_per_class = min(images_per_class, max_per_class)
    elif max_per_class is not None:
        images_per_class = max_per_class
    else:
        # No limits: balance every class down to the smallest one.
        images_per_class = min(len(group) for group in groups)

    balanced_df = pd.concat(
        [group.sample(n=min(len(group), images_per_class), random_state=seed) for group in groups],
        ignore_index=True,
    )

    # Flooring to at least one image per class can overshoot max_images.
    if max_images is not None and len(balanced_df) > max_images:
        balanced_df = balanced_df.sample(n=max_images, random_state=seed)
    return balanced_df

# Example usage of filter_and_save_by_logit. It is kept inside a bare string
# literal, so it has no effect when this cell runs. Copy the call out and
# replace the placeholder paths to use it.
"""
filtered_df = filter_and_save_by_logit(
    orig_dir='path/to/original/images',
    save_dir='path/to/save/filtered/images',
    logit_min=-1.0,
    logit_max=1.0,
    splits=["val"],
    max_images=100,  # optional: total image limit
    max_per_class=20  # optional: per-class limit
)
"""

In [None]:
# Copy up to 200 ImageNet-16 validation images whose robust_gt_logit lies in
# [20, 30], balanced across classes.
difficulty_kwargs = dict(
    orig_dir='../imgproc_code/data/imagenet16_resized',
    save_dir='heatmap_viz/imagenet16_difficulty',
    logit_min=20,
    logit_max=30,
    splits=["val"],
    max_images=200,
)
filtered_df = filter_and_save_by_logit(**difficulty_kwargs)