In [1]:
%load_ext autoreload
%autoreload 2
import sys 
# Make the project's `scripts` and `src/` directories importable.
# Both entries are relative, so they are resolved against the kernel's current
# working directory at import time.
# NOTE(review): later cells mix 'data/...' and '../data/...' prefixes, which
# suggests they were run from different working directories — confirm which one
# this notebook is expected to be launched from.
sys.path.append('scripts')
sys.path.append('src/')

In [3]:
import os
from typing import Optional
from PIL import Image
from tqdm.auto import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_fid.fid_score import calculate_fid_given_paths

from csbm.data import BaseDataset
from csbm.metrics import CMMD, FID, MSE

# Device used by every metric cell below: the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## FID

In [6]:
class SomeDataset(BaseDataset):
    """Flat directory of ``.jpg`` images served as RGB tensors.

    Only entries directly inside ``data_dir`` whose name ends with ``.jpg``
    (case-sensitive) are used. Paths are sorted so that indexing is
    deterministic across runs.

    Attributes:
        transform: Optional per-image transform. When left as ``None`` the
            image is converted with ``transforms.ToTensor()``.
        data_dir: Directory the image paths were collected from.
        dataset: Sorted list of image file paths.
    """

    transform: Optional[transforms.Compose] = None

    # Fallback transform, built once instead of on every ``__getitem__`` call.
    _default_transform = transforms.ToTensor()

    def __init__(
        self,
        data_dir: str,
    ):
        """Collect the sorted ``.jpg`` paths found directly in ``data_dir``.

        Args:
            data_dir: Directory to scan (not recursive).

        Raises:
            OSError: Propagated from ``os.listdir`` if ``data_dir`` cannot be read.
        """
        self.data_dir = data_dir
        self.dataset = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.jpg')
        )

    def __getitem__(self, index):
        """Load the image at ``index``, force RGB and apply the transform."""
        # ``convert`` loads the pixel data and returns a new image, so the
        # source file can be closed as soon as the ``with`` block ends.
        with Image.open(self.dataset[index]) as source:
            image = source.convert('RGB')

        transform = self.transform if self.transform is not None else self._default_transform
        return transform(image)

    def __len__(self):
        """Number of image files found in ``data_dir``."""
        return len(self.dataset)

In [None]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

