In [1]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
import tabulate
import ansiwrap
import warnings
warnings.filterwarnings('ignore')

from data_preprocessing import granularities
from EDCR_pipeline import n_coarse_grain_classes
from vit_pipeline import vit_model_names, lrs
from context import Plot
from utils import green_text, red_text, blue_text

# Load data

In [2]:
def load_data(data_dir: str) -> dict[str, dict]:
    all_data = {}
    
    for granularity in granularities.values():
        
        suffix = '_coarse' if granularity == 'coarse' else ''
        # Initialize dictionaries to store training and test accuracy data for each model
        train_data = {}
        test_data = {}
        
        test_true = np.load(os.path.join(data_dir, f'test_true{suffix}.npy'))
        
        # Loop through all files in the directory
        for filename in os.listdir(data_dir):
            train_match = re.match(pattern=rf'(.+?)_train_(loss|acc)_lr(.+?)_e(\d+?){suffix}.npy',
                                   string=filename)
            test_match = re.match(pattern=rf'(.+?)_test_pred_lr(.+?)_e(\d+?){suffix}.npy',
                                  string=filename)
        
            if train_match:
                model_name = train_match.group(1)
                metric = train_match.group(2)
                lr_value = float(train_match.group(3))
                num_epochs = int(train_match.group(4)) + 1
        
                # Load the data from the .npy file
                data = np.load(os.path.join(data_dir, filename))
        
                # Store the data in the model_data dictionary
                if model_name not in train_data:
                    train_data[model_name] = {}
                if metric not in train_data[model_name]:
                    train_data[model_name][metric] = {}
                if lr_value not in train_data[model_name][metric]:
                    train_data[model_name][metric][lr_value] = {}
        
                train_data[model_name][metric][lr_value][num_epochs] = data[-1]
            elif test_match:
                model_name = test_match.group(1)
                lr_value = float(test_match.group(2))
                num_epochs = int(test_match.group(3)) + 1
        
                # Load the test data from the .npy file
                test_pred = np.load(os.path.join(data_dir, filename))
        
                # Store the data in the model_test_data dictionary
                if model_name not in test_data:
                    test_data[model_name] = {}
                if lr_value not in test_data[model_name]:
                    test_data[model_name][lr_value] = {}
        
                test_data[model_name][lr_value][num_epochs] = \
                    {'acc': accuracy_score(y_true=test_true, 
                                           y_pred=test_pred), 
                     'cm': confusion_matrix(y_true=test_true, 
                                            y_pred=test_pred),
                     'pre': precision_score(y_true=test_true, 
                                            y_pred=test_pred, 
                                            labels=range(n_coarse_grain_classes), 
                                            average=None),
                     'rec': recall_score(y_true=test_true, 
                                         y_pred=test_pred, 
                                         labels=range(n_coarse_grain_classes), 
                                         average=None),
                     'f1': f1_score(y_true=test_true, 
                                    y_pred=test_pred, 
                                    labels=range(n_coarse_grain_classes), 
                                    average=None)}
                
        all_data[granularity] = {'train': train_data, 'test': test_data}
    
    return all_data
    
    
def plot_train_metrics(all_data: dict[str, dict]):
    for granularity in granularities.values():
        for model_name, model_data in sorted(all_data[granularity]['train'].items()):
            print('\n' + '#'* (100 + len(model_name)))
            print('#'* 50 + f'{model_name}' + '#'* 50)
            print('#'* (100 + len(model_name)) + '\n')
            for metric, metric_data in model_data.items():
                with Plot():
                    plt.title(f"{model_name} {granularity}-grain training {metric} vs. epoch")
                    plt.xlabel('Epoch')
                    plt.ylabel(metric.capitalize())
        
                    for lr_value, lr_data in sorted(metric_data.items()):
                        epochs, data = zip(*sorted(lr_data.items())) # Sort the data based on the number of epochs
                        plt.plot(epochs, data, label=f'lr={lr_value}')
                        plt.xticks(np.arange(min(epochs), max(epochs)+1, 1)) # Set the x-axis ticks to be integers
        
                    plt.legend()
                    plt.grid()
                    
