In [1]:
import os
import pickle
import numpy as np
from pathlib import Path
import torch
from torchvision.utils import save_image

def load_cifar_batch(file):
    """Load one python-format CIFAR-10 batch file as a float image tensor.

    Args:
        file: path to a pickled batch file (e.g. ``data_batch_1``) whose dict
            has a ``'data'`` entry of flattened 3*32*32 pixel rows.

    Returns:
        A float32 tensor of shape (N, 3, 32, 32), pixel values divided by 255.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(file, 'rb') as fh:
        batch = pickle.load(fh, encoding='latin1')
    # Each row is one flattened image; unflatten into (N, C, H, W).
    pixels = batch['data'].reshape(-1, 3, 32, 32)
    return torch.from_numpy(pixels).float().div(255.0)

# Load all 5 CIFAR-10 training batches into a single tensor.
root = Path("../HW_4/data/cifar-10-batches-py")
all_imgs = torch.cat(
    [load_cifar_batch(root / f"data_batch_{i}") for i in range(1, 6)],
    dim=0,
)  # (50000, 3, 32, 32)
print("Total real images:", all_imgs.shape)

# Randomly pick a fixed-size reference subset. A seeded generator makes the
# subset -- and therefore any FID computed against it -- reproducible across
# kernel restarts. Note: FID estimated from this few samples is strongly
# biased; raise N_REAL (and the generated-sample count) for meaningful numbers.
N_REAL = 128
SEED = 0
subset_gen = torch.Generator().manual_seed(SEED)
idx = torch.randperm(all_imgs.size(0), generator=subset_gen)[:N_REAL]
real_subset = all_imgs[idx]

out_dir = Path("real_cifar")
out_dir.mkdir(exist_ok=True)
# Remove PNGs from a previous run so the directory holds exactly this subset
# (stale extra files would otherwise be picked up by the metric computation).
for stale in out_dir.glob("*.png"):
    stale.unlink()

for i, img in enumerate(real_subset):
    save_image(img, out_dir / f"{i:06d}.png")

print(f"Saved {len(real_subset)} real images to", out_dir)

Total real images: torch.Size([50000, 3, 32, 32])
Saved 128 real images to real_cifar


In [2]:
# Requires torch-fidelity; if it is missing, run `%pip install torch-fidelity` in this kernel.

In [None]:
import torch
from torch_fidelity import calculate_metrics

# Report which backend is available for Inception feature extraction.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device:", device)

def compute_for_model(gen_dir, real_dir="real_cifar", batch_size=64, **metric_flags):
    """Compute torch-fidelity metrics for generated images vs. a real reference set.

    Args:
        gen_dir: directory of generated images (str or pathlib.Path), or any
            other input form accepted by ``torch_fidelity.calculate_metrics``.
        real_dir: reference real images, same accepted forms as ``gen_dir``.
        batch_size: batch size used for Inception feature extraction.
        **metric_flags: extra keyword arguments forwarded to
            ``calculate_metrics`` (e.g. ``isc=True`` for Inception Score).
            FID is enabled by default; pass ``fid=False`` to disable it.
            Do not pass ``input1``/``input2``/``cuda``/``batch_size`` here.

    Returns:
        The metrics dict returned by ``calculate_metrics``.
    """
    def _as_input(x):
        # torch-fidelity resolves directory inputs from plain strings, so
        # normalize path objects; leave every other input form untouched.
        return os.fspath(x) if isinstance(x, os.PathLike) else x

    options = {"fid": True, **metric_flags}
    metrics_dict = calculate_metrics(
        input1=_as_input(gen_dir),   # generated images
        input2=_as_input(real_dir),  # real images (reference for FID)
        cuda=torch.cuda.is_available(),
        batch_size=batch_size,
        **options,
    )
    print(f"Results for {gen_dir}:")
    for name, value in metrics_dict.items():
        print(f"  {name}: {value:.4f}")
    print()
    return metrics_dict

# Score the generated samples in outputs/epoch_039 against the real subset.
# (Inside a notebook kernel __name__ is "__main__", so this always runs.)
if __name__ == "__main__":
    ddpm_metrics = compute_for_model(gen_dir="outputs/epoch_039")

Using device: cuda


Creating feature extractor "inception-v3-compat" with features ['2048']
Extracting statistics from input 1
Looking for samples non-recursivelty in "outputs/epoch_039" with extensions png,jpg,jpeg
Found 128 samples
Processing samples                                                        
Extracting statistics from input 2
Looking for samples non-recursivelty in "real_cifar" with extensions png,jpg,jpeg
Found 128 samples
Processing samples                                                        
