# Latent Conditioner Input Data Preview

This notebook previews the input data for the latent conditioner using the same settings as `read_latent_conditioner_dataset_img` and demonstrates PCA preprocessing effects.

In [None]:
# Import required libraries
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import natsort
from sklearn.decomposition import PCA
import math
import torch

# Import PCA preprocessor
import sys
sys.path.append('modules')
from pca_preprocessor import PCAPreprocessor

# Notebook-wide matplotlib defaults: roomy figures and a readable base font
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

## Configuration Settings

These settings match the configuration from `input_data/condition.txt`:

In [None]:
# --- Image loading (values copied from condition.txt) ---
param_dir = '/images'                   # image directory, appended to the working directory
param_data_type = '.png'                # file extension used to select images
DEFAULT_IMAGE_SIZE = 256                # images are resized to a square of this side length
INTERPOLATION_METHOD = cv2.INTER_CUBIC  # interpolation passed to cv2.resize

# --- PCA ---
use_pca = False       # set to True to enable PCA preprocessing
pca_components = 256  # number of PCA components (must be < n_samples)
pca_patch_size = 0    # 0 = full-image PCA, >0 = patch-based PCA

# --- Miscellaneous ---
n_sample = 484  # expected number of samples
debug_mode = 1  # 1 enables debug output

# Echo the active configuration, one setting per line
print(
    "Configuration:",
    f"  Image directory: {param_dir}",
    f"  Image type: {param_data_type}",
    f"  Image size: {DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE}",
    f"  Expected samples: {n_sample}",
    f"  PCA enabled: {use_pca}",
    f"  PCA components: {pca_components}",
    f"  PCA patch size: {pca_patch_size} (0=full image)",
    sep="\n",
)

## Load and Preview Raw Images

This section replicates the image loading from `read_latent_conditioner_dataset_img`:

In [None]:
# Load every matching image: natural-sorted file list, grayscale read, then a
# resize to DEFAULT_IMAGE_SIZE x DEFAULT_IMAGE_SIZE with INTERPOLATION_METHOD.
cur_dir = os.getcwd()
# param_dir carries its own leading '/', so plain concatenation is intentional
# (os.path.join would drop cur_dir when the second part looks absolute).
file_dir = cur_dir + param_dir

# Names of the images that were actually loaded; raw_images[i] belongs to files[i].
# Defined up front so later cells can always evaluate len(files).
files = []

print(f"Looking for images in: {file_dir}")

if not os.path.exists(file_dir):
    print(f"❌ Directory {file_dir} does not exist!")
    print(f"Please ensure the images directory exists and contains {param_data_type} files")
else:
    # All files with the configured extension, in natural (human) sort order
    candidate_files = natsort.natsorted(
        [f for f in os.listdir(file_dir) if f.endswith(param_data_type)]
    )

    print(f"Found {len(candidate_files)} image files")

    if len(candidate_files) == 0:
        print(f"❌ No {param_data_type} files found in {file_dir}")
    else:
        print(f"First 5 files: {candidate_files[:5]}")

        # Pre-allocate for the best case; rows are filled densely and the
        # unused tail is trimmed afterwards, so unreadable files never leave
        # all-zero images behind in the dataset.
        image_buffer = np.zeros((len(candidate_files), DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE))

        for i, file in enumerate(candidate_files):
            if debug_mode == 1 and i < 5:
                print(f"Loading: {file}")
            file_path = os.path.join(file_dir, file)
            im = cv2.imread(file_path, 0)  # flag 0 = read as grayscale
            if im is None:
                # cv2.imread signals an unreadable file by returning None
                print(f"❌ Failed to load {file}")
                continue
            image_buffer[len(files)] = cv2.resize(
                im,
                (DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE),
                interpolation=INTERPOLATION_METHOD,
            )
            files.append(file)

        skipped_count = len(candidate_files) - len(files)
        if skipped_count:
            print(f"⚠️  Skipped {skipped_count} unreadable file(s)")

        if files:
            raw_images = image_buffer[:len(files)]
            print(f"✓ Loaded {len(files)} images with shape: {raw_images.shape}")
            print(f"  Image value range: [{raw_images.min():.1f}, {raw_images.max():.1f}]")
            print(f"  Image dtype: {raw_images.dtype}")
        else:
            print(f"❌ None of the {param_data_type} files in {file_dir} could be read")

