In [None]:
import sys

# Make modules located two directories above the working directory importable
# (presumably the repository root, where the `tools` package used below lives).
# Guarded so that re-running this cell does not append duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
# imports: plotting, the project's config helpers, and torch for loading saved results
import matplotlib.pyplot as plt
from tools import data_tools
import torch


# Read the experiment settings shared with the corruption/DOCTOR runs.
config = data_tools.read_config('corruption_doctor_config.yaml')

# The model name is pinned in this notebook instead of being read from the YAML;
# it is written back so the config printout below shows the value actually used.
model_name = 'resnet34_custom'
config["model_name"] = model_name

# Settings used as-is from the YAML file (looked up in this order).
(match_dataset_name, corrupted_dataset_name, model_seed, data_path, device_id,
 magnitudes, temperatures, batch_size, rs) = (
    config[setting] for setting in (
        "match_dataset_name", "corrupted_dataset_name", "model_seed",
        "data_path", "device_id", "magnitudes", "temperatures",
        "batch_size", "rs",
    )
)

# Narrow the sweep to one seed / corruption / intensity. Each override is also
# stored in `config` so the printout reflects the values in effect.
seeds = [1]
config["seeds"] = seeds
corruptions = ['brightness']
config["corruptions"] = corruptions
intensities = [1]
config["intensities"] = intensities

# Echo the effective configuration, one "key value" pair per line.
for setting, setting_value in config.items():
    print(setting, setting_value)

# Run on the CPU regardless of the `device_id` read from the config.
device = torch.device("cpu")


In [None]:
# For every (seed, r, corruption, intensity) run, draw one figure with two panels:
#   left  - validation ROC curves for every (magnitude, temperature) pair, with the
#           pair of highest validation AUC highlighted (ties -> lower validation FPR);
#   right - the test ROC curve of that selected pair.
# Relies on names defined in the config cell: seeds, rs, corruptions, intensities,
# magnitudes, temperatures, match_dataset_name, corrupted_dataset_name, model_name,
# model_seed and device.


def doctor_results_folder(corruption, intensity, r, seed):
    """Return the results directory of one (corruption, intensity, r, seed) run."""
    return (f'{match_dataset_name}_to_{corrupted_dataset_name}/{model_name}/model_seed_{model_seed}'
            f'/{corruption}_{intensity}/results/r_{r}/seed_{seed}')


def load_doctor_metric(run_folder, magnitude, temperature, split, metric):
    """Load ``doctor_<split>_<metric>.pt`` saved for one (magnitude, temperature) pair.

    ``map_location=device`` places tensors on the notebook device (the CPU) so results
    saved from a GPU run can still be loaded and handed to matplotlib.
    """
    hyperparam_folder = f'{run_folder}/magnitude_{magnitude}/temperature_{temperature}'
    return torch.load(f'{hyperparam_folder}/doctor_{split}_{metric}.pt', map_location=device)


def plot_operating_point(ax, fpr, tpr, marker, label, text_offset=0.0):
    """Mark (fpr, tpr) on ``ax`` with dashed guides to both axes and a value annotation."""
    ax.plot(fpr, tpr, marker=marker, color='green', label=label)
    ax.plot([fpr, fpr], [0, tpr], linestyle='--', color='green')
    ax.plot([0, fpr], [tpr, tpr], linestyle='--', color='green')
    ax.annotate(f"(fpr:{fpr:.2f}, tpr:{tpr:.2f})", (fpr, tpr - text_offset))


def style_roc_axes(ax, title):
    """Add the chance diagonal, title, axis labels, grid, [0, 1] limits and legend."""
    ax.plot([0, 1], [0, 1], linestyle='--', label='Random Guess', color='red')
    ax.set_title(title)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.grid()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # The validation panel gets one entry per (magnitude, temperature) curve; keep it small.
    ax.legend(loc='lower right', fontsize='x-small')