data_dir = 'results/'  # Set the directory where your .npy files are located
all_data = load_data(data_dir=data_dir)
# plot_train_metrics(all_data=all_data)

# Test accuracies

In [3]:
def plot_test_metrics():
    for granularity in granularities.values():
        # Create a dictionary to store accuracy values for each model and learning rate
        accuracy_data = {}

        # Now, create plots for test accuracy vs. epoch for each model
        for model_name, model_data in sorted(all_data[granularity]['test'].items()):
            for lr_value, lr_data in sorted(model_data.items()):
                # Collect the accuracy after the last epoch
                last_epoch = sorted(lr_data.items())[-1][1]
                accuracy = last_epoch['acc']

                # Store the accuracy in the dictionary
                if model_name not in accuracy_data:
                    accuracy_data[model_name] = {}
                accuracy_data[model_name][f'lr={lr_value}'] = accuracy

        # Get a list of all learning rates in the data
        all_learning_rates = sorted(set(lr for model_data in accuracy_data.values() for lr in model_data))

        # Generate the 2-D table with manual headers
        headers = ["Model Name"] + all_learning_rates
        table = []

        for model_name, lr_data in accuracy_data.items():
            max_accuracy = max(lr_data.values())
            row = [model_name]
            for lr in all_learning_rates:
                acc = lr_data.get(lr, "N/A")
                if acc == max_accuracy:
                    acc = green_text(acc)  # Highlight in green for maximum accuracy
                row.append(acc)
            if "N/A" not in row:
                table.append(row)

        # Adding a title to the table
        title = f"Pre-EDCR {granularity.capitalize()} Grain Test Accuracies"

        # Generate the table using tabulate
        table_str = tabulate.tabulate(tabular_data=table, 
                                      headers=headers, 
                                      tablefmt="fancy_grid")

        # Insert the title in the middle of the table
        lines = table_str.split('\n')
        lines.insert(0, title.center(len(lines[0])))
        updated_table = '\n'.join(lines)

        # Print the updated table with the title
        print(updated_table)

        # Calculate and print the maximum accuracy across all models and learning rates for each granularity
        max_accuracy = max(acc for model_data in accuracy_data.values() for acc in model_data.values())
        max_model, max_lr = [(model, lr) for model, lr_data in accuracy_data.items() 
                             for lr, acc in lr_data.items() if acc == max_accuracy][0]

        print(f"Max Accuracy for {granularity.capitalize()} Grain: {max_accuracy} for Model: {max_model}, {max_lr}\n")

plot_test_metrics()

         Pre-EDCR Coarse Grain Test Accuracies         
