In [3]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from models import Generator
from torchmetrics.wrappers import FeatureShare
from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance, InceptionScore, PerceptualPathLength
from torchmetrics.functional.image import perceptual_path_length
import torch.nn.functional as F

# Seed the global torch RNG once, up front, so generator noise and DataLoader
# shuffling are reproducible under Restart Kernel -> Run All. Successive metric
# runs inside one kernel still differ, because the RNG state advances.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; every model and tensor below is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cuda device


In [2]:
BATCH_SIZE = 200

# CIFAR-10 training split. ToTensor produces values in [0, 1]; Normalize with
# mean = std = 0.5 on every channel then maps them to [-1, 1].
channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ]
)

cifar10_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    cifar10_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Files already downloaded and verified


In [3]:
def _to_uint8_32px(images):
    """Resize an image batch to 32x32 and map it from [-1, 1] floats to uint8 in [0, 255].

    Assumes a 4D (N, C, H, W) float tensor, as required by bilinear F.interpolate.
    """
    images = F.interpolate(images, size=32, mode='bilinear', align_corners=False)
    # round(): plain truncation would turn e.g. 254.9999 (float error after Normalize) into 254.
    # clamp(): keeps values outside [-1, 1] from wrapping around when cast to uint8.
    return ((images + 1) / 2 * 255).round().clamp(0, 255).to(torch.uint8)


