# OOD-Experiments

In [None]:
import torch
import numpy as np
from idem_net_mnist import IdemNetMnist
from data_loader import load_MNIST
import matplotlib.pyplot as plt

In [None]:

# Identify the training run and the checkpoint file inside it to inspect.
run_id = "mnist20241113-115000"
epoch_num = "final.pth"
checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"

# Choose the compute backend by preference: CUDA, then Apple MPS, then CPU.
backend_name = "cpu"
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
device = torch.device(backend_name)

# Bare expression: as the last line of the cell, the notebook echoes it.
device

In [None]:
# Build an untrained network; its weights are filled in from the checkpoint.
model = IdemNetMnist()

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the loaded tensors on the selected device.
state_dict = torch.load(
    checkpoint_path,
    map_location=device,
    weights_only=True,
)
# Fallback for checkpoints that wrap the weights in a larger dict
# (NOTE(review): confirm the key name against the training script):
# state_dict = state_dict["model_state_dict"]

In [None]:
# Copy the checkpoint's parameters into the freshly constructed model.
model.load_state_dict(state_dict)
# Put the model in evaluation mode (presumably a torch.nn.Module — confirm in
# idem_net_mnist). Kept as the cell's last expression, so the notebook echoes
# whatever eval() returns.
model.eval()

In [None]:
def _to_displayable(image):
    """Return *image* as a squeezed NumPy array that ``plt.imshow`` can draw."""
    # Tensor-like objects: drop autograd history and copy to host memory first,
    # because .numpy() is rejected for tensors that require grad or that live
    # on an accelerator device.
    if hasattr(image, "detach"):
        image = image.detach().cpu().numpy()
    # squeeze() removes singleton axes (e.g. a leading channel axis of size 1).
    return np.asarray(image).squeeze()


def plot_generation(inputs, outputs, num_images=5):
    """
    Plot input and output image pairs side by side, one pair per row.

    Parameters:
    - inputs: Batch of input images, indexed along the first axis. Each item
      must squeeze down to a 2-D image (assumes something like
      (batch_size, 1, height, width) or (batch_size, height, width) — confirm
      against the data loader). Tensors or NumPy arrays are accepted.
    - outputs: Batch of output images, indexed along the first axis in
      lockstep with ``inputs`` (``outputs[i]`` is drawn next to ``inputs[i]``).
      Tensors may require grad and may live on any device.
    - num_images: Maximum number of image pairs to display (default is 5).

    Returns None. Nothing is drawn when there are no pairs to show.
    """
    # Never index past the end of either batch.
    num_images = min(num_images, len(inputs), len(outputs))
    if num_images <= 0:
        # Nothing to draw; also avoids requesting a figure with non-positive height.
        return

    plt.figure(figsize=(8, num_images * 2))
    for row in range(num_images):
        panels = (("Input", inputs[row]), ("Output", outputs[row]))
        # Subplot indices are 1-based: the left column is odd, the right is even.
        for col, (title, image) in enumerate(panels, start=1):
            plt.subplot(num_images, 2, 2 * row + col)
            plt.imshow(_to_displayable(image), cmap='gray')
            plt.title(title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Take the first batch yielded by the training loader.
train_loader, test_loader = load_MNIST(batch_size=256)
images, labels = next(iter(train_loader))

# Inference only: disable autograd bookkeeping, and call the module itself
# rather than .forward() so the regular call path (including any registered
# hooks) is used.
with torch.no_grad():
    output = model(images)

plot_generation(images, output)