In [14]:
import os
import re
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import tabulate
import warnings
warnings.filterwarnings('ignore')

import data_preprocessing
import EDCR_pipeline
import vit_pipeline
import utils

# EDCR Results

In [19]:
def gather_EDCR_data() -> dict:
    data = {} 
    
    # Iterate through filenames to collect accuracy data
    for filename in os.listdir(EDCR_pipeline.figs_folder):
        secondary_granularity_match = re.match(
            pattern='main_(fine|coarse)_(.+?)_lr(.+?)_secondary_(fine|coarse)_(.+?)_lr(.+)',
            string=filename
        )
        
        if secondary_granularity_match:
            (   match,
                main_granularity,
                main_model_name,
                main_lr,
                secondary_granularity,
                secondary_model_name,
                secondary_lr
            ) = (secondary_granularity_match.group(i) for i in range(7))
            
            main_suffix = '_coarse' if main_granularity == 'coarse' else ''
            test_true = np.load(os.path.join(EDCR_pipeline.data_folder, f'test_true{main_suffix}.npy'))
            
            prior_predictions = np.load(os.path.join(EDCR_pipeline.data_folder, rf'{main_model_name}_test_pred_lr{main_lr}_e3{main_suffix}.npy'))
            
            
            secondary_suffix = '_coarse' if secondary_granularity == 'coarse' else ''
            post_predictions = np.load(f'figs/{match}/results{secondary_suffix}.npy')

            # Store accuracy data in the data dictionary
            if main_granularity not in data:
                data[main_granularity] = {}
            if main_model_name not in data[main_granularity]:
                data[main_granularity][main_model_name] = {}
            if secondary_granularity not in data[main_granularity][main_model_name]:
                data[main_granularity][main_model_name][secondary_granularity] = {}
            if secondary_model_name not in data[main_granularity][main_model_name][secondary_granularity]:
                data[main_granularity][main_model_name][secondary_granularity][secondary_model_name] = {}
            if main_lr not in data[main_granularity][main_model_name][secondary_granularity][secondary_model_name]:
                data[main_granularity][main_model_name][secondary_granularity][secondary_model_name][main_lr] = {}
                
            data[main_granularity][main_model_name][secondary_granularity][secondary_model_name][main_lr][secondary_lr] = \
                {prior_or_post: ({'acc': accuracy_score(y_true=test_true, 
                                                       y_pred=(prior_predictions 
                                                       if prior_or_post == 'prior' else post_predictions))} | 
                                {metric_name: metric_value(y_true=test_true, 
                                                           y_pred=(prior_predictions 
                                                                   if prior_or_post == 'prior' else post_predictions), 
                                            average='micro') 
                                 for metric_name, metric_value in {'pre': precision_score, 'rec': recall_score, 'f1': f1_score}.items()})
                 for prior_or_post in ['prior', 'post']}
        else:
            no_secondary_granularity_match = re.match(pattern='main_(fine|coarse)_(.+)_lr(.+)_secondary_(.+)_lr(.+)',
                                                      string=filename)
            
            if no_secondary_granularity_match:
                
                (match,
                 main_granularity,
                 main_model_name,
                 main_lr,
                 secondary_model_name,
                 secondary_lr 
                ) = (no_secondary_granularity_match.group(i) for i in range(6))
                
                classes = EDCR_pipeline.get_classes(granularity=main_granularity)
                
                main_suffix = '_coarse' if main_granularity == 'coarse' else ''
                test_true = np.load(os.path.join(EDCR_pipeline.data_folder, f'test_true{main_suffix}.npy'))
                
                prior_predictions = np.load(os.path.join(EDCR_pipeline.data_folder, rf'{main_model_name}_test_pred_lr{main_lr}_e3{main_suffix}.npy'))
                
                try:
                    post_predictions = np.load(f'figs/{match}/results.npy')
                except FileNotFoundError:
                    post_predictions = np.load(f'figs/{match}/results_coarse.npy')
                    
                if main_granularity not in data:
                    data[main_granularity] = {}
                if main_model_name not in data[main_granularity]:
                    data[main_granularity][main_model_name] = {}
                if secondary_model_name not in data[main_granularity][main_model_name]:
                    data[main_granularity][main_model_name][secondary_model_name] = {}
                if main_lr not in data[main_granularity][main_model_name][secondary_model_name]:
                    data[main_granularity][main_model_name][secondary_model_name][main_lr] = {}
                
                data[main_granularity][main_model_name][secondary_model_name][main_lr][secondary_lr] = \
                    {prior_or_post: {'acc': accuracy_score(y_true=test_true, 
                                                           y_pred=(prior_predictions if prior_or_post == 'prior' else post_predictions))} | 
                                {metric.__name__: metric(y_true=test_true, 
                                            y_pred=(prior_predictions if prior_or_post == 'prior' else post_predictions), 
                                            labels=range(len(classes)), 
                                            average=None) 
                                 for metric in [precision_score, recall_score, f1_score]}
                 for prior_or_post in ['prior', 'post']}
                
    return data


