In [None]:
!pip install git+https://github.com/openai/CLIP.git
!pip install git+https://github.com/richzhang/PerceptualSimilarity

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-g6a_z4is
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-g6a_z4is
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting 

In [None]:
"""
================================
Batch-evaluate multiple reconstruction folders against EEG-ImageNet ground truth.

For each subfolder under `test_results_dir`, it computes:
  - SSIM, PSNR, LPIPS
  - CLIP cosine, n-way identification accuracy
  - CAT score
  - CFID

"""
from __future__ import annotations
import os, glob, csv, math
from pathlib import Path
from typing import Dict, List, Tuple
from scipy.linalg import sqrtm as scipy_sqrtm

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

import lpips
import clip

# ----------------------------- setup -----------------------------
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# create once (saves time)
# Both networks are built at import time and are read as module-level globals
# by the evaluation code further down in this file.
lpips_model = lpips.LPIPS(net="vgg").to(device)
# The second value returned by clip.load is not used in this script.
clip_model, _ = clip.load("ViT-B/32", device=device)

# ----------------------------- utils -----------------------------

def load_image(path: str | Path, size: int = 224) -> torch.Tensor:
    """Load an image file as a float32 CHW tensor scaled to [-1, 1].

    Args:
        path: Location of the image file.
        size: Side length in pixels of the square the image is resized to
            (bicubic resampling; the aspect ratio is not preserved).

    Returns:
        Tensor of shape (3, size, size) with values in [-1, 1].
    """
    # The context manager releases the file handle deterministically.
    # convert()/resize() produce new in-memory images, so nothing is read
    # from the file after the block exits.
    with Image.open(path) as src:
        img = src.convert("RGB").resize((size, size), Image.BICUBIC)
    arr = np.asarray(img).astype(np.float32) / 127.5 - 1.0
    # HWC -> CHW
    return torch.from_numpy(arr).permute(2, 0, 1)

def mse2psnr(mse: float, max_val: float = 255.0) -> float:
    """Convert a mean squared error into a PSNR value in decibels.

    `max_val` is the peak signal value. A tiny constant is added to `mse`
    so that a perfect match does not divide by zero.
    """
    peak_power = max_val**2
    noise_power = mse + 1e-10
    return 10.0 * math.log10(peak_power / noise_power)

def compute_ssim_psnr(real_imgs: torch.Tensor, recon_imgs: torch.Tensor) -> Tuple[float, float]:
    """Average SSIM and PSNR over a batch of image pairs.

    Both inputs are (B, C, H, W) tensors in [-1, 1] with identical shapes.
    They are rescaled to the 0..255 range before scoring. Returns
    `(mean_ssim, mean_psnr)` as Python floats.
    """
    assert real_imgs.shape == recon_imgs.shape

    def to_pixel_range(batch: torch.Tensor) -> np.ndarray:
        # [-1, 1] -> [0, 255], as a NumPy array on the host.
        return (batch.detach().cpu().numpy() + 1) * 127.5

    ssim_scores, psnr_scores = [], []
    for ref, est in zip(to_pixel_range(real_imgs), to_pixel_range(recon_imgs)):
        # skimage is given channel-last images here, hence CHW -> HWC.
        ssim_scores.append(
            ssim(ref.transpose(1, 2, 0), est.transpose(1, 2, 0),
                 channel_axis=2, data_range=255)
        )
        psnr_scores.append(mse2psnr(np.mean((ref - est) ** 2)))

    return float(np.mean(ssim_scores)), float(np.mean(psnr_scores))

