In [None]:
import torch
import numpy as np
from idem_net_mnist import IdemNetMnist
from idem_net_celeba import IdemNetCeleba
from data_loader import load_MNIST, load_CelebA
from torchvision.transforms import GaussianBlur, functional as TF
from plot_utils import *
import matplotlib.pyplot as plt

In [None]:
# Which training run / checkpoint to inspect. The model and data-loader cells
# switch to CelebA whenever "celeba" appears in run_id
# (e.g. run_id = "celeba20241113-154812").
run_id = "mnist20241122-191540"
epoch_num = "final.pth"  # checkpoint file name inside checkpoints/<run_id>/

# Seed torch's global RNG so the sampled batch and any injected noise are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
# (Every branch assigns `device`, so no default assignment is needed beforehand.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"
device

In [None]:
# Instantiate the architecture matching the run, then restore the trained weights.
model = IdemNetCeleba(3) if "celeba" in run_id else IdemNetMnist()

state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
# NOTE(review): model.eval() is not called, so the module stays in its default
# training mode (matters for any dropout / batch-norm layers) — confirm intended.
# NOTE(review): weights are loaded onto `device`, but the model itself is never
# moved there; load_state_dict copies them into the model's existing parameters.
# Confirm whether `model.to(device)` is wanted.
model.load_state_dict(state_dict)

In [None]:
# Choose the dataset loader that matches the run and draw batches of 9 samples
# for the visualisation cells below.
train_loader, test_loader = (load_CelebA if "celeba" in run_id else load_MNIST)(batch_size=9)


In [None]:
# Additive Gaussian-noise perturbation.
# The paper writes the noise as N(0, 0.15). This implementation treats 0.15 as the
# *standard deviation* (variance 0.0225). If the paper means the variance, call
# gaussian_noise(images, std=0.15 ** 0.5) (std ~= 0.387) instead.
def gaussian_noise(images, mean=0.0, std=0.15):
  """Return ``images`` plus i.i.d. Gaussian noise.

  Args:
    images: floating-point tensor of any shape. The noise is drawn with
      ``torch.randn_like`` and so matches its shape, dtype and device.
    mean: mean of the added noise.
    std: standard deviation of the added noise (not the variance).

  Returns:
    A new tensor ``images + noise``. The input is not modified and the
    result is not clamped to any value range.
  """
  noise = torch.randn_like(images) * std + mean
  return images + noise

def rotation(images, angle=90):
  """Rotate every element of ``images`` by ``angle`` degrees.

  Each image is rotated independently with torchvision's ``functional.rotate``
  (positive angles rotate counter-clockwise). The results are stacked back
  together along a new leading dimension.
  """
  rotated = []
  for single_image in images:
    rotated.append(TF.rotate(single_image, angle))
  return torch.stack(rotated)

# Only sensible for 3-channel (RGB) batches such as CelebA: the output always has
# exactly 3 channels, regardless of how many channels the input had.
def grayscale(images):
  """Average the channel axis and tile the result back to 3 identical channels.

  Args:
    images: 4-D batch whose dim 1 is averaged (presumably (B, C, H, W) — TODO
      confirm). Uses an unweighted mean, not luminance weights.

  Returns:
    Tensor of shape (B, 3, H, W) whose three channels are identical.
  """
  channel_mean = torch.mean(images, dim=1, keepdim=True)
  return channel_mean.repeat(1, 3, 1, 1)

# Pencil-sketch style transform. Intended for 3-channel (CelebA) batches because it
# goes through `grayscale`, which always emits 3 channels.
def sketch(images, kernel_size=21, eps=1e-10):
  """Render a batch of images as a pencil-sketch-like image.

  The grayscale image is divided by a Gaussian-blurred copy of itself. Flat,
  non-zero regions therefore map to roughly 1, while edges deviate from 1.

  Args:
    images: batch accepted by ``grayscale`` (presumably (B, 3, H, W) — TODO
      confirm the value range; see the note on the denominator below).
    kernel_size: size of the Gaussian blur kernel. torchvision requires an odd,
      positive integer. The default 21 reproduces the previous hard-coded value.
    eps: small constant added to the blurred denominator to avoid division by
      exactly zero. The default 1e-10 reproduces the previous hard-coded value.

  Returns:
    Tensor with the same shape as ``grayscale(images)``.
  """
  # Same formula OpenCV uses to derive sigma from the kernel size when no sigma is given.
  sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8
  blur = GaussianBlur(kernel_size=kernel_size, sigma=sigma)

  gray = grayscale(images)

  # NOTE(review): the paper reportedly adjusts this ratio by 1. Subtracting 1
  # would shift flat regions from ~1 to ~0. It is left out to preserve the
  # current output — TODO confirm against the paper.
  # NOTE(review): if inputs can be <= 0 (e.g. normalised to [-1, 1]), the blurred
  # denominator can get arbitrarily close to zero and eps will not prevent blow-up.
  return gray / (blur(gray) + eps)



In [None]:

# Take the first batch from the training loader to visualise below.
[images, labels] = next(iter(train_loader))


In [None]:

# Presumably plots the model's outputs for this batch, using plot_generation
# (provided by the star-import from plot_utils).
# NOTE(review): the meaning of the positional arguments 5 and 2 is not visible in
# this notebook. Consider passing them by keyword once their names are confirmed.
plot_generation(images, model, 5, 2)