In [None]:
import sys

# Make modules two directories up importable (presumably where the `tools` package
# imported in the next cell lives — TODO confirm repository layout). Relative
# sys.path entries resolve against the kernel's working directory, so this assumes
# the notebook is started from its own folder. The membership check keeps re-runs
# of this cell from appending duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
import os
import torch
import seaborn as sns
from tools import data_tools
import matplotlib.pyplot as plt


config = data_tools.read_config(
    'corruption_d_matrix_config.yaml')

# Local overrides of the YAML values for this analysis. They are written back into
# `config` so that the printout below shows the values that are actually used.
# Remove an entry to fall back to the value from the config file.
config_overrides = {
    "model_name": 'resnet34_custom',
    "seeds": [1],
    "corruptions": ['brightness'],
    "intensities": [1],
}
for key, value in config_overrides.items():
    config[key] = value

# Read every setting exactly once.
model_name = config["model_name"]
match_dataset_name = config["match_dataset_name"]
corrupted_dataset_name = config["corrupted_dataset_name"]
model_seed = config["model_seed"]
data_path = config["data_path"]
magnitudes = config["magnitudes"]
temperatures = config["temperatures"]
batch_size = config["batch_size"]
rs = config["rs"]
seeds = config["seeds"]
lbds = config["lbds"]
lr = config["lr"]
epochs = config["epochs"]
corruptions = config["corruptions"]
intensities = config["intensities"]

# Echo the effective configuration, one key per line.
for key, value in config.items():
    print(key, value)

# Root folder of all artifacts for this (dataset pair, model, model seed).
dest_folder = f"{match_dataset_name}_to_{corrupted_dataset_name}/{model_name}/model_seed_{model_seed}"
# Device that saved tensors are mapped onto when loaded.
device = torch.device("cpu")


In [None]:
def prepare_D_matrix(params):
    """Turn raw D-matrix parameters into a symmetric, normalized array for plotting.

    Only the strictly lower triangle of ``params`` is kept. It is mirrored onto the
    upper triangle, so the result is symmetric with a zero diagonal. Entries are then
    made non-negative and the matrix is scaled to unit Frobenius norm, which puts
    every entry in [0, 1].

    Args:
        params: square 2-D tensor. It may live on any device and may require grad.

    Returns:
        A ``numpy.ndarray`` on the CPU. If the off-diagonal part is all zeros, the
        result is all zeros rather than NaN (no division by a zero norm).
    """
    # Detach up front so the numpy conversion works regardless of autograd state.
    lower = torch.tril(params.detach(), diagonal=-1)
    # tril(-1) and its transpose never overlap, so the sum is a pure mirror.
    symmetric = (lower + lower.T).abs()
    norm = symmetric.norm()
    if norm > 0:
        symmetric = symmetric / norm
    # Always hand back a CPU numpy array (seaborn cannot plot device tensors).
    return symmetric.cpu().numpy()

# Training diagnostics per (seed, r, lbd) run: the learned D matrix plus the
# loss / AUC / FPR-at-95%-TPR histories saved next to it.
training_history_files = {
    'loss': 'training_D_matrix_loss_history.pt',
    'loss_pos': 'training_D_matrix_loss_history_pos.pt',
    'loss_neg': 'training_D_matrix_loss_history_neg.pt',
    'auc': 'training_D_matrix_auc_history.pt',
    'fpr_at_95_tpr': 'training_D_matrix_fpr_at_95_tpr_history.pt',
}

for seed in seeds:
    for r in rs:
        for lbd in lbds:
            dest_folder_seed = os.path.join(
                dest_folder, f"seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}")
            # map_location on every file (not only the D matrix) so tensors saved on
            # another device still load onto `device`.
            D_matrix = prepare_D_matrix(torch.load(
                os.path.join(dest_folder_seed, 'D_matrix.pt'), map_location=device))
            history = {
                key: torch.load(os.path.join(dest_folder_seed, file_name), map_location=device)
                for key, file_name in training_history_files.items()
            }

            fig, axs = plt.subplots(2, 2, figsize=(10, 10))
            fig.suptitle(f"seed {seed}, r {r}, lbd {lbd}")

            # Top-left: heatmap of the D matrix.
            ax = axs[0, 0]
            ax.set_title('D matrix')
            sns.heatmap(D_matrix, ax=ax, cmap='Greens', vmin=0, vmax=1)

            # Top-right: overall loss and its `_pos` / `_neg` histories.
            ax = axs[0, 1]
            ax.set_title('Loss history')
            for key in ('loss', 'loss_pos', 'loss_neg'):
                ax.plot(history[key], label=key)
            ax.legend()

            # Bottom row: AUC and FPR-at-95%-TPR histories.
            axs[1, 0].set_title('AUC history')
            axs[1, 0].plot(history['auc'])
            axs[1, 1].set_title('FPR at 95 TPR history')
            axs[1, 1].plot(history['fpr_at_95_tpr'])

            # Grid on every subplot except the heatmap. Iterate over `axs`, not
            # `fig.axes`: the latter also holds the colorbar axis added by sns.heatmap.
            for ax in axs.flat[1:]:
                ax.grid()
            plt.show()
            # Release each figure once shown so many runs do not pile up open figures.
            plt.close(fig)
        
        


