In [None]:
import os
import random
import torch
import numpy as np 
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision import transforms
from tqdm.notebook import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance

from dataset import init_dataset

In [None]:
# Load the project-local evaluation configuration; later cells read attributes
# such as `seed` and `device` from this instance and override some of them.
from config import EvalConfig
config = EvalConfig()

# Dataset

In [None]:
# Seed every RNG in use (torch, NumPy, Python's `random`) from the config, so that
# shuffled dataloaders and any other random choices below are reproducible across runs.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(config.seed)

# Evaluation

In [None]:
# Evaluation settings for this run, assigned onto the EvalConfig instance.
config.pretrained_model_dir = 'results/unetSM/mnist-7/epoch-100'
config.folder_name = 'ddim_fake_images'  # subfolder of pretrained_model_dir holding the generated images
config.dataset_name = "~/.pytorch/MNIST_data/"
config.labels = [7]
# Batch size used by the evaluation dataloaders/loops below. It is set here so that
# those cells do not depend on the FID cell having been run first.
config.eval_batch_size = 64

### Fréchet Inception Distance (FID)
* Requires a directory of generated (fake) images
* Requires a dataloader over the real images (`config.dataset_name`, filtered to `config.labels`)

In [None]:
# List all fake images from the corresponding directory.
# Sorted so the subset evaluated below is the same on every run (os.listdir order is arbitrary).
fake_images_dir = f'./{config.pretrained_model_dir}/{config.folder_name}'
fake_images_list = sorted(os.listdir(fake_images_dir))
n_fake_images = len(fake_images_list)

# List (via dataloader) all real images from the corresponding dataset
config.eval_batch_size = 64

# Create eval dataloader
eval_dataloader = torch.utils.data.DataLoader(
    init_dataset(config.dataset_name, split='train', labels=config.labels),
    batch_size=config.eval_batch_size, shuffle=True)

# Define FID metric. normalize=True means inputs are float images expected in [0, 1].
# NOTE(review): FID's InceptionV3 expects 3-channel (N, 3, H, W) images; confirm that
# init_dataset and the saved fakes provide 3 channels (MNIST is natively grayscale)
# and that init_dataset does not rescale the real images to another range (e.g. [-1, 1]).
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(config.device)

# Compare the same number of real and fake images
n_images_to_eval = min(n_fake_images, len(eval_dataloader.dataset))
fake_names_to_eval = fake_images_list[:n_images_to_eval]

# Create the real-image iterator ONCE. Calling iter(eval_dataloader) inside the loop
# restarts (and reshuffles) the loader at every step, so the real batches would repeat
# images instead of walking through the dataset.
real_batches = iter(eval_dataloader)
for batch_idx in tqdm(range(0, n_images_to_eval, config.eval_batch_size), desc='Calculating FID...'):

    # Get the next real batch. Enough batches exist because
    # n_images_to_eval <= len(eval_dataloader.dataset).
    real_images, _ = next(real_batches)
    real_images = real_images.to(config.device)

    # Get the fake images, scaled from uint8 [0, 255] to float [0, 1]
    batch_names = fake_names_to_eval[batch_idx:batch_idx + config.eval_batch_size]
    fake_images = torch.stack([read_image(f"{fake_images_dir}/{name}") for name in batch_names]).to(config.device)
    fake_images = fake_images.float() / 255.0

    # Keep both batches the same size
    n_batch = min(real_images.shape[0], fake_images.shape[0])

    # Update the FID metric
    fid.update(real_images[:n_batch], real=True)
    fid.update(fake_images[:n_batch], real=False)

# Compute the FID score, then clear the accumulated statistics
fid_score = fid.compute().item()
fid.reset()
print('FID =', fid_score)

### Improved Precision and Recall

In [None]:
# Gather every real image of the selected labels (via a dataloader) into a single tensor.
# With shuffle=True and the seeds set earlier, the order decides which real images
# are kept when both sets are truncated to a common size.
real_dataset = init_dataset(config.dataset_name, split='train', labels=config.labels)
eval_dataloader = torch.utils.data.DataLoader(
    real_dataset, batch_size=config.eval_batch_size, shuffle=True)

real_image_batches = [batch_images for batch_images, _batch_labels in eval_dataloader]
real_images = torch.cat(real_image_batches)

In [None]:
# List all fake images from the corresponding directory. The listing is sorted so that
# truncating to a common size with the real images keeps the same subset on every run
# (os.listdir order is arbitrary).
fake_images_dir = f'./{config.pretrained_model_dir}/{config.folder_name}'
fake_images_list = sorted(os.listdir(fake_images_dir))

# Load them all into one tensor, scaled from uint8 [0, 255] to float [0, 1]
fake_images = torch.stack([read_image(f'{fake_images_dir}/{name}') for name in fake_images_list])
fake_images = fake_images.float() / 255.0

In [None]:
# Truncate both sets to a common size so precision and recall compare equally sized samples
n_images_to_eval = min(len(real_images), len(fake_images))
real_images, fake_images = real_images[:n_images_to_eval], fake_images[:n_images_to_eval]

In [None]:
from improved_precision_recall import IPR

# Define IPR metric: k=3 neighbours, one sample per evaluated image.
# batch_size=8 is presumably the feature-extraction batch size (TODO confirm in IPR).
# The device comes from the config instead of a hard-coded 'cuda', consistent with the
# other cells. str() keeps the argument a string, like the original literal.
ipr = IPR(batch_size=8, k=3, num_samples=n_images_to_eval, device=str(config.device))

# Compute the reference manifold from the real images
ipr.compute_manifold_ref(real_images)

In [None]:
metric = ipr.precision_and_recall(fake_images)

# Print results
for metric_name in ('precision', 'recall'):
    print(f'{metric_name} =', getattr(metric, metric_name))

# Not computed here: ipr.realism(fake_images) would give a realism score for the fakes.

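### Classifier
* A ResNet-18 classifier loaded from a checkpoint predicts the label of every fake image; the frequencies of the predicted labels and their entropy summarise how the fakes are distributed across classes.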

In [None]:
from torchvision import models
import torch.nn as nn

# map_location lets the checkpoint load on the configured device even if it was saved on another one
checkpoint = torch.load('./results/eval_classifier/checkpoint_2n7.pth', map_location=config.device)

# Load the actual (label -> class index) and the inverse (class index -> label) mapping
map_labels = checkpoint['mapping']
map_labels_inv = {v: k for k, v in map_labels.items()}

# Define the classifier architecture. No ImageNet weights are downloaded, because
# load_state_dict below overwrites every parameter anyway.
model = models.resnet18()
# One output per class in the checkpoint's mapping: predicted indices are decoded through
# map_labels_inv. config.labels (the label set selected for evaluation) need not have the same size.
model.fc = nn.Linear(model.fc.in_features, len(map_labels))

model.load_state_dict(checkpoint['state_dict'])
# Inference only: eval mode makes BatchNorm use its running statistics
model = model.to(config.device).eval()

In [None]:
# List all fake images from the corresponding directory. The listing is sorted so that batch
# composition (and the last batch left in `fake_images`) is the same on every run.
fake_images_dir = f'./{config.pretrained_model_dir}/{config.folder_name}'
fake_images_list = sorted(os.listdir(fake_images_dir))

In [None]:
import math

# Count how often the classifier predicts each label over all fake images
n_counts = {label: 0 for label in map_labels.keys()}

# Inference mode. In train mode, BatchNorm would use per-batch statistics (so predictions
# depend on batch composition) and would update its running statistics. No autograd graph is needed.
model.eval()
with torch.no_grad():
    for b_idx in tqdm(range(0, len(fake_images_list), config.eval_batch_size)):

        fake_images_names = fake_images_list[b_idx:b_idx + config.eval_batch_size]
        fake_images = torch.stack([read_image(f'{fake_images_dir}/{fi_name}') for fi_name in fake_images_names])
        fake_images = fake_images.float() / 255.0

        # Get predictions and map class indices back to dataset labels
        outputs = model(fake_images.to(config.device))
        _, predicted = torch.max(outputs, 1)
        for label_idx in predicted.tolist():
            n_counts[map_labels_inv[label_idx]] += 1

# Compute the total count of all labels
total_count = sum(n_counts.values())
if total_count == 0:
    raise ValueError(f'No fake images found in {fake_images_dir}')

# Compute the frequency of each label
frequencies = {label: count / total_count for label, count in n_counts.items()}
print("Frequencies:", frequencies)

# Compute the (natural-log) entropy of the predicted-label distribution.
# A label that is never predicted has p == 0 and contributes 0 (the limit of p*log p);
# math.log(0) would otherwise raise "math domain error".
entropy = sum(-p * math.log(p) for p in frequencies.values() if p > 0)
print("Entropy:", entropy)

### Visualize

In [None]:
from diffusers.utils import make_image_grid

# Grid size: up to 4 x 8 images per panel, shrunk when fewer images are available.
# make_image_grid requires exactly rows * cols images.
fake_grid_names = sorted(os.listdir(fake_images_dir))
n_show = min(32, len(fake_grid_names), len(real_images))
assert n_show > 0, 'No images to visualize'
n_cols = min(8, n_show)
n_rows = n_show // n_cols
n_show = n_rows * n_cols

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Re-read the fake images instead of reusing `fake_images`. After the classifier loop,
# that variable only holds the last batch, which may contain fewer than n_show images.
fakes = [transforms.ToPILImage()(read_image(f'{fake_images_dir}/{name}')) for name in fake_grid_names[:n_show]]
grid = make_image_grid(fakes, rows=n_rows, cols=n_cols)
axes[0].imshow(grid); axes[0].set_title('Fake images'); axes[0].axis('off')

# NOTE(review): assumes real_images (from the precision/recall cells) are floats in [0, 1];
# ToPILImage does not display values outside that range correctly.
reals = [transforms.ToPILImage()(image) for image in real_images[:n_show]]
grid = make_image_grid(reals, rows=n_rows, cols=n_cols)
axes[1].imshow(grid); axes[1].set_title('Real images'); axes[1].axis('off')
plt.show()