In [None]:
# Preview the first raw images and summarise their pixel statistics
if 'raw_images' in locals() and len(files) > 0:
    preview_count = min(len(files), 6)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Raw Input Images (First 6 samples)', fontsize=16)

    for i, ax in enumerate(axes.flat):
        # Every panel has its axes hidden; panels beyond the preview stay blank
        ax.axis('off')
        if i >= preview_count:
            continue

        ax.imshow(raw_images[i], cmap='gray')
        ax.set_title(f'Image {i+1}: {files[i]}')

        # Mean / std overlay; a real newline puts the two values on separate lines
        img_mean = raw_images[i].mean()
        img_std = raw_images[i].std()
        ax.text(0.02, 0.98, f'μ={img_mean:.1f}\nσ={img_std:.1f}',
                transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

    # Pixel-value histogram (left) and per-image mean ± std (right)
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    # ravel() avoids the full copy of the image stack that flatten() makes
    plt.hist(raw_images.ravel(), bins=50, alpha=0.7, edgecolor='black')
    plt.title('Pixel Value Distribution (All Images)')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    stats_count = min(20, len(files))
    img_means = [raw_images[i].mean() for i in range(stats_count)]
    img_stds = [raw_images[i].std() for i in range(stats_count)]

    x_pos = range(len(img_means))
    plt.errorbar(x_pos, img_means, yerr=img_stds, fmt='o-', capsize=3)
    plt.title('Per-Image Statistics (First 20 images)')
    plt.xlabel('Image Index')
    plt.ylabel('Pixel Value')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## PCA Preprocessing Analysis

This section demonstrates the PCA preprocessing using the `PCAPreprocessor` class:

In [None]:
# Fit the PCAPreprocessor on the loaded images, transform them, and report
# the resulting shapes and explained variance.
if 'raw_images' in locals() and len(files) > 0:
    print("\n=== PCA Preprocessing Analysis ===")

    # Upper bound on the number of components for this dataset
    n_samples = raw_images.shape[0]
    n_features = DEFAULT_IMAGE_SIZE * DEFAULT_IMAGE_SIZE
    max_components = min(n_samples, n_features)

    print(f"Dataset info:")
    print(f"  Number of samples: {n_samples}")
    print(f"  Features per image: {n_features} ({DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE})")
    print(f"  Maximum PCA components: {max_components}")
    print(f"  Requested PCA components: {pca_components}")

    # The component count must stay strictly below the maximum; >= keeps this
    # check in line with the other component checks in this notebook.
    if pca_components >= max_components:
        print(f"❌ Requested {pca_components} components must be < {max_components}")
        # Never drop below one component, even for a single-image dataset
        adjusted_components = max(1, min(256, max_components - 1))
        print(f"✓ Adjusting to {adjusted_components} components")
        pca_components = adjusted_components

    # patch_size=None selects full-image PCA in the preprocessor call below
    pca_preprocessor = PCAPreprocessor(
        n_components=pca_components,
        patch_size=pca_patch_size if pca_patch_size > 0 else None
    )

    print(f"\nPCA Configuration:")
    print(f"  Components: {pca_components}")
    print(f"  Patch size: {pca_patch_size if pca_patch_size > 0 else 'Full image'}")

    print("\nFitting PCA model...")
    pca_preprocessor.fit(raw_images)

    print("\nTransforming images with PCA...")
    pca_tensor = pca_preprocessor.transform(raw_images)

    print(f"PCA transformation results:")
    print(f"  Original shape: {raw_images.shape}")
    print(f"  PCA tensor shape: {pca_tensor.shape}")
    print(f"  PCA tensor dtype: {pca_tensor.dtype}")
    print(f"  Output shape: {pca_preprocessor.get_output_shape()}")
    print(f"  Output channels: {pca_preprocessor.get_output_channels()}")

    # Flatten to (n_samples, n_values) for the size / reduction report
    if len(pca_tensor.shape) == 4:  # (n_samples, channels, height, width)
        # reshape() also works when the tensor is not contiguous, unlike view()
        pca_data_flattened = pca_tensor.reshape(pca_tensor.shape[0], -1).numpy()
        data_shape = pca_tensor.shape[2:]  # (height, width)
    else:
        pca_data_flattened = pca_tensor.numpy()
        data_shape = pca_preprocessor.get_output_shape()

    print(f"\nFinal processed data:")
    print(f"  Flattened shape: {pca_data_flattened.shape}")
    print(f"  Data shape for model: {data_shape}")
    print(f"  Dimensionality reduction: {n_features} → {pca_data_flattened.shape[1]} ({100*pca_data_flattened.shape[1]/n_features:.1f}%)")

    # Explained-variance report, only when the fitted model exposes the ratios
    if hasattr(pca_preprocessor.pca, 'explained_variance_ratio_'):
        variance_ratios = pca_preprocessor.pca.explained_variance_ratio_
        cumulative_variance = np.cumsum(variance_ratios)
        print(f"  Explained variance: {cumulative_variance[-1]:.1%}")

        plt.figure(figsize=(12, 4))

        # Left: cumulative variance with 95% / 99% reference lines
        plt.subplot(1, 2, 1)
        plt.plot(range(1, len(cumulative_variance) + 1), cumulative_variance, 'b-o', markersize=3)
        plt.xlabel('Number of Components')
        plt.ylabel('Cumulative Explained Variance')
        plt.title('PCA Explained Variance')
        plt.grid(True, alpha=0.3)
        plt.axhline(y=0.95, color='r', linestyle='--', alpha=0.7, label='95% threshold')
        plt.axhline(y=0.99, color='g', linestyle='--', alpha=0.7, label='99% threshold')
        plt.legend()

        # Right: variance ratio of the leading (at most 50) components
        plt.subplot(1, 2, 2)
        leading_count = min(50, len(variance_ratios))
        plt.plot(range(1, leading_count + 1),
                 variance_ratios[:leading_count],
                 'r-o', markersize=3)
        plt.xlabel('Component Number')
        plt.ylabel('Explained Variance Ratio')
        plt.title('Individual Component Variance (First 50)')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

In [None]:
# Visualise the leading PCA components and compare original images with their
# PCA-processed representation.
if 'pca_preprocessor' in locals() and pca_preprocessor.is_fitted:
    print("\n=== PCA Visualization ===")

    # Components are only drawn as images for full-image PCA (no patch size set)
    if not pca_preprocessor.patch_size:
        components_to_show = min(6, pca_preprocessor.pca.components_.shape[0])

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle('First 6 PCA Components', fontsize=16)

        for i, ax in enumerate(axes.flat):
            # Every panel has its axes hidden; unused panels stay blank
            ax.axis('off')
            if i >= components_to_show:
                continue

            # Reshape the component vector back to image layout
            component_img = pca_preprocessor.pca.components_[i].reshape(DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE)

            # Min-max normalise for display; a constant component has zero
            # range, so show it as all zeros rather than dividing by zero.
            value_range = component_img.max() - component_img.min()
            if value_range > 0:
                component_img = (component_img - component_img.min()) / value_range
            else:
                component_img = np.zeros_like(component_img)

            ax.imshow(component_img, cmap='RdBu_r')
            ax.set_title(f'Component {i+1}\n(Var: {pca_preprocessor.pca.explained_variance_ratio_[i]:.3f})')

        plt.tight_layout()
        plt.show()

    print("\nComparing original vs PCA-processed images:")

    samples_to_compare = min(3, len(files))

    # squeeze=False keeps axes two-dimensional even for a single column
    fig, axes = plt.subplots(2, samples_to_compare,
                             figsize=(5*samples_to_compare, 8), squeeze=False)

    for i in range(samples_to_compare):
        # Top row: original image
        axes[0, i].imshow(raw_images[i], cmap='gray')
        axes[0, i].set_title(f'Original Image {i+1}')
        axes[0, i].axis('off')

        # Bottom row: PCA output rendered as a 2-D image
        if len(pca_tensor.shape) == 4:  # (n_samples, channels, height, width)
            pca_img = pca_tensor[i, 0].numpy()  # first channel only
        else:
            pca_data_2d = pca_data_flattened[i]
            if len(data_shape) == 2:
                pca_img = pca_data_2d.reshape(data_shape)
            else:
                # 1-D output: lay the values out on the smallest square that
                # holds them, zero-padding the tail when needed
                side_len = int(np.sqrt(len(pca_data_2d)))
                if side_len * side_len == len(pca_data_2d):
                    pca_img = pca_data_2d.reshape(side_len, side_len)
                else:
                    target_len = side_len + 1
                    padded_data = np.pad(pca_data_2d, (0, target_len*target_len - len(pca_data_2d)))
                    pca_img = padded_data.reshape(target_len, target_len)

        axes[1, i].imshow(pca_img, cmap='viridis')
        axes[1, i].set_title(f'PCA Processed {i+1}\n({pca_components} components)')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

    print(f"\n✓ PCA preprocessing complete!")
    print(f"  Original data: {raw_images.shape} → {raw_images.nbytes / 1024**2:.1f} MB")
    print(f"  PCA data: {pca_data_flattened.shape} → {pca_data_flattened.nbytes / 1024**2:.1f} MB")
    print(f"  Memory reduction: {100 * (1 - pca_data_flattened.nbytes / raw_images.nbytes):.1f}%")

## Data Summary and Recommendations

Final analysis and recommendations for the latent conditioner training:

In [None]:
# Print a dataset / PCA summary and a few configuration recommendations
if 'raw_images' in locals() and len(files) > 0:
    print("\n" + "="*60)
    print("         LATENT CONDITIONER INPUT SUMMARY")
    print("="*60)

    print(f"\n📊 Dataset Information:")
    print(f"   • Total samples: {len(files)}")
    print(f"   • Image size: {DEFAULT_IMAGE_SIZE}x{DEFAULT_IMAGE_SIZE}")
    print(f"   • Image format: {param_data_type}")
    print(f"   • Directory: {file_dir}")

    print(f"\n🔢 Data Statistics:")
    print(f"   • Pixel value range: [{raw_images.min():.1f}, {raw_images.max():.1f}]")
    print(f"   • Mean pixel value: {raw_images.mean():.1f}")
    print(f"   • Std pixel value: {raw_images.std():.1f}")
    print(f"   • Total memory: {raw_images.nbytes / 1024**2:.1f} MB")

    # PCA figures are only available once the PCA analysis cell has run
    if 'pca_preprocessor' in locals() and pca_preprocessor.is_fitted:
        print(f"\n🔍 PCA Analysis:")
        print(f"   • PCA components: {pca_components}")
        print(f"   • Explained variance: {np.sum(pca_preprocessor.pca.explained_variance_ratio_):.1%}")
        print(f"   • Dimensionality reduction: {100 * (1 - pca_data_flattened.shape[1] / (DEFAULT_IMAGE_SIZE**2)):.1f}%")
        print(f"   • Memory reduction: {100 * (1 - pca_data_flattened.nbytes / raw_images.nbytes):.1f}%")
        print(f"   • Final data shape: {pca_data_flattened.shape}")

    print(f"\n⚙️ Configuration Recommendations:")

    # The component count has to stay below the number of samples
    if pca_components >= len(files):
        recommended_components = min(len(files) // 2, 256)
        print(f"   ⚠️  PCA components ({pca_components}) should be < samples ({len(files)})")
        print(f"   ✓  Recommend: pca_components = {recommended_components}")
    else:
        print(f"   ✓  PCA components ({pca_components}) < samples ({len(files)}) ✓")

    # Outline heuristic: average number of Canny edge pixels over the first
    # (at most 10) images. Edge pixels are counted rather than summed because
    # the edge map stores edges as 255, which would inflate a plain sum by
    # that factor relative to the pixel-count threshold below.
    edge_content = np.mean([
        np.count_nonzero(cv2.Canny(raw_images[i].astype(np.uint8), 50, 150))
        for i in range(min(10, len(files)))
    ])
    pixels_per_image = raw_images[0].size
    if edge_content > pixels_per_image * 0.01:  # more than 1% of pixels are edges
        print(f"   ✓  Images appear to contain outline/edge data")
        print(f"   ✓  Outline-preserving augmentations will be applied during training")

    # Rough per-batch memory: bytes of one image times the batch size
    batch_size = 16  # From config
    estimated_batch_memory = raw_images.nbytes * batch_size / len(files) / 1024**2
    print(f"   📊 Estimated batch memory (size={batch_size}): {estimated_batch_memory:.1f} MB")

    if estimated_batch_memory > 1000:  # > 1GB
        print(f"   ⚠️  Consider reducing batch size or enabling PCA preprocessing")
    else:
        print(f"   ✓  Memory usage looks reasonable for training")

    print(f"\n" + "="*60)
    print("Ready for latent conditioner training! 🚀")
    print("="*60)

## Test PCA Settings

Helper for trying different PCA configurations; call `test_pca_config(components, patch_size)` with your own values to compare settings:

In [None]:
# Interactive PCA testing
def test_pca_config(components, patch_size=0):
    """Test different PCA configurations"""
    if 'raw_images' not in locals() or len(files) == 0:
        print("❌ No image data loaded")
        return
    
    max_components = min(len(files), DEFAULT_IMAGE_SIZE**2)
    if components >= max_components:
        print(f"❌ Components ({components}) must be < {max_components}")
        return
    
    print(f"Testing PCA with {components} components, patch_size={patch_size}")
    
    try:
        test_pca = PCAPreprocessor(
            n_components=components,
            patch_size=patch_size if patch_size > 0 else None
        )
        test_pca.fit(raw_images)
        test_tensor = test_pca.transform(raw_images)
        
        if hasattr(test_pca.pca, 'explained_variance_ratio_'):
            variance_explained = np.sum(test_pca.pca.explained_variance_ratio_)
            print(f"✓ Success! Explained variance: {variance_explained:.1%}")
            print(f"  Output shape: {test_tensor.shape}")
            print(f"  Memory reduction: {100 * (1 - test_tensor.numel() * 4 / raw_images.nbytes):.1f}%")
        else:
            print(f"✓ Success! Output shape: {test_tensor.shape}")
            
    except Exception as e:
        print(f"❌ Error: {e}")

# Run the helper over a few component counts, skipping any that the loaded
# dataset is too small to support
if 'raw_images' in locals() and len(files) > 0:
    print("Testing different PCA configurations:")
    test_configs = [64, 128, 256, 400]

    for comp in test_configs:
        if not comp < len(files):
            print(f"Skipping {comp} components (exceeds sample count)")
            continue
        test_pca_config(comp)