In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import FID_util
%matplotlib inline

# Set seed here. torch's RNG drives the shuffled CIFAR DataLoader below; numpy's
# global RNG is seeded as well in case FID_util draws from it.
# NOTE(review): confirm which RNGs FID_util actually uses.
seed = 123459
torch.manual_seed(seed)
np.random.seed(seed)

# Every plt.savefig call in this notebook writes into this directory. Create it up
# front so a fresh checkout does not fail with FileNotFoundError on the first save.
plot_path = "./plots/FID_diffusion_image/"
os.makedirs(plot_path, exist_ok=True)

## Datasets

In [None]:
# Build a shuffled loader over the CIFAR-10 training split (shuffle order comes from
# the torch seed set in the config cell), then compute the baseline FID against the
# diffusion-generated CIFAR images.
# NOTE(review): load_data_single presumably gathers 10000 images at 32x32 from the
# loader — confirm against FID_util.
cifar_to_tensor = transforms.Compose([transforms.PILToTensor()])
cifar_train_dataset = datasets.CIFAR10(
    './data',
    train=True,
    download=True,
    transform=cifar_to_tensor,
)
CIFAR_train_loader = DataLoader(cifar_train_dataset, batch_size=10000, shuffle=True)

CIFAR_train = FID_util.load_data_single(CIFAR_train_loader, 10000, (32, 32))
CIFAR_generated = FID_util.load_from_dir("./diffusion_images/cifar_generated/*.png")
cifar_base_fid = FID_util.compute_FID(CIFAR_train, CIFAR_generated)
print(f"Base FID CIFAR-10: {cifar_base_fid.item()}")

In [None]:
# Reference set: CelebA-HQ 256px training images (loaded with the notebook seed).
CelebA_train = FID_util.load_from_dir("./data/celeba_hq_256/*.jpg", seed)

# Baseline FID of the DDIM samples against the reference set.
CelebA_generated = FID_util.load_from_dir("./diffusion_images/celebA_generated/*.png")
celeba_ddim_base_fid = FID_util.compute_FID(CelebA_train, CelebA_generated, batch_num=125)
print(f"Base FID CelebA-HQ DDIM: {celeba_ddim_base_fid.item()}")

# Baseline FID of the LDM samples against the same reference set.
CelebA_ldm_generated = FID_util.load_from_dir("./diffusion_images/celebA_ldm_generated/*.png")
celeba_ldm_base_fid = FID_util.compute_FID(CelebA_train, CelebA_ldm_generated, batch_num=125)
print(f"Base FID CelebA-HQ LDM: {celeba_ldm_base_fid.item()}")

## Gamma Correction

In [None]:
# Gamma grid: 0.2, 0.4, ..., 2.8 (np.arange excludes the 3.0 endpoint).
# plot_FID presumably applies adjust_gamma at each value and plots the resulting FID.
gamma_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    gamma_array,
    transforms.functional.adjust_gamma,
    "Gamma",
    CIFAR_train,
    CIFAR_generated,
)
plt.savefig(f"{plot_path}g_CIFAR_gamma_fid_scores")

