In [None]:
import sys
sys.path.append('../../../patronus/')
from global_config import * # load REPO_HOME_DIR, DATASET_DIR

import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import os


from train.utils import load_patronus_unet_model
from models.diffusion import SimpleDiffusion
from train.dataloader import get_dataloader
from analysis.analysis_utils import get_samples_from_loader, vis_samples
from analysis.interpretability.visualize_prototype import plot_vis_p,get_most_activated_patch_for_one
from train.dataloader import inverse_transform
from train.dataloader import get_dataloader,get_dataloader_pact

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Diagnosis
## 1 - Load the model for evaluation 

In [None]:
# Dataset name and checkpoint version analysed by this notebook. Both are passed to the
# model loader, and `ds` is also used to build the dataset split names and the latent path.
ds = 'CelebA_hair_smile'
version_num = 2
print(f'Selecting dataset: {ds} version {version_num}')

In [None]:
# ---- Load the Patronus model together with its configuration ----
banner = '*' * 30
print(f'{banner}Load model{banner}')
model, patronus_config_set = load_patronus_unet_model(ds_name=ds, version_num=version_num)

# ---- Build the diffusion scheduler from the loaded configuration ----
training_config = patronus_config_set['TrainingConfig']
base_config = patronus_config_set['BaseConfig']
sd = SimpleDiffusion(
    num_diffusion_timesteps=training_config['TIMESTEPS'],
    img_shape=training_config['IMG_SHAPE'],
    device=base_config['DEVICE'],
)

## 2 - Save the latents and list the most relevant prototypes for each attribute

In [None]:
# ---- Prepare the already-loaded Patronus model for inference -----
# `model` was loaded from this same checkpoint in the cell above, so it is reused here
# instead of being read from disk a second time.
model.to(device)
model.eval()

# The latent file is an on-disk cache: this cell creates it when it is missing and
# does not load it into a notebook variable.
latent_path = os.path.join(REPO_HOME_DIR, f'records/save_latents/{ds}-{version_num}/{ds}_{version_num}_latent.npz')
if os.path.exists(latent_path):
    print(f'Latent representation already exists at {latent_path}.')
else:
    print(f'Latent representation does not exist at {latent_path}. Retrive and save it.')

    # ---- Get the prototype encoder -----
    prototype_encoder = model.proactBlock
    prototype_encoder.eval()
    prototype_encoder.to(device)

    # Wrap the training split in a loader that yields prototype activations.
    dataloader_train_pact = get_dataloader_pact(dataset_name=f'{ds}-train',
        batch_size=128,
        pact_encoder=prototype_encoder,
        device=device,
        shuffle=False,  # order is irrelevant here, so keep it deterministic
    )

    print(f'{len(dataloader_train_pact)=}')

    # These datasets deliver their labels as a sequence with one tensor per attribute,
    # which has to be stacked into a [batch_size, num_attributes] tensor.
    labels_need_stacking = any(tag in ds for tag in ('CelebA', 'ffhq256', 'CHEXPERT'))

    all_pact_train = []
    all_attr = []
    # Iterating the loader directly lets tqdm show the total number of batches, which
    # replaces the manual "every 10 batches" progress print.
    for x, extra_info in tqdm(dataloader_train_pact, desc='Collecting prototype activations'):
        # Flatten every sample's activation to one feature vector.
        pact_train_batch = x.reshape(x.shape[0], -1).detach().cpu().numpy()

        # Process attributes
        label = extra_info[0]
        if labels_need_stacking:
            label_stacked = torch.stack(label, dim=1)  # Shape [batch_size, num_attributes]
        else:
            label_stacked = label
        all_attr_batch = label_stacked.detach().cpu().numpy()

        # Accumulate results
        all_pact_train.append(pact_train_batch)
        all_attr.append(all_attr_batch)

    all_pact_train = np.concatenate(all_pact_train, axis=0)  # Shape: [total_samples, feature_dim]
    all_attr = np.concatenate(all_attr, axis=0)  # Shape: [total_samples, num_attributes]

    print(f'{all_pact_train.shape=}')
    print(f'{all_attr.shape=}')

    # Save to exactly the path that the existence check above looks at. Building the
    # directory by string concatenation would silently point somewhere else whenever
    # REPO_HOME_DIR has no trailing slash, and the cache would then never be found.
    save_dir = os.path.join(os.path.dirname(latent_path), '')  # keeps the trailing separator
    os.makedirs(save_dir, exist_ok=True)
    np.savez(latent_path, all_a=all_pact_train, all_attr=all_attr)
    print(f'Saved latent representation of {ds} - version {version_num} to {save_dir}.')