def embed_images_clip(img_batch: torch.Tensor, clip_model) -> torch.Tensor:
    """Embed a batch of [-1, 1] images with `clip_model` and L2-normalize the result.

    The batch is mapped to [0, 1], resized to 224x224 with bicubic
    interpolation, standardized per channel with the CLIP mean/std constants
    below, and passed to `clip_model.encode_image` without gradient tracking.
    """
    # [-1, 1] -> [0, 1], then resize to the 224x224 input size.
    unit_range = (img_batch + 1) / 2
    resized = F.interpolate(unit_range, size=(224, 224), mode="bicubic", align_corners=False)

    # Per-channel statistics, shaped to broadcast over (B, 3, H, W).
    target_device = resized.device
    channel_mean = torch.tensor(
        [0.48145466, 0.4578275, 0.40821073], device=target_device
    ).view(1, 3, 1, 1)
    channel_std = torch.tensor(
        [0.26862954, 0.26130258, 0.27577711], device=target_device
    ).view(1, 3, 1, 1)
    standardized = (resized - channel_mean) / channel_std

    with torch.no_grad():
        embeddings = clip_model.encode_image(standardized)
    return F.normalize(embeddings, dim=-1)

def compute_clip_cosine(real_feats: torch.Tensor, recon_feats: torch.Tensor) -> float:
    """Mean cosine similarity between paired feature rows.

    Args:
        real_feats: (N, D) L2-normalized features of the ground-truth images.
        recon_feats: (N, D) L2-normalized features of the reconstructions,
            row-aligned with `real_feats`.

    Returns:
        The average of the row-wise dot products as a Python float
        (NaN when N is 0).
    """
    # Features may arrive in float16 (CLIP runs in half precision on CUDA).
    # Reducing in float32 keeps the averaged score from being rounded to the
    # coarse float16 grid.
    cos = (real_feats.float() * recon_feats.float()).sum(dim=-1)
    return float(cos.mean().cpu())

def identification_accuracy(recon_feats: torch.Tensor,
                            real_feats: torch.Tensor,
                            n_way: int = 200,
                            rnd_seed: int = 42) -> float:
    """n-way image identification accuracy using cosine similarity.

    For every reconstruction `i`, the matching real image plus up to
    `n_way - 1` randomly drawn other real images form the candidate set;
    the trial counts as a hit when the matching image has the highest
    dot product with the reconstruction feature.

    Args:
        recon_feats: (N, D) L2-normalized reconstruction features.
        real_feats: (N, D) L2-normalized ground-truth features, row-aligned
            with `recon_feats`.
        n_way: Candidate-set size; capped at N when fewer images exist.
        rnd_seed: Seed for the distractor sampling.

    Returns:
        Fraction of hits in [0, 1], or NaN when there are no rows.
    """
    num = recon_feats.size(0)
    if num == 0:
        return float("nan")

    # A private generator yields the same stream as np.random.seed(rnd_seed)
    # followed by np.random.choice, without disturbing the global RNG state
    # that other code may rely on.
    rng = np.random.RandomState(rnd_seed)

    # float32 avoids doing the similarity ranking in half precision when the
    # features come from an fp16 model.
    feat_np = real_feats.detach().float().cpu().numpy()
    recon_np = recon_feats.detach().float().cpu().numpy()

    all_idx = np.arange(num)
    hits = 0
    for i in range(num):
        idx_pool = np.delete(all_idx, i)
        k = min(n_way - 1, len(idx_pool))
        distractors = rng.choice(idx_pool, size=k, replace=False)
        # The true match sits at position 0, so exact ties resolve in its favor.
        cand_idx = np.concatenate(([i], distractors))
        sims = recon_np[i] @ feat_np[cand_idx].T
        pred = cand_idx[np.argmax(sims)]
        hits += int(pred == i)
    return hits / num

def cat_score(pred_labels: List[str], true_fine: List[str], fine2coarse: Dict[str, str]) -> float:
    """Average of fine-label accuracy and coarse-label accuracy.

    Predicted and true fine labels are mapped through `fine2coarse` to obtain
    the coarse labels; the two accuracies are then averaged with equal weight.
    """
    def to_coarse(labels: List[str]) -> List[str]:
        return [fine2coarse[name] for name in labels]

    def accuracy(guesses: List[str], targets: List[str]):
        return np.mean([guess == target for guess, target in zip(guesses, targets)])

    coarse_targets = to_coarse(true_fine)
    coarse_guesses = to_coarse(pred_labels)

    fine_acc = accuracy(pred_labels, true_fine)
    coarse_acc = accuracy(coarse_guesses, coarse_targets)
    return float((fine_acc + coarse_acc) / 2.0)

