In [1]:
import torch
import PIL.Image as Image
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os


from fid import calculate_statistics_for_given_paths, calculate_frechet_distance
from fastai.vision.all import Path
from tqdm import tqdm

import os
# Expose only physical GPU 2 to this process; this has to be set before CUDA
# is first initialised for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES']='2'

In [2]:
from inception import InceptionV3, InceptionV3_sehun

In [3]:
class png_dataset(torch.utils.data.Dataset):
    """Map-style dataset that loads image files on demand as 3-channel RGB.

    Note that inception.py rescales its inputs from [0, 1] to [-1, 1], so the
    samples returned here should be in [0, 1].
    """

    def __init__(self, dirs, transforms = None):
        '''
        dirs : sequence of image file paths, one sample per path
        transforms : optional callable applied to each loaded image
        '''
        self.dirs = dirs
        self.transforms = transforms

    def __len__(self):
        # One sample per path.
        return len(self.dirs)

    def __getitem__(self, i):
        """Load the i-th image, force RGB if needed, then apply `transforms`."""
        path = self.dirs[i]
        img = Image.open(path)
        # An unconditional .convert('RGB') was dropped because it takes a
        # noticeable amount of time. Converting only when the mode differs
        # keeps the fast path for RGB files while still guaranteeing three
        # channels for grayscale / palette / RGBA inputs.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img



# Preprocessing applied to every image: resize, then convert to a tensor.
# The per-channel Normalize step is left disabled (values kept for reference):
#   transforms.Normalize(mean = [0.3831, 0.2659, 0.1896], std = [0.2915, 0.2107, 0.1708])
transformations = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

# Collect every PNG below the validation / training roots as plain string paths.
# '*.png' (instead of '*png') only matches names with a real ".png" suffix, so
# an entry that merely ends in "png" (e.g. a directory called "png") is skipped.
# sorted() gives a reproducible order, since glob order depends on the filesystem.
valid_list = sorted(str(p) for p in Path('/home/DB/SuGAr/RF/val').rglob('*.png'))
print(len(valid_list))
train_list = sorted(str(p) for p in Path('/home/DB/SuGAr/RF/trn').rglob('*.png'))
print(len(train_list))



def _make_loader(paths):
    """Wrap `paths` in a png_dataset plus a non-shuffled DataLoader over it."""
    dataset = png_dataset(dirs = paths, transforms = transformations)
    loader = torch.utils.data.DataLoader(
        dataset = dataset,
        batch_size = 1024,
        shuffle = False,
        num_workers = 4,
        pin_memory = True)
    return dataset, loader


# Both splits share the same preprocessing and loader settings.
train_dataset, train_loader = _make_loader(train_list)
valid_dataset, valid_loader = _make_loader(valid_list)

8870
70962


In [4]:
# Build the reference InceptionV3 at the 2048-dim feature block and print a
# slice of its first parameter tensor as a quick check of which weights loaded.
# NOTE(review): `model` is not passed to calculate_statistics_for_given_paths
# below - confirm which network that function actually uses.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx])
pa = next(iter(model.parameters()))
print(pa[0,0,0])

# Settings forwarded to the statistics routine.
dims = 2048
device = 'cuda'
batch_size = 1024

# Statistics are computed per entry; each key labels one split.
dataloader_dict = {'RF_ref_train' : train_loader, 'RF_ref_valid' : valid_loader}

calculate_statistics_for_given_paths(dataloader_dict, batch_size, device, dims)

# Read the per-split statistics back from ./np_saves (presumably written by the
# call above under "<key>_m.npy" / "<key>_s.npy" - verify against fid.py).
RF_rdf_val_m, RF_rdf_val_s = (
    np.load('./np_saves/RF_ref_valid_m.npy'),
    np.load('./np_saves/RF_ref_valid_s.npy'),
)
RF_rdf_train_m, RF_rdf_train_s = (
    np.load('./np_saves/RF_ref_train_m.npy'),
    np.load('./np_saves/RF_ref_train_s.npy'),
)

# Last expression of the cell: Frechet distance between validation and training.
calculate_frechet_distance(
    RF_rdf_val_m, RF_rdf_val_s, RF_rdf_train_m, RF_rdf_train_s, eps=1e-6)

tensor([ 0.0126, -0.0016,  0.0909])
calculating statistics for RF_ref_train
total images :  70962


100%|██████████| 70/70 [02:50<00:00,  2.44s/it]


(70962, 2048)
calculating statistics for RF_ref_valid
total images :  8870


100%|██████████| 9/9 [00:26<00:00,  2.90s/it]


(8870, 2048)
saving statistics done for 2


0.7177827708652273

In [None]:
# Build the InceptionV3_sehun variant at the 2048-dim feature block and print a
# slice of its first parameter tensor as a quick check of which weights loaded.
# NOTE(review): `model` is not passed to calculate_statistics_for_given_paths
# below - confirm the statistics really come from InceptionV3_sehun.
block_idx = InceptionV3_sehun.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3_sehun([block_idx])
pa = next(iter(model.parameters()))
print(pa[0,0,0])

# Settings forwarded to the statistics routine.
dims = 2048
device = 'cuda'
batch_size = 1024

# Statistics are computed per entry; each key labels one split.
dataloader_dict = {'sehun_RF_ref_train' : train_loader, 'sehun_RF_ref_valid' : valid_loader}

calculate_statistics_for_given_paths(dataloader_dict, batch_size, device, dims)

# Read the per-split statistics back from ./np_saves (presumably written by the
# call above under "<key>_m.npy" / "<key>_s.npy" - verify against fid.py).
RF_rdf_val_m, RF_rdf_val_s = (
    np.load('./np_saves/sehun_RF_ref_valid_m.npy'),
    np.load('./np_saves/sehun_RF_ref_valid_s.npy'),
)
RF_rdf_train_m, RF_rdf_train_s = (
    np.load('./np_saves/sehun_RF_ref_train_m.npy'),
    np.load('./np_saves/sehun_RF_ref_train_s.npy'),
)

# Last expression of the cell: Frechet distance between validation and training.
calculate_frechet_distance(
    RF_rdf_val_m, RF_rdf_val_s, RF_rdf_train_m, RF_rdf_train_s, eps=1e-6)

tensor([-0.2103, -0.3441, -0.0344])
calculating statistics for sehun_RF_ref_train
total images :  70962


 23%|██▎       | 16/70 [00:40<01:20,  1.48s/it]

In [None]:
# by sehun
# Frechet distance between precomputed validation / training statistics
# loaded from ../FID.
# NOTE(review): the files are named std_*; confirm they hold the same kind of
# statistic as the *_s arrays passed to calculate_frechet_distance elsewhere.
mu_val = np.load('../FID/mu_val.npy')
std_val = np.load('../FID/std_val.npy')
mu_trn = np.load('../FID/mu_trn.npy')
std_trn = np.load('../FID/std_trn.npy')
calculate_frechet_distance(mu_val, std_val, mu_trn, std_trn)