In [1]:
# Machine-specific absolute paths; edit both before running on another system.
# Checkpoint file that process_image() passes to load_model().
model_path = "/home/mhuber/Thesis/GitHub/maisi/models/VAE/lima_best.pt"
# Input image that process_image() passes to load_image().
image_path = "/home/mhuber/Thesis/data/KermanyV3_resized/test/1/DME-145590-3.jpeg"

In [3]:
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import os
from networks.autoencoderkl_maisi import AutoencoderKlMaisi

# Select the compute device: CUDA when it is available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Keyword arguments used to construct AutoencoderKlMaisi in load_model();
# in_channels is also read by load_image() and visualize_results()
config = dict(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    latent_channels=4,
    num_channels=[64, 128, 256],
    num_res_blocks=[2, 2, 2],
    norm_num_groups=32,
    norm_eps=1e-6,
    attention_levels=[False, False, False],
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
    use_checkpointing=False,
    use_convtranspose=False,
    norm_float16=True,
    num_splits=1,
    dim_split=1,
)

# Folder that receives every saved figure and array; created if missing
output_dir = "autoencoder_results"
os.makedirs(output_dir, exist_ok=True)

# Load and preprocess image
def load_image(path, target_size=(256, 256)):
    """Load an image file and convert it into a batched model input tensor.

    The image is converted to grayscale ("L") when ``config["in_channels"]``
    is 1 and to RGB otherwise, resized to ``target_size``, turned into a
    tensor and normalized with mean 0.5 / std 0.5 per channel (the inverse,
    ``(x + 1) / 2``, is applied in ``visualize_results``).

    Args:
        path: Location of the image file on disk.
        target_size: Size handed to ``transforms.Resize``.

    Returns:
        A float32 tensor on ``device`` with a leading batch dimension of 1.
    """
    n_channels = config["in_channels"]
    mode = "L" if n_channels == 1 else "RGB"

    # Use a context manager so the underlying file handle is released;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(path) as src:
        img = src.convert(mode)

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * n_channels, [0.5] * n_channels),
    ])

    # unsqueeze(0) adds the batch dimension expected by the model
    return transform(img).unsqueeze(0).to(device, dtype=torch.float32)

# Load model
def load_model(path):
    """Build the autoencoder and restore its weights from a checkpoint file.

    The checkpoint is expected to be a mapping; the weights are taken from
    the first of the keys "autoencoder_state_dict", "model" or "state_dict"
    that is present.

    Args:
        path: Checkpoint file to read with ``torch.load``.

    Returns:
        The ``AutoencoderKlMaisi`` instance on ``device``, in eval mode.

    Raises:
        ValueError: If none of the known weight keys is in the checkpoint.
    """
    print("Creating model...")
    model = AutoencoderKlMaisi(**config).to(device)

    # weights_only=True restricts unpickling to tensors and plain containers
    print("Loading checkpoint...")
    checkpoint = torch.load(path, map_location=device, weights_only=True)

    print("Available keys in checkpoint:", list(checkpoint.keys()))

    # Probe the known layouts in priority order; the else branch runs only
    # when the loop finishes without finding any of them
    for weights_key in ("autoencoder_state_dict", "model", "state_dict"):
        if weights_key in checkpoint:
            print(f"Found {weights_key} in checkpoint")
            model.load_state_dict(checkpoint[weights_key])
            break
    else:
        raise ValueError("Cannot find model weights in checkpoint")

    print("Model loaded successfully")

    # Inference only: switch to evaluation mode before handing it back
    model.eval()
    return model

# Basic image processing function that doesn't require complicated SPADE segmentation
def process_image():
    """Run one image through the autoencoder and save the visualizations.

    Reads the module-level ``model_path`` and ``image_path``, loads the model
    and image, encodes, samples a latent, decodes it, and passes the tensors
    to ``visualize_results``. Missing input files are reported and the
    function returns early. Any exception is caught, printed together with
    its traceback, and not re-raised, so the function always returns None.
    """
    try:
        print("Checking if files exist...")
        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            return

        if not os.path.exists(image_path):
            print(f"Image file not found: {image_path}")
            return

        print("Loading model...")
        model = load_model(model_path)

        print("Loading image...")
        img = load_image(image_path)
        print(f"Image shape: {img.shape}, dtype: {img.dtype}")

        # Get filename without extension for saving results
        filename_prefix = os.path.splitext(os.path.basename(image_path))[0]

        print("Processing image with autoencoder...")
        # torch.cuda.amp.autocast is deprecated; torch.autocast takes the
        # device type explicitly. Mixed precision stays enabled on CUDA only.
        use_amp = device.type == 'cuda'
        with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
            # Encode the image
            z_mu, z_sigma = model.encode(img)
            latent = model.sampling(z_mu, z_sigma)
            print(f"Latent shape: {latent.shape}")

            # Decode the latent
            reconstructed = model.decode(latent)
            print(f"Reconstructed shape: {reconstructed.shape}")

        # Visualize and save results
        visualize_results(img, latent, reconstructed, filename_prefix)

        print(f"All results saved to {output_dir}")

    except Exception as e:
        # Best-effort notebook entry point: report the failure instead of raising
        print(f"Error processing image: {e}")
        import traceback
        traceback.print_exc()

