In [1]:
!git clone -b dev https://github.com/louisdo/These-Image-Do-Not-Exist.git

Cloning into 'These-Image-Do-Not-Exist'...
remote: Enumerating objects: 83, done.[K
remote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 83 (delta 2), reused 7 (delta 2), pack-reused 75[K
Unpacking objects: 100% (83/83), done.


In [1]:
import sys, json, os, torch, torchvision, random
sys.path.append("./These-Image-Do-Not-Exist")
sys.path.append("./These-Image-Do-Not-Exist/src")
from eval import FIDEval
from dataloader import ImageDataset, get_loader
from model import Generator
from tqdm import tqdm

In [2]:
def create_folder(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

In [4]:
!unzip drive/MyDrive/HUST_related/Intro2DS/capstone/data/image_folder/image_folder.zip -d ./image_folder/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: ./image_folder/data/image_folder/8269654220.jpg  
  inflating: ./image_folder/data/image_folder/8269762253.jpg  
  inflating: ./image_folder/data/image_folder/8269977834.jpg  
  inflating: ./image_folder/data/image_folder/8270033228.jpg  
  inflating: ./image_folder/data/image_folder/8270051568.jpg  
  inflating: ./image_folder/data/image_folder/8270053197.jpg  
  inflating: ./image_folder/data/image_folder/8270499259.jpg  
  inflating: ./image_folder/data/image_folder/8270734569.jpg  
  inflating: ./image_folder/data/image_folder/8270887.jpg  
  inflating: ./image_folder/data/image_folder/8271663631.jpg  
  inflating: ./image_folder/data/image_folder/8271800620.jpg  
  inflating: ./image_folder/data/image_folder/8271802136.jpg  
  inflating: ./image_folder/data/image_folder/8272218752.jpg  
  inflating: ./image_folder/data/image_folder/8272500191.jpg  
  inflating: ./image_folder/data/image_folder/827271791.

In [3]:
fid_eval = FIDEval()

# Real image dataloader

In [4]:
CATEGORY = "0"
IMAGE_FOLDER = "./image_folder/data/image_folder"
REAL_IMAGE_INFO_FILE = "drive/MyDrive/HUST_related/Intro2DS/capstone/data/image_info_data_3categories.json"
OUTLIER_IDS_FILE = "drive/MyDrive/HUST_related/Intro2DS/capstone/data/indices2remove_3categories.json"

REAL_IMAGE_DATASET_CONFIG = {
    "image_size": 299,
    "mean": [0.5,0.5,0.5],
    "std": [0.5,0.5,0.5]
}
REAL_IMAGE_LOADER_CONFIG = {
    "batch_size": 16,
    "num_workers": 0,
    "shuffle": False
}

In [5]:
with open(REAL_IMAGE_INFO_FILE) as f:
    real_image_info = json.load(f)

with open(OUTLIER_IDS_FILE) as f:
    outlier_ids = json.load(f)

In [6]:
real_image_info_category = [real_image_info[CATEGORY][i] for i in range(len(real_image_info[CATEGORY])) if i not in outlier_ids[CATEGORY]]

real_image_dataset = ImageDataset(data = real_image_info_category, data_folder = IMAGE_FOLDER, config = REAL_IMAGE_DATASET_CONFIG)
real_image_loader = get_loader(real_image_dataset, REAL_IMAGE_LOADER_CONFIG)

# Fake image dataloader

## Set up the generator and generate images

In [7]:
NUM_IMAGES_TO_GENERATE = 16000
FAKE_IMAGE_FOLDER = f"./fake_image_folder_{CATEGORY}"
CHECKPOINT = "drive/MyDrive/HUST_related/Intro2DS/capstone/ckpt/16122021/ckpt_epoch40.pth"

GENERATOR_CONFIG = {
    "number_channel": 3,
    "image_size": 64,
    "d_hidden": 128,
    "num_classes": 3
}

device = fid_eval.device
generator = Generator(GENERATOR_CONFIG)
generator.load_state_dict(torch.load(CHECKPOINT)["generator"])
generator = generator.to(device)
generator.eval()
print("Done loading generator")

Done loading generator


In [10]:
create_folder(FAKE_IMAGE_FOLDER)

In [11]:
label = torch.nn.functional.one_hot(torch.tensor(int(CATEGORY)), 
                                    GENERATOR_CONFIG["num_classes"]).unsqueeze(0).repeat(1,1).float().to(device)

for idx in tqdm(range(NUM_IMAGES_TO_GENERATE), desc = "Generating images with generator"):
    noise = torch.randn((1, GENERATOR_CONFIG["d_hidden"])).float().to(device)
    with torch.no_grad():
        _fake_images = generator(noise, label)

    fake_images = torchvision.utils.make_grid(_fake_images, padding = 0, normalize = True)
    torchvision.utils.save_image(fake_images, os.path.join(FAKE_IMAGE_FOLDER, f"{idx}.jpg"))

Generating images with generator: 100%|██████████| 16000/16000 [00:28<00:00, 554.39it/s]


## Set up fake image loader

In [8]:
FAKE_IMAGE_DATASET_CONFIG = REAL_IMAGE_DATASET_CONFIG
FAKE_IMAGE_LOADER_CONFIG = REAL_IMAGE_LOADER_CONFIG

fake_image_info = [{"id": str(idx)} for idx in range(NUM_IMAGES_TO_GENERATE)]

In [13]:
fake_image_dataset = ImageDataset(data = fake_image_info, data_folder = FAKE_IMAGE_FOLDER, config = FAKE_IMAGE_DATASET_CONFIG)
fake_image_loader = get_loader(fake_image_dataset, FAKE_IMAGE_LOADER_CONFIG)

In [14]:
fid_eval.fid_score(fake_image_loader, real_image_loader)

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Inferring features for real images: 100%|██████████| 942/942 [02:03<00:00,  7.64it/s]
Inferring features for fake images: 100%|██████████| 1000/1000 [01:51<00:00,  8.93it/s]


28.80945816111523

In [None]:
# FID for cDCGAN

# category 1 (castle): 31.18888495946385
# category 2 (landscape): 29.882495752254158
# category 3 (seascape): 28.80945816111523

# Random image model

In [9]:
RANDOM_IMAGE_LOADER_CONFIG = REAL_IMAGE_LOADER_CONFIG
RANDOM_IMAGE_DATASET_CONFIG = FAKE_IMAGE_DATASET_CONFIG

ALL_CATEGORIES = {"0", "1", "2"}
random_image_info = []
for cat in ALL_CATEGORIES:
    if cat == CATEGORY: continue
    random_image_info.extend(real_image_info[cat])
random_image_info = random.sample(random_image_info, NUM_IMAGES_TO_GENERATE)

random_image_dataset = ImageDataset(data = random_image_info, data_folder = IMAGE_FOLDER, config = RANDOM_IMAGE_DATASET_CONFIG)
random_image_loader = get_loader(random_image_dataset, RANDOM_IMAGE_LOADER_CONFIG)

In [10]:
fid_eval.fid_score(random_image_loader, real_image_loader)

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
Inferring features for real images: 100%|██████████| 962/962 [02:06<00:00,  7.62it/s]
Inferring features for fake images: 100%|██████████| 1000/1000 [02:10<00:00,  7.67it/s]


30.98752333453183

In [None]:
# FID for random image model

# category 1: 31.63962420116127
# category 2: 30.98752333453183
# category 3: 35.06780713772184