In [8]:
import os
import glob
import sys
import matplotlib
import matplotlib.pyplot as plt
import cv2
import pandas as pd

def plot_img(img_path):
    """Read an image from disk and show it in a new matplotlib figure.

    Args:
        img_path: Path of the image file to display.

    Raises:
        OSError: If ``cv2.imread`` cannot read the file (it returns ``None``
            instead of raising for a missing or undecodable file).
    """
    img = cv2.imread(img_path)
    if img is None:
        # Fail with the offending path instead of the opaque OpenCV assertion
        # that cvtColor would raise on a None input.
        raise OSError(f'cv2.imread could not read image: {img_path}')

    # OpenCV loads pixels as BGR; matplotlib expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.figure()
    plt.imshow(img)

def analyze_exp(model_name, setting_name, df_path, use_plot_img=False, vis_detail=False):
    """Print the metrics of one experiment setting and optionally plot cluster images.

    Args:
        model_name: Name of the model directory under the module-level
            ``result_analysis_dir``.
        setting_name: Name of the experiment-setting directory under the
            model directory; also printed as a header.
        df_path: Path of the Excel workbook containing the ``grouping`` and
            ``discriminative`` sheets.
        use_plot_img: When True, plot up to 4 images for each of up to 10
            ``cluster_<id>`` directories found under ``<setting>/cl``.
        vis_detail: When True, also print the ``ALL`` row of the
            ``discriminative`` sheet for the ``<1``, ``<2`` and ``<3`` columns.
    """
    # visualize exp. setting
    print('Setting name:', setting_name)

    # visualize metrics
    df_grouping = pd.read_excel(df_path, sheet_name='grouping', index_col=0, engine='openpyxl')
    print(f'group: {df_grouping["<1"].loc["ALL"]:.2f}')

    if vis_detail:
        df_disc = pd.read_excel(df_path, sheet_name='discriminative', index_col=0, engine='openpyxl')
        for threshold in ['<1', '<2', '<3']:
            print(f'key-person ({threshold}):{df_disc[threshold].loc["ALL"]:.2f}')

    # The directory walk below only feeds plot_img, so skip the filesystem
    # access entirely when no plotting was requested.
    if not use_plot_img:
        return

    # visualize images
    vis_cl_all_dir = os.path.join(result_analysis_dir, model_name, setting_name, 'cl')
    # NOTE(review): one directory entry is excluded from the cluster count,
    # presumably a non-cluster entry stored alongside the cluster_<id>
    # directories -- confirm against the code that writes this directory.
    cluster_count = min(10, len(os.listdir(vis_cl_all_dir)) - 1)
    for vis_cl_id in range(cluster_count):
        vis_cl_dir = os.path.join(vis_cl_all_dir, f'cluster_{vis_cl_id}')
        # os.listdir returns entries in arbitrary order: list once and sort so
        # the same images are shown on every run.
        img_names = sorted(os.listdir(vis_cl_dir))
        for img_name in img_names[:4]:
            plot_img(os.path.join(vis_cl_dir, img_name))

def get_setting_name(cl_sum, cl_cent_type, cl_activate_type, use_debug_model):
    """Return the directory name of an experiment setting.

    The name is ``all_cluster_<cl_sum>_<cl_cent_type>_<cl_activate_type>``,
    with ``_debug`` appended when ``use_debug_model`` is truthy.
    """
    debug_suffix = '_debug' if use_debug_model else ''
    return f'all_cluster_{cl_sum}_{cl_cent_type}_{cl_activate_type}{debug_suffix}'

def get_df_path(model_name, setting_name, use_debug_model):
    """Return the path of the metrics workbook (.xlsx) of an experiment setting.

    The workbook lives in ``<result_analysis_dir>/<model_name>/<setting_name>``
    and its file name ends with ``_debug`` (before the extension) when
    ``use_debug_model`` is truthy.

    NOTE: the file name is built from the module-level names ``cl_sum``,
    ``cl_cent_type`` and ``cl_activate_type`` (not from parameters), and the
    root comes from the module-level ``result_analysis_dir``; all of them must
    be defined, and match ``setting_name``, at call time.
    """
    debug_suffix = '_debug' if use_debug_model else ''
    xlsx_name = f'test_cluster_{cl_sum}_scene_{cl_cent_type}_{cl_activate_type}{debug_suffix}.xlsx'
    return os.path.join(result_analysis_dir, model_name, setting_name, xlsx_name)

# define parameters
# Root directory of the analysis results; analyze_exp and get_df_path join
# <model_name>/<setting_name> onto it.
result_analysis_dir = os.path.join('../result_analysis', 'xai_la_on_volleyball')

# Passed to analyze_exp: plot sample images of each cluster when True.
use_plot_img = False
# use_plot_img = True

# Passed to analyze_exp: also print the 'discriminative' sheet metrics when True.
# vis_detail = False
vis_detail = True

# Selects the '_debug' variants of the setting directory and workbook names.
use_debug_model = False
# use_debug_model = True

# Used in the setting name and, as a global, in the workbook file name.
cl_cent_type = 'mean'
# cl_cent_type = 'median'

# Model result directories to analyze (uncomment entries to include them).
model_name_list = []
# model_name_list.append('[GR ours rand mask 5_stage2]<2023-10-16_22-26-54>')
# model_name_list.append('[GR ours_stage2]<2023-07-08_08-59-54>')
model_name_list.append('[GR ours recon feat random mask 0 w temp cond_stage2]<2023-11-10_13-26-06>')

# Values of cl_sum to sweep over (one setting per value).
# cl_sum_list = [8, 30, 100, 200, 300]
# cl_sum_list = [8, 30, 100]
cl_sum_list = [8]
# cl_sum_list = [100]
# Values of cl_activate_type to sweep over (one setting per value).
# cl_activate_type_list = ['perturbation_original', 'backprop_energy_original', 'random']
# cl_activate_type_list = ['perturbation_original', 'backprop_energy_original']
cl_activate_type_list = ['perturbation_original', 'random']
# cl_activate_type_list = ['perturbation_original']
# cl_activate_type_list = ['backprop_original']

# Analyze every (model, cl_activate_type, cl_sum) combination.
# NOTE: get_df_path reads cl_sum, cl_cent_type and cl_activate_type as
# module-level globals, so the loop variables below must keep these names.
for model_name in model_name_list:
    print('Model name:', model_name)
    for cl_activate_type in cl_activate_type_list:
        for cl_sum in cl_sum_list:
            setting_name = get_setting_name(cl_sum, cl_cent_type, cl_activate_type, use_debug_model)
            df_path = get_df_path(model_name, setting_name, use_debug_model)
            analyze_exp(model_name, setting_name, df_path, use_plot_img=use_plot_img, vis_detail=vis_detail)

Model name: [GR ours recon feat random mask 0 w temp cond_stage2]<2023-11-10_13-26-06>
Setting name: all_cluster_8_mean_perturbation_original
group: 0.25
key-person (<1):0.06
key-person (<2):0.12
key-person (<3):0.19
Setting name: all_cluster_8_mean_random
group: 0.28
key-person (<1):0.08
key-person (<2):0.16
key-person (<3):0.25
