In [9]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from models import Generator
from torchmetrics.wrappers import FeatureShare
from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance, InceptionScore
from torchmetrics.functional.image import perceptual_path_length
import torch.nn.functional as F

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

Using cuda device


In [16]:
# Number of images drawn from the loader per step; also read by the metric
# helper below as the size of each chunk fed to the metrics.
BATCH_SIZE = 200

# Preprocessing for CIFAR-10: convert to a tensor, then shift/scale every
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Training split of CIFAR-10 (downloaded into ./data when missing), served in
# shuffled mini-batches.
cifar10_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=cifar10_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Files already downloaded and verified


In [17]:
def compute_inception_metrics(generator, num_samples=100, fid_features=2048, label=None):
    """Compute FID, KID and Inception Score for a conditional generator.

    Real images are drawn from the module-level ``train_loader``; fake images
    are produced by feeding ``generator`` a vector made of 100 Gaussian noise
    values concatenated with a 10-way one-hot class label. Images are fed to
    the metrics in chunks of exactly ``BATCH_SIZE`` until at least
    ``num_samples`` images have been processed.

    Args:
        generator: Callable mapping a ``(N, 110)`` float tensor to a batch of
            images (assumed to be in ``[-1, 1]`` -- TODO confirm against the
            Generator definition).
        num_samples: Minimum number of real/fake image pairs to evaluate. The
            actual count is rounded up to a multiple of ``BATCH_SIZE``.
        fid_features: ``feature`` argument forwarded to
            ``FrechetInceptionDistance``.
        label: ``None`` to sample a random class per image, or an integer
            class id (``0`` included) to restrict both real and generated
            images to that class.

    Returns:
        Tuple ``(fid, kid, inception_score)`` taken from the computed metric
        dictionary; callers index ``kid`` and ``inception_score`` with
        ``[0]`` / ``[1]``.

    Raises:
        RuntimeError: If one full pass over ``train_loader`` yields fewer than
            ``BATCH_SIZE`` matching images (the function could otherwise never
            make progress).
    """
    latent_dim = 100
    num_classes = 10

    fs = FeatureShare([FrechetInceptionDistance(feature=fid_features), KernelInceptionDistance(subset_size=100)]).to(device)
    inception = InceptionScore().to(device)

    processed = 0
    while processed < num_samples:
        real_chunks = []
        fake_chunks = []
        collected = 0  # number of IMAGES gathered so far (not number of batches)

        for real_images, real_labels in train_loader:
            # `label is not None` (rather than truthiness) so class 0 is honoured.
            if label is not None:
                real_images = real_images[real_labels == label]

            batch_len = real_images.size(0)
            if batch_len == 0:
                # No image of the requested class in this mini-batch.
                continue

            real_images = real_images.to(device)
            if label is None:
                class_ids = torch.randint(low=0, high=num_classes, size=(batch_len,), device=device)
            else:
                class_ids = torch.full((batch_len,), label, dtype=torch.long, device=device)

            one_hot_labels = F.one_hot(class_ids, num_classes)
            noise = torch.randn(batch_len, latent_dim, device=device)
            input_vector = torch.cat((noise, one_hot_labels.float()), dim=1)

            # Pure inference: do not build an autograd graph for the samples.
            with torch.no_grad():
                fake_images = generator(input_vector)

            # Resize real and fake images to 32x32
            real_images = F.interpolate(real_images, size=32, mode='bilinear', align_corners=False)
            fake_images = F.interpolate(fake_images, size=32, mode='bilinear', align_corners=False)

            # Denormalize the images (assuming they were normalized to [-1, 1]).
            # Clamp before the cast so out-of-range values saturate instead of
            # wrapping around in uint8.
            real_images = ((real_images + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
            fake_images = ((fake_images + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)

            real_chunks.append(real_images)
            fake_chunks.append(fake_images)
            collected += batch_len

            if collected >= BATCH_SIZE:
                break

        if collected < BATCH_SIZE:
            raise RuntimeError(
                f"Only {collected} matching images found in a full pass over the "
                f"data loader; at least {BATCH_SIZE} are required (label={label})."
            )

        # Feed exactly BATCH_SIZE real and fake images to the metrics.
        real_images_batch = torch.cat(real_chunks, dim=0)[:BATCH_SIZE]
        fake_images_batch = torch.cat(fake_chunks, dim=0)[:BATCH_SIZE]
        fs.update(real_images_batch, True)
        fs.update(fake_images_batch, False)
        inception.update(fake_images_batch)
        processed += len(real_images_batch)

        print(f"Processed {processed}/{num_samples} samples")

    score = fs.compute()
    inception_score = inception.compute()
    score["InceptionScore"] = inception_score
    print(score)
    return score["FrechetInceptionDistance"], score["KernelInceptionDistance"], score["InceptionScore"]

In [None]:
# Load the trained generator. `map_location` lets a checkpoint saved on GPU be
# loaded when `device` fell back to CPU.
generator = Generator(100).to(device)
generator.load_state_dict(torch.load("final_models/final_generator.pth", map_location=device))
# Evaluate in inference mode, consistent with the perceptual-path-length cell.
generator.eval()

# Number of independent metric evaluations that are averaged below.
NUM_RUNS = 3

avg_fid = 0
avg_kid_mean = 0
avg_kid_std = 0
avg_is_mean = 0
avg_is_std = 0
for _ in range(NUM_RUNS):
    # kid and inception are each indexed as (mean, std).
    fid, kid, inception = compute_inception_metrics(generator, num_samples=200, label=7)
    avg_fid += fid
    avg_kid_mean += kid[0]
    avg_kid_std += kid[1]
    avg_is_mean += inception[0]
    avg_is_std += inception[1]
print(f"Average FID: {avg_fid / NUM_RUNS}")
print(f"Average KID mean: {avg_kid_mean / NUM_RUNS}, Average KID std: {avg_kid_std / NUM_RUNS}")
print(f"Average IS mean: {avg_is_mean / NUM_RUNS}, Average IS std: {avg_is_std / NUM_RUNS}")

In [14]:
class GeneratorWrapper(Generator):
    """Adapter exposing ``Generator`` through the ``forward``/``sample`` pair
    that ``perceptual_path_length`` is called with in this notebook.

    ``forward`` rescales the parent's output and ``sample`` produces latent
    inputs consisting of Gaussian noise concatenated with a random one-hot
    class label.
    """

    def __init__(self, z_size, num_classes=10):
        """Create the wrapped generator.

        Args:
            z_size: Length of the Gaussian noise part of the latent vector.
            num_classes: Number of classes used for the one-hot label part.
        """
        # Kept before the parent initialiser, as in the original code.
        self.latent_size = z_size
        super().__init__(z_size, num_classes=num_classes)
        # Remember the class count so sample() agrees with the constructor
        # argument instead of a hard-coded 10.
        self._sample_num_classes = num_classes

    def forward(self, z):
        """Return the parent's output mapped through ``255 * (x * 0.5 + 0.5)``.

        NOTE(review): this yields values in [0, 255] only if the parent
        generator emits values in [-1, 1] -- confirm against ``Generator``.
        """
        return 255 * (super().forward(z) * 0.5 + 0.5)

    def sample(self, num_samples):
        """Draw ``num_samples`` latent input vectors.

        Each row is ``latent_size`` standard-normal values followed by a
        one-hot encoding of a uniformly drawn class id. All tensors are
        created on the device holding ``self.fc1.weight`` so that the final
        concatenation never mixes CPU and GPU tensors.
        """
        target_device = self.fc1.weight.device
        label = torch.randint(0, self._sample_num_classes, (num_samples,), device=target_device)
        one_hot_labels = F.one_hot(label, self._sample_num_classes)
        noise = torch.randn(num_samples, self.latent_size, device=target_device)
        input_vector = torch.cat((noise, one_hot_labels.float()), dim=1)
        return input_vector

In [15]:
# Build the wrapped generator on the target device and restore the trained
# weights. `map_location` lets a checkpoint saved on GPU be loaded when
# `device` fell back to CPU.
generator = GeneratorWrapper(100).to(device)
generator.load_state_dict(torch.load("final_models/final_generator.pth", map_location=device))
generator.eval()

# Last expression of the cell: its value is displayed as the cell output.
perceptual_path_length(generator, num_samples=100, batch_size=50)

TEST


(tensor(39.2204),
 tensor(27.6649),
 tensor([ 32.2516,  21.8606,  35.5805,  56.8346, 103.4226,   9.8458,  45.6617,
          97.4959,  30.9462,  21.3177, 120.9710,  60.5314,  23.0506,  17.4871,
          37.4160,  86.5625,  25.3680,  12.3684,  14.4844,  33.7865,  20.4293,
          18.6896,  18.6774,  36.0100,  23.8001,  18.8495,  25.5616,  42.2301,
          14.0406,  50.3401,  29.9364,  66.0843,  28.7563,  62.2913,  27.8636,
          33.1642,  12.3069,  33.7130,  33.2749,  38.8485,  29.1089,  42.7706,
          78.3030,  41.3270,  37.8709,  15.1425,  45.8934,  26.6008,  42.9832,
          21.2979,  22.9904,  56.7789,  37.0110,  44.1315,  17.6006,  11.0275,
          13.7830,  29.3422,  11.6531,   9.0502,  36.4368,  13.6588, 129.3887,
          60.8706,  13.0150,  54.9913,  53.8828,  50.6271,  11.9265,  64.1133,
          23.0048,  48.7219,  55.5936,  46.1058,  43.3842,  34.6218,  31.3851,
          45.9221,  10.3538,  24.9722,  49.8425,  21.5106,  52.3321,   9.1489,
          87.850