In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance

import FID_util
%matplotlib inline

# Reproducibility: seed the global RNGs the experiments may draw from.
# torch's global RNG is used e.g. by DataLoader shuffling (shuffle=True below).
# NumPy's global RNG is seeded as well in case FID_util's distortions (such as the
# salt-and-pepper noise) use it. NOTE(review): confirm which RNGs FID_util uses.
seed = 123459
torch.manual_seed(seed)
np.random.seed(seed)

# Output directory for every figure saved below. plt.savefig does not create
# missing directories, so create it up front (no-op if it already exists).
plot_path = "./plots/FID_real_image/"
os.makedirs(plot_path, exist_ok=True)

<torch._C.Generator at 0x14a757fbf530>

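## Datasets

Load the train and test splits of MNIST, CIFAR-100 and CelebA, and report the FID between the two splits as a baseline for the distortion experiments below.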

In [None]:
# MNIST: download (if needed) both splits. Images only go through PILToTensor;
# no resizing or normalisation is applied here.
mnist_transform = transforms.Compose([transforms.PILToTensor()])
MNIST_train_loader, MNIST_test_loader = (
    FID_util.DataLoader(
        datasets.MNIST('./data', train=is_train, download=True, transform=mnist_transform),
        batch_size=10000,
        shuffle=True,
    )
    for is_train in (True, False)
)
# NOTE(review): 10000 and (28, 28) are presumably the number of images to keep
# and their (H, W) size as expected by FID_util.load_data. Confirm.
MNIST_train, MNIST_test = FID_util.load_data(MNIST_train_loader, MNIST_test_loader, 10000, (28, 28))

# Baseline: FID between the undistorted train and test splits.
mnist_base_fid = FID_util.compute_FID(MNIST_train, MNIST_test).item()
print(f"Base FID MNIST: {mnist_base_fid}")

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /home/shaoqingf/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|████████████████████████████████████████████████████████████████████| 91.2M/91.2M [00:04<00:00, 20.3MB/s]


In [None]:
# CIFAR-100: download (if needed) both splits. Images only go through PILToTensor;
# no resizing or normalisation is applied here.
cifar_transform = transforms.Compose([transforms.PILToTensor()])
CIFAR_train_loader, CIFAR_test_loader = (
    FID_util.DataLoader(
        datasets.CIFAR100('./data', train=is_train, download=True, transform=cifar_transform),
        batch_size=10000,
        shuffle=True,
    )
    for is_train in (True, False)
)
# NOTE(review): 10000 and (32, 32) are presumably the number of images to keep
# and their (H, W) size as expected by FID_util.load_data. Confirm.
CIFAR_train, CIFAR_test = FID_util.load_data(CIFAR_train_loader, CIFAR_test_loader, 10000, (32, 32))

# Baseline: FID between the undistorted train and test splits.
cifar_base_fid = FID_util.compute_FID(CIFAR_train, CIFAR_test).item()
print(f"Base FID CIFAR: {cifar_base_fid}")

In [None]:
# CelebA: download (if needed) the "train" and "test" splits. Images only go
# through PILToTensor. Unlike the MNIST/CIFAR cells, this uses torch's DataLoader directly.
celeba_transform = transforms.Compose([transforms.PILToTensor()])
CelebA_train_loader, CelebA_test_loader = (
    DataLoader(
        datasets.CelebA('./data', split=split_name, download=True, transform=celeba_transform),
        batch_size=10000,
        shuffle=True,
    )
    for split_name in ("train", "test")
)
# NOTE(review): 10000 and (218, 178) are presumably the number of images to keep
# and their (H, W) size. batch_num=125 presumably splits the FID computation into
# batches to bound memory. Confirm both against FID_util.
CelebA_train, CelebA_test = FID_util.load_data(CelebA_train_loader, CelebA_test_loader, 10000, (218, 178))

# Baseline: FID between the undistorted train and test splits.
celeba_base_fid = FID_util.compute_FID(CelebA_train, CelebA_test, batch_num=125).item()
print(f"Base FID CelebA: {celeba_base_fid}")

## Gamma Correction

In [None]:
# Gamma factors 0.2, 0.4, ..., 2.8 (np.arange stops before 3.0).
# The CIFAR and CelebA gamma cells below reuse this array.
gamma_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    gamma_array, transforms.functional.adjust_gamma, "Gamma", MNIST_train, MNIST_test
)
plt.savefig(f"{plot_path}MNIST_gamma_fid_scores")

In [None]:
# Same gamma sweep as for MNIST, applied to CIFAR-100.
FID_util.plot_FID(
    gamma_array, transforms.functional.adjust_gamma, "Gamma", CIFAR_train, CIFAR_test
)
plt.savefig(f"{plot_path}CIFAR_gamma_fid_scores")

In [None]:
# Same gamma sweep, applied to CelebA (batch_num=125 as in the other CelebA cells).
FID_util.plot_FID(
    gamma_array, transforms.functional.adjust_gamma, "Gamma",
    CelebA_train, CelebA_test, batch_num=125,
)
plt.savefig(f"{plot_path}CelebA_gamma_fid_scores")

## Image Saturation

In [None]:
# Saturation factors 0.2, 0.4, ..., 2.8 (np.arange stops before 3.0).
# The CIFAR and CelebA saturation cells below reuse this array.
# NOTE(review): MNIST is grayscale, so adjust_saturation is expected to leave the
# images unchanged and this curve may stay at the base FID. Confirm.
level_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    level_array, transforms.functional.adjust_saturation, "Saturation", MNIST_train, MNIST_test
)
plt.savefig(f"{plot_path}MNIST_saturation_fid_scores")

In [None]:
# Saturation sweep on CIFAR-100.
FID_util.plot_FID(
    level_array, transforms.functional.adjust_saturation, "Saturation", CIFAR_train, CIFAR_test
)
plt.savefig(f"{plot_path}CIFAR_saturation_fid_scores")

In [None]:
# Saturation sweep on CelebA (batch_num=125 as in the other CelebA cells).
FID_util.plot_FID(
    level_array, transforms.functional.adjust_saturation, "Saturation",
    CelebA_train, CelebA_test, batch_num=125,
)
plt.savefig(f"{plot_path}CelebA_saturation_fid_scores")

## Image Sharpening

In [None]:
# Sharpness factors 0.2, 0.4, ..., 2.8 (np.arange stops before 3.0).
# This redefines level_array with the same values as the saturation section.
level_array = np.arange(0.2, 3.0, 0.2)
FID_util.plot_FID(
    level_array, transforms.functional.adjust_sharpness, "Sharpness", MNIST_train, MNIST_test
)
plt.savefig(f"{plot_path}MNIST_sharpness_fid_scores")

In [None]:
# Sharpness sweep on CIFAR-100.
FID_util.plot_FID(
    level_array, transforms.functional.adjust_sharpness, "Sharpness", CIFAR_train, CIFAR_test
)
plt.savefig(f"{plot_path}CIFAR_sharpness_fid_scores")

In [None]:
# Sharpness sweep on CelebA (batch_num=125 as in the other CelebA cells).
FID_util.plot_FID(
    level_array, transforms.functional.adjust_sharpness, "Sharpness",
    CelebA_train, CelebA_test, batch_num=125,
)
plt.savefig(f"{plot_path}CelebA_sharpness_fid_scores")

## Hue Adjustment

MNIST is omitted here because its images are grayscale, so a hue shift has no effect on them.

In [None]:
# Hue shifts -0.5, -0.4, ..., 0.4 (np.arange stops before 0.5).
# The middle value is only approximately 0 because of floating-point accumulation.
# MNIST is not evaluated here, presumably because its images are grayscale.
hue_array = np.arange(-0.5, 0.5, 0.1)
FID_util.plot_FID(
    hue_array, transforms.functional.adjust_hue, "Hue", CIFAR_train, CIFAR_test
)
plt.savefig(f"{plot_path}CIFAR_hue_fid_scores")

In [None]:
# Hue sweep on CelebA (batch_num=125 as in the other CelebA cells).
FID_util.plot_FID(
    hue_array, transforms.functional.adjust_hue, "Hue",
    CelebA_train, CelebA_test, batch_num=125,
)
plt.savefig(f"{plot_path}CelebA_hue_fid_scores")

## Image Inversion

In [None]:
# FID with image inversion applied (see FID_util.invert_FID).
mnist_inverted_fid = FID_util.invert_FID(MNIST_train, MNIST_test)
print(f"FID of Inverted MNIST: {mnist_inverted_fid}")

In [None]:
# FID with image inversion applied (see FID_util.invert_FID).
cifar_inverted_fid = FID_util.invert_FID(CIFAR_train, CIFAR_test)
print(f"FID of Inverted CIFAR: {cifar_inverted_fid}")

In [None]:
# FID with image inversion applied (batch_num=125 as in the other CelebA cells).
celeba_inverted_fid = FID_util.invert_FID(CelebA_train, CelebA_test, batch_num=125)
print(f"FID of Inverted CelebA: {celeba_inverted_fid}")

## Gaussian Blur

In [None]:
# Gaussian-blur sweep for MNIST. The third argument is a path prefix under
# plot_path, presumably where generate_heatmap saves its figure.
mnist_blur_path = f"{plot_path}MNIST_blurring_fid"
FID_util.generate_heatmap(MNIST_train, MNIST_test, mnist_blur_path)

In [None]:
# Gaussian-blur sweep for CIFAR-100. The third argument is a path prefix under
# plot_path, presumably where generate_heatmap saves its figure.
cifar_blur_path = f"{plot_path}CIFAR_blurring_fid"
FID_util.generate_heatmap(CIFAR_train, CIFAR_test, cifar_blur_path)

In [None]:
# Gaussian-blur sweep for CelebA (batch_num=125 as in the other CelebA cells).
celeba_blur_path = f"{plot_path}CelebA_blurring_fid"
FID_util.generate_heatmap(CelebA_train, CelebA_test, celeba_blur_path, batch_num=125)

## Salt and Pepper Noise

In [None]:
# Salt-and-pepper noise on CIFAR-100: FID as a function of the noise level.
# `percentages` holds the fractions 0.00, 0.01, ..., 0.09 (np.arange stops before 0.1).
# The CelebA noise cell below reuses it.
percentages = np.arange(0, 0.1, 0.01)
noise_FID_results = [FID_util.noisify_FID(p, CIFAR_train, CIFAR_test) for p in percentages]

# Draw on a dedicated figure. The x-axis is labelled in percent, so the fractions
# are scaled by 100 for display only; noisify_FID still receives the raw values.
# NOTE(review): assumes noisify_FID interprets `p` as a fraction of pixels. Confirm.
fig, ax = plt.subplots()
ax.plot(percentages * 100, noise_FID_results)
ax.set_xlabel("% of 'Salt and Pepper' Noise")
ax.set_ylabel("FID Score")
ax.set_title("CIFAR-100: FID vs. 'Salt and Pepper' Noise")
fig.savefig(plot_path + "CIFAR_Noise_fid_scores")

In [None]:
# Salt-and-pepper noise on CelebA, reusing the `percentages` sweep from the CIFAR
# noise cell (batch_num=125 as in the other CelebA cells).
noise_FID_results = [FID_util.noisify_FID(p, CelebA_train, CelebA_test, batch_num=125) for p in percentages]

# Draw on a dedicated figure. The x-axis is labelled in percent, so the fractions
# are scaled by 100 for display only; noisify_FID still receives the raw values.
# NOTE(review): assumes noisify_FID interprets `p` as a fraction of pixels. Confirm.
fig, ax = plt.subplots()
ax.plot(percentages * 100, noise_FID_results)
ax.set_xlabel("% of 'Salt and Pepper' Noise")
ax.set_ylabel("FID Score")
ax.set_title("CelebA: FID vs. 'Salt and Pepper' Noise")
fig.savefig(plot_path + "CelebA_Noise_fid_scores")

## Image Rotation

In [None]:
# Whole-image rotations for the square datasets (MNIST 28x28, CIFAR 32x32), where
# every multiple of 90 degrees maps the frame onto itself.
angles = [90, 180, 270]
square_datasets = {
    "MNIST": (MNIST_train, MNIST_test),
    "CIFAR": (CIFAR_train, CIFAR_test),
}
for angle in angles:
    print(f"------- Rotation of {angle} degrees -------")
    for name, (train_imgs, test_imgs) in square_datasets.items():
        rotated_fid = FID_util.compute_transform_FID(transforms.functional.rotate, angle, train_imgs, test_imgs)
        print(f"{name}: {rotated_fid}")

In [None]:
# CelebA images are rectangular (218x178), so a 180-degree rotation is the only
# one that maps the frame onto itself; 90/270 degrees would require cropping or
# padding to keep the same shape.
celeba_rot180_fid = FID_util.compute_transform_FID(
    transforms.functional.rotate, 180, CelebA_train, CelebA_test, batch_num=125
)
print(f"CelebA 180 degrees: {celeba_rot180_fid}")

## Sectional Rotations

In [None]:
# Region to rotate within each image. NOTE(review): the coordinate order is
# defined by FID_util.rotate_chunk. Confirm it there.
mnist_chunk = (9, 9, 19, 19)
mnist_chunk_fid = FID_util.rotate_chunk(mnist_chunk, MNIST_train, MNIST_test)
print(f"MNIST: {mnist_chunk_fid}")

In [None]:
# Region to rotate within each image. NOTE(review): the coordinate order is
# defined by FID_util.rotate_chunk. Confirm it there.
cifar_chunk = (10, 10, 22, 22)
cifar_chunk_fid = FID_util.rotate_chunk(cifar_chunk, CIFAR_train, CIFAR_test)
print(f"CIFAR: {cifar_chunk_fid}")

In [None]:
# Region to rotate within each image (batch_num=125 as in the other CelebA cells).
# NOTE(review): the coordinate order is defined by FID_util.rotate_chunk. Confirm it there.
celeba_chunk = (40, 90, 140, 160)
celeba_chunk_fid = FID_util.rotate_chunk(celeba_chunk, CelebA_train, CelebA_test, batch_num=125)
print(f"CelebA: {celeba_chunk_fid}")

## Swapping Segments

In [None]:
# Two regions exchanged within each image. NOTE(review): the coordinate
# convention is defined by FID_util.swap_chunks.
mnist_region_a, mnist_region_b = (5, 10, 10, 20), (20, 15, 25, 25)
FID_util.swap_chunks(mnist_region_a, mnist_region_b, MNIST_train, MNIST_test)

In [None]:
# Same two regions as the MNIST cell, swapped within each CIFAR-100 image.
# NOTE(review): the coordinate convention is defined by FID_util.swap_chunks.
cifar_region_a, cifar_region_b = (5, 10, 10, 20), (20, 15, 25, 25)
FID_util.swap_chunks(cifar_region_a, cifar_region_b, CIFAR_train, CIFAR_test)

In [None]:
# Two regions swapped within each CelebA image (batch_num=125 as in the other
# CelebA cells). NOTE(review): the coordinate convention is defined by FID_util.swap_chunks.
celeba_region_a, celeba_region_b = (125, 100, 150, 130), (40, 100, 65, 130)
FID_util.swap_chunks(celeba_region_a, celeba_region_b, CelebA_train, CelebA_test, batch_num=125)