In [1]:
import os
import re
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
import tabulate
import ansiwrap
import warnings
warnings.filterwarnings('ignore')

import data_preprocessing
import EDCR_pipeline
import vit_pipeline
import utils

# EDCR Results

In [13]:
def _load_run_accuracies(run_name: str,
                         main_granularity: str,
                         main_model_name: str,
                         main_lr: str,
                         result_filenames: tuple) -> dict:
    """Compute the accuracy of the main model before and after EDCR for one run.

    :param run_name: Name of the run's entry inside ``EDCR_pipeline.figs_folder``.
    :param main_granularity: ``'fine'`` or ``'coarse'``; selects which ground-truth
        and prior-prediction files are read from ``EDCR_pipeline.data_folder``.
    :param main_model_name: Main model name, as parsed from the run name.
    :param main_lr: Main model learning rate, kept as the string parsed from the run name.
    :param result_filenames: Candidate result file names, tried in order; the first
        one that exists is used.
    :return: ``{'prior': <accuracy>, 'post': <accuracy>}``.
    :raises FileNotFoundError: If none of the candidate result files exists.
    """
    main_suffix = '_coarse' if main_granularity == 'coarse' else ''
    test_true = np.load(os.path.join(EDCR_pipeline.data_folder, f'test_true{main_suffix}.npy'))

    # The `_e3` token is fixed in the file name (presumably the epoch count --
    # TODO confirm against the training pipeline).
    prior_predictions = np.load(os.path.join(
        EDCR_pipeline.data_folder,
        f'{main_model_name}_test_pred_lr{main_lr}_e3{main_suffix}.npy'))

    # Results are read from the same folder that was listed, not from a
    # hard-coded relative 'figs/' path, so listing and loading cannot disagree.
    run_folder = os.path.join(EDCR_pipeline.figs_folder, run_name)
    post_predictions = None

    for attempt, result_filename in enumerate(result_filenames, start=1):
        try:
            post_predictions = np.load(os.path.join(run_folder, result_filename))
            break
        except FileNotFoundError:
            # Only give up once every candidate has been tried.
            if attempt == len(result_filenames):
                raise

    return {'prior': accuracy_score(y_true=test_true, y_pred=prior_predictions),
            'post': accuracy_score(y_true=test_true, y_pred=post_predictions)}


def gather_EDCR_data() -> dict:
    """Collect prior/posterior accuracies for every EDCR run found on disk.

    Every entry of ``EDCR_pipeline.figs_folder`` is matched against two naming
    schemes:

    * ``main_<gran>_<model>_lr<lr>_secondary_<gran>_<model>_lr<lr>`` -- stored as
      ``data[main_gran][main_model][secondary_gran][secondary_model][main_lr][secondary_lr]``
    * ``main_<gran>_<model>_lr<lr>_secondary_<model>_lr<lr>`` -- stored as
      ``data[main_gran][main_model][secondary_model][main_lr][secondary_lr]``

    Entries matching neither scheme are ignored. Learning rates are kept as the
    strings found in the entry name.

    :return: The nested dictionary described above; each leaf is
        ``{'prior': <accuracy>, 'post': <accuracy>}``.
    """
    data = {}

    for filename in os.listdir(EDCR_pipeline.figs_folder):
        with_granularity = re.match(
            pattern='main_(fine|coarse)_(.+?)_lr(.+?)_secondary_(fine|coarse)_(.+?)_lr(.+)',
            string=filename
        )

        if with_granularity:
            (main_granularity,
             main_model_name,
             main_lr,
             secondary_granularity,
             secondary_model_name,
             secondary_lr) = with_granularity.groups()

            secondary_suffix = '_coarse' if secondary_granularity == 'coarse' else ''
            accuracies = _load_run_accuracies(
                run_name=with_granularity.group(0),
                main_granularity=main_granularity,
                main_model_name=main_model_name,
                main_lr=main_lr,
                result_filenames=(f'results{secondary_suffix}.npy',))

            # setdefault creates each missing nesting level on the way down.
            (data.setdefault(main_granularity, {})
                 .setdefault(main_model_name, {})
                 .setdefault(secondary_granularity, {})
                 .setdefault(secondary_model_name, {})
                 .setdefault(main_lr, {}))[secondary_lr] = accuracies
            continue

        without_granularity = re.match(pattern='main_(fine|coarse)_(.+)_lr(.+)_secondary_(.+)_lr(.+)',
                                       string=filename)

        if without_granularity:
            (main_granularity,
             main_model_name,
             main_lr,
             secondary_model_name,
             secondary_lr) = without_granularity.groups()

            # The name does not say which result file was written, so try the
            # fine-grain file first and fall back to the coarse-grain one.
            accuracies = _load_run_accuracies(
                run_name=without_granularity.group(0),
                main_granularity=main_granularity,
                main_model_name=main_model_name,
                main_lr=main_lr,
                result_filenames=('results.npy', 'results_coarse.npy'))

            (data.setdefault(main_granularity, {})
                 .setdefault(main_model_name, {})
                 .setdefault(secondary_model_name, {})
                 .setdefault(main_lr, {}))[secondary_lr] = accuracies

    return data