def get_row_addition(secondary_lr: float, 
                     curr_data: dict,
                     max_accuracy: float = None) -> (str, float):
    curr_prior_data = curr_data['prior']
    curr_post_data = curr_data['post']
    
    curr_prior_accuracy = round(curr_prior_data['acc'] * 100, 1)
    curr_post_accuracy = round(curr_post_data['acc'] * 100, 1)
    curr_accuracy_diff = round(curr_post_accuracy - curr_prior_accuracy, 1)
    
    post_acc_str = (utils.blue_text(curr_post_accuracy) 
                    if max_accuracy is not None and abs(curr_post_accuracy - max_accuracy) < 1e-5
                    else str(curr_post_accuracy))
    acc_diff_sign_str = (utils.green_text('+') if curr_accuracy_diff > 0 else '')

    curr_prior_average_precision = round(curr_prior_data['pre'] * 100, 1)
    curr_post_average_precision = round(curr_post_data['pre'] * 100, 1)
    
    row_addition = (f"{secondary_lr}: acc: {post_acc_str}%, ({acc_diff_sign_str}"  + 
                    (utils.green_text(f'{curr_accuracy_diff}%') if curr_accuracy_diff > 0 
                     else utils.red_text(f'{curr_accuracy_diff}%')) + '), ' 
                    f'micro-pre: {curr_post_average_precision}%' + '\n')
    
    return row_addition, curr_prior_accuracy


def get_row_data(main_lr_data: dict,
                 secondary_lr: float) -> (str, float):
    curr_data = main_lr_data[secondary_lr]
    
    row_additions, curr_prior_acc = get_row_addition(secondary_lr=secondary_lr, 
                                                     curr_data=curr_data)
    
    return row_additions, curr_prior_acc


def print_one_secondary_granularity(main_model_data: dict,
                                    k: str,
                                    main_granularity: str,
                                    main_model_name: str):

    secondary_granularity_data = main_model_data[k]
    main_learning_rates = sorted(secondary_granularity_data[list(secondary_granularity_data.keys())[0]].keys())
    header = [''] + main_learning_rates
    table_data = [header]
    priors = {}

    for secondary_model_name in sorted(secondary_granularity_data.keys()):
        secondary_model_data = secondary_granularity_data[secondary_model_name]
        row = [secondary_model_name]
        
        for main_lr in sorted(secondary_model_data.keys()):
            main_lr_data = secondary_model_data[main_lr]
            row_add = ''
            
            for secondary_lr in sorted(main_lr_data.keys()):
                row_addition, curr_prior = get_row_data(main_lr_data=main_lr_data,
                                                        secondary_lr=secondary_lr)
                row_add += row_addition
                priors[main_lr] = curr_prior
                
            row += [row_add]
        table_data += [row]
    
    table_data[0] = [''] + [f'{main_lr} ({priors[main_lr]}%)' for main_lr in main_learning_rates]
    
    # Rest of your code to create and print the table remains unchanged
    table = tabulate.tabulate(
        tabular_data=table_data, 
        headers='firstrow', 
        tablefmt='grid'
    )
    print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name}, "
          f"secondary granularity: {k}")
    print(table)
    print("\n")


