In [None]:
!pip install torch-fidelity
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import os
import random
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch_fidelity import calculate_metrics
import clip
from tqdm import tqdm

def reservoir_sample_dir(root_dir, sample_size, exts=(".png", ".jpg", ".jpeg")):
    """Draw a uniform random sample of image file paths from one directory.

    The directory is streamed with ``os.scandir`` and at most ``sample_size``
    paths are kept using reservoir sampling (Algorithm R), so the full listing
    is never held in memory. Sub-directories are not descended into.

    Args:
        root_dir: Directory to scan.
        sample_size: Maximum number of paths to return.
        exts: File-name suffix, or iterable of suffixes, to accept. Matching is
            case-insensitive.

    Returns:
        A list of at most ``sample_size`` paths. When the directory holds no
        more matching files than ``sample_size``, every match is returned.
    """
    if isinstance(exts, str):
        exts = (exts,)
    # str.endswith only accepts a tuple, so normalise lists/sets as well.
    suffixes = tuple(ext.lower() for ext in exts)

    reservoir = []
    # Number of *accepted* files seen so far. The replacement index must be
    # drawn from this count; counting skipped entries (directories, other file
    # types) would lower the replacement probability and bias the sample
    # toward files that happen to be listed early.
    seen = 0
    with os.scandir(root_dir) as entries:
        for entry in entries:
            if not entry.is_file():
                continue
            if not entry.name.lower().endswith(suffixes):
                continue
            seen += 1
            if len(reservoir) < sample_size:
                reservoir.append(entry.path)
                continue
            # Keep the new file with probability sample_size / seen.
            j = random.randint(0, seen - 1)
            if j < sample_size:
                reservoir[j] = entry.path
    return reservoir

