In [1]:
from dataclasses import dataclass, field

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision import transforms
from tqdm.notebook import tqdm
from datasets import load_dataset
from diffusers import UNet2DModel

In [2]:
@dataclass
class TrainingConfig:
    """Settings shared by the cells of this notebook.

    Every setting is an annotated dataclass field, so it can be overridden per
    instance (e.g. ``TrainingConfig(num_epochs=3)``) and appears in the
    generated ``repr``. ``TrainingConfig()`` yields the defaults below.
    """

    device: str = "cuda"  # device the models are moved to
    dataset_name: str = "mnist"  # name passed to `load_dataset`
    # Class labels kept by `filter_dataset`. A default_factory gives every
    # instance its own list instead of one list shared through the class.
    labels: list = field(default_factory=lambda: [1, 2])
    image_size: int = 32  # the generated image resolution
    train_batch_size: int = 128
    eval_batch_size: int = 128  # how many images to sample during evaluation
    num_epochs: int = 500
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 10
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "/experiments/mnist1n2s-ep500"  # the model name locally and on the HF Hub

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0

config = TrainingConfig()

In [3]:
# Image pipeline: resize to a square of side `config.image_size`, then convert
# the image to a tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size,) * 2),
        transforms.ToTensor(),
    ]
)

def transform(examples):
    """Batch transform: preprocess every image and pass the labels through.

    Args:
        examples: mapping with an "image" sequence and a "label" entry. Each
            image must support ``.convert("RGB")`` (presumably PIL images).

    Returns:
        dict with "image" (list of `preprocess` outputs, one per input image)
        and "label" (the input labels, unchanged).
    """
    tensors = []
    for img in examples["image"]:
        # Force three channels before running the shared pipeline.
        rgb = img.convert("RGB")
        tensors.append(preprocess(rgb))
    return {"image": tensors, "label": examples["label"]}

def filter_dataset(dataset, labels):
    """Return `dataset` restricted to examples whose "label" is in `labels`.

    Args:
        dataset: object exposing ``filter(predicate)``.
        labels: container of labels to keep (anything supporting ``in``).

    Returns:
        Whatever ``dataset.filter`` returns for the membership predicate.
    """
    def _has_wanted_label(example):
        return example["label"] in labels

    return dataset.filter(_has_wanted_label)

In [4]:
# Load the test split of the configured dataset.
test_dataset = load_dataset(config.dataset_name, split="test")
# Keep only the configured labels BEFORE installing the on-the-fly transform.
# `filter` has to read every example; with the transform already set, each of
# those reads would presumably also run the image preprocessing just to look at
# the label. Filtering first yields the same dataset without that extra work.
test_dataset = filter_dataset(test_dataset, config.labels)
test_dataset.set_transform(transform)
print(test_dataset)

# Fixed-order batches (shuffle=False) of `config.eval_batch_size` examples.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.eval_batch_size, shuffle=False)

Dataset({
    features: ['image', 'label'],
    num_rows: 2167
})