def print_two_secondary_granularities(main_model_data: dict,
                                      two_secondary_table_data: list,
                                      k: str,
                                      main_granularity: str,
                                      main_model_name: str):
    main_learning_rates = sorted(vit_pipeline.lrs)
    
    priors = {}
    
    # Initialize the table_data with header if it's empty
    if len(two_secondary_table_data) == 0:
        header = [''] + main_learning_rates
        two_secondary_table_data += [header]
        
    secondary_model_data = main_model_data[k]
    row = [k]
    
    for main_lr in sorted(secondary_model_data.keys()):
        main_lr_data = secondary_model_data[main_lr]
        row_add = ''
        
        for secondary_lr in sorted(main_lr_data.keys()):
            row_addition, curr_prior = get_row_data(main_lr_data=main_lr_data,
                                                    secondary_lr=secondary_lr)
            row_add += row_addition
            priors[main_lr] = curr_prior
    
        row += [row_add]

    two_secondary_table_data += [row]
    
    # Modify the generated table data to highlight the cell with the maximal accuracy in blue
    
    if len(two_secondary_table_data) == len(main_learning_rates) + 1:
        
        two_secondary_table_data[0] = [''] + [f'{main_lr} ({priors[str(main_lr)]}%)' for main_lr in main_learning_rates]
        
        # Create the table using tabulate
        table = tabulate.tabulate(
            tabular_data=two_secondary_table_data,
            headers='firstrow',
            tablefmt='grid'
        )
        
        # Print the main model name and the corresponding table
        print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name} "
              f"with both fine and coarse grain secondary models")
        print(table)
        print("\n")
    else:
        return two_secondary_table_data


def print_EDCR_tables():
    data = gather_EDCR_data()
    
    for main_granularity in sorted(data.keys()):
        
        print('#' * 40 + f' Main granularity: {main_granularity} ' + '#' * 40 + '\n' + '#' * 104 + '\n')
        main_granularity_data = data[main_granularity]
        
        for main_model_name in sorted(main_granularity_data.keys()):
            main_model_data = main_granularity_data[main_model_name]
            two_secondary_table_data = []

            for k in (sorted(set(main_model_data.keys()).intersection(data_preprocessing.granularities.values())) + 
                      sorted(set(main_model_data.keys()).intersection(vit_pipeline.vit_model_names))):
            
                if k in data_preprocessing.granularities.values():
                    print_one_secondary_granularity(main_model_data=main_model_data,
                        k=k,
                        main_granularity=main_granularity,
                        main_model_name=main_model_name)
                else:
                    two_secondary_table_data = print_two_secondary_granularities(main_model_data=main_model_data,
                                                      two_secondary_table_data=two_secondary_table_data,
                                                      k=k,
                                                      main_granularity=main_granularity,
                                                      main_model_name=main_model_name)
            print('#' * 100)

print_EDCR_tables()

######################################## Main granularity: coarse ########################################
########################################################################################################

Main model: Coarse-grain vit_b_16, secondary granularity: coarse
+----------+----------------------------------------------+-----------------------------------------------+----------------------------------------------+
|          | 1e-05 (80.9%)                                | 1e-06 (65.6%)                                 | 5e-05 (83.7%)                                |
| vit_b_32 | 1e-05: acc: 81.0%, ([92m+[0m[92m0.1%[0m), micro-pre: 81.0% | 1e-05: acc: 73.0%, ([92m+[0m[92m7.4%[0m), micro-pre: 73.0%  | 1e-05: acc: 82.4%, ([91m-1.3%[0m), micro-pre: 82.4% |
|          | 1e-06: acc: 80.9%, ([91m0.0%[0m), micro-pre: 80.9%  | 1e-06: acc: 65.3%, ([91m-0.3%[0m), micro-pre: 65.3%  | 1e-06: acc: 83.7%, ([91m0.0%[0m), micro-pre: 83.7%  |
|          | 5e-05: acc: 81.2%

KeyError: 'pre'