In [3]:
%load_ext autoreload
%autoreload 2
import sys 
sys.path.append('../')
sys.path.append('../src/')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import argparse
import glob
import os
from typing import Literal, Optional
from PIL import Image
import pandas as pd
from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from pytorch_fid.fid_score import compute_statistics_of_path, calculate_frechet_distance, calculate_fid_given_paths
from pytorch_fid.inception import InceptionV3
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from scripts.eval import calculate_fid, calculate_cmmd, calculate_mse, compute_embeddings_for_dir, ClipEmbeddingModel
from csbm.models.quantized_images import Codec

# Device string used by the metric cells: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
class CelebaDataset:
    """CelebA images of a single sex, served as RGB PIL images.

    Image names are read from ``<data_dir>/celeba/list_attr_celeba.csv`` and
    resolved under ``<data_dir>/celeba/img_align_celeba/raw``.
    """

    # Optional preprocessing pipeline; it is not applied anywhere in this class.
    # The annotation is quoted so defining the class does not need torchvision.
    transform: Optional["transforms.Compose"] = None

    def __init__(
        self,
        sex: Literal['male', 'female'],
        data_dir: str,
        size: Optional[int] = None,
        train: bool = True,
        return_names: bool = False,
        count: int = 0
    ):
        """Collect the image paths for the requested sex.

        Args:
            sex: ``'male'`` keeps rows whose ``Male`` attribute is not -1;
                any other value keeps rows where it equals -1.
            data_dir: Directory that contains the ``celeba`` folder.
            size: Stored on the instance; not used by this class.
            train: Stored on the instance; not used by this class.
            return_names: If true, ``__getitem__`` also returns the file name.
            count: Keep only the last ``count`` images; ``0`` keeps all of them.
        """
        self.train = train
        self.size = size
        self.return_names = return_names

        attrs = pd.read_csv(os.path.join(data_dir, 'celeba', 'list_attr_celeba.csv'))
        if sex == 'male':
            attrs = attrs[attrs['Male'] != -1]  # only males
        else:
            attrs = attrs[attrs['Male'] == -1]

        image_dir = os.path.join(data_dir, 'celeba', 'img_align_celeba', 'raw')
        paths = [os.path.join(image_dir, name) for name in attrs['image_id'].tolist()]
        # Spell out the count == 0 case: `paths[-0:]` happens to be the whole
        # list, which is easy to misread as "no images".
        self.dataset = paths[-count:] if count else paths

    def __getitem__(self, index):
        """Return the RGB image at ``index``, plus its file name if requested."""
        path = self.dataset[index]
        # Close the file deterministically instead of relying on garbage
        # collection; `convert` returns an independent image object.
        with Image.open(path) as raw:
            image = raw.convert('RGB')

        if self.return_names:
            # basename honours the platform separator that os.path.join used.
            return image, os.path.basename(path)
        return image

    def __len__(self):
        """Number of image paths held by the dataset."""
        return len(self.dataset)

In [7]:
# Settings passed to the metric helpers in the cells below.
dims, batch_size, num_workers = 2048, 32, 1

# Reference image directory and the cache files stored next to the images.
ref_data_path = '../data/celeba/female_test'
stats_ref_path = os.path.join(ref_data_path, 'fid_stats.npz')
embd_ref_path = os.path.join(ref_data_path, 'cmmd_embed.npy')

# Default generation directory; the metric loops below assign their own value.
# gen_data_path = '../experiments/quantized_images/uniform/dim_128_aplha_0.01_14.01.25_21:22:30/checkpoints/forward_10/generation'
gen_data_path = '../experiments/quantized_images/uniform/small_dim_128_aplha_0.01_20.01.25_16:43:26/checkpoints/forward_4/generation'

# max_count is 10% of the number of female CelebA images listed in the attribute CSV.
female_images = CelebaDataset('female', '../data/')
max_count = int(0.1 * len(female_images))

In [8]:
# Load the cached Inception statistics of the reference set, or compute and
# cache them.
# NOTE(review): the cache is keyed only by its path; delete the file if `dims`
# or `max_count` change.
if os.path.exists(stats_ref_path):
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open.
    # Copy the arrays out so both branches produce the same plain dict and the
    # file handle is released.
    with np.load(stats_ref_path) as cached:
        stats_ref = {'mu': cached['mu'], 'sigma': cached['sigma']}
else:
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)
    m, s = compute_statistics_of_path(
        ref_data_path,
        model, batch_size, dims, device, num_workers, max_count
    )
    stats_ref = {'mu': m, 'sigma': s}
    np.savez(stats_ref_path, mu=m, sigma=s)

In [9]:
# Reference embeddings: compute and store them when no cache file exists,
# otherwise reuse the saved array. Either way the result is float32.
if not os.path.exists(embd_ref_path):
    embs_ref = compute_embeddings_for_dir(
        ref_data_path, ClipEmbeddingModel(), batch_size, max_count
    )
    embs_ref = embs_ref.astype("float32")
    np.save(embd_ref_path, embs_ref)
else:
    embs_ref = np.load(embd_ref_path)
    embs_ref = embs_ref.astype("float32")

