In [None]:
import torch
from matplotlib import pyplot as plt

from perceptual_loss import PerceptualLoss
from idem_net_celeba import IdemNetCeleba
from data_loader import load_CelebA

In [None]:
# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")

# Seed the global torch RNG so the noise sampled later (and the batch order, if
# the loader shuffles with the global RNG) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Second return value is unused here (presumably a test/validation loader — confirm).
dataloader, _ = load_CelebA()

In [None]:
# Which training run to evaluate, and which checkpoint file inside its directory.
run_id = "celeba20241201-151237"
epoch_num = "_final.pth"  # checkpoint file name within the run directory

checkpoint_path = f"checkpoints/{run_id}/{epoch_num}"

# Build the architecture matching the run. Only CelebA runs are wired up; fail
# fast here rather than with a NameError on `model` further down.
if "celeba" in run_id:
  model = IdemNetCeleba(3)  # NOTE(review): 3 is presumably the image channel count — confirm
else:
  raise ValueError(
    f"No architecture configured for run_id={run_id!r}; "
    "only 'celeba' runs are supported (IdemNetMnist is not imported here)."
  )

state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device)
model.load_state_dict(state_dict)
# Inference mode: normalization/dropout layers (if any) use their eval behavior
# and running statistics are not updated by the forward passes below.
model.eval();

In [None]:
# Run the network on one real CelebA batch and on Gaussian noise of the same
# shape, then plot the input and both outputs side by side (first sample only).
test_img, _ = next(iter(dataloader))
test_noise = torch.randn_like(test_img)

# Feed inputs on whatever device the model's weights live on, so this cell works
# whether or not the model has been moved off the CPU.
model_device = next(model.parameters()).device
with torch.no_grad():
  model_img = model(test_img.to(model_device))
  model_noise = model(test_noise.to(model_device))

# One panel per image: a second imshow on the same axes would hide the first.
# `[0]` selects the first sample, which also works for batch sizes > 1 (where
# `.squeeze()` would leave a 4-D tensor that permute(1, 2, 0) rejects).
panels = [("input", test_img), ("idem(img)", model_img), ("idem(noise)", model_noise)]
fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
for ax, (title, batch) in zip(axes, panels):
  # CHW -> HWC for imshow. Clamp to [0, 1], which imshow would do for float RGB
  # anyway (with a warning). NOTE(review): if outputs are in [-1, 1], rescale first.
  ax.imshow(batch[0].detach().cpu().permute(1, 2, 0).clamp(0, 1).numpy())
  ax.set_title(title)
  ax.axis("off")
plt.show()

In [None]:
# Perceptual-loss networks compared in the next cell:
#   imnet_model         - default backbone (name suggests ImageNet weights — confirm)
#   celeba_model        - "vgg16_celeba" backbone, restricted to layers=[3]
#   squeeze_imnet_model - "squeeze" backbone (presumably SqueezeNet — confirm)
imnet_model = PerceptualLoss(device=device)
celeba_model = PerceptualLoss("vgg16_celeba", layers=[3], device=device)
squeeze_imnet_model = PerceptualLoss("squeeze", device=device)

# Show the structure of the squeeze-based loss module.
print(squeeze_imnet_model)

In [None]:

# Perceptual distance from the real image of (a) pure noise and (b) idem(image),
# under each perceptual-loss network. Printed in the same order as before:
# all "noise" rows first, then all "idem(img)" rows.
loss_models = [
  ("imnet", imnet_model),
  ("celeba", celeba_model),
  ("squeeze", squeeze_imnet_model),
]
comparisons = [
  ("noise", test_noise),
  ("idem(img)", model_img),
]

# NOTE(review): assumes PerceptualLoss(device=...) places its network on `device`.
# The inputs may still be on CPU (dataloader output / CPU model), so move them there.
target = test_img.to(device)

with torch.no_grad():  # evaluation only; no backward pass is needed
  for comparison_name, candidate in comparisons:
    candidate = candidate.to(device)
    for loss_name, loss_model in loss_models:
      print(f"{loss_name} {comparison_name} img", loss_model(candidate, target))

# TODO: also compare idem(noise) against the image. An earlier draft passed an
# extra third argument (1) to the loss, whose meaning is not visible here:
# print("imnet idem(noise) img", imnet_model(model_noise, test_img, 1))
# print("celeba idem(noise) img", celeba_model(model_noise, test_img, 1))