# OOD-Experiments

In [None]:
import torch
import numpy as np
from idem_net_mnist import IdemNetMnist
from idem_net_celeba import IdemNetCeleba
from data_loader import load_MNIST, load_CelebA
import matplotlib.pyplot as plt
from plot_utils import plot_generation

In [None]:

run_id = "mnist20241113-115000"
# Checkpoint file name inside the run directory (e.g. "final.pth").
epoch_num = "final.pth"

# Pick the best available accelerator, falling back to the CPU.
# `torch.backends.mps` only exists on newer PyTorch builds, hence the getattr.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"
device

In [None]:
# Build the architecture matching the run, then read the checkpoint from disk.
if "celeba" in run_id:
    model = IdemNetCeleba(3)  # NOTE(review): presumably 3 = RGB input channels — confirm
else:
    model = IdemNetMnist()

# weights_only=True restricts unpickling to tensors and plain containers.
# NOTE: the model itself is never moved to `device` in this notebook, so
# load_state_dict later copies these tensors into the model's own parameters.
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
# Some checkpoints wrap the weights under "model_state_dict" (alongside other
# training state); unwrap them so both that layout and a bare state_dict load.
if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]


In [None]:
# Copy the checkpoint weights into the model; strict=True (the default) makes
# missing or unexpected keys raise instead of being ignored.
model.load_state_dict(state_dict, strict=True)
# NOTE(review): eval() is intentionally left disabled, so the model stays in
# whatever mode its constructor left it (PyTorch modules default to training
# mode, where BatchNorm normalizes with batch statistics).
# model.eval()

In [None]:
def image_to_numpy(image, channels_first=False):
    """
    Convert a single image (tensor or array-like) into a NumPy array for imshow.

    Parameters:
    - image: One image. Tensors are detached and copied to the CPU first, so
      outputs that track gradients or live on an accelerator can be plotted.
    - channels_first: If True, the squeezed image is treated as (C, H, W) and
      transposed to (H, W, C), the layout matplotlib expects for RGB data.

    Returns:
    - NumPy array with size-1 dimensions squeezed away.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    array = np.asarray(image).squeeze()
    if channels_first:
        array = np.transpose(array, (1, 2, 0))
    return array


def plot_image_pairs(inputs, outputs, num_images=5, channels_first=False, cmap=None):
    """
    Plot input/output image pairs side by side, one pair per row.

    Parameters:
    - inputs: Batch of input images, indexed as inputs[i].
    - outputs: Batch of output images with the same batch-first layout;
      outputs[i] is shown next to inputs[i].
    - num_images: Maximum number of pairs to display (default is 5).
    - channels_first: Passed to image_to_numpy (True for (C, H, W) RGB images).
    - cmap: Colormap forwarded to imshow (None uses matplotlib's default).
    """
    # Never index past the end of either batch.
    num_images = min(num_images, len(inputs), len(outputs))
    if num_images <= 0:
        return  # nothing to draw; avoid creating a zero-height figure

    plt.figure(figsize=(8, num_images * 2))
    for row in range(num_images):
        panels = ((inputs[row], "Input"), (outputs[row], "Output"))
        for column, (image, title) in enumerate(panels):
            plt.subplot(num_images, 2, 2 * row + column + 1)
            plt.imshow(image_to_numpy(image, channels_first), cmap=cmap)
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()


def plot_gray_generation(inputs, outputs, num_images=5):
    """
    Plots grayscale input and output image pairs side by side.

    Parameters:
    - inputs: Batch of input images, shape (batch_size, height, width) or
      (batch_size, 1, height, width).
    - outputs: Batch of output images with the same layout as inputs;
      outputs[i] is paired with inputs[i].
    - num_images: Number of image pairs to display (default is 5).
    """
    plot_image_pairs(inputs, outputs, num_images, cmap="gray")


def plot_rgb_generation(inputs, outputs, num_images=5):
    """
    Plots RGB input and output image pairs side by side.

    Parameters:
    - inputs: Batch of channels-first input images, shape
      (batch_size, channels, height, width).
    - outputs: Batch of output images with the same layout as inputs;
      outputs[i] is paired with inputs[i].
    - num_images: Number of image pairs to display (default is 5).
    """
    plot_image_pairs(inputs, outputs, num_images, channels_first=True)

In [None]:
# Load the dataset that matches the checkpoint (CelebA uses a much smaller batch).
train_loader, test_loader = (
    load_CelebA(batch_size=9) if "celeba" in run_id else load_MNIST(batch_size=256)
)
# A single training batch feeds the experiments in the following cells.
first_batch = next(iter(train_loader))
images, labels = first_batch


In [None]:
from torch import nn

# Diagnostic: compare each BatchNorm2d layer's per-channel batch mean on
# `images` against the layer's stored running mean.
bn_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, nn.BatchNorm2d)
}

# Snapshot every BatchNorm buffer before the forward pass. In training mode
# (model.eval() is never called above) BatchNorm updates running_mean /
# running_var / num_batches_tracked on every forward call, even under
# torch.no_grad(). Without the snapshot, the current batch would be folded into
# the "running" numbers being compared, and later cells would see a mutated model.
bn_buffer_snapshot = {
    name: {key: buf.detach().clone() for key, buf in module.named_buffers()}
    for name, module in bn_layers.items()
}

intermediate_activations = {}


def save_activation(name):
    """Return a forward hook that records the BatchNorm layer's input."""
    def hook(module, inputs, output):
        # running_mean tracks the mean of the layer's *input*. The output is
        # already normalized (in training mode its per-channel mean is just the
        # affine bias), so it is not comparable to running_mean.
        intermediate_activations[name] = inputs[0].detach()
    return hook


hook_handles = [
    module.register_forward_hook(save_activation(name))
    for name, module in bn_layers.items()
]
try:
    # Perform forward pass
    with torch.no_grad():
        output = model(images)
finally:
    # Remove the hooks so re-running this cell does not stack duplicates, and
    # restore any BatchNorm buffers the forward pass updated.
    for handle in hook_handles:
        handle.remove()
    with torch.no_grad():
        for name, module in bn_layers.items():
            for key, buf in module.named_buffers():
                buf.copy_(bn_buffer_snapshot[name][key])

# Compare batch means and running means for each BatchNorm layer
for name in bn_layers:
    # None when the layer was built with track_running_stats=False.
    running_mean = bn_buffer_snapshot[name].get("running_mean")
    activations = intermediate_activations.get(name)
    if running_mean is None or activations is None:
        print(f"{name}: skipped (no running mean or layer not used in forward)")
        continue
    batch_mean = activations.mean(dim=(0, 2, 3))  # per-channel mean over N, H, W
    print(f"{name} batch: {batch_mean.mean()}")
    print(f"{name} running: {running_mean.mean()}")
    print(f"{name} diff: {(batch_mean - running_mean).mean()}")
    # The signed mean above can cancel across channels; also report the magnitude.
    print(f"{name} mean |diff|: {(batch_mean - running_mean).abs().mean()}")

In [None]:

# Visualize the model on the loaded data batch; num_applications presumably sets
# how many times plot_generation feeds the model its own output — see plot_utils.
plot_generation(
    images,
    model,
    num_applications=2,
)

In [None]:
# Pure Gaussian noise as an out-of-distribution input: a single sample shaped
# like one image of the loaded batch, (1, 1, 28, 28) for MNIST. Deriving the
# shape from `images` keeps this cell valid for the CelebA configuration too.
noise = torch.randn((1, *images.shape[1:]))

plot_generation(noise, model, num_applications=2)