def compute_inception_metrics(generator, num_samples=100, fid_features=2048, loader=None, latent_size=100):
    """Compute FID, KID and Inception Score of ``generator`` against real images.

    Real images are drawn batch by batch from ``loader``. For each real batch, an
    equally sized fake batch is generated from standard-normal noise of shape
    ``(n, latent_size, 1, 1)``. Both batches are resized to 32x32 and converted
    from [-1, 1] to uint8. Real and fake images go to FID/KID (via FeatureShare);
    only fake images go to Inception Score. Exactly ``num_samples`` real and
    ``num_samples`` fake images are used. If the loader runs out first, it is
    iterated again.

    The generator's train/eval mode is not changed here; call ``generator.eval()``
    beforehand if inference-mode layers are wanted.

    Args:
        generator: callable mapping a noise tensor to an image batch. Its output is
            assumed to lie in [-1, 1] (e.g. a tanh head) — TODO confirm for models.Generator.
        num_samples: number of real and of fake images to evaluate.
        fid_features: Inception feature dimension used by FrechetInceptionDistance.
        loader: iterable of ``(images, labels)`` batches normalized to [-1, 1];
            ``None`` means the notebook-level ``train_loader``.
        latent_size: channel count of the generator's noise input.

    Returns:
        ``(fid, kid, inception_score)``, where ``kid`` and ``inception_score`` are
        the ``(mean, std)`` pairs computed by torchmetrics.

    Raises:
        ValueError: if ``num_samples`` is not positive or ``loader`` yields no images.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if loader is None:
        loader = train_loader

    fs = FeatureShare([FrechetInceptionDistance(feature=fid_features), KernelInceptionDistance(subset_size=100)]).to(device)
    fs.reset()

    inception = InceptionScore().to(device)
    inception.reset()

    processed = 0
    # Evaluation only: no autograd graph is needed for the generator forward passes.
    with torch.no_grad():
        while processed < num_samples:
            processed_before_pass = processed
            for real_images, _ in loader:
                # Take only as many images as are still needed, so exactly num_samples are used.
                take = min(real_images.size(0), num_samples - processed)
                real_images = real_images[:take].to(device)
                noise = torch.randn(take, latent_size, 1, 1, device=device)
                fake_images = generator(noise)

                real_u8 = _to_uint8_32px(real_images)
                fake_u8 = _to_uint8_32px(fake_images)
                fs.update(real_u8, True)
                fs.update(fake_u8, False)
                inception.update(fake_u8)

                processed += take
                print(f"Processed {processed}/{num_samples} samples")
                if processed >= num_samples:
                    break
            if processed == processed_before_pass:
                # A full pass added nothing: stop instead of looping forever.
                raise ValueError("loader yielded no images")

    score = fs.compute()
    inception_score = inception.compute()
    score["InceptionScore"] = inception_score
    print(score)
    return score["FrechetInceptionDistance"], score["KernelInceptionDistance"], score["InceptionScore"]

In [4]:
# Number of independent metric runs averaged below.
NUM_RUNS = 3

generator = Generator(100).to(device)
# map_location lets a checkpoint saved on a GPU load when `device` fell back to CPU.
generator.load_state_dict(torch.load("final_models/final_generator.pth", map_location=device))
# Evaluate in inference mode (affects BatchNorm/Dropout layers, if the model has any).
generator.eval()

runs = [compute_inception_metrics(generator, num_samples=10000) for _ in range(NUM_RUNS)]

# Each run returns (fid, (kid_mean, kid_std), (is_mean, is_std)); average each entry over runs.
avg_fid = sum(fid for fid, _, _ in runs) / NUM_RUNS
avg_kid_mean = sum(kid[0] for _, kid, _ in runs) / NUM_RUNS
avg_kid_std = sum(kid[1] for _, kid, _ in runs) / NUM_RUNS
avg_is_mean = sum(inception[0] for _, _, inception in runs) / NUM_RUNS
avg_is_std = sum(inception[1] for _, _, inception in runs) / NUM_RUNS
print(f"Average FID: {avg_fid}")
print(f"Average KID mean: {avg_kid_mean}, Average KID std: {avg_kid_std}")
print(f"Average IS mean: {avg_is_mean}, Average IS std: {avg_is_std}")



Processed 200/10000 samples
Processed 400/10000 samples
Processed 600/10000 samples
Processed 800/10000 samples
Processed 1000/10000 samples
Processed 1200/10000 samples
Processed 1400/10000 samples
Processed 1600/10000 samples
Processed 1800/10000 samples
Processed 2000/10000 samples
Processed 2200/10000 samples
Processed 2400/10000 samples
Processed 2600/10000 samples
Processed 2800/10000 samples
Processed 3000/10000 samples
Processed 3200/10000 samples
Processed 3400/10000 samples
Processed 3600/10000 samples
Processed 3800/10000 samples
Processed 4000/10000 samples
Processed 4200/10000 samples
Processed 4400/10000 samples
Processed 4600/10000 samples
Processed 4800/10000 samples
Processed 5000/10000 samples
Processed 5200/10000 samples
Processed 5400/10000 samples
Processed 5600/10000 samples
Processed 5800/10000 samples
Processed 6000/10000 samples
Processed 6200/10000 samples
Processed 6400/10000 samples
Processed 6600/10000 samples
Processed 6800/10000 samples
Processed 7000/100

In [4]:
class GeneratorWrapper(Generator):
    """Generator adapter for torchmetrics' perceptual path length computation.

    ``forward`` rescales the base generator's output x to ``255 * (x * 0.5 + 0.5)``.
    That is a [-1, 1] -> [0, 255] mapping, assuming the base output is
    tanh-bounded — TODO confirm for models.Generator. ``sample`` draws latent
    noise on the same device as the model's first layer.
    """

    def __init__(self, z_size):
        # Construction is delegated entirely to Generator.
        super().__init__(z_size)

    def forward(self, z):
        """Run the base generator on ``z`` and rescale its output to the [0, 255] scale."""
        base_output = super().forward(z)
        return 255 * (base_output * 0.5 + 0.5)

    def sample(self, num_samples):
        """Return standard-normal latents of shape (num_samples, latent_size, 1, 1)."""
        # Put the noise wherever the first layer's weights live so forward() gets matching devices.
        weights_device = self.main[0].weight.device
        return torch.randn(num_samples, self.latent_size, 1, 1, device=weights_device)

In [6]:
generator = GeneratorWrapper(100).to(device)
# map_location lets a GPU-saved checkpoint load when `device` fell back to CPU.
generator.load_state_dict(torch.load("final_models/final_generator.pth", map_location=device))
generator.eval()

# Pass `device` explicitly: perceptual_path_length otherwise runs on its own
# default device, which is the CPU, even when a GPU is available.
mean, std, _ = perceptual_path_length(generator, device=device)

In [7]:
# Show the perceptual path length mean and std, separated by a single space.
print(mean, std, sep=" ")

tensor(10.1979) tensor(6.0251)