# -------- CFID helpers --------

def frechet_distance(mu1: torch.Tensor,
                     sigma1: torch.Tensor,
                     mu2: torch.Tensor,
                     sigma2: torch.Tensor) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``|mu1 - mu2|^2 + tr(sigma1) + tr(sigma2) - 2 tr(sqrt(sigma1 sigma2))``.

    Args:
        mu1, mu2: 1-D mean vectors of equal length.
        sigma1, sigma2: Square covariance matrices matching the means.

    Returns:
        The distance as a Python float.
    """
    # Everything is promoted to float64 on the CPU: the matrix square root is
    # ill-conditioned for near-singular covariances, and float32 (or float16
    # inputs) loses precision there.
    m1 = mu1.detach().cpu().double()
    m2 = mu2.detach().cpu().double()
    diff = (m1 - m2).numpy()
    diff_sq = float(diff.dot(diff))

    s1 = sigma1.detach().cpu().double()
    s2 = sigma2.detach().cpu().double()
    tr1 = float(torch.trace(s1).item())
    tr2 = float(torch.trace(s2).item())

    cov_prod = (s1 @ s2).numpy()
    cov_sqrt = scipy_sqrtm(cov_prod)
    # sqrtm can return a complex result with a negligible imaginary part
    # caused by round-off; only the real part is meaningful here.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real
    tr_cov_sqrt = float(np.trace(cov_sqrt))

    return diff_sq + tr1 + tr2 - 2 * tr_cov_sqrt

def compute_cfid(real_feats: torch.Tensor,
                 recon_feats: torch.Tensor,
                 labels: List[str]) -> float:
    """Conditional FID: per-class Fréchet distance, averaged by class size.

    Args:
        real_feats: (N, D) features of the ground-truth images.
        recon_feats: (N, D) features of the reconstructions, row-aligned
            with `real_feats`.
        labels: Class label for each of the N rows.

    Returns:
        The class-size-weighted mean of the per-class Fréchet distances.
        Classes with fewer than 10 samples are skipped; NaN is returned when
        no class qualifies.
    """
    labels_np = np.array(labels)

    # Promote once to CPU float64 so covariance estimation does not run in
    # half precision and the ridge term below always matches device and dtype.
    real64 = real_feats.detach().cpu().double()
    recon64 = recon_feats.detach().cpu().double()

    # A small ridge on the diagonal keeps the covariances well-conditioned.
    # It is identical for every class, so it is built a single time.
    eps = 1e-6
    ridge = eps * torch.eye(real64.shape[1], dtype=torch.float64)

    dists, weights = [], []
    for cls in np.unique(labels_np):
        idx = np.where(labels_np == cls)[0]
        if len(idx) < 10:
            continue
        r = real64[idx]
        f = recon64[idx]
        # torch.cov expects variables in rows, hence the transpose.
        sigma_r = torch.cov(r.T) + ridge
        sigma_f = torch.cov(f.T) + ridge
        dists.append(frechet_distance(r.mean(0), sigma_r, f.mean(0), sigma_f))
        weights.append(len(idx))
    return float(np.average(dists, weights=weights)) if dists else float("nan")

# ------------------------- evaluation core -------------------------

def _index_recon_files(recon_dir: str) -> Tuple[List[str], List[str]]:
    """Return parallel lists (file names, paths) of all images under `recon_dir`, sorted by name."""
    entries = []
    # Escape the root so brackets or wildcards in the folder name are literal.
    root = glob.escape(recon_dir)
    for ext in ("png", "jpg", "jpeg", "PNG", "JPG", "JPEG"):
        for path in glob.glob(os.path.join(root, "**", f"*.{ext}"), recursive=True):
            entries.append((os.path.basename(path), path))
    entries.sort()
    return [name for name, _ in entries], [path for _, path in entries]


def _first_recon_match(names: List[str], paths: List[str], base: str):
    """Return the lexicographically smallest path whose file name starts with `<base>_`, or None."""
    import bisect

    prefix = f"{base}_"
    # Names sharing a prefix are contiguous in a sorted list.
    pos = bisect.bisect_left(names, prefix)
    best = None
    while pos < len(names) and names[pos].startswith(prefix):
        if best is None or paths[pos] < best:
            best = paths[pos]
        pos += 1
    return best


def evaluate_folder(real_dir: str,
                    recon_dir: str,
                    labels_csv: str,
                    n_id: int = 200,
                    device: str = "cuda",
                    batch_size: int = 128) -> Dict[str, float]:
    """Score one folder of reconstructions against the ground-truth images.

    The CSV must provide `filename`, `fine_label` and `coarse_label` columns.
    Each row's real image is read from `real_dir/<fine_label>/<filename>`, and
    its reconstruction is the first (in sorted path order) image anywhere
    under `recon_dir` named `<filename stem>_*.<png|jpg|jpeg>`.

    Args:
        real_dir: Root folder of the ground-truth images.
        recon_dir: Folder (searched recursively) holding the reconstructions.
        labels_csv: CSV listing the test images; row order defines alignment.
        n_id: Candidate-set size for the identification accuracy.
        device: Device the image batches and text tokens are moved to. They
            are fed to the module-level `lpips_model` and `clip_model`, so
            this should match the device those were loaded on.
        batch_size: Number of image pairs processed per step.

    Returns:
        Dict with keys `ssim`, `psnr`, `lpips`, `clip_cos`, `id_acc`, `cat`
        and `cfid`. The same values are also printed.

    Raises:
        ValueError: If the CSV contains no rows.
        FileNotFoundError: If a listed image has no matching reconstruction.
    """
    # ---- read CSV (order matters) ----
    fnames, fine_labels, fine2coarse = [], [], {}
    # newline="" is what the csv module expects for files it reads.
    with open(labels_csv, newline="") as f:
        for row in csv.DictReader(f):
            fnames.append(row["filename"])
            fine_labels.append(row["fine_label"])
            fine2coarse[row["fine_label"]] = row["coarse_label"]
    if not fnames:
        raise ValueError(f"No rows found in {labels_csv}")

    # ---- build real & recon paths from CSV order ----
    real_paths = [os.path.join(real_dir, label, fname)
                  for label, fname in zip(fine_labels, fnames)]

    # Reconstructions can be nested and carry a suffix after the stem. The
    # folder is scanned once and each stem is resolved from the in-memory
    # index instead of re-walking the tree for every CSV row.
    names, paths = _index_recon_files(recon_dir)
    resolved: Dict[str, object] = {}
    recon_paths: List[str] = []
    for real_fname in fnames:
        base = os.path.splitext(real_fname)[0]  # e.g. "n02655020_3570"
        if base not in resolved:
            resolved[base] = _first_recon_match(names, paths, base)
        match = resolved[base]
        if match is None:
            raise FileNotFoundError(f"No reconstruction found for {base} under {recon_dir}")
        recon_paths.append(match)

    # ---- accumulators (weighted by batch size) ----
    n_imgs = 0
    ssim_sum = 0.0
    psnr_sum = 0.0
    lpips_sum = 0.0
    real_feats_list, recon_feats_list = [], []

    # ---- loop ----
    total_batches = (len(real_paths) + batch_size - 1) // batch_size
    for start in tqdm(range(0, len(real_paths), batch_size),
                      total=total_batches,
                      unit="batch",
                      desc=os.path.basename(recon_dir)):
        rp = real_paths[start:start + batch_size]
        gp = recon_paths[start:start + batch_size]
        batch_real = torch.stack([load_image(p) for p in rp]).to(device)
        batch_recon = torch.stack([load_image(p) for p in gp]).to(device)
        bsz = batch_real.size(0)

        # SSIM & PSNR come back as batch means, so re-weight by batch size.
        sr, pr = compute_ssim_psnr(batch_real, batch_recon)
        ssim_sum += sr * bsz
        psnr_sum += pr * bsz

        # LPIPS
        with torch.no_grad():
            d = lpips_model(batch_recon, batch_real).view(-1).mean().item()
        lpips_sum += d * bsz

        # CLIP features (kept on the CPU to bound GPU memory)
        fr = embed_images_clip(batch_real, clip_model)
        fg = embed_images_clip(batch_recon, clip_model)
        real_feats_list.append(fr.cpu())
        recon_feats_list.append(fg.cpu())

        n_imgs += bsz

    # ---- aggregate ----
    # The encoder may emit float16; the downstream statistics are computed in
    # float32 so they are not limited by half-precision rounding.
    real_feats = torch.cat(real_feats_list, dim=0).float()
    recon_feats = torch.cat(recon_feats_list, dim=0).float()
    SSIM = ssim_sum / n_imgs
    PSNR = psnr_sum / n_imgs
    LPIPS = lpips_sum / n_imgs

    # ---- semantic / distribution metrics ----
    clip_cos = compute_clip_cosine(real_feats, recon_feats)
    id_acc = identification_accuracy(recon_feats, real_feats, n_way=n_id)

    # Zero-shot label prediction: nearest fine-label text embedding.
    label_list = list(fine2coarse)
    with torch.no_grad():
        text_tokens = clip.tokenize(label_list).to(device)
        text_feats = F.normalize(clip_model.encode_text(text_tokens).float(), dim=-1)
        sims = recon_feats.to(text_feats.device) @ text_feats.T
        pred_idx = sims.argmax(dim=-1).cpu().tolist()
    pred_labels = [label_list[i] for i in pred_idx]

    cat = cat_score(pred_labels, fine_labels, fine2coarse)
    cfid_val = compute_cfid(real_feats, recon_feats, fine_labels)

    # ---- report ----
    print(f"Folder: {os.path.basename(recon_dir)}")
    print(f" SSIM: {SSIM:.4f}, PSNR: {PSNR:.2f}, LPIPS: {LPIPS:.4f}")
    print(f" CLIP cos: {clip_cos:.4f}, {n_id}-way ID: {id_acc*100:.4f}%")
    print(f" CAT: {cat*100:.4f}%, CFID: {cfid_val:.4f}\n")

    return {
        "ssim": SSIM,
        "psnr": PSNR,
        "lpips": LPIPS,
        "clip_cos": clip_cos,
        "id_acc": id_acc,
        "cat": cat,
        "cfid": cfid_val,
    }

# ----------------------------- runner -----------------------------

if __name__ == "__main__":
    # Example multi-split sweep: every sub-directory of each split's results
    # folder is evaluated against the same ground-truth image root.
    real_dir = "/content/drive/MyDrive/ImageNet_Images/images_80class"
    splits = ("fine0", "fine1", "fine2", "fine3", "fine4", "coarse")

    for g in splits:
        labels_csv = f"/content/drive/MyDrive/ImageNet_Images/latent_dumps/granularity/Time/all_channels/{g}/eeg_imagenet_test_labels.csv"
        test_results_dir = f"/content/drive/MyDrive/ImageNet_Images/testResults/Time/{g}"

        for entry in sorted(os.listdir(test_results_dir)):
            candidate = os.path.join(test_results_dir, entry)
            # Only directories hold reconstructions; skip stray files.
            if not os.path.isdir(candidate):
                continue
            evaluate_folder(real_dir, candidate, labels_csv, n_id=200, device=device, batch_size=128)


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 214MB/s]


Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/vgg.pth


100%|███████████████████████████████████████| 338M/338M [00:08<00:00, 40.0MiB/s]
ACGAN_reconstructions_fine0: 100%|██████████| 3/3 [01:17<00:00, 25.92s/batch]


Folder: ACGAN_reconstructions_fine0
 SSIM: 0.1943, PSNR: 8.52, LPIPS: 0.8986
 CLIP cos: 0.6182, 200-way ID: 16.5775%
 CAT: 17.1123%, CFID: 0.7648



DCGAN_reconstructions_fine0: 100%|██████████| 3/3 [00:18<00:00,  6.31s/batch]


Folder: DCGAN_reconstructions_fine0
 SSIM: 0.2190, PSNR: 8.54, LPIPS: 0.8946
 CLIP cos: 0.6113, 200-way ID: 0.0000%
 CAT: 8.5561%, CFID: 0.7772



cProGAN_reconstructions_fine0: 100%|██████████| 3/3 [00:19<00:00,  6.66s/batch]


Folder: cProGAN_reconstructions_fine0
 SSIM: 0.2010, PSNR: 8.27, LPIPS: 0.8839
 CLIP cos: 0.5957, 200-way ID: 8.5561%
 CAT: 27.8075%, CFID: 0.8081



capsGAN_reconstructions_fine0: 100%|██████████| 3/3 [00:19<00:00,  6.58s/batch]


Folder: capsGAN_reconstructions_fine0
 SSIM: 0.2157, PSNR: 8.47, LPIPS: 0.9043
 CLIP cos: 0.5957, 200-way ID: 8.5561%
 CAT: 0.0000%, CFID: 0.8091



ACGAN_reconstructions: 100%|██████████| 3/3 [01:16<00:00, 25.34s/batch]


Folder: ACGAN_reconstructions
 SSIM: 0.2091, PSNR: 10.60, LPIPS: 0.8010
 CLIP cos: 0.4844, 200-way ID: 3.1915%
 CAT: 7.9787%, CFID: 1.0267



DCGAN_reconstructions: 100%|██████████| 3/3 [00:18<00:00,  6.04s/batch]


Folder: DCGAN_reconstructions
 SSIM: 0.2055, PSNR: 10.00, LPIPS: 0.8117
 CLIP cos: 0.5112, 200-way ID: 8.5106%
 CAT: 7.9787%, CFID: 0.9789



cProGAN_reconstructions: 100%|██████████| 3/3 [00:17<00:00,  5.97s/batch]


Folder: cProGAN_reconstructions
 SSIM: 0.2403, PSNR: 10.75, LPIPS: 0.7807
 CLIP cos: 0.5713, 200-way ID: 21.8085%
 CAT: 0.0000%, CFID: 0.8578



capsGAN_reconstructions: 100%|██████████| 3/3 [00:19<00:00,  6.65s/batch]


Folder: capsGAN_reconstructions
 SSIM: 0.1508, PSNR: 7.50, LPIPS: 0.8783
 CLIP cos: 0.4536, 200-way ID: 3.1915%
 CAT: 7.9787%, CFID: 1.0921



ACGAN_reconstructions: 100%|██████████| 3/3 [01:32<00:00, 30.89s/batch]


Folder: ACGAN_reconstructions
 SSIM: 0.1887, PSNR: 9.21, LPIPS: 0.8487
 CLIP cos: 0.4963, 200-way ID: 8.1967%
 CAT: 16.3934%, CFID: 1.0078



DCGAN_reconstructions: 100%|██████████| 3/3 [00:17<00:00,  5.97s/batch]


Folder: DCGAN_reconstructions
 SSIM: 0.1879, PSNR: 9.09, LPIPS: 0.8196
 CLIP cos: 0.5273, 200-way ID: 8.1967%
 CAT: 16.3934%, CFID: 0.9431



cProGAN_reconstructions: 100%|██████████| 3/3 [00:17<00:00,  5.91s/batch]


Folder: cProGAN_reconstructions
 SSIM: 0.2036, PSNR: 9.04, LPIPS: 0.8302
 CLIP cos: 0.5464, 200-way ID: 8.1967%
 CAT: 16.3934%, CFID: 0.9087



capsGAN_reconstructions: 100%|██████████| 3/3 [00:19<00:00,  6.61s/batch]


Folder: capsGAN_reconstructions
 SSIM: 0.1347, PSNR: 7.54, LPIPS: 0.8732
 CLIP cos: 0.5156, 200-way ID: 16.3934%
 CAT: 16.3934%, CFID: 0.9665



ACGAN_reconstructions: 100%|██████████| 15/15 [03:29<00:00, 13.96s/batch]


Folder: ACGAN_reconstructions
 SSIM: 0.2407, PSNR: 8.64, LPIPS: 0.8241
 CLIP cos: 0.5762, 200-way ID: 1.6146%
 CAT: 1.6146%, CFID: 0.8300



DCGAN_reconstructions: 100%|██████████| 15/15 [01:37<00:00,  6.50s/batch]


Folder: DCGAN_reconstructions
 SSIM: 0.2261, PSNR: 8.22, LPIPS: 0.8332
 CLIP cos: 0.5630, 200-way ID: 1.7761%
 CAT: 24.4349%, CFID: 0.8504



cProGAN_reconstructions: 100%|██████████| 15/15 [01:36<00:00,  6.46s/batch]


Folder: cProGAN_reconstructions
 SSIM: 0.2746, PSNR: 8.29, LPIPS: 0.7603
 CLIP cos: 0.5684, 200-way ID: 0.2153%
 CAT: 11.3025%, CFID: 0.8580



capsGAN_reconstructions: 100%|██████████| 15/15 [01:45<00:00,  7.04s/batch]


Folder: capsGAN_reconstructions
 SSIM: 0.2362, PSNR: 8.14, LPIPS: 0.8428
 CLIP cos: 0.5547, 200-way ID: 0.0538%
 CAT: 8.0732%, CFID: 0.8712



ACGAN_reconstructions_fine4: 100%|██████████| 3/3 [01:12<00:00, 24.02s/batch]


Folder: ACGAN_reconstructions_fine4
 SSIM: 0.2355, PSNR: 10.24, LPIPS: 0.7750
 CLIP cos: 0.5654, 200-way ID: 16.3934%
 CAT: 8.1967%, CFID: 0.8689



DCGAN_reconstructions_fine4: 100%|██████████| 3/3 [00:18<00:00,  6.09s/batch]


Folder: DCGAN_reconstructions_fine4
 SSIM: 0.1952, PSNR: 8.93, LPIPS: 0.8435
 CLIP cos: 0.5557, 200-way ID: 0.0000%
 CAT: 8.1967%, CFID: 0.8881



cProGAN_reconstructions_fine4: 100%|██████████| 3/3 [00:18<00:00,  6.28s/batch]


Folder: cProGAN_reconstructions_fine4
 SSIM: 0.2512, PSNR: 10.14, LPIPS: 0.7676
 CLIP cos: 0.5708, 200-way ID: 16.3934%
 CAT: 8.1967%, CFID: 0.8588



capsGAN_reconstructions_fine4: 100%|██████████| 3/3 [00:20<00:00,  6.83s/batch]


Folder: capsGAN_reconstructions_fine4
 SSIM: 0.2026, PSNR: 8.83, LPIPS: 0.8699
 CLIP cos: 0.5303, 200-way ID: 8.1967%
 CAT: 8.1967%, CFID: 0.9385



ACGAN_reconstructions: 100%|██████████| 37/37 [14:23<00:00, 23.34s/batch]


Folder: ACGAN_reconstructions
 SSIM: 0.1940, PSNR: 9.42, LPIPS: 0.8262
 CLIP cos: 0.5430, 200-way ID: 1.2956%
 CAT: 2.5912%, CFID: 0.9127



DCGAN_reconstructions: 100%|██████████| 37/37 [05:24<00:00,  8.77s/batch]


Folder: DCGAN_reconstructions
 SSIM: 0.2004, PSNR: 8.39, LPIPS: 0.8638
 CLIP cos: 0.5420, 200-way ID: 0.6910%
 CAT: 1.6195%, CFID: 0.8847



cProGAN_reconstructions: 100%|██████████| 37/37 [05:21<00:00,  8.69s/batch]


Folder: cProGAN_reconstructions
 SSIM: 0.2206, PSNR: 8.99, LPIPS: 0.8322
 CLIP cos: 0.5503, 200-way ID: 0.9069%
 CAT: 2.5912%, CFID: 0.8937



capsGAN_reconstructions: 100%|██████████| 37/37 [06:08<00:00,  9.97s/batch]


Folder: capsGAN_reconstructions
 SSIM: 0.1876, PSNR: 8.52, LPIPS: 0.8633
 CLIP cos: 0.5366, 200-way ID: 0.0648%
 CAT: 2.5912%, CFID: 0.9046

