# Latent Conditioner Input Data Preview

This notebook previews the input data for the latent conditioner using the same settings as `read_latent_conditioner_dataset_img` and demonstrates PCA preprocessing effects.

In [None]:
# Import required libraries
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import natsort
from sklearn.decomposition import PCA
import math
import torch

# Import PCA preprocessor
import sys
sys.path.append('modules')
from pca_preprocessor import PCAPreprocessor

# Notebook-wide matplotlib defaults: larger canvas and a readable base font
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

## Configuration Settings

These settings match the configuration from `input_data/condition.txt`:

In [None]:
# --- Image source (mirrors condition.txt) ---
param_dir = '/images'  # Image directory; appended to the working directory by the loading cell
param_data_type = '.png'  # Image file type
DEFAULT_IMAGE_SIZE = 256  # High resolution for sharp outline detection
INTERPOLATION_METHOD = cv2.INTER_CUBIC  # High-quality interpolation

# --- PCA settings ---
use_pca = False  # Set to True to enable PCA preprocessing
pca_components = 256  # Number of PCA components (must be < n_samples)
pca_patch_size = 0  # 0=full image PCA, >0=patch-based PCA

# --- Other settings ---
n_sample = 484  # Number of samples
debug_mode = 1  # Enable debug output

# Echo the active configuration so it is recorded next to the outputs below.
print("\n".join([
    "Configuration:",
    f"  Image directory: {param_dir}",
    f"  Image type: {param_data_type}",
    f"  Image size: {DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE}",
    f"  Expected samples: {n_sample}",
    f"  PCA enabled: {use_pca}",
    f"  PCA components: {pca_components}",
    f"  PCA patch size: {pca_patch_size} (0=full image)",
]))

## Load and Preview Raw Images

This section replicates the image loading from `read_latent_conditioner_dataset_img`:

In [None]:
# Load images using the same method as read_latent_conditioner_dataset_img
cur_dir = os.getcwd()
# param_dir begins with a path separator, so it is concatenated on purpose:
# os.path.join would treat it as absolute and drop cur_dir.
file_dir = cur_dir + param_dir

print(f"Looking for images in: {file_dir}")

# Start from an empty list so later cells never see a stale `files` from a
# previous run when the directory is missing or empty.
files = []

if not os.path.isdir(file_dir):
    print(f"❌ Directory {file_dir} does not exist!")
    print("Please ensure the images directory exists and contains .png files")
else:
    # Get list of image files in natural (human) sort order
    files = natsort.natsorted(
        [f for f in os.listdir(file_dir) if f.endswith(param_data_type)]
    )

    print(f"Found {len(files)} image files")

    if len(files) == 0:
        print(f"❌ No {param_data_type} files found in {file_dir}")
    else:
        print(f"First 5 files: {files[:5]}")

        # Read every file as grayscale and resize it to the configured square size.
        # Only successfully decoded images are kept, so `files` and `raw_images`
        # stay index-aligned and no all-zero placeholder rows leak into the
        # statistics or the PCA fit.
        loaded_files = []
        loaded_images = []

        for i, file in enumerate(files):
            if debug_mode == 1 and i < 5:
                print(f"Loading: {file}")
            file_path = os.path.join(file_dir, file)
            im = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
            if im is None:
                # cv2.imread signals an unreadable file by returning None
                print(f"❌ Failed to load {file}")
                continue
            resized_im = cv2.resize(im, (DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE),
                                    interpolation=INTERPOLATION_METHOD)
            loaded_images.append(resized_im)
            loaded_files.append(file)

        failed_count = len(files) - len(loaded_files)
        files = loaded_files

        if loaded_images:
            # float64 matches the dtype of the np.zeros buffer used previously
            raw_images = np.stack(loaded_images).astype(np.float64)

            print(f"✓ Loaded {len(files)} images with shape: {raw_images.shape}")
            if failed_count:
                print(f"  Skipped {failed_count} unreadable file(s)")
            print(f"  Image value range: [{raw_images.min():.1f}, {raw_images.max():.1f}]")
            print(f"  Image dtype: {raw_images.dtype}")
        else:
            print(f"❌ None of the {failed_count} image files could be decoded")

