In [6]:
import sys

# Make the patronus package root importable. The path is relative to the notebook's
# working directory; the membership check keeps this cell idempotent on re-run.
PATRONUS_ROOT = '../../../patronus/'
if PATRONUS_ROOT not in sys.path:
    sys.path.append(PATRONUS_ROOT)

# Imported before torch (as in the original ordering) in case global_config has
# import-time side effects, e.g. environment variables -- NOTE(review): confirm.
from global_config import REPO_HOME_DIR, DATASET_DIR

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from train.utils import load_patronus_unet_model
from models.diffusion import SimpleDiffusion
from train.dataloader import get_dataloader, get_dataloader_pact
from analysis.analysis_utils import get_samples_from_loader, vis_samples

# Run on GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prototype quality
We follow InfoDiffusion in assessing prototype (latent) quality, which consists of two parts:  

a) *Prototype capability in semantic representation*:  
i.e., how well the latents predict known semantic attributes.  
b) *Prototype disentanglement*:  
i.e., how well a single latent dimension predicts known semantic attributes.

## 1 - Save the latents of a given training version (this may take a while)

In [1]:
# Experiment selection: the dataset name and the trained model version analysed below.
ds, version_num = 'CelebA', 4
print('Selecting dataset: {} version {}'.format(ds, version_num))

Selecting dataset: CelebA version 4


In [2]:
# ---- Load the patronus model -----
print('*'*30 + 'Load model' + '*'*30)
model, patronus_config_set = load_patronus_unet_model(ds_name=ds,
                                                      version_num=version_num)
model.to(device)
model.eval()


def _collect_latents_and_attrs(loader, ds_name):
    """Flatten every batch of prototype activations and collect its attribute labels.

    Args:
        loader: iterable yielding ``(x, extra_info)`` where ``x`` is a tensor batch of
            prototype activations and ``extra_info[0]`` holds the attribute labels.
        ds_name: dataset name; selects how the labels are collated (see below).

    Returns:
        ``(latents, attrs)`` NumPy arrays concatenated along the sample axis;
        ``latents`` has shape ``[total_samples, feature_dim]``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    latent_batches, attr_batches = [], []
    # tqdm picks up len(loader) itself, so the bar shows a real total.
    for x, extra_info in tqdm(loader, desc='extracting latents'):
        # Flatten all non-batch dims into one feature vector per sample;
        # reshape (unlike view) also copes with non-contiguous tensors.
        latent_batches.append(x.reshape(x.shape[0], -1).detach().cpu().numpy())

        label = extra_info[0]
        if 'CelebA' in ds_name or 'ffhq256' in ds_name or 'CHEXPERT' in ds_name:
            # NOTE(review): for these datasets the labels appear to arrive as a sequence of
            # per-attribute tensors; stack to [batch_size, num_attributes].
            label = torch.stack(label, dim=1)
        attr_batches.append(label.detach().cpu().numpy())

    if not latent_batches:
        raise ValueError('dataloader yielded no batches; nothing to save')
    return np.concatenate(latent_batches, axis=0), np.concatenate(attr_batches, axis=0)


# Latents are cached on disk: extraction runs only when the npz file is missing.
# This cell does not load an existing file into memory.
latent_path = os.path.join(REPO_HOME_DIR, f'records/save_latents/{ds}-{version_num}/{ds}_{version_num}_latent.npz')
if os.path.exists(latent_path):
    print(f'Latent representation already exists at {latent_path}.')
else:
    print(f'Latent representation does not exist at {latent_path}. Retrive and save it.')

    # ---- Get the prototype encoder -----
    prototype_encoder = model.proactBlock
    # Defensive: model.to/eval above should already cover this if proactBlock is a
    # registered submodule.
    prototype_encoder.eval()
    prototype_encoder.to(device)

    # Inference only: no autograd graphs while computing the training-set activations.
    with torch.no_grad():
        dataloader_train_pact = get_dataloader_pact(dataset_name=f'{ds}-train',
            batch_size=128,
            pact_encoder=prototype_encoder,
            device=device,
            shuffle=False,  # deterministic order; no need to shuffle for feature extraction
        )
        print(f'{len(dataloader_train_pact)=}')

        all_pact_train, all_attr = _collect_latents_and_attrs(dataloader_train_pact, ds)

    assert all_pact_train.shape[0] == all_attr.shape[0], 'latent / attribute row counts differ'
    print(f'{all_pact_train.shape=}')
    print(f'{all_attr.shape=}')

    # Save to exactly the path checked above, so the cache check finds it next time.
    # latent_path already ends in '.npz', so np.savez does not append another extension.
    save_dir = os.path.dirname(latent_path)
    os.makedirs(save_dir, exist_ok=True)
    np.savez(latent_path, all_a=all_pact_train, all_attr=all_attr)
    print(f'Saved latent representation of {ds} - version {version_num} to {save_dir}.')



## 2 - Run evaluation
Note that the disentanglement evaluation is only available for the CelebA and CheXpert datasets.

In [1]:
# Prototype-quality evaluation for the dataset/version selected above.
# NOTE(review): presumably reads the latents saved in step 1 -- confirm in p_quality_tool.
from p_quality_tool import eval_disentanglement

eval_disentanglement(version_num=version_num, ds_name=ds)