class PathListDataset(Dataset):
    """Dataset that serves images from a list of file paths as uint8 tensors.

    Each item is decoded, converted to RGB, resized to ``size`` with bilinear
    resampling and returned as a ``torch.uint8`` tensor laid out C×H×W.
    """

    def __init__(self, paths, size=(299, 299)):
        """Remember the image paths and the target size.

        Args:
            paths: Sequence of image file paths.
            size: Size tuple handed to ``PIL.Image.Image.resize``.
        """
        self.paths = paths
        self.size = size

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it as a uint8 C×H×W tensor."""
        # Image.open is lazy and holds the file handle; the context manager
        # releases it deterministically instead of relying on garbage
        # collection. convert() returns an independent, fully loaded image.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        img = img.resize(self.size, resample=Image.BILINEAR)
        arr = np.array(img, dtype=np.uint8)         # H×W×C
        return torch.from_numpy(arr).permute(2, 0, 1)  # C×H×W


# Source folders: the reference ("real") images and the generated images that
# are scored against them.
real_dir = "/content/drive/MyDrive/images2/images2"
gen_dir  = "/content/drive/MyDrive/generated_images"

# Each folder is capped at sample_n randomly chosen image files.
sample_n     = 50000
real_sample  = reservoir_sample_dir(real_dir, sample_n)
gen_sample   = reservoir_sample_dir(gen_dir,  sample_n)

print(f"Sampled {len(real_sample)} real paths, {len(gen_sample)} gen paths.")

ds_real_samp = PathListDataset(real_sample)
ds_gen_samp  = PathListDataset(gen_sample)

# FID and KID via torch-fidelity; Inception Score is switched off.
# The GPU is requested only when one is actually present, matching the device
# selection used for CLIP further down, instead of unconditionally asking for
# CUDA on a CPU-only runtime.
# NOTE(review): num_workers does not appear to be one of torch-fidelity's
# documented calculate_metrics options — confirm it is honoured rather than
# silently ignored.
metrics = calculate_metrics(
    input1=ds_real_samp,
    input2=ds_gen_samp,
    cuda=torch.cuda.is_available(),
    isc=False,
    fid=True,
    kid=True,
    num_workers=8,
    batch_size=512,
)

print("Frechet Inception Distance:", metrics["frechet_inception_distance"])
print(
    "Kernel Inception Distance:",
    f"{metrics['kernel_inception_distance_mean']:.6f} ± "
    f"{metrics['kernel_inception_distance_std']:.6f}"
)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32 model plus the transform applied to each PIL image before
# it is encoded.
model, preprocess = clip.load("ViT-B/32", device=device)

def get_clip_embeddings(paths, batch_size=32):
    """Encode images with CLIP and return unit-length float32 embeddings.

    Uses the module-level ``model``, ``preprocess`` and ``device``.

    Args:
        paths: Sequence of image file paths.
        batch_size: Number of images encoded per forward pass.

    Returns:
        A CPU float32 tensor with one L2-normalised embedding row per path,
        in the same order as ``paths``.
    """
    embs = []
    for i in tqdm(range(0, len(paths), batch_size), desc="CLIP batches"):
        batch = []
        for p in paths[i : i + batch_size]:
            # Close each file as soon as it has been decoded.
            with Image.open(p) as img:
                batch.append(preprocess(img.convert("RGB")))
        bt = torch.stack(batch).to(device, non_blocking=True)
        with torch.no_grad(), torch.cuda.amp.autocast():
            e = model.encode_image(bt)
        # The encoder can emit half-precision features under autocast.
        # Normalise and store them in float32 so the later pairwise-distance
        # and kernel sums do not lose precision or overflow.
        e = e.float()
        e = e / e.norm(dim=-1, keepdim=True)
        embs.append(e.cpu())
    return torch.cat(embs, dim=0)

def median_heuristic_sigma(x, y):
    """Estimate an RBF bandwidth as the median pairwise Euclidean distance.

    The rows of ``x`` and ``y`` are pooled, up to 1000 of them are drawn at
    random, and the median distance between *distinct* sampled points is
    returned.

    Args:
        x: 2-D tensor with one sample per row.
        y: 2-D tensor with one sample per row and the same width as ``x``.

    Returns:
        A strictly positive Python float. Degenerate input (fewer than two
        points, or all points identical) yields float32 machine epsilon so
        that callers can safely divide by the result.
    """
    with torch.no_grad():
        z = torch.cat([x, y], dim=0)
        if z.dtype in (torch.float16, torch.bfloat16):
            # Distances are computed in float32; half precision is too coarse.
            z = z.float()
        # Pick the subset through indices rather than permuting all of z.
        keep = torch.randperm(len(z))[:1000]
        sample = z[keep]
        # pdist yields each unordered pair once and omits self-distances,
        # whose zeros would otherwise pull the median downwards.
        dists = torch.pdist(sample, p=2)
        sigma = torch.median(dists).item() if dists.numel() > 0 else 0.0
    return max(sigma, torch.finfo(torch.float32).eps)

def _rbf_kernel_sum(a, b, gamma, chunk_size, skip_diagonal=False):
    """Sum exp(-gamma * ||a_i - b_j||^2) over all pairs, one row chunk at a time."""
    total = torch.zeros((), dtype=torch.float64, device=a.device)
    for start in range(0, a.size(0), chunk_size):
        block = torch.cdist(a[start:start + chunk_size], b, p=2).pow(2)
        block = torch.exp(-gamma * block)
        total = total + block.sum(dtype=torch.float64)
        if skip_diagonal:
            # Row i of this chunk is global row start + i, so its
            # self-similarity sits at column start + i.
            total = total - block.diagonal(offset=start).sum(dtype=torch.float64)
    return total


def compute_mmd(x, y, sigma=None, *, chunk_size=1024):
    """Unbiased squared MMD between two samples under a Gaussian RBF kernel.

    The kernel is ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))``. Kernel sums
    are accumulated over row chunks in float64, so peak memory is about
    ``chunk_size * max(m, n)`` elements instead of several dense ``m × n``
    matrices.

    Args:
        x: 2-D tensor of shape (m, d), one sample per row.
        y: 2-D tensor of shape (n, d), one sample per row.
        sigma: Kernel bandwidth. When ``None`` it is taken from
            ``median_heuristic_sigma(x, y)``.
        chunk_size: Number of rows of the left operand processed at once.

    Returns:
        A 0-dim tensor holding the estimate. It is float64 when either input
        is float64 and float32 otherwise. Being the unbiased estimator, it
        can be slightly negative.

    Raises:
        ValueError: If either sample has fewer than two rows, if ``sigma`` is
            not strictly positive, or if ``chunk_size`` is below 1.
    """
    m, n = x.size(0), y.size(0)
    if m < 2 or n < 2:
        raise ValueError(
            f"compute_mmd needs at least 2 samples per set, got {m} and {n}"
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Work in at least float32: half-precision kernel sums overflow quickly.
    work_dtype = torch.float64 if torch.float64 in (x.dtype, y.dtype) else torch.float32
    x = x.to(work_dtype)
    y = y.to(work_dtype)

    if sigma is None:
        sigma = median_heuristic_sigma(x, y)
    if not sigma > 0:
        raise ValueError(f"sigma must be strictly positive, got {sigma}")
    gamma = 1.0 / (2 * sigma ** 2)

    # Within-sample terms exclude the i == j self-similarities.
    term_x = _rbf_kernel_sum(x, x, gamma, chunk_size, skip_diagonal=True) / (m * (m - 1))
    term_y = _rbf_kernel_sum(y, y, gamma, chunk_size, skip_diagonal=True) / (n * (n - 1))
    term_xy = 2 * _rbf_kernel_sum(x, y, gamma, chunk_size) / (m * n)
    return (term_x + term_y - term_xy).to(work_dtype)

# Embed both path samples with CLIP via get_clip_embeddings.
print("\nComputing CLIP embeddings for real images...")
real_emb = get_clip_embeddings(real_sample)

print("\nComputing CLIP embeddings for generated images...")
gen_emb = get_clip_embeddings(gen_sample)

# Bandwidth is estimated once from the pooled embeddings and passed in
# explicitly, so the value printed here is the one compute_mmd uses.
sigma = median_heuristic_sigma(real_emb, gen_emb)
print(f"Using sigma = {sigma:.4f}")

# Squared MMD between the real and generated embedding sets.
cmmd2 = compute_mmd(real_emb, gen_emb, sigma=sigma)
print(f"CMMD² (CLIP-based MMD): {cmmd2:.6f}")
