In [None]:
import sys

# Make modules two directories above the working directory importable
# (presumably where the `tools` package imported in the next cell lives —
# TODO confirm). Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
# Imports and experiment configuration (read from mismatch_doctor_config.yaml).
import matplotlib.pyplot as plt
from tools import data_tools
import torch

config = data_tools.read_config(
    './mismatch_doctor_config.yaml')

model_name = config["model_name"]
match_dataset_name = config["match_dataset_name"]
mismatch_dataset_name = config["mismatch_dataset_name"]
model_seed = config["model_seed"]
data_path = config["data_path"]
device_id = config["device_id"]
magnitudes = config["magnitudes"]
temperatures = config["temperatures"]
batch_size = config["batch_size"]

# Local overrides for the sweep over r values and seeds. The chosen values are
# written back into `config` so the printout below shows what is actually used.
# Set an override to None to use the value from the YAML file instead.
RS_OVERRIDE = [10]
SEEDS_OVERRIDE = [1]
rs = config["rs"] = RS_OVERRIDE if RS_OVERRIDE is not None else config["rs"]
seeds = config["seeds"] = SEEDS_OVERRIDE if SEEDS_OVERRIDE is not None else config["seeds"]

# print the effective config one entry per line
for key, value in config.items():
    print(key, value)

# Evaluation runs on the CPU; note that config["device_id"] (read above) is
# not used in this cell.
device = torch.device("cpu")


In [None]:
# Validation-split model selection.
# For every (seed, r) pair, sweep all (magnitude, temperature) combinations,
# load their saved validation results, plot every ROC curve and track:
#   * the combination with the lowest saved operating-point FPR
#     (`doctor_val_fpr.pt`, presumably FPR at 95% TPR — TODO confirm),
#     ties broken by the higher validation AUC;
#   * the combination with the highest validation AUC, ties broken by the
#     lower FPR.
# The max-AUC combination is stored in best_magnitude_dict /
# best_temperature_dict as [f"seed_{seed}"][f"r_{r}"] for the test cell below.
best_magnitude_dict = {}
best_temperature_dict = {}

for seed in seeds:
    # One entry per r. Created once per seed (not per r) so that every r is
    # kept, not only the last one iterated.
    best_magnitude_dict_r = {}
    best_temperature_dict_r = {}

    for r in rs:
        # Reset the trackers for each r: the selected (magnitude, temperature)
        # is later loaded from this r's result folder, so it must be chosen
        # from this r's sweep only.
        min_fpr_at_95_tpr = float('inf')
        tpr_at_min_fpr = None
        auc_at_min_fpr = None
        magnitude_at_min_fpr = None
        temperature_at_min_fpr = None

        max_auc = float('-inf')
        fpr_at_95_tpr_at_max_auc = None
        tpr_at_max_auc = None
        magnitude_at_max_auc = None
        temperature_at_max_auc = None

        dest_folder = f'./{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/results/r_{r}/seed_{seed}'
        fig, ax = plt.subplots(figsize=(10, 5))

        for magnitude in magnitudes:
            for temperature in temperatures:
                final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'
                doctor_val_fprs = torch.load(
                    f'{final_dest_folder}/doctor_val_fprs.pt')
                doctor_val_tprs = torch.load(
                    f'{final_dest_folder}/doctor_val_tprs.pt')
                doctor_val_fpr = torch.load(
                    f'{final_dest_folder}/doctor_val_fpr.pt')
                doctor_val_tpr = torch.load(
                    f'{final_dest_folder}/doctor_val_tpr.pt')
                doctor_val_auc = torch.load(
                    f'{final_dest_folder}/doctor_val_auc.pt')

                # Lowest FPR wins; on an exact tie prefer the higher AUC.
                if doctor_val_fpr < min_fpr_at_95_tpr or (
                        doctor_val_fpr == min_fpr_at_95_tpr
                        and doctor_val_auc > auc_at_min_fpr):
                    min_fpr_at_95_tpr = doctor_val_fpr
                    tpr_at_min_fpr = doctor_val_tpr
                    auc_at_min_fpr = doctor_val_auc
                    magnitude_at_min_fpr = magnitude
                    temperature_at_min_fpr = temperature

                # Highest AUC wins; on an exact tie prefer the lower FPR.
                if doctor_val_auc > max_auc or (
                        doctor_val_auc == max_auc
                        and doctor_val_fpr < fpr_at_95_tpr_at_max_auc):
                    max_auc = doctor_val_auc
                    fpr_at_95_tpr_at_max_auc = doctor_val_fpr
                    tpr_at_max_auc = doctor_val_tpr
                    magnitude_at_max_auc = magnitude
                    temperature_at_max_auc = temperature

                # ROC curve of this (magnitude, temperature) combination
                ax.plot(doctor_val_fprs, doctor_val_tprs,
                        label=f"r={r}, magnitude={magnitude}, temperature={temperature}")

        # diagonal reference line
        ax.plot([0, 1], [0, 1], linestyle='--', label='Random Guess')

        # mark the min-FPR operating point with an x and dashed guides to both axes
        ax.plot(min_fpr_at_95_tpr, tpr_at_min_fpr, marker='x', color='orange',
                label=f"min fpr at 95 tpr: {min_fpr_at_95_tpr:.2f}, magnitude={magnitude_at_min_fpr}, temperature={temperature_at_min_fpr}")
        ax.plot([min_fpr_at_95_tpr, min_fpr_at_95_tpr], [0, tpr_at_min_fpr],
                linestyle='--', color='orange')
        ax.plot([0, min_fpr_at_95_tpr], [tpr_at_min_fpr, tpr_at_min_fpr],
                linestyle='--', color='orange')
        ax.annotate(f"(fpr: {min_fpr_at_95_tpr:.2f}, tpr: {tpr_at_min_fpr:.2f})",
                    (min_fpr_at_95_tpr-.1, tpr_at_min_fpr+.02))

        # mark the max-AUC operating point with an x and dashed guides to both axes
        ax.plot(fpr_at_95_tpr_at_max_auc, tpr_at_max_auc, marker='x', color='green',
                label=f"max auc: {max_auc:.2f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}")
        ax.plot([fpr_at_95_tpr_at_max_auc, fpr_at_95_tpr_at_max_auc], [0,
                tpr_at_max_auc], linestyle='--', color='green')
        ax.plot([0, fpr_at_95_tpr_at_max_auc], [tpr_at_max_auc,
                tpr_at_max_auc], linestyle='--', color='green')
        ax.annotate(f"(fpr:{fpr_at_95_tpr_at_max_auc:.2f}, tpr:{tpr_at_max_auc:.2f})",
                    (fpr_at_95_tpr_at_max_auc, tpr_at_max_auc-.07))

        # r is included in the title so figures for different r are distinguishable
        ax.set_title(
            f"ROC Curve for {model_name} on {match_dataset_name} to {mismatch_dataset_name} with seed {seed}, r={r}")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.grid()
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        plt.show()
        # release the figure now rather than accumulating one per (seed, r)
        plt.close(fig)

        best_magnitude_dict_r[f"r_{r}"] = magnitude_at_max_auc
        best_temperature_dict_r[f"r_{r}"] = temperature_at_max_auc

        print(f"min fpr at 95 tpr: {min_fpr_at_95_tpr:.3f}, magnitude={magnitude_at_min_fpr}, temperature={temperature_at_min_fpr}, tpr={tpr_at_min_fpr:.3f}, auc={auc_at_min_fpr:.3f}")
        print(f"max auc: {max_auc:.3f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}, fpr={fpr_at_95_tpr_at_max_auc:.3f}, tpr={tpr_at_max_auc:.3f}")

    best_temperature_dict[f"seed_{seed}"] = best_temperature_dict_r
    best_magnitude_dict[f"seed_{seed}"] = best_magnitude_dict_r

