In [1]:
import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
from torch import nn
from gan import Generator, Discriminator

In [2]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus
    the CUDA generators when a GPU is available), and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # The cuDNN flags do not depend on the seeding calls, so set them up front.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # CPU-side generators: Python, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing else to seed when no GPU is visible.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Fix all RNG seeds before any stochastic code runs.
set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'


In [3]:
class MNISTDataset(Dataset):
    """Map-style dataset pairing each sample with its label.

    Args:
        data: Indexable collection of samples; must support ``len()``.
        label: Indexable collection of labels, aligned with ``data`` by index.
        transform: Optional callable applied to a sample when it is fetched.
            When ``None`` (the default) the sample is returned unchanged.
    """

    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, applying ``transform`` if set."""
        sample = self.data[idx]
        # ``transform`` defaults to None; calling it unconditionally would
        # raise TypeError for every dataset built without a transform.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.label[idx]

In [4]:
# Read the MNIST .mat archive; the arrays are looked up by variable name below.
mnist = loadmat("../mnist-original.mat/mnist-original.mat")

# First row of the "label" variable (presumably one label per sample -- TODO confirm).
mnist_label = mnist["label"][0]

# Transposed "data" variable (presumably one flattened image per row after
# the transpose -- TODO confirm against the file layout).
mnist_data = np.transpose(mnist["data"])

In [5]:
# Values that identify the trained checkpoint: all four are interpolated into
# the generator checkpoint file name, and latent_dim is also passed to Generator.
latent_dim = 128
k = 1
lr = 0.0002
EPOCHS = 10

# Not referenced by the cells shown here; presumably DataLoader settings
# carried over from training -- TODO confirm.
num_workers = 8
batch_size = 64


In [6]:
# Rebuild the generator and restore the trained weights selected by the
# hyper-parameters above (lr, EPOCHS, latent_dim and k are part of the file name).
generator = Generator(latent_dim).to(device)
checkpoint_path = f"/mnt/d/data/mnist_model/GAN_BCE/g_lr_{lr}_epoch_{EPOCHS}_latent{latent_dim}_k_{k}.pth"

# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs; it avoids executing arbitrary pickled code from the
# checkpoint file and the warning torch emits when the argument is omitted.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
generator.load_state_dict(state_dict)

  generator.load_state_dict(torch.load( f"/mnt/d/data/mnist_model/GAN_BCE/g_lr_{lr}_epoch_{EPOCHS}_latent{latent_dim}_k_{k}.pth", map_location=device))


<All keys matched successfully>

In [7]:
# Sample images from the trained generator, one batch per forward pass.
NUM_GEN_BATCHES = 100  # number of forward passes
GEN_BATCH_SIZE = 100   # images produced per forward pass

# The buffer size is derived from the two constants above so it cannot drift
# out of sync with the loop bounds, and it is allocated directly on the target
# device instead of on the CPU followed by a copy.
# NOTE(review): the (1, 28, 28) per-image shape has to match what the generator
# emits -- confirm against gan.Generator.
result = torch.empty(NUM_GEN_BATCHES * GEN_BATCH_SIZE, 1, 28, 28, device=device)

generator.eval()
with torch.no_grad():  # sampling only; no autograd graph is needed
    for i in range(NUM_GEN_BATCHES):
        print(f"Generating batch {i+1}/{NUM_GEN_BATCHES}")
        # Noise is drawn and then moved exactly as before, so the sampled
        # values are unchanged by this rewrite.
        z = torch.randn(GEN_BATCH_SIZE, generator.latent_dim).to(device)
        output = generator(z)
        start = i * GEN_BATCH_SIZE
        result[start:start + GEN_BATCH_SIZE] = output
    

Generating batch 1/100
Generating batch 2/100
Generating batch 3/100
Generating batch 4/100
Generating batch 5/100
Generating batch 6/100
Generating batch 7/100
Generating batch 8/100
Generating batch 9/100
Generating batch 10/100
Generating batch 11/100
Generating batch 12/100
Generating batch 13/100
Generating batch 14/100
Generating batch 15/100
Generating batch 16/100
Generating batch 17/100
Generating batch 18/100
Generating batch 19/100
Generating batch 20/100
Generating batch 21/100
Generating batch 22/100
Generating batch 23/100
Generating batch 24/100
Generating batch 25/100
Generating batch 26/100
Generating batch 27/100
Generating batch 28/100
Generating batch 29/100
Generating batch 30/100
Generating batch 31/100
Generating batch 32/100
Generating batch 33/100
Generating batch 34/100
Generating batch 35/100
Generating batch 36/100
Generating batch 37/100
Generating batch 38/100
Generating batch 39/100
Generating batch 40/100
Generating batch 41/100
Generating batch 42/100
G

In [8]:
# Min-max rescale the whole tensor to [0, 1] using its global extrema
# (min()/max() without a dim reduce over every element).
result_min = result.min()
result_max = result.max()
result = (result - result_min) / (result_max - result_min)

In [12]:
import os
from os.path import join as ospj

import tifffile as tif

# Write every generated image to its own zero-padded, 5-digit .tif file.
result_dir = "/mnt/d/data/mnist_result/gan_bce_result/"
os.makedirs(result_dir, exist_ok=True)

# Copy the whole tensor to host memory once instead of issuing one
# device-to-host transfer per image inside the loop.
images_np = result.cpu().numpy()

# NOTE(review): `result` was min-max scaled in an earlier cell and is written
# here without any dtype conversion, so these TIFFs presumably hold floats in
# [0, 1]. Confirm that the FID image loader interprets float TIFFs as intended
# (a loader that casts to 8-bit without rescaling would see near-black images)
# and that the real-image folder uses the same encoding.
for i in range(images_np.shape[0]):
    # squeeze() drops the singleton channel axis, as the original per-image call did.
    tif.imwrite(ospj(result_dir, f"{str(i).zfill(5)}.tif"), images_np[i].squeeze())

In [13]:
import torch
from pytorch_fid import fid_score

# Folders compared by the FID computation: reference samples vs. generator output.
test_dir = "/mnt/d/data/mnist_result/test_sample/"
result_dir = "/mnt/d/data/mnist_result/gan_bce_result/"
real_images_folder = test_dir
generated_images_folder = result_dir

# Same device rule as the rest of the notebook: GPU if available, else CPU.
fid_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keyword arguments forwarded unchanged to pytorch_fid.
fid_kwargs = dict(batch_size=50, device=fid_device, dims=2048)
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_folder, generated_images_folder],
    **fid_kwargs,
)
print("FID score:", fid_value)

100%|██████████| 200/200 [00:44<00:00,  4.50it/s]
100%|██████████| 200/200 [00:54<00:00,  3.68it/s]


FID score: 2.8555386485814287
