In [5]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from ignite.metrics import FID
from torch.distributed.optim import ZeroRedundancyOptimizer  # NOTE(review): not used in the visible cells

from data_loader import load_CelebA, load_MNIST
from idem_net_celeba import IdemNetCeleba
from idem_net_mnist import IdemNetMnist
from plot_utils import plot_generation


In [2]:

# Which training run / checkpoint to evaluate.
# The checkpoint is read from checkpoints/<run_id>/<epoch_num>.
run_id = "mnist20241113-115000"
epoch_num = "final.pth"  # checkpoint file name (a file, not an epoch number)

# Pick the best available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"
device

device(type='cpu')

In [3]:
# Build the architecture that matches the dataset encoded in run_id.
if "celeba" in run_id:
    # NOTE(review): 3 is presumably the number of image channels (RGB) — confirm.
    model = IdemNetCeleba(3)
else:
    model = IdemNetMnist()

# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the loaded tensors on `device`.
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)

# Presumably some checkpoints are full training-state dicts that keep the
# weights under "model_state_dict"; unwrap them so both formats load.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]


In [4]:
# Copy the checkpoint weights into the model. Strict loading (the default)
# raises on missing/unexpected keys; the returned report lists them otherwise.
load_report = model.load_state_dict(state_dict)

# load_state_dict copies into the model's existing parameters, which are
# presumably created on the CPU by the constructor, so the model must be moved
# to `device` explicitly; otherwise inputs created on CUDA/MPS later in the
# notebook would hit a device mismatch.
# eval() switches layers such as dropout / batch-norm (if the model has any)
# to inference behavior for generation and FID evaluation.
model = model.to(device).eval()

load_report

<All keys matched successfully>

In [10]:
# Number of real and generated images compared.
# NOTE(review): FID estimated from only 256 samples against 2048-d Inception
# features is strongly biased and noisy; use several thousand samples for
# scores comparable across runs or to published numbers.
FID_NUM_SAMPLES = 256
INCEPTION_INPUT_SIZE = (299, 299)  # input resolution expected by InceptionV3
FID_SEED = 0  # fixes the noise (and the real batch, if the loader shuffles via torch's RNG)


def _to_unit_range(images: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Min-max scale a whole batch into [0, 1] using one min/max over the batch.

    The denominator is clamped to ``eps`` so a constant batch (e.g. a collapsed
    generator) yields zeros instead of NaNs.
    """
    lo = images.min()
    span = (images.max() - lo).clamp_min(eps)
    return (images - lo) / span


def _prepare_for_inception(images: torch.Tensor) -> torch.Tensor:
    """Scale an (N, C, H, W) batch to [0, 1], expand grayscale to 3 channels, resize for Inception."""
    images = _to_unit_range(images)
    if images.shape[1] == 1:
        # Inception needs 3 channels; replicate the single grayscale channel.
        images = images.repeat(1, 3, 1, 1)
    return F.interpolate(images, size=INCEPTION_INPUT_SIZE, mode="bilinear", align_corners=False)


torch.manual_seed(FID_SEED)

# FID metric with its InceptionV3 feature extractor.
fid_metric = FID(device=device)

# Real batch.
# NOTE(review): despite the name, test_imgs comes from the first loader
# returned by load_MNIST (presumably the training split) — confirm which split
# FID should use. This cell also always loads MNIST, even for CelebA runs;
# switch to load_CelebA for those (its signature is not visible here).
data_loader, test_loader = load_MNIST(batch_size=FID_NUM_SAMPLES)
test_imgs, _ = next(iter(data_loader))
test_imgs = test_imgs.to(device)

# Generate one fake image per real image from standard-normal noise of the same shape.
z_gen = torch.randn_like(test_imgs)
with torch.no_grad():
    fake_images = model(z_gen)

real_images = _prepare_for_inception(test_imgs)
fake_images = _prepare_for_inception(fake_images)

# ignite documents the update input as (y_pred, y); FID is symmetric in its
# two distributions, so this ordering does not change the score.
fid_metric.update((fake_images, real_images))
fid_score = fid_metric.compute()

print(f"FID Score on the last images: {fid_score}")


: 