In [None]:
# Display first few raw images
if 'raw_images' in locals() and len(files) > 0:
    preview_count = min(len(files), 6)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Raw Input Images (First 6 samples)', fontsize=16)

    for i, ax in enumerate(axes.flat):
        # Every panel has its axes hidden; unused panels simply stay blank.
        ax.axis('off')
        if i >= preview_count:
            continue

        ax.imshow(raw_images[i], cmap='gray')
        ax.set_title(f'Image {i+1}: {files[i]}')

        # Overlay per-image statistics; the label needs a real newline between
        # mean and std (an escaped "\\n" would be drawn literally).
        img_mean = raw_images[i].mean()
        img_std = raw_images[i].std()
        ax.text(0.02, 0.98, f'μ={img_mean:.1f}\nσ={img_std:.1f}',
                transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

    # Pixel-value histogram (left) and per-image mean ± std (right)
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    # ravel() gives a flat view where possible instead of copying the whole stack
    plt.hist(raw_images.ravel(), bins=50, alpha=0.7, edgecolor='black')
    plt.title('Pixel Value Distribution (All Images)')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    # Show per-image statistics for at most the first 20 images
    stats_count = min(20, len(files))
    img_means = [raw_images[i].mean() for i in range(stats_count)]
    img_stds = [raw_images[i].std() for i in range(stats_count)]

    x_pos = range(len(img_means))
    plt.errorbar(x_pos, img_means, yerr=img_stds, fmt='o-', capsize=3)
    plt.title('Per-Image Statistics (First 20 images)')
    plt.xlabel('Image Index')
    plt.ylabel('Pixel Value')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## PCA Preprocessing Analysis

This section demonstrates the PCA preprocessing using the `PCAPreprocessor` class:

In [None]:
# PCA Analysis and Preprocessing
if 'raw_images' in locals() and len(files) > 0:
    print("\n=== PCA Preprocessing Analysis ===")

    # Check if we can apply PCA with current settings
    n_samples = raw_images.shape[0]
    n_features = DEFAULT_IMAGE_SIZE * DEFAULT_IMAGE_SIZE
    max_components = min(n_samples, n_features)

    print("Dataset info:")
    print(f"  Number of samples: {n_samples}")
    print(f"  Features per image: {n_features} ({DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE})")
    print(f"  Maximum PCA components: {max_components}")
    print(f"  Requested PCA components: {pca_components}")

    # Adjust PCA components if necessary.
    # NOTE: this deliberately rebinds the notebook-level `pca_components`,
    # because later cells report the value that was actually used.
    if pca_components > max_components:
        print(f"❌ Requested {pca_components} components exceeds maximum {max_components}")
        adjusted_components = min(256, max_components - 1)  # Leave some margin
        print(f"✓ Adjusting to {adjusted_components} components")
        pca_components = adjusted_components

    # Initialize PCA preprocessor (patch_size=None selects full-image mode)
    pca_preprocessor = PCAPreprocessor(
        n_components=pca_components,
        patch_size=pca_patch_size if pca_patch_size > 0 else None
    )

    print("\nPCA Configuration:")
    print(f"  Components: {pca_components}")
    print(f"  Patch size: {pca_patch_size if pca_patch_size > 0 else 'Full image'}")

    # Fit PCA on the data
    print("\nFitting PCA model...")
    pca_preprocessor.fit(raw_images)

    # Transform images using PCA
    print("\nTransforming images with PCA...")
    pca_tensor = pca_preprocessor.transform(raw_images)

    print("PCA transformation results:")
    print(f"  Original shape: {raw_images.shape}")
    print(f"  PCA tensor shape: {pca_tensor.shape}")
    print(f"  PCA tensor dtype: {pca_tensor.dtype}")
    print(f"  Output shape: {pca_preprocessor.get_output_shape()}")
    print(f"  Output channels: {pca_preprocessor.get_output_channels()}")

    # Convert to format expected by latent conditioner
    if len(pca_tensor.shape) == 4:  # (n_samples, channels, height, width)
        # reshape() also handles non-contiguous tensors, where view() would raise
        pca_data_flattened = pca_tensor.reshape(pca_tensor.shape[0], -1).numpy()
        data_shape = pca_tensor.shape[2:]  # (height, width)
    else:
        pca_data_flattened = pca_tensor.numpy()
        data_shape = pca_preprocessor.get_output_shape()

    print("\nFinal processed data:")
    print(f"  Flattened shape: {pca_data_flattened.shape}")
    print(f"  Data shape for model: {data_shape}")
    print(f"  Dimensionality reduction: {n_features} → {pca_data_flattened.shape[1]} ({100*pca_data_flattened.shape[1]/n_features:.1f}%)")

    # Show explained variance (only when the fitted model exposes it)
    if hasattr(pca_preprocessor.pca, 'explained_variance_ratio_'):
        variance_ratios = pca_preprocessor.pca.explained_variance_ratio_
        cumulative_variance = np.cumsum(variance_ratios)
        print(f"  Explained variance: {cumulative_variance[-1]:.1%}")

        # Plot explained variance
        plt.figure(figsize=(12, 4))

        # Left: cumulative variance with the usual 95% / 99% reference lines
        plt.subplot(1, 2, 1)
        plt.plot(range(1, len(cumulative_variance) + 1), cumulative_variance, 'b-o', markersize=3)
        plt.xlabel('Number of Components')
        plt.ylabel('Cumulative Explained Variance')
        plt.title('PCA Explained Variance')
        plt.grid(True, alpha=0.3)
        plt.axhline(y=0.95, color='r', linestyle='--', alpha=0.7, label='95% threshold')
        plt.axhline(y=0.99, color='g', linestyle='--', alpha=0.7, label='99% threshold')
        plt.legend()

        # Right: variance carried by each of the first (up to) 50 components
        plt.subplot(1, 2, 2)
        shown_components = min(50, len(variance_ratios))
        plt.plot(range(1, shown_components + 1),
                 variance_ratios[:shown_components],
                 'r-o', markersize=3)
        plt.xlabel('Component Number')
        plt.ylabel('Explained Variance Ratio')
        plt.title('Individual Component Variance (First 50)')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

In [None]:
# Visualize PCA components and reconstructions
if 'pca_preprocessor' in locals() and pca_preprocessor.is_fitted:
    print("\n=== PCA Visualization ===")

    # Show first few PCA components as images
    if not pca_preprocessor.patch_size:  # Full image PCA
        components_to_show = min(6, pca_preprocessor.pca.components_.shape[0])

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle('First 6 PCA Components', fontsize=16)

        for i, ax in enumerate(axes.flat):
            # Hide axes everywhere; panels beyond the available components stay blank
            ax.axis('off')
            if i >= components_to_show:
                continue

            # Reshape component to image
            component_img = pca_preprocessor.pca.components_[i].reshape(DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)

            # Normalize to [0, 1] for display; a constant component has zero
            # range and would otherwise divide by zero.
            comp_min = component_img.min()
            comp_range = component_img.max() - comp_min
            if comp_range > 0:
                component_img = (component_img - comp_min) / comp_range
            else:
                component_img = np.zeros_like(component_img)

            ax.imshow(component_img, cmap='RdBu_r')
            ax.set_title(f'Component {i+1}\n(Var: {pca_preprocessor.pca.explained_variance_ratio_[i]:.3f})')

        plt.tight_layout()
        plt.show()

    # Show original vs PCA-processed images
    print("\nComparing original vs PCA-processed images:")

    # Take first few samples for comparison
    samples_to_compare = min(3, len(files))

    fig, axes = plt.subplots(2, samples_to_compare, figsize=(5*samples_to_compare, 8))
    if samples_to_compare == 1:
        # plt.subplots returns a 1-D array for a single column; keep 2-D indexing
        axes = axes.reshape(2, 1)

    for i in range(samples_to_compare):
        # Original image
        axes[0, i].imshow(raw_images[i], cmap='gray')
        axes[0, i].set_title(f'Original Image {i+1}')
        axes[0, i].axis('off')

        # PCA-processed visualization
        if len(pca_tensor.shape) == 4:  # (n_samples, channels, height, width)
            pca_img = pca_tensor[i, 0].numpy()  # Take first channel
        else:
            # Reshape flattened PCA data for visualization
            pca_data_2d = pca_data_flattened[i]
            if len(data_shape) == 2:
                pca_img = pca_data_2d.reshape(data_shape)
            else:
                # For 1D data, lay the coefficients out on the smallest square
                # grid that holds them, zero-padding the tail if needed.
                side_len = int(np.sqrt(len(pca_data_2d)))
                if side_len * side_len == len(pca_data_2d):
                    pca_img = pca_data_2d.reshape(side_len, side_len)
                else:
                    # Pad to make square
                    target_len = side_len + 1
                    padded_data = np.pad(pca_data_2d, (0, target_len*target_len - len(pca_data_2d)))
                    pca_img = padded_data.reshape(target_len, target_len)

        axes[1, i].imshow(pca_img, cmap='viridis')
        axes[1, i].set_title(f'PCA Processed {i+1}\n({pca_components} components)')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    print("\n✓ PCA preprocessing complete!")
    print(f"  Original data: {raw_images.shape} → {raw_images.nbytes / 1024**2:.1f} MB")
    print(f"  PCA data: {pca_data_flattened.shape} → {pca_data_flattened.nbytes / 1024**2:.1f} MB")
    print(f"  Memory reduction: {100 * (1 - pca_data_flattened.nbytes / raw_images.nbytes):.1f}%")

## Data Summary and Recommendations

Final analysis and recommendations for the latent conditioner training:

In [None]:
# Summary and recommendations
if 'raw_images' in locals() and len(files) > 0:
    print("\n" + "="*60)
    print("         LATENT CONDITIONER INPUT SUMMARY")
    print("="*60)

    print("\n📊 Dataset Information:")
    print(f"   • Total samples: {len(files)}")
    print(f"   • Image size: {DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE}")
    print(f"   • Image format: {param_data_type}")
    print(f"   • Directory: {file_dir}")

    print("\n🔢 Data Statistics:")
    print(f"   • Pixel value range: [{raw_images.min():.1f}, {raw_images.max():.1f}]")
    print(f"   • Mean pixel value: {raw_images.mean():.1f}")
    print(f"   • Std pixel value: {raw_images.std():.1f}")
    print(f"   • Total memory: {raw_images.nbytes / 1024**2:.1f} MB")

    # The PCA section is only reported when the PCA analysis cell has been run
    if 'pca_preprocessor' in locals() and pca_preprocessor.is_fitted:
        print("\n🔍 PCA Analysis:")
        print(f"   • PCA components: {pca_components}")
        print(f"   • Explained variance: {np.sum(pca_preprocessor.pca.explained_variance_ratio_):.1%}")
        print(f"   • Dimensionality reduction: {100 * (1 - pca_data_flattened.shape[1] / (DEFAULT_IMAGE_SIZE**2)):.1f}%")
        print(f"   • Memory reduction: {100 * (1 - pca_data_flattened.nbytes / raw_images.nbytes):.1f}%")
        print(f"   • Final data shape: {pca_data_flattened.shape}")

    print("\n⚙️ Configuration Recommendations:")

    # Check sample size vs PCA components
    if pca_components >= len(files):
        recommended_components = min(len(files) // 2, 256)
        print(f"   ⚠️  PCA components ({pca_components}) should be < samples ({len(files)})")
        print(f"   ✓  Recommend: pca_components = {recommended_components}")
    else:
        print(f"   ✓  PCA components ({pca_components}) < samples ({len(files)}) ✓")

    # Check if images are outline/edge data, using Canny on up to the first 10 images.
    # NOTE(review): the Canny map is summed as-is, so each edge pixel contributes
    # its marked value rather than 1, while the threshold is 1% of the pixel
    # *count* — confirm whether a count of edge pixels was intended.
    edge_content = np.mean([cv2.Canny(raw_images[i].astype(np.uint8), 50, 150).sum() for i in range(min(10, len(files)))])
    if edge_content > raw_images.size / len(files) * 0.01:  # Heuristic for outline detection
        print("   ✓  Images appear to contain outline/edge data")
        print("   ✓  Outline-preserving augmentations will be applied during training")

    # Memory recommendations (per-sample bytes of the raw array times batch size)
    batch_size = 16  # From config
    estimated_batch_memory = raw_images.nbytes * batch_size / len(files) / 1024**2
    print(f"   📊 Estimated batch memory (size={batch_size}): {estimated_batch_memory:.1f} MB")

    if estimated_batch_memory > 1000:  # > 1GB
        print("   ⚠️  Consider reducing batch size or enabling PCA preprocessing")
    else:
        print("   ✓  Memory usage looks reasonable for training")

    print("\n" + "="*60)
    print("Ready for latent conditioner training! 🚀")
    print("="*60)

## Test PCA Settings

Interactive cell to test different PCA configurations:

In [None]:
# Interactive PCA testing
def test_pca_config(components, patch_size=0):
    """Test different PCA configurations"""
    if 'raw_images' not in locals() or len(files) == 0:
        print("❌ No image data loaded")
        return
    
    max_components = min(len(files), DEFAULT_IMAGE_SIZE**2)
    if components >= max_components:
        print(f"❌ Components ({components}) must be < {max_components}")
        return
    
    print(f"Testing PCA with {components} components, patch_size={patch_size}")
    
    try:
        test_pca = PCAPreprocessor(
            n_components=components,
            patch_size=patch_size if patch_size > 0 else None
        )
        test_pca.fit(raw_images)
        test_tensor = test_pca.transform(raw_images)
        
        if hasattr(test_pca.pca, 'explained_variance_ratio_'):
            variance_explained = np.sum(test_pca.pca.explained_variance_ratio_)
            print(f"✓ Success! Explained variance: {variance_explained:.1%}")
            print(f"  Output shape: {test_tensor.shape}")
            print(f"  Memory reduction: {100 * (1 - test_tensor.numel() * 4 / raw_images.nbytes):.1f}%")
        else:
            print(f"✓ Success! Output shape: {test_tensor.shape}")
            
    except Exception as e:
        print(f"❌ Error: {e}")

# Try a handful of component counts, skipping any that the dataset is too small for
if 'raw_images' in locals() and len(files) > 0:
    print("Testing different PCA configurations:")
    test_configs = [64, 128, 256, 400]

    for comp in test_configs:
        if comp >= len(files):
            print(f"Skipping {comp} components (exceeds sample count)")
            continue
        test_pca_config(comp)

## PCA Reconstruction Analysis

This section shows how well PCA can reconstruct the original images with different numbers of components:

In [None]:
# PCA Reconstruction Analysis
def reconstruct_with_pca(images, n_components, patch_size=None):
    """Project images through a fitted PCA and back, and score the reconstruction.

    A PCAPreprocessor is fitted on ``images``; its underlying ``.pca`` model is
    then used directly for ``transform`` / ``inverse_transform``.

    Args:
        images: Array of shape (n_samples, height, width).
        n_components: Number of PCA components passed to PCAPreprocessor.
        patch_size: None for full-image PCA. Otherwise the edge length of the
            square patches; each image is cut into ``height // patch_size``
            patches per dimension and every patch is reconstructed separately.
            Pixels outside that grid are left at zero.
            NOTE(review): assumes ``pca_proc.pca`` was fitted on flattened
            patch vectors of length patch_size**2 in this mode — confirm
            against PCAPreprocessor.

    Returns:
        Tuple ``(reconstructed, mse, explained_variance, pca_proc)``. On any
        failure the error is printed and ``(None, inf, 0, None)`` is returned
        so callers can skip the configuration.
    """
    try:
        # Create and fit the PCA preprocessor
        pca_proc = PCAPreprocessor(
            n_components=n_components,
            patch_size=patch_size
        )
        pca_proc.fit(images)
        pca_model = pca_proc.pca

        if patch_size is None:  # Full image PCA
            # Flatten, project to PCA space and back, restore the image shape
            images_flat = images.reshape(images.shape[0], -1)
            reconstructed = pca_model.inverse_transform(
                pca_model.transform(images_flat)
            ).reshape(images.shape)
        else:
            # Patch-based PCA: the grid size is derived from the image height
            # and applied to both dimensions.
            patches_per_dim = images.shape[1] // patch_size
            covered = patches_per_dim * patch_size
            n_patches = patches_per_dim * patches_per_dim
            reconstructed = np.zeros_like(images)

            if n_patches > 0:
                for sample_idx in range(images.shape[0]):
                    region = images[sample_idx, :covered, :covered]

                    # (grid_row, y, grid_col, x) -> (grid_row, grid_col, y, x),
                    # then one row per patch, flattened row-major.
                    patches = (
                        region
                        .reshape(patches_per_dim, patch_size, patches_per_dim, patch_size)
                        .swapaxes(1, 2)
                        .reshape(n_patches, patch_size * patch_size)
                    )

                    # One batched PCA round trip per image instead of one per patch
                    patches_recon = pca_model.inverse_transform(pca_model.transform(patches))

                    # Undo the patch layout to place every patch back in position
                    reconstructed[sample_idx, :covered, :covered] = (
                        patches_recon
                        .reshape(patches_per_dim, patches_per_dim, patch_size, patch_size)
                        .swapaxes(1, 2)
                        .reshape(covered, covered)
                    )

        # Quality metrics are the same for both modes
        mse = np.mean((images - reconstructed) ** 2)
        explained_var = np.sum(pca_model.explained_variance_ratio_)

        return reconstructed, mse, explained_var, pca_proc

    except Exception as e:
        # Deliberate best-effort: report and return a sentinel the callers check for
        print(f"Error in PCA reconstruction with {n_components} components: {e}")
        return None, float('inf'), 0, None

# Test different PCA configurations for reconstruction
if 'raw_images' in locals() and len(files) > 0:
    print("\n=== PCA Reconstruction Analysis ===")

    # Test different numbers of components
    component_configs = [16, 32, 64, 128, 256]
    max_components = min(len(files) - 1, DEFAULT_IMAGE_SIZE**2)

    # Filter valid configurations
    valid_configs = [c for c in component_configs if c < max_components]

    print(f"Testing PCA reconstruction with components: {valid_configs}")
    print(f"Maximum possible components: {max_components}")

    # Results keyed by component count; read by the visualization cells below.
    # Each entry keeps the full reconstructed stack, so memory grows with the
    # number of configurations tested.
    reconstruction_results = {}

    for n_comp in valid_configs:
        print(f"\nTesting {n_comp} components...")
        reconstructed, mse, explained_var, pca_proc = reconstruct_with_pca(raw_images, n_comp)

        # reconstruct_with_pca returns None as the first element on failure
        if reconstructed is not None:
            reconstruction_results[n_comp] = {
                'reconstructed': reconstructed,
                'mse': mse,
                'explained_variance': explained_var,
                'pca_processor': pca_proc
            }
            print(f"  MSE: {mse:.2e}")
            print(f"  Explained variance: {explained_var:.1%}")
        else:
            print(f"  Failed to reconstruct with {n_comp} components")

    print(f"\n✓ Reconstruction analysis complete for {len(reconstruction_results)} configurations")

In [None]:
# Visualize PCA reconstructions
if 'reconstruction_results' in locals() and len(reconstruction_results) > 0:
    print("\n=== PCA Reconstruction Visualization ===")

    # Select first few images for comparison
    images_to_show = min(3, len(files))
    configs_to_show = list(reconstruction_results.keys())

    # Grid: row 0 = originals, one further row per PCA configuration
    fig, axes = plt.subplots(len(configs_to_show) + 1, images_to_show,
                             figsize=(5*images_to_show, 3*(len(configs_to_show) + 1)))

    if images_to_show == 1:
        # plt.subplots returns a 1-D array for a single column; keep 2-D indexing
        axes = axes.reshape(-1, 1)

    # Show original images in the first row
    for img_idx in range(images_to_show):
        axes[0, img_idx].imshow(raw_images[img_idx], cmap='gray', vmin=0, vmax=255)
        axes[0, img_idx].set_title(f'Original Image {img_idx+1}')
        axes[0, img_idx].axis('off')

    # Show reconstructions for each PCA configuration
    for config_idx, (n_comp, results) in enumerate(reconstruction_results.items()):
        row_idx = config_idx + 1
        reconstructed = results['reconstructed']
        mse = results['mse']
        explained_var = results['explained_variance']

        for img_idx in range(images_to_show):
            # Clip reconstructed values to valid range
            recon_img = np.clip(reconstructed[img_idx], 0, 255)

            axes[row_idx, img_idx].imshow(recon_img, cmap='gray', vmin=0, vmax=255)
            axes[row_idx, img_idx].set_title(f'{n_comp} components\nMSE: {mse:.1e}, Var: {explained_var:.1%}')
            axes[row_idx, img_idx].axis('off')

    plt.suptitle('PCA Reconstruction Comparison', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.show()

    # Plot reconstruction quality metrics
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # MSE vs Components
    components = list(reconstruction_results.keys())
    mse_values = [reconstruction_results[c]['mse'] for c in components]
    explained_vars = [reconstruction_results[c]['explained_variance'] for c in components]

    ax1.semilogy(components, mse_values, 'bo-', markersize=8, linewidth=2)
    ax1.set_xlabel('Number of PCA Components')
    ax1.set_ylabel('Mean Squared Error (log scale)')
    ax1.set_title('Reconstruction Error vs PCA Components')
    ax1.grid(True, alpha=0.3)

    # Add MSE values as text
    for comp, mse in zip(components, mse_values):
        ax1.annotate(f'{mse:.1e}', (comp, mse), textcoords="offset points",
                     xytext=(0, 10), ha='center', fontsize=8)

    # Explained Variance vs Components
    ax2.plot(components, [v*100 for v in explained_vars], 'ro-', markersize=8, linewidth=2)
    ax2.set_xlabel('Number of PCA Components')
    ax2.set_ylabel('Explained Variance (%)')
    ax2.set_title('Explained Variance vs PCA Components')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 100)

    # Add percentage values as text
    for comp, var in zip(components, explained_vars):
        ax2.annotate(f'{var*100:.1f}%', (comp, var*100), textcoords="offset points",
                     xytext=(0, 10), ha='center', fontsize=8)

    plt.tight_layout()
    plt.show()

    # Show detailed error analysis for the first image
    print("\nDetailed Error Analysis (First Image):")
    print("="*50)

    original_img = raw_images[0]

    for n_comp, results in reconstruction_results.items():
        recon_img = results['reconstructed'][0]

        # Calculate various error metrics
        mse = np.mean((original_img - recon_img) ** 2)
        mae = np.mean(np.abs(original_img - recon_img))
        max_err = np.max(np.abs(original_img - recon_img))

        # Calculate PSNR (Peak Signal-to-Noise Ratio) against a peak value of 255
        if mse > 0:
            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
        else:
            psnr = float('inf')

        # Calculate SSIM (Structural Similarity Index) - simplified version:
        # a single global window over the whole image rather than local windows
        mu1, mu2 = original_img.mean(), recon_img.mean()
        sigma1, sigma2 = original_img.std(), recon_img.std()
        sigma12 = np.mean((original_img - mu1) * (recon_img - mu2))

        c1, c2 = (0.01 * 255) ** 2, (0.03 * 255) ** 2
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1**2 + sigma2**2 + c2))

        print(f"{n_comp:3d} components: MSE={mse:8.1e}, MAE={mae:6.1f}, Max_Err={max_err:6.1f}, PSNR={psnr:5.1f}dB, SSIM={ssim:5.3f}")

