In [None]:
import sys
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

if '..' not in sys.path:
    sys.path.append('..')
    
from src.lit_models.lit_masking_model import LitMaskingModel
from src.datasets.cifar10 import CIFAR10Data


In [None]:
# Checkpoint of the masking model. Its 'hyper_parameters' and 'state_dict'
# entries are read further down to rebuild the model.
CHECKPOINT_PATH = "../weights/masker-cifar10.ckpt"
# Checkpoint holding the MAE weights. The path itself is also handed to the
# masking model through the `mae_checkpoint_path` hyperparameter.
MAE_CHECKPOINT_PATH = "../weights/mae-cifar10.ckpt"

In [None]:
# Deserialize both checkpoints onto the CPU. Without `map_location`, tensors
# saved from a GPU are restored to that same CUDA device, which fails on
# CPU-only machines. The model is moved to its target device after the weights
# have been loaded into it.
# NOTE(review): torch.load unpickles the file, so only point these paths at
# checkpoints you trust.
masking_model_state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
mae_model_state_dict = torch.load(MAE_CHECKPOINT_PATH, map_location="cpu")

In [None]:
# Rebuild the constructor arguments from the checkpoint, dropping
# underscore-prefixed entries, and point the model at the MAE checkpoint.
hyperparams = {
    name: value
    for name, value in masking_model_state_dict['hyper_parameters'].items()
    if not name.startswith('_')
}
hyperparams['mae_checkpoint_path'] = MAE_CHECKPOINT_PATH
state_dict = masking_model_state_dict['state_dict']

# Remove every '_orig_mod.' segment from the parameter names so they match the
# plain (uncompiled) module. `str.replace` is a no-op for keys without it.
# NOTE(review): this prefix is presumably left by torch.compile -- confirm.
new_state_dict = {
    key.replace('_orig_mod.', ''): value
    for key, value in state_dict.items()
}

model = LitMaskingModel(**hyperparams)
model.load_state_dict(new_state_dict)
model.eval()

# Use the GPU when one is available instead of unconditionally calling
# `.cuda()`, which raises on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded successfully.")


In [None]:
# Size of the batch pulled from the validation loader; every image in it is a
# candidate row for the figures created later.
NUM_ROWS = 16

print("Loading CIFAR10 dataset...")
datamodule = CIFAR10Data(batch_size=NUM_ROWS, data_dir="../data")
datamodule.prepare_data()
datamodule.setup(stage="fit")

# Take the first validation batch, discard the second element of the pair,
# and place the images on the same device as the model.
val_loader = datamodule.val_dataloader()
batch_iterator = iter(val_loader)
images, _ = next(batch_iterator)
images = images.to(model.device)

num_loaded = images.shape[0]
print(f"Loaded {num_loaded} images for visualization.")


In [None]:
# Batch positions shown in the "good" figure (hand-picked results).
cherry_picked_indices = [1, 3, 6, 9, 11, 13, 15] # Example indices, you can change this

# Every other position of the batch, in ascending order, goes to the second figure.
all_indices = list(range(NUM_ROWS))
remaining_indices = sorted(set(all_indices) - set(cherry_picked_indices))

print(f"Cherry-picked indices: {cherry_picked_indices}")
print(f"Remaining indices: {remaining_indices}")


In [None]:
# Single forward pass over the whole batch; gradients are disabled because the
# outputs are only used for plotting.
print(f"Running inference with tau: {model.tau}")
with torch.no_grad():
    x_recon, mask_probs = model(images)
    # Keep reconstructions inside the [0, 1] range.
    x_recon = x_recon.clamp(0, 1)

print(
    "Inference complete.",
    f"  - Reconstructions tensor shape: {x_recon.shape}",
    f"  - Mask probabilities tensor shape: {mask_probs.shape}",
    sep="\n",
)


In [None]:
# Upsample masks from patch resolution to image resolution
num_patches = mask_probs.shape[1]
num_patches_per_side = int(round(np.sqrt(num_patches)))
# The patches must tile a square grid. Without this check a non-square count
# could still be reshaped with `-1` into a tensor with the wrong batch size,
# silently scrambling the masks.
if num_patches_per_side * num_patches_per_side != num_patches:
    raise ValueError(
        f"Cannot arrange {num_patches} mask entries per image into a square patch grid."
    )
image_size = model.hparams.image_size

# Reshape from (B, N) to (B, H', W') where H' and W' are patch dimensions.
# The batch size is given explicitly, and `reshape` (unlike `view`) also
# accepts non-contiguous tensors.
mask_spatial = mask_probs.reshape(
    mask_probs.shape[0], num_patches_per_side, num_patches_per_side
)
# Add a channel dimension: (B, 1, H', W')
mask_spatial = mask_spatial.unsqueeze(1)
# Upsample to image size (B, 1, H, W); 'nearest' keeps each patch a flat block
mask_image = F.interpolate(
    mask_spatial,
    size=(image_size, image_size),
    mode='nearest'
)