In [5]:
# UNet with six blocks per side. Spatial self-attention is used in the fifth
# down block and the second up block; all other blocks are regular ResNet
# down/upsampling blocks.
model = UNet2DModel(
    # the target image resolution
    sample_size=config.image_size,
    # 3 input and 3 output channels (RGB images)
    in_channels=3,
    out_channels=3,
    # how many ResNet layers to use per UNet block
    layers_per_block=2,
    # the number of output channels for each UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

# Load pretrained model
# model = UNet2DModel.from_pretrained(f'./{config.output_dir}', subfolder="unet", use_safetensors=True)
# Place the network on the configured device.
model = model.to(config.device)

In [None]:
# Load all real images of test_dataloader into a single tensor by concatenating
# the batches along the batch dimension.
y_real = torch.cat([batch["image"] for batch in test_dataloader], dim=0)
# Convert to uint8 for FID calculation. Round (and clamp) before casting:
# `.to(torch.uint8)` truncates, so a value that float arithmetic left just
# below an integer (e.g. 254.99998) would otherwise lose a whole intensity level.
y_real = (y_real * 255).round().clamp(0, 255).to(torch.uint8)

# Uniform random uint8 noise with the same shape as the real images.
y_random = torch.randint(0, 256, size=y_real.shape, dtype=torch.uint8)

# Load all fake images from directory into a tensor. Files are read as
# 0.png, 1.png, ... and exactly as many are read as there are real images.
fake_images_dir = f"./{config.output_dir}/fake_images"
y_fake = torch.stack(
    [read_image(f'{fake_images_dir}/{i}.png') for i in range(y_real.shape[0])]
)

print(f'Number of real images: {y_real.shape[0]}')
print(f'Number of fake images: {y_fake.shape[0]}')

### Frechet Inception Distance (FID)

In [None]:
_ = torch.manual_seed(123)

from torchmetrics.image.fid import FrechetInceptionDistance

# Run the metric (and its feature extractor) on the configured device, like the
# other models in this notebook, instead of leaving it on the CPU.
fid = FrechetInceptionDistance(feature=2048).to(config.device)

# Feed real and fake images to the metric in aligned slices of `bs` images.
# Each slice is moved to the metric's device only for the update, so the full
# image tensors stay where they were created.
bs = config.eval_batch_size
for b_idx in tqdm(range(0, y_real.shape[0], bs)):
    fid.update(y_real[b_idx : b_idx + bs].to(config.device), real=True)
    fid.update(y_fake[b_idx : b_idx + bs].to(config.device), real=False)

print(fid.compute())
# Clear the accumulated statistics so re-running the cell starts from scratch.
fid.reset()

### Entropy (Variance)

In [None]:
from torchvision import models
from torch import nn

# Build the ResNet-50 architecture with freshly initialised weights. Requesting
# ImageNet weights (`pretrained=True`, a deprecated argument) only triggered a
# download: the strict `load_state_dict` below overwrites every parameter and
# buffer with the values from the MNIST checkpoint.
model = models.resnet50()
# Re-create the first convolution for 3-channel input
model.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Change the last layer to output 10 classes
model.fc = nn.Linear(2048, 10, bias=True)

# Load the mnist pre-trained model. `map_location` lets a checkpoint that was
# saved from a different device be loaded for `config.device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('mnist-classifier.pth', map_location=config.device))

# Load the model to device
model = model.to(config.device)
# The classifier is only used for inference here: eval mode makes BatchNorm use
# its stored running statistics instead of per-batch statistics.
model.eval()

In [None]:
# NOTE: `H` is not used in this cell; the entropy is computed directly below.
from utils import entropy as H

# Predictions must be made in eval mode: in training mode BatchNorm normalises
# with per-batch statistics (and updates its running averages), so the
# predicted labels would depend on how the images are batched.
model.eval()

# Classify the fake images, `config.eval_batch_size` images at a time
y_fake_labels = []
with torch.no_grad():  # inference only, no autograd graph needed
    for idx in range(0, y_fake.shape[0], config.eval_batch_size):
        # Get the batch of fake images
        fake_images = y_fake[idx : idx + config.eval_batch_size]
        # Send the batch of fake images to device
        fake_images = fake_images.to(config.device)
        # Normalize the images (pixel values are divided by 255)
        fake_images = fake_images / 255
        # Get the model prediction
        labels = model(fake_images)
        # Append the predicted class index of every image to the list
        y_fake_labels += [ torch.argmax(labels, dim=1) ]

# Concatenate the list of predictions into a tensor
y_fake_labels = torch.cat(y_fake_labels, dim=0)

# Find the proportion of each distinct label in y_fake_labels with torch.
# Only labels that actually occur are returned, so every proportion is > 0
# and the logarithm below is always finite.
_, counts = torch.unique(y_fake_labels, return_counts=True)
proportion = counts / torch.sum(counts)

# Calculate the entropy (in bits) of the label proportions with torch
entropy = torch.sum(-proportion * torch.log2(proportion))
entropy