## Patch-Based PCA Reconstruction

Test patch-based PCA reconstruction with different patch sizes:

In [None]:
# Test patch-based PCA reconstruction
if 'raw_images' in locals() and len(files) > 0:
    print("\n=== Patch-Based PCA Reconstruction ===")

    # Test different patch sizes (must divide image size evenly)
    image_size = DEFAULT_IMAGE_SIZE  # 256
    possible_patch_sizes = [16, 32, 64, 128]  # All divide 256 evenly

    # Filter patch sizes that work with our image size
    valid_patch_sizes = [p for p in possible_patch_sizes if image_size % p == 0]

    print(f"Testing patch sizes: {valid_patch_sizes} (image size: {image_size}x{image_size})")

    # Patch-based reconstruction results keyed by patch size; read by the
    # visualization cell below.
    patch_reconstruction_results = {}

    for patch_size in valid_patch_sizes:
        print(f"\nTesting patch size {patch_size}x{patch_size}...")

        # Calculate components per patch (use reasonable number)
        patches_per_dim = image_size // patch_size
        total_patches = patches_per_dim * patches_per_dim
        patch_pixels = patch_size * patch_size

        # Use smaller number of components for patches: capped at 32, at a
        # quarter of the patch pixels, and at half the number of images
        components_per_patch = min(32, patch_pixels // 4, len(files) // 2)

        print(f"  Patches per image: {total_patches} ({patches_per_dim}x{patches_per_dim})")
        print(f"  Pixels per patch: {patch_pixels}")
        print(f"  Components per patch: {components_per_patch}")

        try:
            reconstructed, mse, explained_var, pca_proc = reconstruct_with_pca(
                raw_images, components_per_patch, patch_size=patch_size
            )

            # reconstruct_with_pca returns None as the first element on failure
            if reconstructed is not None:
                patch_reconstruction_results[patch_size] = {
                    'reconstructed': reconstructed,
                    'mse': mse,
                    'explained_variance': explained_var,
                    'components_per_patch': components_per_patch,
                    'total_patches': total_patches
                }
                print(f"  ✓ Success! MSE: {mse:.2e}, Explained variance: {explained_var:.1%}")
            else:
                print(f"  ❌ Failed to reconstruct with patch size {patch_size}")

        except Exception as e:
            print(f"  ❌ Error with patch size {patch_size}: {e}")

    print(f"\n✓ Patch-based reconstruction analysis complete for {len(patch_reconstruction_results)} configurations")

In [None]:
# Visualize patch-based PCA reconstructions
if 'patch_reconstruction_results' in locals() and len(patch_reconstruction_results) > 0:
    print("\n=== Patch-Based PCA Visualization ===")

    # Compare original vs different patch-based reconstructions
    images_to_show = min(2, len(files))
    patch_configs = list(patch_reconstruction_results.keys())

    # Grid: row 0 = originals, one further row per patch size
    fig, axes = plt.subplots(len(patch_configs) + 1, images_to_show,
                             figsize=(6*images_to_show, 3*(len(patch_configs) + 1)))

    if images_to_show == 1:
        # plt.subplots returns a 1-D array for a single column; keep 2-D indexing
        axes = axes.reshape(-1, 1)

    # Show original images in the first row
    for img_idx in range(images_to_show):
        axes[0, img_idx].imshow(raw_images[img_idx], cmap='gray', vmin=0, vmax=255)
        axes[0, img_idx].set_title(f'Original Image {img_idx+1}')
        axes[0, img_idx].axis('off')

    # Show patch-based reconstructions
    for config_idx, (patch_size, results) in enumerate(patch_reconstruction_results.items()):
        row_idx = config_idx + 1
        reconstructed = results['reconstructed']
        mse = results['mse']
        components_per_patch = results['components_per_patch']

        for img_idx in range(images_to_show):
            recon_img = np.clip(reconstructed[img_idx], 0, 255)

            axes[row_idx, img_idx].imshow(recon_img, cmap='gray', vmin=0, vmax=255)
            axes[row_idx, img_idx].set_title(f'Patch {patch_size}x{patch_size}\n{components_per_patch} comp/patch, MSE: {mse:.1e}')
            axes[row_idx, img_idx].axis('off')

            # Draw patch grid overlay; the -0.5 offset puts each line on the
            # boundary between pixels rather than through a pixel centre
            patches_per_dim = DEFAULT_IMAGE_SIZE // patch_size
            for i in range(1, patches_per_dim):
                # Vertical lines
                axes[row_idx, img_idx].axvline(x=i*patch_size-0.5, color='red', alpha=0.3, linewidth=0.5)
                # Horizontal lines
                axes[row_idx, img_idx].axhline(y=i*patch_size-0.5, color='red', alpha=0.3, linewidth=0.5)

    plt.suptitle('Patch-Based PCA Reconstruction Comparison', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.show()

    # Create comparison between full-image PCA and patch-based PCA
    if 'reconstruction_results' in locals() and len(reconstruction_results) > 0:
        print("\n=== Full-Image vs Patch-Based PCA Comparison ===")

        # Create comparison plot: top row = images, bottom row = error maps
        fig, axes = plt.subplots(2, 3, figsize=(15, 8))

        # Show original. All three top-row panels share the 0-255 intensity
        # scale so they can be compared visually.
        axes[0, 0].imshow(raw_images[0], cmap='gray', vmin=0, vmax=255)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')

        # Show best full-image PCA reconstruction (highest components)
        if reconstruction_results:
            best_full_comp = max(reconstruction_results.keys())
            best_full_recon = reconstruction_results[best_full_comp]['reconstructed'][0]
            best_full_mse = reconstruction_results[best_full_comp]['mse']

            axes[0, 1].imshow(np.clip(best_full_recon, 0, 255), cmap='gray', vmin=0, vmax=255)
            axes[0, 1].set_title(f'Full-Image PCA\n{best_full_comp} components\nMSE: {best_full_mse:.1e}')
            axes[0, 1].axis('off')

        # Show best patch-based PCA reconstruction (lowest dataset-wide MSE)
        if patch_reconstruction_results:
            best_patch_size = min(patch_reconstruction_results.keys(), key=lambda x: patch_reconstruction_results[x]['mse'])
            best_patch_recon = patch_reconstruction_results[best_patch_size]['reconstructed'][0]
            best_patch_mse = patch_reconstruction_results[best_patch_size]['mse']
            best_patch_comp = patch_reconstruction_results[best_patch_size]['components_per_patch']

            axes[0, 2].imshow(np.clip(best_patch_recon, 0, 255), cmap='gray', vmin=0, vmax=255)
            axes[0, 2].set_title(f'Patch-Based PCA\n{best_patch_size}x{best_patch_size}, {best_patch_comp} comp/patch\nMSE: {best_patch_mse:.1e}')
            axes[0, 2].axis('off')

        # Show error maps (absolute per-pixel difference for the first image)
        if reconstruction_results and patch_reconstruction_results:
            # Full-image error
            full_error = np.abs(raw_images[0] - best_full_recon)
            im1 = axes[1, 1].imshow(full_error, cmap='hot', vmin=0, vmax=np.max(full_error))
            axes[1, 1].set_title(f'Full-Image Error\nMax: {np.max(full_error):.1f}')
            axes[1, 1].axis('off')
            plt.colorbar(im1, ax=axes[1, 1], fraction=0.046, pad=0.04)

            # Patch-based error
            patch_error = np.abs(raw_images[0] - best_patch_recon)
            im2 = axes[1, 2].imshow(patch_error, cmap='hot', vmin=0, vmax=np.max(patch_error))
            axes[1, 2].set_title(f'Patch-Based Error\nMax: {np.max(patch_error):.1f}')
            axes[1, 2].axis('off')
            plt.colorbar(im2, ax=axes[1, 2], fraction=0.046, pad=0.04)

        # Hide unused subplot
        axes[1, 0].axis('off')

        plt.tight_layout()
        plt.show()

        # Summary comparison table (PSNR derived from the dataset-wide MSE)
        print("\nPCA Mode Comparison Summary:")
        print("="*60)
        print(f"{'Mode':<20} {'Components':<12} {'MSE':<12} {'PSNR (dB)':<10}")
        print("-"*60)

        if reconstruction_results:
            for comp, results in sorted(reconstruction_results.items()):
                mse = results['mse']
                psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else float('inf')
                print(f"{'Full-Image':<20} {comp:<12} {mse:<12.2e} {psnr:<10.1f}")

        if patch_reconstruction_results:
            for patch_size, results in sorted(patch_reconstruction_results.items()):
                mse = results['mse']
                comp_per_patch = results['components_per_patch']
                total_patches = results['total_patches']
                effective_comp = comp_per_patch * total_patches
                psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else float('inf')
                mode_name = f"Patch {patch_size}x{patch_size}"
                comp_desc = f"{comp_per_patch}x{total_patches}={effective_comp}"
                print(f"{mode_name:<20} {comp_desc:<12} {mse:<12.2e} {psnr:<10.1f}")