In [None]:
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
import random
from gan_model import Discriminator, Generator
from fid_score import *
from inception import *
import pandas as pd
import torchvision.utils as vutils
from fid import calculate_fid_given_paths

In [None]:
# Seed every RNG this notebook draws from (torch: latent noise and DataLoader
# shuffling; numpy: the real-image window start) so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included
# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so this line cannot
# change hash randomization of the already-running kernel; it only affects
# Python processes launched after this point.
os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# CIFAR-10 *test* split. ToTensor gives [0, 1]; Normalize(mean=0.5, std=0.5)
# per channel then maps pixels to [-1, 1].
trans_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10(
    root='./datasets/cifar/',
    train=False,
    transform=trans_cifar,
    download=True,
)
# batch_size=10000 serves the whole test split (10,000 images) as one shuffled batch.
dataloader_test = torch.utils.data.DataLoader(dataset, batch_size=10000, shuffle=True)


In [None]:
# Build the pool of real test images restricted to the chosen classes.
# CIFAR-10 label ids 1, 3, 5, 9 are automobile, cat, dog, truck.
CHOSEN_CLASSES = [9, 1, 5, 3]
chosen_batches = []
for x, y in dataloader_test:
    # Vectorized membership test: (batch, 1) == (n_classes,) -> (batch, n_classes).
    keep = (y.unsqueeze(1) == torch.tensor(CHOSEN_CLASSES)).any(dim=1)
    chosen_batches.append(x[keep])
# Concatenate across *all* batches so the result stays correct even if the
# loader is ever configured to yield more than one batch.
test_imgs = torch.cat(chosen_batches)
assert len(test_imgs) > 0, 'no test images matched CHOSEN_CLASSES'

In [None]:
NOISE_DIM = 128        # latent size; noise is drawn as (N, NOISE_DIM, 1, 1)
FID_BATCH_SIZE = 1200  # real and generated images per FID evaluation
epochs = 5             # number of repeated FID evaluations per generator

# Generator checkpoints to evaluate, in plotting/CSV order (gen 1..N).
# The second path component is used as the run name for output folders.
file_loc = [
    'runs/2W_MIN_LOSS/models/G_epoch_199',
    'runs/2W_MAX_LOSS/models/G_epoch_199',

    'runs/2W_MAX_LOSS_OVERRIDE/models/G_epoch_199',
    # NOTE(review): epoch 149 here vs 199 elsewhere — confirm this is intentional.
    'runs/2W_MIN_LOSS_OVERRIDE_2:1/models/G_epoch_149',

    'runs/2W_WEIGHTED_MOST/models/G_epoch_199',
    'runs/2W_WEIGHTED_LEAST/models/G_epoch_199',

    'runs/2W_WEIGHTED_MOST_LR/models/G_epoch_199',

    'runs/2W_WEIGHTED_MOST_3:1/models/G_epoch_199',
    'runs/2W_WEIGHTED_LEAST_3:1/models/G_epoch_199',
]
# Guard against editing file_loc without updating the expected count.
gen_count_to_test = 9
assert gen_count_to_test == len(file_loc)

# Former eager-loading approach, kept for reference only (the evaluation loop
# loads one generator at a time instead):
# fic_model = InceptionV3().to(dev)
# generators = {}
# for i in range(gen_count_to_test):
#     key = f'gen{i}'
#     gen = Generator().to(dev)
#     gen.load_state_dict(torch.load(file_loc[i]))
#     # gen.eval()  # set the model to evaluation mode
#     generators[key] = gen

