In [1]:
import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
from pixcnn import *

In [2]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed unchanged to every seeding call below.
    """
    # CPU-side generators: stdlib `random`, NumPy's global RNG, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Seed the current CUDA device first, then every visible CUDA device.
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, before any model or data object is created.
set_seed(42)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [3]:
class MNISTDataset(Dataset):
    """Map-style dataset pairing indexable samples with their labels.

    Args:
        data: Indexable collection of samples; ``len(data)`` is the dataset size.
        label: Indexable collection of labels, indexed with the same ``idx``
            as ``data``.
        transform: Optional callable applied to each sample on access. When
            ``None`` (the default) the sample is returned unchanged.
    """

    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, transforming the sample if configured."""
        sample = self.data[idx]
        # `transform` is optional: calling it unconditionally would raise
        # TypeError when the default of None is used.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.label[idx]

In [4]:
# Read the MNIST .mat archive into a dict-like object keyed by variable name.
mnist = loadmat("../mnist-original.mat/mnist-original.mat")

# Transpose "data" and keep the first row of "label".
# Presumably this yields one flattened image per row and a 1-D label vector — TODO confirm.
mnist_data, mnist_label = mnist["data"].T, mnist["label"][0]

In [5]:
# Run configuration. EPOCHS and lr are also interpolated into the checkpoint
# filename when the trained weights are loaded.
EPOCHS, lr = 100, 5e-4
batch_size, num_workers = 64, 8

# Instantiate the network and move it to the selected device.
model = PixelCNN().to(device)


In [6]:
from matplotlib import pyplot as plt
import torchvision.utils as vutils
def show_images(images):
    """Display a batch of images as one grid figure, four images per row.

    Args:
        images: Batch handed straight to ``torchvision.utils.make_grid``
            (assumed to be a (B, C, H, W) tensor — TODO confirm).
    """
    grid = vutils.make_grid(images, nrow=4, padding=2, normalize=True)
    # Reorder the dimensions so the channel axis comes last, as imshow expects.
    channels_last = grid.permute(1, 2, 0).cpu()

    plt.figure(figsize=(8, 8))
    plt.imshow(channels_last, cmap='gray')
    plt.axis('off')
    plt.show()
    plt.close()

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Restore the trained weights. The filename encodes the learning rate and the
# epoch count used for training.
checkpoint_path = f"/mnt/d/data/mnist_model/pixcnn_lr_{lr}_epoch_{EPOCHS}.pth"

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code; it also silences the
# FutureWarning newer PyTorch versions emit for the implicit default.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

# Kept as the last expression so the notebook still shows the key-matching report.
model.load_state_dict(state_dict)


  model.load_state_dict(torch.load( f"/mnt/d/data/mnist_model/pixcnn_lr_{lr}_epoch_{EPOCHS}.pth", map_location=device))


<All keys matched successfully>

In [8]:
from torch.utils.data import random_split
from torchvision import transforms

# Per-sample preprocessing: reshape to 28x28, convert to a tensor, then pad
# two pixels on every side (28x28 -> 32x32).
transform = transforms.Compose([
    transforms.Lambda(lambda x: x.reshape(28, 28)),
    transforms.ToTensor(),
    transforms.Pad((2, 2, 2, 2)),
])

# mnist_data and mnist_label come from the .mat loading cell.
dataset = MNISTDataset(mnist_data, mnist_label, transform=transform)

# Fixed train/validation sizes; whatever remains becomes the test split.
# (For a quick smoke test, 500 / 100 were used instead.)
train_size, val_size = 50000, 10000
test_size = len(dataset) - (train_size + val_size)

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

# All three loaders share batch size and worker count; only training shuffles.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [10]:
# Sampling configuration: NUM_BATCHES * SAMPLES_PER_BATCH images are generated
# in total. Naming these keeps the buffer size, the loop bound and the slice
# arithmetic in sync.
NUM_BATCHES = 100
SAMPLES_PER_BATCH = 100
H, W, C = 32, 32, 1