In [None]:
from analysis.evaluation.p_quality_tool import eval_disentanglement
# Evaluate the prototypes of the selected dataset/version. As used in the following cells,
# `y_names` is searched by attribute name and `auroc_score_all[0]` (the first fold) is
# indexed as [attribute, prototype].
auroc_score_all, y_names = eval_disentanglement(ds, version_num, return_auroc = True)


In [None]:
# Positions of the attributes of interest inside `y_names`.
black_hair_index, brown_hair_index, blonde_hair_index, smiling_index = (
    y_names.index(attribute)
    for attribute in ('Black_Hair', 'Brown_Hair', 'Blond_Hair', 'Smiling')
)

In [None]:
# For each attribute of interest, print and record the ten prototypes with the highest AUROC.
hair_names = ['Black Hair', 'Brown Hair', 'Blonde Hair', 'Smiling']
auroc_result = auroc_score_all[0]  # choose the first fold -- the results usually are not very different
attribute_rows = [black_hair_index, brown_hair_index, blonde_hair_index, smiling_index]

top10_hair_smile_records = {}
for attr_label, attr_row in zip(hair_names, attribute_rows):
    print(attr_label)
    attr_auroc = auroc_result[attr_row, :]

    # argsort is ascending: reverse it and keep the first ten (highest first).
    descending_order = np.argsort(attr_auroc)[::-1]
    top_10_indices = descending_order[:10].tolist()
    top_10_values = attr_auroc[top_10_indices]

    print("Top 10 values:", top_10_values)
    print("Indices of top 10 values:", ','.join(str(p_index) for p_index in top_10_indices))

    top10_hair_smile_records[attr_label] = top_10_indices

## 3 - Visualize the selected prototypes accordingly
### 3.1 - select some samples

In [None]:
# Test split, used below only to look up a handful of images by their ID.
# Shuffling is disabled so the iteration order (and therefore how long the lookup
# takes) is the same on every run.
dataloader_test = get_dataloader(
        dataset_name=f'{ds}-test',
        batch_size=128,
        device='cpu',
        shuffle=False,  # test set should not be shuffled
)

# Value written into a selected prototype's activation when it is enhanced.
max_pact = 0.5

In [None]:
selected_sample_id = ['198406.jpg','184336.jpg','200877.jpg','196817.jpg','193017.jpg','200659.jpg','194489.jpg','183418.jpg'] # hair and smile
# selected_sample_id = ['183106.jpg','187412.jpg','186898.jpg']
num_selected_sample = len(selected_sample_id)

# Map every requested ID to its row in `selected_img`, so the images end up in the
# same order as `selected_sample_id` regardless of where they sit in the test set.
id_to_index = {img_id: idx for idx, img_id in enumerate(selected_sample_id)}

# Pre-allocate the output tensor (all images are assumed to share one shape).
example_img_shape = next(iter(dataloader_test))[0][0].shape  # shape of a single image
selected_img = torch.zeros((num_selected_sample, *example_img_shape))

# IDs that have already been copied into `selected_img`.
found_ids = set()

for b_image, extra_info in tqdm(dataloader_test):
    b_img_id = extra_info[1]  # Assuming this contains the image IDs

    matched_ids = []
    for i, img_id in enumerate(b_img_id):
        if img_id not in id_to_index:
            continue
        matched_ids.append(img_id)
        if img_id not in found_ids:  # keep only the first occurrence of an ID
            selected_img[id_to_index[img_id]] = b_image[i]
            found_ids.add(img_id)

    if matched_ids:
        print(f'Found selected samples: {matched_ids}')

    # Stop as soon as every distinct requested ID has been found.
    if len(found_ids) == len(id_to_index):
        break

# Without this check, an ID that is absent from the test split would silently leave an
# all-zero image in `selected_img` and every later cell would run on that blank input.
missing_ids = [img_id for img_id in selected_sample_id if img_id not in found_ids]
assert not missing_ids, f'Selected samples not found in {ds}-test: {missing_ids}'

print(selected_img.shape)