In [None]:
def _save_images(images, out_dir, prefix):
    """Save each image of ``images`` to ``out_dir`` as ``{prefix}_{j}.png``.

    Inputs are expected in [-1, 1] (the range produced by ``trans_cifar``'s
    Normalize(0.5, 0.5)). They are mapped back to [0, 1] first because
    ``vutils.save_image`` clamps to [0, 1] and would otherwise zero out every
    negative pixel, corrupting the images that FID is computed on.
    """
    os.makedirs(out_dir, exist_ok=True)
    images = images * 0.5 + 0.5  # invert Normalize(mean=0.5, std=0.5)
    # NOTE(review): files from an earlier run with a larger FID_BATCH_SIZE are
    # not removed and would be picked up by the FID computation.
    for j, img in enumerate(images):
        vutils.save_image(img, os.path.join(out_dir, f'{prefix}_{j}.png'))


assert len(test_imgs) >= FID_BATCH_SIZE, 'not enough real images for one FID batch'

# fid_avg[i] collects one FID score per epoch for the generator at file_loc[i].
fid_avg = [[] for _ in file_loc]
for epoch in range(epochs):
    # One shared latent batch per epoch so every generator sees the same noise.
    fid_z = torch.randn(FID_BATCH_SIZE, NOISE_DIM, 1, 1).to(dev)
    # randint's upper bound is exclusive; +1 makes the last full window reachable.
    random_start = np.random.randint(len(test_imgs) - FID_BATCH_SIZE + 1)
    real_imgs = test_imgs[random_start:random_start + FID_BATCH_SIZE]

    # Save real images once per epoch; every generator is compared against them.
    real_img_dir = f'test/real_imgs/epoch_{epoch}/'
    _save_images(real_imgs, real_img_dir, 'real_img')

    for i, gen_path in enumerate(file_loc):
        # Load one generator at a time to bound GPU memory; map_location lets
        # GPU-saved checkpoints load on CPU-only machines too.
        gen = Generator().to(dev)
        gen.load_state_dict(torch.load(gen_path, map_location=dev))
        # NOTE(review): the generator is left in train mode (gen.eval() was
        # deliberately commented out upstream) — confirm that is intended.
        with torch.no_grad():  # inference only; avoid building an autograd graph
            gen_imgs = gen(fid_z)

        # Saved generated images go under the run name (second path component).
        # Assumes Generator outputs lie in [-1, 1] like the real data it was
        # trained against (e.g. a tanh head) — TODO confirm in gan_model.py.
        gen_name = gen_path.split('/')[1]
        gen_img_dir = f'test/generator_imgs/{gen_name}/epoch_{epoch}/'
        _save_images(gen_imgs, gen_img_dir, 'gen_img')

        fid = calculate_fid_given_paths([gen_img_dir, real_img_dir], 'init_models')
        fid_avg[i].append(fid)
        print(f'gen#{i}/epoch#{epoch} fid_score: {fid:0.2f}')

        # Release the generator and its outputs before loading the next one.
        del gen, gen_imgs
        torch.cuda.empty_cache()

In [None]:
# fid_avg = []
# for epoch in range(epochs):
#     fid_z = torch.randn(FID_BATCH_SIZE, NOISE_DIM, 1,1).to(dev)
#     random_start = np.random.randint(len(test_imgs)-FID_BATCH_SIZE)
#     for i,gen in enumerate(generators.values()):
#         if epoch == 0:
#             fid_avg.append([])
#         gen_imgs = gen(fid_z.detach())

#         # # Save generated images
#         # for j, img in enumerate(gen_imgs):
#         #     vutils.save_image(img, f'test/generator_imgs/gen_{i}_epoch_{j}_{epoch}.png')
        
#         # # Save real images
#         # real_imgs = test_imgs[random_start:random_start+FID_BATCH_SIZE]
#         # for j, img in enumerate(real_imgs):
#         #     vutils.save_image(img,  f'test/real_imgs/real_epoch_{j}_{epoch}.png')