# Move tensors to CPU and convert to float32 numpy arrays for matplotlib
images_np = images.cpu().numpy().astype(np.float32)
x_recon_np = x_recon.cpu().numpy().astype(np.float32)
mask_image_np = mask_image.cpu().numpy().astype(np.float32)

print("Data prepared for visualization.")


In [None]:
def create_and_save_visualization(indices, images_np, x_recon_np, mask_image_np, filename):
    """
    Render a four-panel comparison for the selected samples, save it and show it.

    Each selected sample becomes one figure row with the columns: original
    image, generated mask, mask overlay (red tint on masked regions) and
    composite image (original and reconstruction blended by the mask).
    The figure is written to ``../assets/<filename>``; the directory is
    created when missing.

    Args:
        indices (list): Sample indices to include, one figure row each.
            Negative indices count from the end, as with regular indexing.
        images_np (np.ndarray): Original images, indexed as
            (sample, channel, height, width).
        x_recon_np (np.ndarray): Reconstructed images, same layout as
            ``images_np``.
        mask_image_np (np.ndarray): Masks at image resolution, indexed as
            (sample, channel, height, width); only channel 0 is used and it is
            treated as the per-pixel transmission weight.
        filename (str): File name of the saved plot; also used to build the
            figure title.

    Returns:
        None. When ``indices`` is empty a message is printed and nothing is
        drawn or saved.

    Raises:
        IndexError: If any index lies outside the range covered by all three
            arrays. Raised before a figure is created.
    """
    num_rows = len(indices)
    if num_rows == 0:
        print(f"No indices to visualize for {filename}.")
        return

    # Fail fast on bad indices so no half-drawn figure is left behind.
    num_samples = min(len(images_np), len(x_recon_np), len(mask_image_np))
    out_of_range = [idx for idx in indices if not -num_samples <= idx < num_samples]
    if out_of_range:
        raise IndexError(
            f"Indices {out_of_range} are out of range for {num_samples} available samples."
        )

    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(num_rows, 4, figsize=(12, 3 * num_rows), squeeze=False)
    fig.suptitle(f"Model Visualization - {filename.replace('.png', '').replace('_', ' ').title()}", fontsize=16)

    titles = ["Original Image", "Generated Mask", "Mask Overlay", "Composite Image"]

    for i, idx in enumerate(indices):
        # Channel-first -> channel-last once per row; reused by three panels.
        original_img = np.transpose(images_np[idx], (1, 2, 0))
        recon_img = np.transpose(x_recon_np[idx], (1, 2, 0))
        transmission_map = mask_image_np[idx, 0]

        # 1. Original Image
        axes[i, 0].imshow(original_img)

        # 2. Generated Mask (white = transmitted, black = masked).
        # A fixed 0..1 range keeps that legend true for every row; with
        # auto-scaling each mask would be stretched to its own min/max and a
        # constant mask would always render black.
        axes[i, 1].imshow(transmission_map, cmap='gray', vmin=0.0, vmax=1.0)

        # 3. Mask Overlay (red overlay on masked regions, at most 60% opaque)
        masking_prob = 1.0 - transmission_map
        red_overlay = np.zeros_like(original_img)
        red_overlay[..., 0] = 1.0
        alpha = 0.6 * masking_prob[..., np.newaxis]
        overlayed_image = (1 - alpha) * original_img + alpha * red_overlay
        axes[i, 2].imshow(np.clip(overlayed_image, 0, 1))

        # 4. Composite Image (what receiver sees): original where transmitted,
        # reconstruction where masked.
        transmission_prob = transmission_map[..., np.newaxis]
        composite_img = transmission_prob * original_img + (1 - transmission_prob) * recon_img
        axes[i, 3].imshow(np.clip(composite_img, 0, 1))

        # Column headers only on the first row; no axis decorations anywhere.
        for column, ax in enumerate(axes[i]):
            ax.set_title(titles[column] if i == 0 else "")
            ax.axis('off')

    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # Save the figure
    save_path = os.path.join("..", "assets", filename)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Visualization saved to {save_path}")

    plt.show()

# --- Create and save both visualizations ---

# One entry per figure: (progress message, selected batch indices, output file).
# The cherry-picked ("good") results come first, then the remaining ("bad") ones.
visualization_jobs = [
    ("\nCreating cherry-picked visualization...", cherry_picked_indices, "good_visualization.png"),
    ("\nCreating remaining visualization...", remaining_indices, "bad_visualization.png"),
]

for progress_message, selected_indices, output_name in visualization_jobs:
    print(progress_message)
    create_and_save_visualization(
        selected_indices,
        images_np,
        x_recon_np,
        mask_image_np,
        output_name,
    )