def get_row_addition(secondary_lr: float, 
                     curr_post: float, 
                     curr_diff: float,
                     max_accuracy: float = None) -> str:
    """Format one line of a table cell: ``<lr>: <post>%, <diff>%`` plus a newline.

    :param secondary_lr: Learning rate label that starts the line.
    :param curr_post: Posterior accuracy, in percent.
    :param curr_diff: Posterior minus prior accuracy, in percent.
    :param max_accuracy: When given, a posterior accuracy within 1e-5 of this
        value is wrapped with ``utils.blue_text``.
    :return: The formatted line, terminated by a newline.
    """
    is_maximum = max_accuracy is not None and abs(curr_post - max_accuracy) < 1e-5
    post_text = utils.blue_text(curr_post) if is_maximum else str(curr_post)

    if curr_diff > 0:
        # Improvements get an explicit, separately coloured plus sign.
        diff_text = utils.green_text('+') + utils.green_text(f'{curr_diff}%')
    else:
        # Zero and negative differences both go through red_text.
        diff_text = utils.red_text(f'{curr_diff}%')

    return f"{secondary_lr}: " + post_text + '%, ' + diff_text + '\n'


def get_row_data(main_lr_data: dict,
                 secondary_lr: float):
    """Build the table line for one secondary learning rate.

    :param main_lr_data: Mapping from secondary learning rate to a dict holding
        the ``'prior'`` and ``'post'`` accuracies as fractions.
    :param secondary_lr: Key of the entry to format.
    :return: ``(line, prior)`` where ``line`` is the text produced by
        ``get_row_addition`` and ``prior`` is the prior accuracy in percent,
        rounded to one decimal.
    """
    def as_percent(fraction):
        # Fractions are shown as percentages with a single decimal.
        return round(fraction * 100, 1)

    accuracies = main_lr_data[secondary_lr]
    post_percent = as_percent(accuracies['post'])
    prior_percent = as_percent(accuracies['prior'])

    line = get_row_addition(secondary_lr=secondary_lr,
                            curr_post=post_percent,
                            curr_diff=round(post_percent - prior_percent, 1))

    return line, prior_percent


