In [9]:
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from diffusers import UNet2DModel
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import transforms
from torchvision.io import ImageReadMode, read_image
from tqdm.notebook import tqdm

In [2]:
from train_config import TrainingConfig

# Project-local settings object. This notebook reads image_size, dataset_name, labels,
# output_dir, device and eval_batch_size from it (presumably the same config used for
# training — confirm against train_config.py).
config = TrainingConfig()

In [3]:
# Resize every image to the square training resolution, then convert it to a tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.ToTensor(),
    ]
)


def transform(examples):
    """Batch transform for `Dataset.set_transform`.

    Args:
        examples: batch dict with an "image" column (objects supporting
            `.convert("RGB")`, e.g. PIL images) and a "label" column.

    Returns:
        dict with "image" -> list of preprocessed tensors and "label" -> the
        batch's labels, unchanged.
    """
    processed = []
    for img in examples["image"]:
        # Force 3 channels so grayscale/RGBA inputs match the RGB model.
        processed.append(preprocess(img.convert("RGB")))
    return {"image": processed, "label": examples["label"]}


def filter_dataset(dataset, labels):
    """Return the rows of `dataset` whose "label" value is contained in `labels`."""
    def _keep(example):
        return example["label"] in labels

    return dataset.filter(_keep)

In [4]:
# Held-out split used as the "real" reference set for evaluation.
test_dataset = load_dataset(config.dataset_name, split="test")

# Filter on the raw rows *before* attaching the on-the-fly transform. Otherwise the
# label predicate is fed transformed rows, i.e. every image would be resized and
# tensorized just to check its label.
if config.labels is not None:
    test_dataset = filter_dataset(test_dataset, config.labels)
    print(test_dataset)

test_dataset.set_transform(transform)

# shuffle=False keeps the order of real images deterministic across runs.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.eval_batch_size, shuffle=False
)

In [5]:
# Load the trained UNet from the checkpoint directory. `from_pretrained` rebuilds the
# architecture (block types, channel widths, sample size) from the checkpoint's saved
# config, so a randomly initialised UNet2DModel is NOT constructed first — that only
# allocated a full set of weights to be discarded immediately.
model = UNet2DModel.from_pretrained(f'./{config.output_dir}', subfolder="unet", use_safetensors=True)
model = model.to(config.device)

In [15]:
# Load all generated ("fake") images from disk into a single uint8 tensor.
# Files are expected to be named 0.png, 1.png, ..., (n-1).png.
fake_images_dir = f"./{config.output_dir}/fake_images"
# Count only .png files: stray entries (e.g. .ipynb_checkpoints) would otherwise
# inflate the count and make the loop below request a non-existent file.
n_fake_images = sum(1 for name in os.listdir(fake_images_dir) if name.endswith(".png"))
if n_fake_images == 0:
    raise FileNotFoundError(f"No .png images found in {fake_images_dir}")
# Force 3-channel RGB: the Inception network used for FID expects 3 channels, and a
# PNG may be stored as grayscale or RGBA.
y_fake = torch.stack([
    read_image(f"{fake_images_dir}/{i}.png", mode=ImageReadMode.RGB)
    for i in range(n_fake_images)
])

# Load all real test images (float tensors in [0, 1] produced by ToTensor) into one tensor.
y_real = torch.cat([batch["image"] for batch in test_dataloader], dim=0)
if y_real.shape[0] < n_fake_images:
    raise ValueError(
        f"Only {y_real.shape[0]} real test images for {n_fake_images} fake images"
    )
# Only take the same number of real images as fake images.
y_real = y_real[:n_fake_images]
# Convert to uint8 for FID. Round rather than truncate: floating-point error can leave
# k/255*255 slightly below k, and a plain cast would then map it to k-1.
y_real = y_real.mul(255).round().to(torch.uint8)

y_random = torch.randint(0, 256, size=y_real.shape, dtype=torch.uint8)

print(f'Number of real images: {y_real.shape[0]}')
print(f'Number of fake images: {y_fake.shape[0]}')

print(f'Number of real images: {y_real.shape[0]}')
print(f'Number of fake images: {y_fake.shape[0]}')

Number of real images: 9728
Number of fake images: 9728


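### Fréchet Inception Distance (FID)

Compares the statistics of Inception-v3 features extracted from real test images and from generated images; lower is better.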

In [25]:
# torchmetrics' FID (with the default normalize=False) expects uint8 images in [0, 255].
# Both tensors should already be uint8 at this point; the casts are a defensive no-op.
y_fake = y_fake.to(torch.uint8)
y_real = y_real.to(torch.uint8)
# Inception-v3 requires 3-channel input.
assert y_fake.shape[1] == 3 and y_real.shape[1] == 3, (y_fake.shape, y_real.shape)

In [26]:
_ = torch.manual_seed(123)

# 2048-d pooled Inception-v3 features (the standard FID setting). Run the metric on the
# same device as the model; pushing ~10k images through Inception on CPU is very slow.
fid = FrechetInceptionDistance(feature=2048).to(config.device)


def update_fid(metric, images, real, batch_size, device):
    """Feed a stacked uint8 image tensor to an FID metric in batches.

    Args:
        metric: FrechetInceptionDistance instance to update.
        images: uint8 tensor with images stacked along dim 0.
        real: True for reference (real) images, False for generated ones.
        batch_size: number of images per `metric.update` call.
        device: device the metric lives on; each batch is moved there.
    """
    for start in tqdm(range(0, images.shape[0], batch_size), desc="real" if real else "fake"):
        metric.update(images[start : start + batch_size].to(device), real=real)


# Feed real and fake images independently so every image is used even if the counts differ.
bs = config.eval_batch_size
update_fid(fid, y_real, real=True, batch_size=bs, device=config.device)
update_fid(fid, y_fake, real=False, batch_size=bs, device=config.device)

# Keep the score before resetting the metric's accumulated state.
fid_score = fid.compute().item()
fid.reset()
print(f"FID: {fid_score:.3f}")

  0%|          | 0/76 [00:00<?, ?it/s]