for seed in seeds:
    for r in rs:
        for corruption in corruptions:
            for intensity in intensities:
                # Loop-invariant for the hyperparameter sweep below, so computed once.
                run_folder = doctor_results_folder(corruption, intensity, r, seed)
                fig, (ax_val, ax_test) = plt.subplots(1, 2, figsize=(15, 5))

                # Validation: plot every curve and keep the best (magnitude, temperature).
                # doctor_val_fpr / doctor_val_tpr are the stored operating point
                # (presumably FPR at 95% TPR, per the variable naming) - TODO confirm.
                max_auc = float('-inf')
                magnitude_at_max_auc = temperature_at_max_auc = None
                fpr_at_95_tpr_at_max_auc = tpr_at_max_auc = None
                for magnitude in magnitudes:
                    for temperature in temperatures:
                        doctor_val_fprs = load_doctor_metric(run_folder, magnitude, temperature, 'val', 'fprs')
                        doctor_val_tprs = load_doctor_metric(run_folder, magnitude, temperature, 'val', 'tprs')
                        doctor_val_fpr = load_doctor_metric(run_folder, magnitude, temperature, 'val', 'fpr')
                        doctor_val_tpr = load_doctor_metric(run_folder, magnitude, temperature, 'val', 'tpr')
                        doctor_val_auc = load_doctor_metric(run_folder, magnitude, temperature, 'val', 'auc')

                        # Higher AUC wins; an exact tie goes to the lower FPR. The
                        # `is not None` check stops the tie branch comparing against None.
                        is_better = doctor_val_auc > max_auc or (
                            magnitude_at_max_auc is not None
                            and doctor_val_auc == max_auc
                            and doctor_val_fpr < fpr_at_95_tpr_at_max_auc)
                        if is_better:
                            max_auc = doctor_val_auc
                            fpr_at_95_tpr_at_max_auc = doctor_val_fpr
                            tpr_at_max_auc = doctor_val_tpr
                            magnitude_at_max_auc = magnitude
                            temperature_at_max_auc = temperature

                        ax_val.plot(doctor_val_fprs, doctor_val_tprs,
                                    label=f"r={r}, magnitude={magnitude}, temperature={temperature}")

                if magnitude_at_max_auc is None:
                    # Empty sweep, or every AUC was NaN (NaN never compares greater).
                    plt.close(fig)
                    raise ValueError(
                        f"no validation AUC could be selected under {run_folder} "
                        "(empty magnitude/temperature sweep or all AUCs are NaN)")

                plot_operating_point(
                    ax_val, fpr_at_95_tpr_at_max_auc, tpr_at_max_auc, marker='x', text_offset=.07,
                    label=f"max auc: {max_auc:.2f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}")
                style_roc_axes(ax_val, "Validation")

                # Test: evaluate only the hyperparameters selected on validation.
                doctor_test_fprs = load_doctor_metric(
                    run_folder, magnitude_at_max_auc, temperature_at_max_auc, 'test', 'fprs')
                doctor_test_tprs = load_doctor_metric(
                    run_folder, magnitude_at_max_auc, temperature_at_max_auc, 'test', 'tprs')
                doctor_test_fpr = load_doctor_metric(
                    run_folder, magnitude_at_max_auc, temperature_at_max_auc, 'test', 'fpr')
                doctor_test_tpr = load_doctor_metric(
                    run_folder, magnitude_at_max_auc, temperature_at_max_auc, 'test', 'tpr')
                doctor_test_auc = load_doctor_metric(
                    run_folder, magnitude_at_max_auc, temperature_at_max_auc, 'test', 'auc')

                # Label with the *selected* pair (the original used stale loop variables).
                ax_test.plot(doctor_test_fprs, doctor_test_tprs,
                             label=f"r={r}, magnitude={magnitude_at_max_auc}, "
                                   f"temperature={temperature_at_max_auc}, auc={doctor_test_auc:.2f}")
                plot_operating_point(
                    ax_test, doctor_test_fpr, doctor_test_tpr, marker='o',
                    label=f"test fpr: {doctor_test_fpr:.2f}, test tpr: {doctor_test_tpr:.2f}")
                style_roc_axes(ax_test, "Test (hyperparameters selected on validation)")

                fig.suptitle(f"ROC Curve for {corruption} {intensity} on {match_dataset_name} to {corrupted_dataset_name} {model_name} model seed {model_seed} seed {seed} r {r} corruption {corruption} intensity {intensity}")
                plt.show()
                # Release each figure as soon as it is shown instead of accumulating them.
                plt.close(fig)


plt.close('all')
