In [1]:
%load_ext autoreload
%autoreload 2
import sys 
sys.path.append('scripts')
sys.path.append('src/')

In [6]:
import os
from typing import Optional
from PIL import Image
from tqdm.auto import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_fid.fid_score import calculate_fid_given_paths

from csbm.data import BaseDataset, CouplingDataset
from csbm.metrics import CMMD, FID, MSE, LPIPS

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [22]:
class SomeDataset(BaseDataset):
    transform: Optional[transforms.Compose] = None
    
    def __init__(
        self, 
        data_dir: str,
    ):
        self.data_dir= data_dir

        self.dataset = os.listdir(data_dir)
        self.dataset = [os.path.join(data_dir, x) for x in self.dataset]
        self.dataset = list(filter(lambda x: x.endswith('.jpg'), self.dataset))
        # self.dataset = sorted(self.dataset)

    def __getitem__(self, index):
        transform = transforms.ToTensor()
        image = Image.open(self.dataset[index])
        image = image.convert('RGB')
        image = transform(image)
        return image

    def __len__(self):
        return len(self.dataset)
    
    def repeat(self, n: int, max_len: int):
        self.dataset = self.dataset * n
        self.dataset = self.dataset[:max_len]

## FID

In [None]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:# , 'dim_128_aplha_0.01_14.01.25_21:22:30']:
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
    fid = calculate_fid_given_paths(
        paths=[ref_data_path, gen_data_path],
        dims=2048,
        batch_size=32,
        device=device,
    )
    print(f'Iter: {iteration}, FID: {fid}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:26<00:00, 14.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 506/506 [00:35<00:00, 14.08it/s]


Iter: 4, FID: 9.916346980044523


In [23]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:#, 'dim_128_aplha_0.01_14.01.25_21:22:30']: 
    fid = FID().to(device)
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'

    dataset = CouplingDataset(SomeDataset(data_dir=ref_data_path), SomeDataset(data_dir=gen_data_path))
    dataloader = DataLoader(dataset, batch_size=32)
    for real_images, fake_images in tqdm(dataloader):
        real_images = real_images.to(device)
        fake_images = fake_images.to(device)
        fid.update(real_images, real=True)
        fid.update(fake_images, real=False)    
    print(f'Iter: {iteration}, FID: {fid.compute().detach().cpu().numpy()}')

  state_dict = torch.load(feature_extractor_weights_path)


  0%|          | 0/370 [00:00<?, ?it/s]

Iter: 4, FID: 11.367501258850098


## CMMD

In [None]:
if os.path.exists(embd_ref_path):
    embs_ref = np.load(embd_ref_path).astype("float32")
else:
    embs_ref = compute_embeddings_for_dir(
        os.path.join(ref_data_path),
        ClipEmbeddingModel(), batch_size, max_count
    ).astype("float32")
    np.save(embd_ref_path, embs_ref)

In [None]:
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']: # ['dim_128_aplha_0.01_14.01.25_21:22:30']: #, 'small_dim_128_aplha_0.01_20.01.25_16:43:26']: #'tiny_dim_128_aplha_0.01_17.01.25_22:02:58', 'tiny_dim_128_aplha_0.01_19.01.25_21:21:21']:
    for iteration in range(4, 5):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        cmmd = calculate_cmmd(
            eval_dir=gen_data_path,
            embs_ref=embs_ref,
            batch_size=batch_size,
            max_count=max_count
        )
        print(f'CMMD: {cmmd}')

In [None]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:#, 'dim_128_aplha_0.01_14.01.25_21:22:30']: 
    cmmd = CMMD().to(device)
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'

    dataset = CouplingDataset(SomeDataset(data_dir=ref_data_path), SomeDataset(data_dir=gen_data_path))
    dataloader = DataLoader(dataset, batch_size=32)
    for real_images, fake_images in tqdm(dataloader):
        real_images = real_images.to(device)
        fake_images = fake_images.to(device)
        cmmd.update(real_images, real=True)
        cmmd.update(fake_images, real=False)
    
    print(f'Iter: {iteration}, CMMD: {cmmd.compute().detach().cpu().numpy()}')

  0%|          | 0/506 [00:00<?, ?it/s]

Iter: 4, CMMD: 0.17380714416503906


## MSE

In [None]:
for exp_name in ['small_dim_128_aplha_0.01_20.01.25_16:43:26']:
    for iteration in range(2, 7):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        mse = calculate_mse(
            eval_dir=gen_data_path,
            ref_dir='../data/celeba/',
            batch_size=batch_size,
            num_workers=num_workers
        )
        print(f'Iter: {iteration}, : {mse}')

# LPIPS

In [21]:
iteration = 4
ref_data_path = 'data/celeba/male_test'

for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:#, 'dim_128_aplha_0.01_14.01.25_21:22:30']: 
    lpip = LPIPS(normalize=True, reduction='mean').to(device)
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'

    dataset = CouplingDataset(SomeDataset(data_dir=ref_data_path), SomeDataset(data_dir=gen_data_path))
    dataloader = DataLoader(dataset, batch_size=32)
    for fake_images, real_images in tqdm(dataloader):
        fake_images = fake_images.to(device)
        real_images = real_images.to(device)
        lpip.update(fake_images, real_images)
    
    print(f'Iter: {iteration}, LPIPS: {lpip.compute().detach().cpu().numpy()}')

  0%|          | 0/264 [00:00<?, ?it/s]

Iter: 4, LPIPS: 0.17472831904888153


In [None]:
import lpips

iteration = 4
ref_data_path = 'data/celeba/female_test'

for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:#, 'dim_128_aplha_0.01_14.01.25_21:22:30']: 
    lpip = lpips.LPIPS(net='alex', lpips=True).to(device).eval()
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
    metric = 0

    real_loader = DataLoader(
        SomeDataset(data_dir=ref_data_path), batch_size=32
    )
    fake_loader = DataLoader(
        SomeDataset(data_dir=gen_data_path), batch_size=32
    )
    for fake_images, real_images in tqdm(zip(fake_loader, real_loader)):
        fake_images = 2 * fake_images.to(device) - 1
        real_images = 2 * real_images.to(device) - 1
        try:
            metric += lpip(fake_images, real_images).sum().item()
        except RuntimeError as e:
            print(f'Error: {e}')
            break
    
    print(f'Iter: {iteration}, LPIPS: {metric/11817}')

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /trinity/home/g.ksenofontov/anaconda3/envs/csbm/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


0it [00:00, ?it/s]

Error: The size of tensor a (32) must match the size of tensor b (9) at non-singleton dimension 0
Iter: 4, LPIPS: 0.4421703618341688