#         mu_gen, sigma_gen = calculate_activation_statistics(gen_imgs, fic_model, batch_size=FID_BATCH_SIZE,cuda=True)
#         mu_test, sigma_test = calculate_activation_statistics(test_imgs[random_start:random_start+FID_BATCH_SIZE], fic_model, batch_size=FID_BATCH_SIZE,cuda=True)
#         fid = calculate_frechet_distance(mu_gen, sigma_gen, mu_test, sigma_test)
#         fid_avg[i].append(fid)
#         print(f'gen#{i}/epoch#{epoch} fid_score: {fid:0.2f}')


# fid_avg = []
# for epoch in range(epochs):
#     fid_z = torch.randn(FID_BATCH_SIZE, NOISE_DIM, 1,1).to(dev)
#     random_start = np.random.randint(len(test_imgs)-FID_BATCH_SIZE)
#     real_imgs = test_imgs[random_start:random_start+FID_BATCH_SIZE]

#     # Save real images once per epoch
#     real_img_dir = f'test/real_imgs/epoch_{epoch}/'
#     os.makedirs(real_img_dir, exist_ok=True)
#     for j, img in enumerate(real_imgs):
#         vutils.save_image(img,  f'{real_img_dir}real_img_{j}.png')

#     for i,gen in enumerate(generators.values()):
#         if epoch == 0:
#             fid_avg.append([])
#         gen_imgs = gen(fid_z.detach())

#         # Save generated images
#         gen_name = file_loc[i].split('/')[1]  # Use the second name in the breadcrumb
#         gen_img_dir = f'test/generator_imgs/{gen_name}/epoch_{epoch}/'
#         os.makedirs(gen_img_dir, exist_ok=True)
#         for j, img in enumerate(gen_imgs):
#             vutils.save_image(img, f'{gen_img_dir}gen_img_{j}.png')

#         fid = calculate_fid_given_paths([gen_img_dir, real_img_dir], 'init_models')
#         fid_avg[i].append(fid)
#         print(f'gen#{i}/epoch#{epoch} fid_score: {fid:0.2f}')

In [None]:
# Rows are generators, columns are epochs: summarize each generator across epochs.
fid_avg_np = np.array(fid_avg)
for summarize in (np.mean, np.std):
    print(summarize(fid_avg_np, axis=1))

In [None]:
# Summary table, one row per generator numbered 1..N in file_loc order.
# The mean FID is rounded to the nearest integer (np.rint), matching the
# np.around labels on the bar chart; plain .astype(int) would truncate instead.
df = pd.DataFrame({
    'gen': range(1, gen_count_to_test + 1),
    'mean': np.rint(fid_avg_np.mean(axis=1)).astype(int),
    'sd': fid_avg_np.std(axis=1),
})
 

In [None]:
# Mean FID per generator (lower is better) with ±1 standard-deviation error bars.
fid_means = fid_avg_np.mean(axis=1)
fid_stds = fid_avg_np.std(axis=1)
# Number bars from the score array itself; `generators` is never defined
# (its construction is commented out), so len(generators) raised NameError.
x = np.arange(1, len(fid_means) + 1)

fig, ax = plt.subplots()
bars = ax.bar(x, np.around(fid_means), yerr=fid_stds, ecolor='black', capsize=10)
ax.bar_label(bars, label_type='edge')
ax.set_xticks(x)
ax.set_xlabel('generator number')
ax.set_ylabel('FID average')
ax.set_title('FID per generator (mean ± SD over epochs)')

# Zoom the y-axis on the data rather than a fixed [200, 240] window, which
# silently hid any bar whose FID fell outside it.
y_low = (fid_means - fid_stds).min()
y_high = (fid_means + fid_stds).max()
pad = 0.1 * (y_high - y_low) or 1.0  # fall back to a fixed pad if all values coincide
ax.set_ylim(max(0.0, y_low - pad), y_high + pad)
# plt.savefig('{}/worker_cont.png'.format(logger.writer.logdir))
plt.show()

In [None]:
# Persist the per-generator summary table, without the DataFrame index.
df.to_csv(path_or_buf='generator-results.csv', index=False)