In [10]:
# FID of generated samples against the reference statistics.
# Other experiment directories that were listed here before (commented out):
#   'dim_128_aplha_0.01_14.01.25_21:22:30'
#   'small_dim_128_aplha_0.01_20.01.25_16:43:26'
#   'tiny_dim_128_aplha_0.01_17.01.25_22:02:58'
#   'tiny_dim_128_aplha_0.01_19.01.25_21:21:21'
fid_exp_names = ['dim_128_aplha_0.005_27.01.25_21:56:36']

for exp_name in fid_exp_names:
    for iteration in range(4, 5):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        fid_kwargs = dict(
            eval_dir=gen_data_path,
            stats_ref=stats_ref,
            dims=dims,
            batch_size=batch_size,
            num_workers=num_workers,
            device=device,
            max_count=max_count,
        )
        fid = calculate_fid(**fid_kwargs)
        print(f'Iter: {iteration}, FID: {fid}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:44<00:00,  8.26it/s]


Iter: 3, FID: 8.245995455934349


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 370/370 [00:30<00:00, 11.94it/s]


Iter: 4, FID: 10.600888926217465


In [11]:
# FID for forward iterations 3 and 4 of the listed experiment.
# No FrechetInceptionDistance object is created here: that name is not
# imported in this notebook (NameError on a fresh kernel), and `fid` is
# assigned by calculate_fid before it is ever read.
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    for iteration in range(3, 5):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        fid = calculate_fid(
            eval_dir=gen_data_path,
            stats_ref=stats_ref,
            dims=dims,
            batch_size=batch_size,
            num_workers=num_workers,
            device=device,
            max_count=max_count
        )
        print(f'Iter: {iteration}, FID: {fid}')

  4%|███▌                                                                                                   | 13/370 [00:01<00:32, 11.15it/s]


KeyboardInterrupt: 

In [11]:
# CMMD of generated samples against the reference embeddings.
# Other experiment directories that were listed here before (commented out):
#   'dim_128_aplha_0.01_14.01.25_21:22:30'
#   'small_dim_128_aplha_0.01_20.01.25_16:43:26'
#   'tiny_dim_128_aplha_0.01_17.01.25_22:02:58'
#   'tiny_dim_128_aplha_0.01_19.01.25_21:21:21'
for exp_name in ['dim_128_aplha_0.005_27.01.25_21:56:36']:
    for iteration in range(4, 5):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        cmmd = calculate_cmmd(
            eval_dir=gen_data_path,
            embs_ref=embs_ref,
            batch_size=batch_size,
            max_count=max_count
        )
        # Include the iteration, as the FID and MSE cells do, so results from
        # several iterations can be told apart in the output.
        print(f'Iter: {iteration}, CMMD: {cmmd}')

Calculating embeddings for 9344 images from ../experiments/quantized_images/uniform/dim_128_aplha_0.005_27.01.25_21:56:36/checkpoints/forward_3/generation.


  0%|          | 0/292 [00:00<?, ?it/s]

CMMD: 0.15425682067871094
Calculating embeddings for 11816 images from ../experiments/quantized_images/uniform/dim_128_aplha_0.005_27.01.25_21:56:36/checkpoints/forward_4/generation.


  0%|          | 0/369 [00:00<?, ?it/s]

CMMD: 0.16450881958007812


In [7]:
# MSE for forward iterations 2..6 of the listed experiment, computed by
# calculate_mse from the generation directory and the CelebA data directory.
for exp_name in ['small_dim_128_aplha_0.01_20.01.25_16:43:26']:
    for iteration in range(2, 7):
        gen_data_path = f'../experiments/quantized_images/uniform/{exp_name}/checkpoints/forward_{iteration}/generation'
        mse = calculate_mse(
            eval_dir=gen_data_path,
            ref_dir='../data/celeba/',
            batch_size=batch_size,
            num_workers=num_workers
        )
        # The metric label was missing from the message (it printed ", : <value>").
        print(f'Iter: {iteration}, MSE: {mse}')

  0%|          | 0/506 [00:00<?, ?it/s]

Iter: 2, : 0.029523000946706575


  0%|          | 0/506 [00:00<?, ?it/s]

Iter: 3, : 0.01858686597710519


  0%|          | 0/506 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x155467428a40>
Traceback (most recent call last):
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Iter: 4, : 0.014704428473130447


  0%|          | 0/506 [00:00<?, ?it/s]

Iter: 5, : 0.012771712402777892


Exception ignored in: 

  0%|          | 0/506 [00:00<?, ?it/s]

<function _MultiProcessingDataLoaderIter.__del__ at 0x155467428a40>
Traceback (most recent call last):
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/trinity/home/g.ksenofontov/anaconda3/envs/disc_sbm/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Iter: 6, : 0.0114071473567727


In [None]:
# Build the codec from the VQGAN CelebA config and checkpoint files.
vector_quantizer_config = dict(
    config_path='../configs/vqgan_celeba_f8_1024.yaml',
    ckpt_path='../checkpoints/vqgan_celeba_f8_1024.ckpt',
)
vq = Codec(**vector_quantizer_config)