In [None]:
# Gamma sweep for the CelebA-HQ DDIM samples, on the gamma grid defined above.
FID_util.plot_FID(
    gamma_array,
    transforms.functional.adjust_gamma,
    "Gamma",
    CelebA_train,
    CelebA_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_gamma_fid_scores")

In [None]:
# Gamma sweep for the CelebA-HQ LDM samples, on the gamma grid defined above.
FID_util.plot_FID(
    gamma_array,
    transforms.functional.adjust_gamma,
    "Gamma",
    CelebA_train,
    CelebA_ldm_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_ldm_gamma_fid_scores")

## Image Saturation

In [None]:
# Saturation grid: 0.2, 0.4, ..., 2.8 (np.arange excludes the 3.0 endpoint).
level_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    level_array,
    transforms.functional.adjust_saturation,
    "Saturation",
    CIFAR_train,
    CIFAR_generated,
)
plt.savefig(f"{plot_path}g_CIFAR_saturation_fid_scores")

In [None]:
# Saturation sweep for the CelebA-HQ DDIM samples.
FID_util.plot_FID(
    level_array,
    transforms.functional.adjust_saturation,
    "Saturation",
    CelebA_train,
    CelebA_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_saturation_fid_scores")

In [None]:
# Saturation sweep for the CelebA-HQ LDM samples. Uses this section's saturation
# grid (level_array), like the CIFAR and DDIM saturation cells, so the cell does
# not depend on the Gamma Correction section having been run.
FID_util.plot_FID(level_array, transforms.functional.adjust_saturation, "Saturation", CelebA_train, CelebA_ldm_generated, batch_num=125)
plt.savefig(plot_path + "g_CelebA_ldm_saturation_fid_scores")

## Image Sharpening

In [None]:
# Sharpness grid: 0.2, 0.4, ..., 2.8 (np.arange excludes the 3.0 endpoint).
level_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    level_array,
    transforms.functional.adjust_sharpness,
    "Sharpness",
    CIFAR_train,
    CIFAR_generated,
)
plt.savefig(f"{plot_path}g_CIFAR_sharpness_fid_scores")

In [None]:
# Sharpness sweep for the CelebA-HQ DDIM samples.
FID_util.plot_FID(
    level_array,
    transforms.functional.adjust_sharpness,
    "Sharpness",
    CelebA_train,
    CelebA_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_sharpness_fid_scores")

In [None]:
# Sharpness sweep for the CelebA-HQ LDM samples. Uses this section's sharpness
# grid (level_array), like the CIFAR and DDIM sharpness cells, so the cell does
# not depend on the Gamma Correction section having been run.
FID_util.plot_FID(level_array, transforms.functional.adjust_sharpness, "Sharpness", CelebA_train, CelebA_ldm_generated, batch_num=125)
plt.savefig(plot_path + "g_CelebA_ldm_sharpness_fid_scores")

## Hue Adjustment

In [None]:
# Hue grid: -0.5, -0.4, ..., 0.4 (np.arange excludes the 0.5 endpoint).
hue_array = np.arange(-0.5, 0.5, 0.1)
FID_util.plot_FID(
    hue_array,
    transforms.functional.adjust_hue,
    "Hue",
    CIFAR_train,
    CIFAR_generated,
)
plt.savefig(f"{plot_path}g_CIFAR_hue_fid_scores")

In [None]:
# Hue sweep for the CelebA-HQ DDIM samples.
FID_util.plot_FID(
    hue_array,
    transforms.functional.adjust_hue,
    "Hue",
    CelebA_train,
    CelebA_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_hue_fid_scores")

In [None]:
# Hue sweep for the CelebA-HQ LDM samples.
FID_util.plot_FID(
    hue_array,
    transforms.functional.adjust_hue,
    "Hue",
    CelebA_train,
    CelebA_ldm_generated,
    batch_num=125,
)
plt.savefig(f"{plot_path}g_CelebA_ldm_hue_fid_scores")

## Image Inversion

In [None]:
# FID between CIFAR-10 training images and generated samples under FID_util.invert_FID.
inverted_cifar_fid = FID_util.invert_FID(CIFAR_train, CIFAR_generated)
print(f"FID of Inverted CIFAR: {inverted_cifar_fid}")

In [None]:
# FID of the CelebA-HQ DDIM samples under FID_util.invert_FID, measured against the
# CelebA-HQ training images (the matching reference set, not CIFAR-10).
print(f"FID of Inverted CelebA DDIM: {FID_util.invert_FID(CelebA_train, CelebA_generated, batch_num=125)}")

In [None]:
# FID of the CelebA-HQ LDM samples under FID_util.invert_FID, measured against the
# CelebA-HQ training images (the matching reference set, not CIFAR-10).
print(f"FID of Inverted CelebA LDM: {FID_util.invert_FID(CelebA_train, CelebA_ldm_generated, batch_num=125)}")

## Rotations

In [None]:
# For each right-angle rotation, report the FID for every dataset / generator pair.
rotation_angles = [90, 180, 270]
rotate = transforms.functional.rotate
for rotation_degrees in rotation_angles:
    print(f"------- Rotation of {rotation_degrees} degrees -------")

    cifar_rot_fid = FID_util.compute_transform_FID(rotate, rotation_degrees, CIFAR_train, CIFAR_generated)
    print(f"CIFAR: {cifar_rot_fid}")

    ddim_rot_fid = FID_util.compute_transform_FID(rotate, rotation_degrees, CelebA_train, CelebA_generated, batch_num=125)
    print(f"CelebA DDIM: {ddim_rot_fid}")

    ldm_rot_fid = FID_util.compute_transform_FID(rotate, rotation_degrees, CelebA_train, CelebA_ldm_generated, batch_num=125)
    print(f"CelebA LDM: {ldm_rot_fid}")

## Sectional Rotations

In [None]:
# Region passed to FID_util.rotate_chunk.
# NOTE(review): coordinate order presumably (left, top, right, bottom) — confirm in FID_util.
cifar_rotated_box = (10, 10, 22, 22)
print(f"CIFAR: {FID_util.rotate_chunk(cifar_rotated_box, CIFAR_train, CIFAR_generated)}")

In [None]:
# Region passed to FID_util.rotate_chunk for the CelebA-HQ DDIM samples.
# NOTE(review): coordinate order presumably (left, top, right, bottom) — confirm in FID_util.
celeba_rotated_box = (40, 90, 140, 160)
print(f"CelebA DDIM: {FID_util.rotate_chunk(celeba_rotated_box, CelebA_train, CelebA_generated, batch_num=125)}")

In [None]:
# Region passed to FID_util.rotate_chunk for the CelebA-HQ LDM samples.
# NOTE(review): coordinate order presumably (left, top, right, bottom) — confirm in FID_util.
celeba_rotated_box = (40, 90, 140, 160)
print(f"CelebA LDM: {FID_util.rotate_chunk(celeba_rotated_box, CelebA_train, CelebA_ldm_generated, batch_num=125)}")

## Switching Segments

In [None]:
# Swap two rectangular regions and report the resulting FID for the CIFAR-10 samples.
# Compared against CIFAR_generated (loaded in the Data Sets section); CIFAR_test is
# never defined in this notebook.
FID_util.swap_chunks((5, 10, 10, 20), (20, 15, 25, 25), CIFAR_train, CIFAR_generated)

In [None]:
# The two regions exchanged by FID_util.swap_chunks for the CelebA-HQ DDIM samples.
celeba_region_a = (125, 100, 150, 130)
celeba_region_b = (40, 100, 65, 130)
FID_util.swap_chunks(celeba_region_a, celeba_region_b, CelebA_train, CelebA_generated, batch_num=125)

In [None]:
# The two regions exchanged by FID_util.swap_chunks for the CelebA-HQ LDM samples.
celeba_region_a = (125, 100, 150, 130)
celeba_region_b = (40, 100, 65, 130)
FID_util.swap_chunks(celeba_region_a, celeba_region_b, CelebA_train, CelebA_ldm_generated, batch_num=125)

## Gaussian Blur

In [None]:
# Blur heatmap for CIFAR-10; FID_util.generate_heatmap receives the output path directly.
cifar_blur_heatmap_path = f"{plot_path}g_CIFAR_blurring_fid"
FID_util.generate_heatmap(CIFAR_train, CIFAR_generated, cifar_blur_heatmap_path)

In [None]:
# Blur heatmap for the CelebA-HQ DDIM samples.
# NOTE(review): unlike every other CelebA cell, no batch_num=125 is passed here —
# confirm whether generate_heatmap accepts/needs it.
celeba_ddim_blur_heatmap_path = f"{plot_path}g_CelebA_blurring_fid"
FID_util.generate_heatmap(CelebA_train, CelebA_generated, celeba_ddim_blur_heatmap_path)

In [None]:
# Blur heatmap for the CelebA-HQ LDM samples.
# NOTE(review): unlike every other CelebA cell, no batch_num=125 is passed here —
# confirm whether generate_heatmap accepts/needs it.
celeba_ldm_blur_heatmap_path = f"{plot_path}g_CelebA_ldm_blurring_fid"
FID_util.generate_heatmap(CelebA_train, CelebA_ldm_generated, celeba_ldm_blur_heatmap_path)

## Salt and Pepper Noise

In [None]:
# FID of the CIFAR-10 samples as a function of salt-and-pepper noise level.
# `percentages` holds 0.00, 0.01, ..., 0.09 and is reused by the CelebA noise cells.
# Values are multiplied by 100 when plotted so the x axis matches its "%" label.
# NOTE(review): assumes noisify_FID treats p as a fraction of pixels — confirm.
percentages = np.arange(0, 0.1, 0.01)
cifar_noise_fids = [FID_util.noisify_FID(p, CIFAR_train, CIFAR_generated, seed=seed) for p in percentages]

# Explicit figure so this plot never shares axes with a previous cell's lines.
fig, ax = plt.subplots()
ax.plot(percentages * 100, cifar_noise_fids)
ax.set_xlabel("% of 'Salt and Pepper' Noise")
ax.set_ylabel("FID Score")
ax.set_title("CIFAR-10: FID vs. salt-and-pepper noise")
fig.savefig(plot_path + "g_CIFAR_Noise_fid_scores")

In [None]:
# FID of the CelebA-HQ DDIM samples as a function of salt-and-pepper noise level,
# on the `percentages` grid from the CIFAR noise cell. Values are multiplied by 100
# when plotted so the x axis matches its "%" label.
# NOTE(review): assumes noisify_FID treats p as a fraction of pixels — confirm.
celeba_ddim_noise_fids = [FID_util.noisify_FID(p, CelebA_train, CelebA_generated, seed=seed, batch_num=125) for p in percentages]

# Explicit figure so this plot never shares axes with a previous cell's lines.
fig, ax = plt.subplots()
ax.plot(percentages * 100, celeba_ddim_noise_fids)
ax.set_xlabel("% of 'Salt and Pepper' Noise")
ax.set_ylabel("FID Score")
ax.set_title("CelebA-HQ DDIM: FID vs. salt-and-pepper noise")
fig.savefig(plot_path + "g_CelebA_Noise_fid_scores")

In [None]:
# FID of the CelebA-HQ LDM samples as a function of salt-and-pepper noise level,
# on the `percentages` grid from the CIFAR noise cell. Values are multiplied by 100
# when plotted so the x axis matches its "%" label.
# NOTE(review): assumes noisify_FID treats p as a fraction of pixels — confirm.
celeba_ldm_noise_fids = [FID_util.noisify_FID(p, CelebA_train, CelebA_ldm_generated, seed=seed, batch_num=125) for p in percentages]

# Explicit figure so this plot never shares axes with a previous cell's lines.
fig, ax = plt.subplots()
ax.plot(percentages * 100, celeba_ldm_noise_fids)
ax.set_xlabel("% of 'Salt and Pepper' Noise")
ax.set_ylabel("FID Score")
ax.set_title("CelebA-HQ LDM: FID vs. salt-and-pepper noise")
fig.savefig(plot_path + "g_CelebA_ldm_Noise_fid_scores")