In [None]:
# Visualize the selected samples side by side, each titled with its image ID.
# squeeze=False always returns a 2-D array of axes, so indexing also works when only
# one sample is selected (the default would return a bare Axes object in that case).
fig, axes = plt.subplots(1, num_selected_sample, figsize=(num_selected_sample*4, 4), squeeze=False)
axes = axes.ravel()
for i in range(num_selected_sample):
    this_img = selected_img[i]
    # channel-first tensor -> H x W x C uint8 array for imshow
    axes[i].imshow(np.transpose(inverse_transform(this_img).type(torch.uint8).cpu().squeeze().numpy(), (1, 2, 0)))
    axes[i].axis('off')
    axes[i].set_title(f'{selected_sample_id[i]}')

plt.show()

In [None]:
# Pure inference: disable autograd so no computation graph is built or kept alive
# for the prototype activations and the inverted samples.
with torch.no_grad():
    selected_img_on_device = selected_img.to(device)

    # Prototype activations of the selected images.
    selected_pact = model.proactBlock(selected_img_on_device)
    print(selected_pact.shape)

    # Run the reverse sampling loop on the images, conditioned on their own activations;
    # the result is reused below as the starting noise for generation.
    selected_xT = sd.reverse_sample_loop(model, selected_img_on_device, model_kwargs={'given_cond_vector':selected_pact})['sample']
    print(selected_xT.shape)


### 3.2 - Choose the prototypes and visualize them

In [None]:
# Take the two top-ranked prototypes of each attribute, in this fixed attribute order.
selected_p_ind = [
    p_index
    for attr_label in ('Smiling', 'Blonde Hair', 'Brown Hair', 'Black Hair')
    for p_index in top10_hair_smile_records[attr_label][:2]
]
print(f'selected p index: {selected_p_ind}')

highest_activated_record = {}

# For every selected prototype, copy the activations of all selected samples and
# overwrite that single prototype's activation with `max_pact`.
enhanced_variants = []
for i_p in selected_p_ind:
    boosted_pact = selected_pact.clone()
    boosted_pact[:, i_p] = max_pact
    enhanced_variants.append(boosted_pact)
enhanced_pact_chosen_p_all = torch.cat(enhanced_variants, dim=0)

# Tile the images and their inverted noise once per selected prototype so that they
# line up with the concatenated activations above.
num_variants = len(selected_p_ind)
random_pick_img_all = selected_img.repeat(num_variants, 1, 1, 1)
random_pick_xT_all = selected_xT.repeat(num_variants, 1, 1, 1)

# Generate all enhanced images in one batched sampling call.
this_x_0_enhanced_all = sd.sample(
    model,
    shape=random_pick_img_all.shape,
    noise=random_pick_xT_all,
    progress=True,
    model_kwargs={'given_cond_vector': enhanced_pact_chosen_p_all},
    num_samples=1,
)


In [None]:
# Per-prototype record of the original images, the images generated with that
# prototype enhanced, the most activated patch / bounding box of each generated
# image, and the prototype's activation before and after enhancement.
records = {}
num_selected_p = len(selected_p_ind)

# The activations of the original images do not depend on the prototype being
# processed, so compute them once instead of once per loop iteration.
with torch.no_grad():
    this_p_act_all_original = model.proactBlock(selected_img.to(device))

for real_i, i_p in enumerate(selected_p_ind):
    # Slice out the generated images that belong to this prototype: the batch was
    # built as one contiguous group of `num_selected_sample` images per prototype.
    start = real_i * num_selected_sample
    this_x_0_enhanced = this_x_0_enhanced_all[start:start + num_selected_sample]

    with torch.no_grad():
        this_p_act_all = model.proactBlock(this_x_0_enhanced)
    this_p_act = this_p_act_all[:, i_p]

    # Activation of the same prototype on the original (un-enhanced) images.
    this_p_act_ori = this_p_act_all_original[:, i_p]

    # Most activated patch and its bounding box for each generated image.
    # (Left outside no_grad: the helper's internals are not visible from here.)
    most_activated_patches_this = []
    bounding_boxes_this = []
    for i in range(num_selected_sample):
        patches_tmp, boxes_tmp = get_most_activated_patch_for_one(this_x_0_enhanced[i].unsqueeze(0),
                                                                  [i_p],
                                                                  model)
        most_activated_patches_this.append(patches_tmp[0])
        bounding_boxes_this.append(boxes_tmp[0])

    records[i_p] = {'ori_img': selected_img,
                    'enhanced_img': this_x_0_enhanced,
                    'enhanced_patch': most_activated_patches_this,
                    'enhanced_b_box': bounding_boxes_this,
                    'enhanced_p_act': this_p_act,
                    # Not populated; the key is kept so the record layout is unchanged.
                    'enhanced_p_act_nor': None,
                    'ori_p_act': this_p_act_ori,
                    }