print(f"best magnitude dict: {best_magnitude_dict}")
print(f"best temperature dict: {best_temperature_dict}")


In [None]:
# Test-split evaluation of the selected hyper-parameters.
# For each (seed, r), load the test results saved for the (magnitude,
# temperature) pair chosen on the validation split (max validation AUC),
# print the test metrics, and plot the test ROC curve with the saved
# operating point marked.
for seed in seeds:
    for r in rs:
        magnitude = best_magnitude_dict[f"seed_{seed}"][f"r_{r}"]
        temperature = best_temperature_dict[f"seed_{seed}"][f"r_{r}"]

        dest_folder = f'./{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/results/r_{r}/seed_{seed}'
        final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'
        doctor_test_fprs = torch.load(
            f'{final_dest_folder}/doctor_test_fprs.pt')
        doctor_test_tprs = torch.load(
            f'{final_dest_folder}/doctor_test_tprs.pt')
        doctor_test_fpr = torch.load(
            f'{final_dest_folder}/doctor_test_fpr.pt')
        doctor_test_tpr = torch.load(
            f'{final_dest_folder}/doctor_test_tpr.pt')
        doctor_test_auc = torch.load(
            f'{final_dest_folder}/doctor_test_auc.pt')

        print(f"test auc: {doctor_test_auc:.3f}, magnitude={magnitude}, temperature={temperature}, fpr={doctor_test_fpr:.3f}, tpr={doctor_test_tpr:.3f}")

        fig, ax = plt.subplots(figsize=(10, 5))
        # test ROC curve
        ax.plot(doctor_test_fprs, doctor_test_tprs,
                label=f"r={r}, magnitude={magnitude}, temperature={temperature}")
        # saved test operating point, with dashed guides to both axes
        ax.plot(doctor_test_fpr, doctor_test_tpr, marker='o', color='green',
                label=f"test fpr: {doctor_test_fpr:.2f}, test tpr: {doctor_test_tpr:.2f}")
        ax.plot([doctor_test_fpr, doctor_test_fpr], [0,
                doctor_test_tpr], linestyle='--', color='green')
        ax.plot([0, doctor_test_fpr], [doctor_test_tpr,
                doctor_test_tpr], linestyle='--', color='green')
        ax.annotate(f"(fpr:{doctor_test_fpr:.2f}, tpr:{doctor_test_tpr:.2f})",
                    (doctor_test_fpr, doctor_test_tpr))
        # diagonal reference line
        ax.plot([0, 1], [0, 1], linestyle='--', label='Random Guess', color='red')

        # title/labels so each (seed, r) figure stands on its own
        ax.set_title(
            f"Test ROC Curve for {model_name} on {match_dataset_name} to {mismatch_dataset_name} with seed {seed}, r={r}")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.grid()
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        plt.show()
        plt.close(fig)