# Reference FID computed by the pytorch_fid package directly from two image
# folders; the next cell computes the same quantity with csbm.metrics.FID so
# the two numbers can be compared.
# Other run evaluated previously: 'dim_128_aplha_0.01_14.01.25_21:22:30'
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
    # The result is an FID score, so it is named accordingly (it used to be
    # stored in a variable called `cmmd`).
    fid_value = calculate_fid_given_paths(
        paths=[ref_data_path, gen_data_path],
        dims=2048,
        batch_size=32,
        device=device,
    )
    print(f'Iter: {iteration}, FID: {fid_value}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:26<00:00, 14.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 506/506 [00:35<00:00, 14.08it/s]


Iter: 4, FID: 9.916346980044523


In [None]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

# FID computed with csbm.metrics.FID: reference images are accumulated with
# real=True, generated images with real=False, then the score is computed.
# Other run evaluated previously: 'dim_128_aplha_0.01_14.01.25_21:22:30'
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    fid_metric = FID().to(device)
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'

    # This cell never calls backward, so gradient tracking is switched off
    # while the metric accumulates its statistics.
    with torch.no_grad():
        # Same order as before: the whole reference set first, then the generated set.
        for data_dir, is_real in [(ref_data_path, True), (gen_data_path, False)]:
            dataloader = DataLoader(
                SomeDataset(data_dir=data_dir), batch_size=32
            )
            for images in tqdm(dataloader):
                fid_metric.update(images.to(device), real=is_real)

    fid_value = fid_metric.compute()
    print(f'Iter: {iteration}, FID: {fid_value.detach().cpu().numpy()}')

  state_dict = torch.load(feature_extractor_weights_path)


  0%|          | 0/370 [00:00<?, ?it/s]

  0%|          | 0/506 [00:00<?, ?it/s]

Iter: 4, FID: 9.969728469848633


## CMMD

In [None]:
# Load the cached reference embeddings if present, otherwise compute them once
# and cache them as float32.
# NOTE(review): embd_ref_path, batch_size, max_count, compute_embeddings_for_dir
# and ClipEmbeddingModel are not defined or imported in any cell visible in this
# notebook, so this cell cannot run on a fresh kernel as shown — confirm where
# they are meant to come from.
# NOTE(review): np.save appends '.npy' when the file name lacks that extension;
# if embd_ref_path does not end in '.npy' the existence check below would never
# find the cache — verify the path.
if os.path.exists(embd_ref_path):
    embs_ref = np.load(embd_ref_path).astype("float32")
else:
    # ref_data_path is passed directly; wrapping it in a one-argument
    # os.path.join was a no-op.
    embs_ref = compute_embeddings_for_dir(
        ref_data_path,
        ClipEmbeddingModel(), batch_size, max_count
    ).astype("float32")
    np.save(embd_ref_path, embs_ref)

In [None]:
# CMMD of each generated folder against the precomputed reference embeddings.
# Runs evaluated previously:
#   'dim_128_aplha_0.01_14.01.25_21:22:30'
#   'small_dim_128_aplha_0.01_20.01.25_16:43:26'
#   'tiny_dim_128_aplha_0.01_17.01.25_22:02:58'
#   'tiny_dim_128_aplha_0.01_19.01.25_21:21:21'
# NOTE(review): calculate_cmmd, batch_size and max_count are not defined in the
# visible cells, and this path prefix ('../experiments/quantized_images/uniform')
# differs from the one used by the FID cells — confirm both.
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    for iteration in range(4, 5):
        gen_data_path = (
            '../experiments/quantized_images/uniform/'
            f'{exp_name}/checkpoints/forward_{iteration}/generation'
        )
        cmmd = calculate_cmmd(
            eval_dir=gen_data_path, embs_ref=embs_ref,
            batch_size=batch_size, max_count=max_count,
        )
        print(f'CMMD: {cmmd}')

In [9]:
iteration = 4
ref_data_path = 'data/celeba/female_test'

# CMMD computed with csbm.metrics.CMMD: reference images are accumulated with
# real=True, generated images with real=False, then the score is computed.
# NOTE(review): the output saved with this cell is
# `TypeError: expected np.ndarray (got Tensor)`. The traceback is not shown, so
# the failing call is unknown; it is presumably raised inside csbm.metrics.CMMD
# (likely a numpy-only conversion being handed a torch.Tensor) — confirm and fix
# there.
# Other run evaluated previously: 'dim_128_aplha_0.01_14.01.25_21:22:30'
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    cmmd_metric = CMMD().to(device)
    gen_data_path = f'experiments/quantized_images/celeba/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'

    # Same order as the unrolled version: the reference set first, then the generated set.
    for data_dir, is_real in ((ref_data_path, True), (gen_data_path, False)):
        dataloader = DataLoader(SomeDataset(data_dir=data_dir), batch_size=32)
        for images in tqdm(dataloader):
            images = images.to(device)
            cmmd_metric.update(images, real=is_real)

    print(f'Iter: {iteration}, CMMD: {cmmd_metric.compute().detach().cpu().numpy()}')

  0%|          | 0/370 [00:00<?, ?it/s]

  0%|          | 0/506 [00:00<?, ?it/s]

TypeError: expected np.ndarray (got Tensor)

## MSE

In [None]:
# MSE between each iteration's generated folder and the reference directory.
# NOTE(review): calculate_mse, batch_size and num_workers are not defined in the
# visible cells (only the `MSE` class is imported from csbm.metrics) — confirm
# where they are meant to come from.
for exp_name in ['small_dim_128_aplha_0.01_20.01.25_16:43:26']:
    for iteration in range(2, 7):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        mse = calculate_mse(
            eval_dir=gen_data_path,
            ref_dir='../data/celeba/',
            batch_size=batch_size,
            num_workers=num_workers
        )
        # The label was missing from the message (it printed "Iter: N, : value").
        print(f'Iter: {iteration}, MSE: {mse}')