In [None]:
fontsize = 15  # font size of the column titles drawn by `plot`


def _to_display_image(img):
    """Return `img` as an H x W x C uint8 NumPy array suitable for `imshow`.

    The tensor is passed through `inverse_transform`, cast to uint8, moved to the
    CPU, squeezed, and transposed from channel-first to channel-last.
    """
    return np.transpose(inverse_transform(img).type(torch.uint8).cpu().squeeze().numpy(), (1, 2, 0))


def plot(records,
             selected_p_ind,
             save_pth=None):
    """Plot the original samples next to their prototype-enhanced versions.

    The figure has one row per sample. The first column shows the original image;
    after it, every prototype in `selected_p_ind` contributes two columns: the image
    generated with that prototype enhanced (with the bounding box of its most
    activated patch drawn in red) and the patch itself.

    Parameters
    ----------
    records : dict
        Maps a prototype index to a dict with the keys 'ori_img', 'enhanced_img',
        'enhanced_patch' and 'enhanced_b_box'. Bounding boxes are read as
        (x0, y0, x1, y1).
    selected_p_ind : sequence
        Prototype indices to plot, in column order; must be non-empty and every
        entry must be a key of `records`.
    save_pth : str or None, optional
        When given, the figure is also saved to this path.
    """
    ori_img = records[selected_p_ind[0]]['ori_img']
    # Take the row count from the data instead of relying on a notebook global.
    n_rows = len(ori_img)
    n_cols = len(selected_p_ind) * 2 + 1

    # squeeze=False keeps `ax` two-dimensional even when there is a single row.
    fig, ax = plt.subplots(n_rows, n_cols,
                           figsize=(len(selected_p_ind) * 3, n_rows * 1.5), dpi=100,
                           squeeze=False)

    col_ind = 0  # Column index starts at 0

    # Plot original images in the first column
    for j in range(n_rows):
        ax[j, col_ind].imshow(_to_display_image(ori_img[j]))
        ax[j, col_ind].axis('off')
        if j == 0:
            ax[j, col_ind].set_title("x_0", fontsize=fontsize, pad=10)
    col_ind += 1

    # Plot enhanced images with bounding boxes, then the patches
    for i_p in selected_p_ind:
        patch = records[i_p]['enhanced_patch']
        g_img = records[i_p]['enhanced_img']
        bounding_boxes = records[i_p]['enhanced_b_box']

        # Enhanced images with the most activated patch outlined in red
        for j in range(n_rows):
            img = g_img[j]
            axis = ax[j, col_ind]
            axis.imshow(_to_display_image(img))

            # Draw the box as one closed outline through its four corners.
            x0, y0, x1, y1 = (bounding_boxes[j][k] for k in range(4))
            axis.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], 'r')

            # `img` is channel-first, so shape[1] is the height and shape[2] the width:
            # the x axis spans the width and the (inverted) y axis spans the height.
            height, width = img.shape[1], img.shape[2]
            axis.set_xlim([0, width])
            axis.set_ylim([height, 0])
            axis.axis('off')
            axis.tick_params(labelbottom=False)

            if j == 0:
                # The backslash is escaped so the title text stays the same without
                # relying on the invalid "\h" escape sequence.
                axis.set_title(f"\\hat_x(p'_{i_p})", fontsize=fontsize, pad=10)
        col_ind += 1

        # Plot the patches
        for j in range(n_rows):
            p = patch[j].squeeze(0)
            ax[j, col_ind].imshow(_to_display_image(p))
            ax[j, col_ind].axis('off')
            if j == 0:
                ax[j, col_ind].set_title(f"p_{i_p}", fontsize=fontsize, pad=10)
        col_ind += 1

    plt.tight_layout(h_pad=0.5, w_pad=0.0)
    if save_pth is not None:
        plt.savefig(save_pth, bbox_inches='tight')
    plt.show()

In [None]:

# Leave as None to only display the figure; set a file path
# (for example './plot_diagnosis.pdf') to save it as well.
save_pth = None
plot(records, selected_p_ind, save_pth=save_pth)