def highlight_max(table_data: list):
    """Re-render the table cells so the highest posterior accuracy is highlighted.

    Each data cell is expected to hold exactly three lines in the layout written
    by ``get_row_addition`` (``<lr>: <post>%, <diff>%``). Cells that do not match
    this layout, the header row, and the first column are left untouched.

    :param table_data: Table as a list of rows; the first row is the header and
        the first element of each row is the row label. Modified in place.
    :return: The same ``table_data`` list.
    """
    cell_pattern = ('(.+): (.+)%, [+]*(.+)%\n'
                    '(.+): (.+)%, [+]*(.+)%\n'
                    '(.+): (.+)%, [+]*(.+)%')

    # Parse every matching cell once into three (lr, post, diff) triples.
    parsed_cells = {}
    for row_index, row in enumerate(table_data):
        if row_index == 0:  # Skip the header row
            continue
        for col_index, cell in enumerate(row):
            if col_index == 0:  # Skip the row label
                continue

            match = re.match(pattern=cell_pattern,
                             string=ansiwrap.strip_color(cell))
            if match:
                numbers = [float(group) for group in match.groups()]
                parsed_cells[(row_index, col_index)] = [tuple(numbers[start:start + 3])
                                                        for start in range(0, 9, 3)]

    # The posterior accuracy is the second value of each triple (regex groups
    # 2, 5 and 8). Reading groups 2, 4 and 6 would mix in a learning rate and a
    # difference and yield the wrong maximum.
    max_accuracy = 0.0
    for triples in parsed_cells.values():
        for _, post_accuracy, _ in triples:
            max_accuracy = max(max_accuracy, post_accuracy)

    # Rebuild each parsed cell, letting get_row_addition colour the maximum.
    for (row_index, col_index), triples in parsed_cells.items():
        table_data[row_index][col_index] = ''.join(
            get_row_addition(secondary_lr=secondary_lr,
                             curr_post=curr_post,
                             curr_diff=curr_diff,
                             max_accuracy=max_accuracy)
            for secondary_lr, curr_post, curr_diff in triples)

    return table_data


def print_one_secondary_granularity(main_model_data: dict,
                                    k: str,
                                    main_granularity: str,
                                    main_model_name: str):
    """Print the accuracy table of one main model for one secondary granularity.

    Rows are secondary models, columns are main learning rates, and every cell
    lists one line per secondary learning rate. The header shows each main
    learning rate together with its prior accuracy.

    :param main_model_data: Data of the main model; ``main_model_data[k]`` maps
        secondary model name -> main lr -> secondary lr -> accuracies.
    :param k: Secondary granularity key to print.
    :param main_granularity: Granularity of the main model, used in the title.
    :param main_model_name: Name of the main model, used in the title.
    """
    per_secondary_model = main_model_data[k]

    # Column order is taken from the first secondary model's learning rates.
    first_model_name = list(per_secondary_model)[0]
    main_learning_rates = sorted(per_secondary_model[first_model_name])

    prior_by_main_lr = {}
    body_rows = []

    for secondary_model_name in sorted(per_secondary_model):
        per_main_lr = per_secondary_model[secondary_model_name]
        cells = []

        for main_lr in sorted(per_main_lr):
            per_secondary_lr = per_main_lr[main_lr]
            lines = []

            for secondary_lr in sorted(per_secondary_lr):
                line, prior_by_main_lr[main_lr] = get_row_data(main_lr_data=per_secondary_lr,
                                                               secondary_lr=secondary_lr)
                lines.append(line)

            cells.append(''.join(lines))

        body_rows.append([secondary_model_name] + cells)

    # The header can only be built once the priors have been collected.
    header = [''] + [f'{main_lr} ({prior_by_main_lr[main_lr]}%)' for main_lr in main_learning_rates]

    table = tabulate.tabulate(
        tabular_data=[header] + body_rows,
        headers='firstrow',
        tablefmt='grid'
    )
    print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name}, "
          f"secondary granularity: {k}")
    print(table)
    print("\n")