def _to_display_image(batch):
    """Turn the first sample of a normalized image batch into a plottable array.

    Expects a channel-first batch, (N, C, H, W), as produced by
    ``process_image``. Undoes the mean 0.5 / std 0.5 normalization and
    returns (H, W) for one channel or (H, W, C) for several.
    """
    # float() keeps plotting robust if the tensor is half precision
    # (process_image decodes under autocast)
    sample = batch.detach()[0].float().cpu().numpy()
    sample = (sample + 1) / 2
    if sample.shape[0] == 1:
        return sample[0]
    # imshow wants channels last for colour images
    return np.moveaxis(sample, 0, -1)


def _save_heatmap(data, title, path):
    """Save a single 2-D array as a viridis heatmap with a colorbar."""
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(data, cmap='viridis')
    fig.colorbar(im, ax=ax)
    ax.set_title(title)
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


def visualize_results(original, latent, reconstructed, filename_prefix):
    """Save reconstruction and latent-space visualizations to ``output_dir``.

    Written files (all prefixed with ``filename_prefix``):
      * ``_reconstruction.png``    - original next to reconstruction
      * ``_latent_channels.png``   - every latent channel plus their mean
      * ``_latent.npy`` / ``_mean_latent.npy`` - raw arrays
      * ``_latent_channel_<i>.png`` - one heatmap per channel (1-based)
      * ``_mean_latent.png``       - heatmap of the channel mean

    Args:
        original: Normalized input batch; only the first sample is shown.
        latent: Latent batch; only the first sample is used.
        reconstructed: Decoder output batch; only the first sample is shown.
        filename_prefix: Stem used for every output file name.
    """
    # 1. Original vs. reconstruction, side by side
    panels = (
        ('Original Image', _to_display_image(original)),
        ('Reconstructed Image', _to_display_image(reconstructed)),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (title, img) in zip(axes, panels):
        if img.ndim == 2:
            # Fixed 0..1 range so both panels share one intensity scale
            # rather than being auto-stretched independently
            ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{filename_prefix}_reconstruction.png"), dpi=300)
    plt.close(fig)

    # 2. Latent channels. Index the batch dimension explicitly: a bare
    # squeeze() would also drop the channel axis of a 1-channel latent.
    latent_np = latent.detach()[0].cpu().numpy()
    n_channels = latent_np.shape[0]
    mean_latent = np.mean(latent_np, axis=0)

    fig, axes = plt.subplots(1, n_channels + 1, figsize=(20, 4))
    for i in range(n_channels):
        im = axes[i].imshow(latent_np[i], cmap='viridis')
        axes[i].set_title(f'Channel {i+1}')
        axes[i].axis('off')
        fig.colorbar(im, ax=axes[i])

    # Last panel: mean over all channels
    im = axes[-1].imshow(mean_latent, cmap='viridis')
    axes[-1].set_title('Mean of Channels')
    axes[-1].axis('off')
    fig.colorbar(im, ax=axes[-1])

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{filename_prefix}_latent_channels.png"), dpi=300)
    plt.close(fig)

    # 3. Save the latent and its channel mean as numpy arrays
    np.save(os.path.join(output_dir, f"{filename_prefix}_latent.npy"), latent_np)
    np.save(os.path.join(output_dir, f"{filename_prefix}_mean_latent.npy"), mean_latent)

    # 4. One PNG per latent channel
    for i in range(n_channels):
        _save_heatmap(
            latent_np[i],
            f'Latent Channel {i+1}',
            os.path.join(output_dir, f"{filename_prefix}_latent_channel_{i+1}.png"),
        )

    # 5. Mean latent as its own image
    _save_heatmap(
        mean_latent,
        'Mean of Latent Channels',
        os.path.join(output_dir, f"{filename_prefix}_mean_latent.png"),
    )

# Entry point: run the full load -> encode -> decode -> visualize pipeline.
# The guard is true when this cell is executed in the notebook (see output below).
if __name__ == "__main__":
    process_image()

Using device: cuda
Checking if files exist...
Loading model...
Creating model...
Loading checkpoint...
Available keys in checkpoint: ['epoch', 'autoencoder_state_dict', 'discriminator_state_dict', 'optimizer_g_state_dict', 'optimizer_d_state_dict', 'scheduler_g_state_dict', 'scheduler_d_state_dict', 'best_val_loss', 'config']
Found autoencoder_state_dict in checkpoint
Model loaded successfully
Loading image...
Image shape: torch.Size([1, 1, 256, 256]), dtype: torch.float32
Processing image with autoencoder...
Latent shape: torch.Size([1, 4, 64, 64])
Reconstructed shape: torch.Size([1, 1, 256, 256])


  with torch.cuda.amp.autocast(enabled=device.type=='cuda'):


All results saved to autoencoder_results