In [None]:
def _load_roc_artifact(folder, name):
    """Load ``<folder>/<name>.pt``, mapping tensors onto the notebook-level ``device``."""
    return torch.load(os.path.join(folder, f'{name}.pt'), map_location=device)


def _finish_roc_panel(ax, fpr, tpr, title):
    """Decorate an ROC axis that already holds its curve(s).

    Adds the chance diagonal, marks and annotates the (fpr, tpr) operating point
    with dashed guides to both axes, and sets the title, labels and [0, 1] limits.

    Args:
        ax: matplotlib Axes to decorate.
        fpr, tpr: scalar coordinates of the point to highlight (formatted with ``:.2f``).
        title: panel title.
    """
    ax.set_title(title)
    ax.plot([0, 1], [0, 1], linestyle='--', color='black')  # chance-level diagonal
    ax.plot(fpr, tpr, marker='x', color='red')
    # Dashed guides from both axes to the highlighted point.
    ax.plot([0, fpr], [tpr, tpr], linestyle='--', color='black')
    ax.plot([fpr, fpr], [0, tpr], linestyle='--', color='black')
    ax.annotate(f"({fpr:.2f}, {tpr:.2f})", (fpr, tpr))
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)


# For every run, select the (magnitude, temperature) pair with the best validation
# AUROC (ties -> lower FPR) and report that pair on the test split.
for seed in seeds:
    for corruption in corruptions:
        for intensity in intensities:
            for r in rs:
                for lbd in lbds:
                    dest_folder_seed = os.path.join(
                        dest_folder, f"seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}/{corruption}_{intensity}")
                    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
                    fig.suptitle(f"seed {seed}, {corruption} {intensity}, r {r}, lbd {lbd}")

                    # Validation: draw every candidate's ROC curve and keep the best one.
                    max_auc = -float('inf')
                    fpr_max_auc = tpr_max_auc = None
                    magnitude_max_auc = temperature_max_auc = None
                    for magnitude in magnitudes:
                        for temperature in temperatures:
                            candidate_folder = os.path.join(
                                dest_folder_seed, f"magnitude_{magnitude}", f"temperature_{temperature}")
                            D_fprs_val = _load_roc_artifact(candidate_folder, 'D_fprs_val')
                            D_tprs_val = _load_roc_artifact(candidate_folder, 'D_tprs_val')
                            D_fpr_val = _load_roc_artifact(candidate_folder, 'D_fpr_val')
                            D_tpr_val = _load_roc_artifact(candidate_folder, 'D_tpr_val')
                            D_auc_val = _load_roc_artifact(candidate_folder, 'D_auc_val')
                            # Highest AUC wins; an exact tie goes to the lower operating-point FPR.
                            if D_auc_val > max_auc or (
                                    D_auc_val == max_auc and D_fpr_val < fpr_max_auc):
                                max_auc, fpr_max_auc, tpr_max_auc = D_auc_val, D_fpr_val, D_tpr_val
                                magnitude_max_auc, temperature_max_auc = magnitude, temperature
                            axs[0].plot(D_fprs_val, D_tprs_val)

                    if magnitude_max_auc is None:
                        # Empty magnitude/temperature grids (or only NaN AUCs): nothing selected.
                        print(f"No valid validation AUROC for seed {seed}, r {r}, lbd {lbd}; skipping")
                        plt.close(fig)
                        continue

                    _finish_roc_panel(axs[0], fpr_max_auc, tpr_max_auc, 'ROC curve validation')
                    print(
                        f"The best auroc is {max_auc:.2f} with magnitude {magnitude_max_auc} and temperature {temperature_max_auc}, for seed {seed}, r {r}, lbd {lbd}")

                    # Test: evaluate only the pair selected on validation.
                    best_folder = os.path.join(
                        dest_folder_seed, f"magnitude_{magnitude_max_auc}", f"temperature_{temperature_max_auc}")
                    D_fprs_test = _load_roc_artifact(best_folder, 'D_fprs_test')
                    D_tprs_test = _load_roc_artifact(best_folder, 'D_tprs_test')
                    D_fpr_test = _load_roc_artifact(best_folder, 'D_fpr_test')
                    D_tpr_test = _load_roc_artifact(best_folder, 'D_tpr_test')
                    D_auc_test = _load_roc_artifact(best_folder, 'D_auc_test')

                    axs[1].plot(D_fprs_test, D_tprs_test)
                    _finish_roc_panel(axs[1], D_fpr_test, D_tpr_test, 'ROC curve test')
                    print(
                        f"The best auroc in test is {D_auc_test:.2f} with magnitude {magnitude_max_auc} and temperature {temperature_max_auc}, for seed {seed}, r {r}, lbd {lbd}")
                    plt.show()
                    # Close each figure once shown so the loop does not accumulate figures.
                    plt.close(fig)