def print_two_secondary_granularities(main_model_data: dict,
                                      two_secondary_table_data: list,
                                      k: str,
                                      main_granularity: str,
                                      main_model_name: str) -> list:
    """Append one secondary model's row to the shared table and print it when full.

    The table is printed once it holds ``len(vit_pipeline.lrs) + 1`` rows
    (header included).

    :param main_model_data: Data of the main model; ``main_model_data[k]`` maps
        main lr -> secondary lr -> accuracies.
    :param two_secondary_table_data: Table accumulated so far (empty on the first
        call). Extended in place.
    :param k: Secondary model name; used as the row label.
    :param main_granularity: Granularity of the main model, used in the title.
    :param main_model_name: Name of the main model, used in the title.
    :return: The accumulated table. It is returned even after printing, so a
        caller that keeps feeding rows does not receive ``None``.
    """
    # Header and row cells must use the same ordering. The data keys are strings
    # while vit_pipeline.lrs may hold floats (TODO confirm), so both sides are
    # sorted by numeric value rather than by their native type.
    main_learning_rates = sorted(vit_pipeline.lrs, key=float)

    priors = {}

    # Initialize the table with a header on the first call
    if len(two_secondary_table_data) == 0:
        two_secondary_table_data.append([''] + main_learning_rates)

    secondary_model_data = main_model_data[k]
    row = [k]

    for main_lr in sorted(secondary_model_data.keys(), key=float):
        main_lr_data = secondary_model_data[main_lr]
        row_add = ''

        for secondary_lr in sorted(main_lr_data.keys()):
            row_addition, curr_prior = get_row_data(main_lr_data=main_lr_data,
                                                    secondary_lr=secondary_lr)
            row_add += row_addition
            priors[main_lr] = curr_prior

        row.append(row_add)

    two_secondary_table_data.append(row)

    if len(two_secondary_table_data) == len(main_learning_rates) + 1:
        # priors is keyed by the string form of the learning rate.
        two_secondary_table_data[0] = [''] + [f'{main_lr} ({priors[str(main_lr)]}%)'
                                              for main_lr in main_learning_rates]

        # Create the table using tabulate
        table = tabulate.tabulate(
            tabular_data=two_secondary_table_data,
            headers='firstrow',
            tablefmt='grid'
        )

        # Print the main model name and the corresponding table
        print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name} "
              f"with both fine and coarse grain secondary models")
        print(table)
        print("\n")

    return two_secondary_table_data


def print_EDCR_tables():
    """Print every EDCR accuracy table, grouped by main granularity and main model.

    For each main model, the keys that are secondary granularities are printed
    first (one table each). The keys that are secondary model names are then
    accumulated into a single table by ``print_two_secondary_granularities``.
    """
    data = gather_EDCR_data()

    for main_granularity in sorted(data):
        banner = f"{'#' * 40} Main granularity: {main_granularity} {'#' * 40}\n{'#' * 104}\n"
        print(banner)

        per_main_model = data[main_granularity]

        for main_model_name in sorted(per_main_model):
            main_model_data = per_main_model[main_model_name]
            available_keys = set(main_model_data.keys())

            granularity_keys = sorted(available_keys.intersection(data_preprocessing.granularities.values()))
            model_keys = sorted(available_keys.intersection(vit_pipeline.vit_model_names))

            two_secondary_table_data = []

            for k in granularity_keys + model_keys:
                if k not in data_preprocessing.granularities.values():
                    # Secondary model name: add its row to the combined table.
                    two_secondary_table_data = print_two_secondary_granularities(
                        main_model_data=main_model_data,
                        two_secondary_table_data=two_secondary_table_data,
                        k=k,
                        main_granularity=main_granularity,
                        main_model_name=main_model_name)
                else:
                    # Secondary granularity: it gets a table of its own.
                    print_one_secondary_granularity(
                        main_model_data=main_model_data,
                        k=k,
                        main_granularity=main_granularity,
                        main_model_name=main_model_name)

            print('#' * 100)

# Run the report: gather the stored accuracies and print all EDCR tables.
print_EDCR_tables()

######################################## Main granularity: coarse ########################################
########################################################################################################

Main model: Coarse-grain vit_b_16, secondary granularity: coarse
+----------+---------------------+----------------------+---------------------+
|          | 1e-05 (80.9%)       | 1e-06 (65.6%)        | 5e-05 (83.7%)       |
| vit_b_32 | 1e-05: 81.0%, [92m+[0m[92m0.1%[0m | 1e-05: 73.0%, [92m+[0m[92m7.4%[0m  | 1e-05: 82.4%, [91m-1.3%[0m |
|          | 1e-06: 80.9%, [91m0.0%[0m  | 1e-06: 65.3%, [91m-0.3%[0m  | 1e-06: 83.7%, [91m0.0%[0m  |
|          | 5e-05: 81.2%, [92m+[0m[92m0.3%[0m | 5e-05: 72.5%, [92m+[0m[92m6.9%[0m  | 5e-05: 83.7%, [91m0.0%[0m  |
+----------+---------------------+----------------------+---------------------+
| vit_l_16 | 1e-05: 83.0%, [92m+[0m[92m2.1%[0m | 1e-05: 76.8%, [92m+[0m[92m11.2%[0m | 1e-05: 83.8%, [92m+[0m[92m0.