# CPU-side buffer that collects every generated batch.
result = torch.empty(NUM_BATCHES * SAMPLES_PER_BATCH, C, H, W)

model.eval()
with torch.no_grad():
    for i in range(NUM_BATCHES):
        print(f"Generating batch {i+1}/{NUM_BATCHES}")
        # Start from an all-zero canvas, allocated directly on the target device.
        output = torch.zeros(SAMPLES_PER_BATCH, C, H, W, device=device)
        # Pixels are filled one at a time in raster order (row, column, channel).
        # Each step re-runs the model on the partially filled canvas.
        # Presumably the model's prediction at (h, w) only depends on pixels
        # already sampled — the architecture lives in `pixcnn`; TODO confirm.
        for h in range(H):
            for w in range(W):
                for c in range(C):
                    # Logit of the current pixel for every image in the batch.
                    logits = model(output)[:, c, h, w]
                    probs = logits.sigmoid()
                    # Draw a binary pixel value from the predicted probability.
                    output[:, c, h, w] = torch.bernoulli(probs)
        # Copy the finished batch into its slot of the CPU buffer.
        start = i * SAMPLES_PER_BATCH
        result[start:start + SAMPLES_PER_BATCH] = output

Generating batch 1/100
Generating batch 2/100
Generating batch 3/100
Generating batch 4/100
Generating batch 5/100
Generating batch 6/100
Generating batch 7/100
Generating batch 8/100
Generating batch 9/100
Generating batch 10/100
Generating batch 11/100
Generating batch 12/100
Generating batch 13/100
Generating batch 14/100
Generating batch 15/100
Generating batch 16/100
Generating batch 17/100
Generating batch 18/100
Generating batch 19/100
Generating batch 20/100
Generating batch 21/100
Generating batch 22/100
Generating batch 23/100
Generating batch 24/100
Generating batch 25/100
Generating batch 26/100
Generating batch 27/100
Generating batch 28/100
Generating batch 29/100
Generating batch 30/100
Generating batch 31/100
Generating batch 32/100
Generating batch 33/100
Generating batch 34/100
Generating batch 35/100
Generating batch 36/100
Generating batch 37/100
Generating batch 38/100
Generating batch 39/100
Generating batch 40/100
Generating batch 41/100
Generating batch 42/100
G

In [11]:
import tifffile as tif
import os
from os.path import join as ospj

# Destination folder for the generated samples; created on demand.
result_dir = "/mnt/d/data/mnist_result/pixcnn_result/"
os.makedirs(result_dir, exist_ok=True)

# Drop the 2-pixel border on every side (32x32 -> 28x28) for all images at once.
# NOTE(review): pixels are written with whatever dtype `result` has (float 0/1
# samples if the default dtype is unchanged) — confirm the FID image reader
# treats them the same way as the real samples.
cropped = result[:, 0, 2:-2, 2:-2].cpu().numpy()

# One TIFF per sample, named by its zero-padded five-digit index.
for idx, image in enumerate(cropped):
    tif.imwrite(ospj(result_dir, f"{idx:05d}.tif"), image)
    

In [12]:
import torch
import torchvision
import torchvision.transforms as transforms
from pytorch_fid import fid_score

# Folders compared by the FID computation: held-out real samples and the
# images produced by the sampling cell.
test_dir = "/mnt/d/data/mnist_result/test_sample/"
result_dir = "/mnt/d/data/mnist_result/pixcnn_result/"
real_images_folder = test_dir
generated_images_folder = result_dir

# No torchvision Inception model is instantiated here: none is passed to
# calculate_fid_given_paths, which obtains its own feature extractor. The
# previous unused `inception_v3(pretrained=True)` call only cost a weight
# download and a deprecation warning.
fid_value = fid_score.calculate_fid_given_paths(
    [real_images_folder, generated_images_folder],
    batch_size=50,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    dims=2048,
)
print("FID score:", fid_value)


100%|██████████| 200/200 [01:13<00:00,  2.73it/s]
100%|██████████| 200/200 [01:05<00:00,  3.03it/s]


FID score: 5.982401717675742