╒══════════════╤════════════╤════════════╤════════════╕
│ Model Name   │   lr=1e-05 │   lr=1e-06 │   lr=5e-05 │
╞══════════════╪════════════╪════════════╪════════════╡
│ vit_b_16     │   0.809377 │   0.656385 │   [92m0.837138[0m │
├──────────────┼────────────┼────────────┼────────────┤
│ vit_b_32     │   [92m0.770512[0m │   0.630475 │   0.760025 │
├──────────────┼────────────┼────────────┼────────────┤
│ vit_l_16     │   [92m0.84269[0m  │   0.739667 │   0.837138 │
├──────────────┼────────────┼────────────┼────────────┤
│ vit_l_32     │   0.792104 │   0.595312 │   [92m0.807526[0m │
╘══════════════╧════════════╧════════════╧════════════╛
Max Accuracy for Coarse Grain: 0.8426896977174584 for Model: vit_l_16, lr=1e-05

          Pre-EDCR Fine Grain Test Accuracies          
╒══════════════╤════════════╤════════════╤════════════╕
│ Model Name   │   lr=1e-05 │   lr=1e-06 │   lr=5e-05 │
╞══════════════╪════════════╪════════════╪═

# EDCR Results

In [7]:
def gather_EDCR_data() -> dict:
    data = {}  # Create an empty dictionary to store the accuracy data
    
    models_and_lrs_folders = os.listdir(f'figs')
    
    # Iterate through filenames to collect accuracy data
    for filename in models_and_lrs_folders:
        secondary_granularity_match = re.match(
            pattern='main_(fine|coarse)_(.+?)_lr(.+?)_secondary_(fine|coarse)_(.+?)_lr(.+)',
            string=filename
        )
        
        if secondary_granularity_match:
            (   
                match,
                main_granularity,
                main_model_name,
                main_lr,
                secondary_granularity,
                secondary_model_name,
                secondary_lr
            ) = (secondary_granularity_match.group(i) for i in range(7))
            
            main_suffix = '_coarse' if main_granularity == 'coarse' else ''
            test_true = np.load(os.path.join(data_dir, f'test_true{main_suffix}.npy'))
            
            prior_predictions = np.load(os.path.join(data_dir, rf'{main_model_name}_test_pred_lr{main_lr}_e3{main_suffix}.npy'))
            prior_acc = accuracy_score(y_true=test_true, 
                                       y_pred=prior_predictions)
            
            secondary_suffix = '_coarse' if secondary_granularity == 'coarse' else ''
            post_predictions = np.load(f'figs/{match}/results{secondary_suffix}.npy')
            posterior_acc = accuracy_score(y_true=test_true, 
                                           y_pred=post_predictions)

            # Store accuracy data in the data dictionary
            if main_granularity not in data:
                data[main_granularity] = {}
            if main_model_name not in data[main_granularity]:
                data[main_granularity][main_model_name] = {}
            if secondary_granularity not in data[main_granularity][main_model_name]:
                data[main_granularity][main_model_name][secondary_granularity] = {}
            if secondary_model_name not in data[main_granularity][main_model_name][secondary_granularity]:
                data[main_granularity][main_model_name][secondary_granularity][secondary_model_name] = {}
            if main_lr not in data[main_granularity][main_model_name][secondary_granularity][secondary_model_name]:
                data[main_granularity][main_model_name][secondary_granularity][secondary_model_name][main_lr] = {}
                
            data[main_granularity][main_model_name][secondary_granularity][secondary_model_name][main_lr][secondary_lr] = \
                {'prior': prior_acc, 'post': posterior_acc}
        
        else:
            no_secondary_granularity_match = re.match(
            pattern='main_(fine|coarse)_(.+)_lr(.+)_secondary_(.+)_lr(.+)',
            string=filename)
            
            if no_secondary_granularity_match:
                
                (
                    match,
                    main_granularity,
                    main_model_name,
                    main_lr,
                    secondary_model_name,
                    secondary_lr 
                ) = (no_secondary_granularity_match.group(i) for i in range(6))
                
                main_suffix = '_coarse' if main_granularity == 'coarse' else ''
                test_true = np.load(os.path.join(data_dir, f'test_true{main_suffix}.npy'))
                
                prior_predictions = np.load(os.path.join(data_dir, rf'{main_model_name}_test_pred_lr{main_lr}_e3{main_suffix}.npy'))
                prior_acc = accuracy_score(y_true=test_true, 
                                           y_pred=prior_predictions)
                
                try:
                    post_predictions = np.load(f'figs/{match}/results.npy')
                except FileNotFoundError:
                    post_predictions = np.load(f'figs/{match}/results_coarse.npy')
                    
                posterior_acc = accuracy_score(y_true=test_true, 
                                               y_pred=post_predictions)
    
                if main_granularity not in data:
                    data[main_granularity] = {}
                if main_model_name not in data[main_granularity]:
                    data[main_granularity][main_model_name] = {}
                if secondary_model_name not in data[main_granularity][main_model_name]:
                    data[main_granularity][main_model_name][secondary_model_name] = {}
                if main_lr not in data[main_granularity][main_model_name][secondary_model_name]:
                    data[main_granularity][main_model_name][secondary_model_name][main_lr] = {}
                
                data[main_granularity][main_model_name][secondary_model_name][main_lr][secondary_lr] = \
                    {'prior': prior_acc, 'post': posterior_acc}
                
    return data

In [37]:
def get_row_addition(secondary_lr: float, 
                     curr_post: float, 
                     curr_diff: float,
                     max_accuracy = None) -> str:
    return (f"{secondary_lr}: "
            + (blue_text(curr_post) if max_accuracy is not None and curr_post == max_accuracy else str(curr_post)) + '%, ' + (green_text('+') if curr_diff > 0 else '') + (green_text(f'{curr_diff}%') if curr_diff > 0 
                                                     else red_text(f'{curr_diff}%')) + '\n')

def get_row_data(main_lr_data: dict,
                 secondary_lr: float):
    curr_data = main_lr_data[secondary_lr]
    curr_post = round(curr_data['post'] * 100, 1)
    curr_prior = round(curr_data['prior'] * 100, 1)
    curr_diff = round(curr_post - curr_prior, 1)
    row_additions = get_row_addition(secondary_lr=secondary_lr, 
                                     curr_post=curr_post, 
                                     curr_diff=curr_diff)
    
    return row_additions, curr_prior


def highlight_max(table_data: list):
    # Find the maximum accuracy value
    max_accuracy = 0
    cell_pattern = ('(.+): (.+)% \([+]*[-]*(.+)%\)\n'
                    '(.+): (.+)% \([+]*[-]*(.+)%\)\n'
                    '(.+): (.+)% \([+]*[-]*(.+)%\)')
    
    for row in table_data[1:]:  # Skip the header row
        for cell in row[1:]:  # Skip the first element in each row (model names)
            match = re.match(
                    pattern=cell_pattern,
                    string=ansiwrap.strip_color(cell))
            
            if match:
                post_accuracies = [float(match.group(2 * i)) for i in range (1, 4)]
                for post_accuracy in post_accuracies:
                    if post_accuracy > max_accuracy:
                        max_accuracy = post_accuracy

    # Highlight the maximum accuracy value in blue
    for i, row in enumerate(table_data):
        if i > 0:
            for j, cell in enumerate(row):
                if j > 0:  # Skip the first element in each row (model names)
                    match = re.match(
                        pattern=cell_pattern,
                        string=ansiwrap.strip_color(cell))
                    
                    if match:
                        new_cell = ''
                        for secondary_lr, curr_post, curr_diff in [(float(match.group(3 * i + j)) for j in range(1, 4)) 
                                                                   for i in range(3)]:
                            new_cell += get_row_addition(secondary_lr=secondary_lr,
                                                         curr_post=curr_post,
                                                         curr_diff=curr_diff,
                                                         max_accuracy=max_accuracy)
                        
                        table_data[i][j] = new_cell
    
    return table_data


def print_one_secondary_granularity(main_model_data: dict,
                                    k: str,
                                    main_granularity: str,
                                    main_model_name: str):

    secondary_granularity_data = main_model_data[k]
    main_learning_rates = sorted(secondary_granularity_data[list(secondary_granularity_data.keys())[0]].keys())
    header = [''] + main_learning_rates
    table_data = [header]
    priors = {}

    for secondary_model_name in sorted(secondary_granularity_data.keys()):
        secondary_model_data = secondary_granularity_data[secondary_model_name]
        row = [secondary_model_name]
        
        for main_lr in sorted(secondary_model_data.keys()):
            main_lr_data = secondary_model_data[main_lr]
            row_add = ''
            
            for secondary_lr in sorted(main_lr_data.keys()):
                row_addition, curr_prior = get_row_data(main_lr_data=main_lr_data,
                                                        secondary_lr=secondary_lr)
                row_add += row_addition
                priors[main_lr] = curr_prior
    
            row += [row_add]
        table_data += [row]
    
    table_data[0] = [''] + [f'{main_lr} ({priors[main_lr]}%)' for main_lr in main_learning_rates]
    # Rest of your code to create and print the table remains unchanged
    table = tabulate.tabulate(
        tabular_data=highlight_max(table_data), 
        headers='firstrow', 
        tablefmt='grid'
    )
    print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name}, "
          f"secondary granularity: {k}")
    print(table)
    print("\n")


def print_two_secondary_granularities(main_model_data: dict,
                                      table_data: list,
                                      k: str,
                                      main_granularity: str,
                                      main_model_name: str):
    main_learning_rates = sorted(lrs)
    
    priors = {}
    
    # Initialize the table_data with header if it's empty
    if len(table_data) == 0:
        header = [''] + main_learning_rates
        table_data += [header]
        
    secondary_model_data = main_model_data[k]
    row = [k]
    
    for main_lr in sorted(secondary_model_data.keys()):
        main_lr_data = secondary_model_data[main_lr]
        row_add = ''
        
        for secondary_lr in sorted(main_lr_data.keys()):
            row_addition, curr_prior = get_row_data(main_lr_data=main_lr_data,
                                                    secondary_lr=secondary_lr)
            row_add += row_addition
            priors[main_lr] = curr_prior
    
        row += [row_add]

    table_data += [row]
    
    # Modify the generated table data to highlight the cell with the maximal accuracy in blue
    
    if len(table_data) == len(main_learning_rates) + 1:
        
        table_data[0] = [''] + [f'{main_lr} ({priors[str(main_lr)]}%)' for main_lr in main_learning_rates]
        
        # Create the table using tabulate
        table = tabulate.tabulate(
            tabular_data=highlight_max(table_data),
            headers='firstrow',
            tablefmt='grid'
        )
        
        # Print the main model name and the corresponding table
        print(f"Main model: {main_granularity.capitalize()}-grain {main_model_name} "
              f"with both fine and coarse grain secondary models")
        print(table)
        print("\n")
    
    return table_data
    
    
def print_EDCR_tables():
    data = gather_EDCR_data()
    
    for main_granularity in sorted(data.keys()):
        
        print('#' * 40 + f' Main granularity: {main_granularity} ' + '#' * 40 + '\n' + '#' * 104 + '\n')
        main_granularity_data = data[main_granularity]
        
        for main_model_name in sorted(main_granularity_data.keys()):
            main_model_data = main_granularity_data[main_model_name]
            table_data = []

            for k in (sorted(set(main_model_data.keys()).intersection(granularities.values())) + 
                      sorted(set(main_model_data.keys()).intersection(vit_model_names))):
            
                if k in granularities.values():
                    print_one_secondary_granularity(main_model_data=main_model_data,
                        k=k,
                        main_granularity=main_granularity,
                        main_model_name=main_model_name)
                else:
                    table_data = print_two_secondary_granularities(main_model_data=main_model_data,
                                                      table_data=table_data,
                                                      k=k,
                                                      main_granularity=main_granularity,
                                                      main_model_name=main_model_name)
                    
            print('#' * 100)

print_EDCR_tables()

######################################## Main granularity: coarse ########################################
########################################################################################################

Main model: Coarse-grain vit_b_16, secondary granularity: coarse
+----------+---------------------+----------------------+---------------------+
|          | 1e-05 (80.9%)       | 1e-06 (65.6%)        | 5e-05 (83.7%)       |
| vit_b_32 | 1e-05: 81.0%, [92m+[0m[92m0.1%[0m | 1e-05: 73.0%, [92m+[0m[92m7.4%[0m  | 1e-05: 82.4%, [91m-1.3%[0m |
|          | 1e-06: 80.9%, [91m0.0%[0m  | 1e-06: 65.3%, [91m-0.3%[0m  | 1e-06: 83.7%, [91m0.0%[0m  |
|          | 5e-05: 81.2%, [92m+[0m[92m0.3%[0m | 5e-05: 72.5%, [92m+[0m[92m6.9%[0m  | 5e-05: 83.7%, [91m0.0%[0m  |
+----------+---------------------+----------------------+---------------------+
| vit_l_16 | 1e-05: 83.0%, [92m+[0m[92m2.1%[0m | 1e-05: 76.8%, [92m+[0m[92m11.2%[0m | 1e-05: 83.8%, [92m+[0m[92m0.