In [None]:
import json
import torch
import module
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from module import AKIPredictionModel, test
from sklearn.calibration import calibration_curve

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

https://nalepae.github.io/pandarallel/troubleshooting/


In [9]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Allow TensorFloat-32 inside cuDNN kernels on GPUs that support it.
# NOTE(review): PyTorch already defaults this flag to True; TF32 for plain matmuls is the
# separate switch torch.backends.cuda.matmul.allow_tf32, which this cell does not touch.
torch.backends.cudnn.allow_tf32 = True

# Data

In [None]:
# Load the "ABC" calibration and test splits from the preprocessed bundle.
# The bundle is read once (it used to be deserialized twice) and then released, so the
# other splits/cohorts it contains are not kept alive in memory.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# torch >= 2.6 loading non-tensor objects may additionally require weights_only=False.
_datasets_bundle = torch.load("processed/datasets.pt")
calibration_dataset = _datasets_bundle['ABC']['calibration']
test_dataset = _datasets_bundle['ABC']['test']
del _datasets_bundle

dataset_names = ["calibration", "test"]
# Explicit name -> dataset mapping instead of eval()-ing "<name>_dataset" strings.
split_datasets = {"calibration": calibration_dataset, "test": test_dataset}

# Evaluation-only loaders: shuffle=False so `test(...)` sees samples in a stable,
# reproducible order on every run (shuffling has no benefit for inference).
dataloaders = [
    DataLoader(split_datasets[name], batch_size=1, shuffle=False, drop_last=False)
    for name in dataset_names
]

calibration_dataloader, test_dataloader = dataloaders

In [None]:
# Rebuild the model from the tuned hyper-parameters, load the selected checkpoint, and run
# `test(...)` on the test and calibration splits.
# `device` is taken from the configuration cell above; it is no longer re-created here
# (the duplicate used "cuda" while the config cell uses "cuda:0").
with open("model/best_params.json", "r") as fp:
    params = json.load(fp)

with open("model/best_ckpt_path.txt", "r") as fp:
    ckpt_path = fp.read().strip()

print("▶  Using checkpoint:", ckpt_path)

SEQ_LEN = 56  # sequence length passed to the model (previously an inline magic number)

# Input widths are read off the first test sample's tensors (index 0 -> numeric input,
# index 1 -> presence input, matching the keyword arguments below).
first_test_sample = test_dataloader.dataset[0]

model = AKIPredictionModel(
    hidden_size          = params['hidden_size'],
    embedding_size       = params['embedding_size'],
    recurrent_num_layers = params['recurrent_num_layers'],
    embedding_num_layers = params['embedding_num_layers'],
    activation_type      = params['activation_type'],
    recurrent_type       = params['recurrent_type'],
    seq_len              = SEQ_LEN,
    LN                   = bool(params['LN']),
    highway_network      = bool(params['highway_network']),
    numeric_input_size   = first_test_sample.tensors[0].shape[-1],
    presence_input_size  = first_test_sample.tensors[1].shape[-1],
    CB                   = bool(params['CB']),
).to(device)

# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
# Put training-only layers (dropout, normalization statistics, if present) into inference
# mode. Harmless if `test` also calls model.eval() — that is not visible from here.
model.eval()

main_dataset, sub_dataset = test(model, test_dataloader)
main_dataset_cal, sub_dataset_cal = test(model, calibration_dataloader)

print("✅ evaluation complete & .pt files are saved")

# Dataloader

In [None]:
# Wrap every dataset returned by `test(model, test_dataloader)` in a sequential,
# single-sample loader. drop_last never drops anything at batch_size=1.
main_dataloaders = [
    DataLoader(ds, batch_size=1, shuffle=False, drop_last=True)
    for ds in main_dataset
]
sub_dataloaders = [
    DataLoader(ds, batch_size=1, shuffle=False, drop_last=True)
    for ds in sub_dataset
]

# Unpacking fails loudly if `test` returned a different number of datasets.
(
    main_dataloader_6h,
    main_dataloader_12h,
    main_dataloader_18h,
    main_dataloader_24h,
    main_dataloader_30h,
    main_dataloader_36h,
    main_dataloader_42h,
    main_dataloader_48h,
) = main_dataloaders
(sub_dataloader_1, sub_dataloader_2, sub_dataloader_3, sub_dataloader_3D) = sub_dataloaders

In [None]:
# Same wrapping as for the test split, applied to the datasets returned by
# `test(model, calibration_dataloader)`; these loaders are used to fit the calibration.
main_dataloaders_cal = [
    DataLoader(ds, batch_size=1, shuffle=False, drop_last=True)
    for ds in main_dataset_cal
]
sub_dataloaders_cal = [
    DataLoader(ds, batch_size=1, shuffle=False, drop_last=True)
    for ds in sub_dataset_cal
]

(
    main_dataloader_6h_cal,
    main_dataloader_12h_cal,
    main_dataloader_18h_cal,
    main_dataloader_24h_cal,
    main_dataloader_30h_cal,
    main_dataloader_36h_cal,
    main_dataloader_42h_cal,
    main_dataloader_48h_cal,
) = main_dataloaders_cal
(
    sub_dataloader_1_cal,
    sub_dataloader_2_cal,
    sub_dataloader_3_cal,
    sub_dataloader_3D_cal,
) = sub_dataloaders_cal

# Calibration Curve (before calibration)

In [None]:
# Main dataloaders — reliability diagram of the *uncalibrated* model.
# Loaders are rebuilt from `main_dataset` (raw `test(...)` output) instead of reading
# `main_dataloader_6h` ... `main_dataloader_48h`, because the calibration cell rebinds those
# names to calibrated loaders; this keeps the figure correct if the cell is re-run later.
# The list gets its own name so `main_dataloaders` (the list of DataLoaders) is not clobbered.
raw_main_curves = [
    (DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True), label)
    for dataset, label in zip(main_dataset, ["6h", "12h", "18h", "24h", "30h", "36h", "42h", "48h"])
]
assert len(raw_main_curves) == 8, "expected 8 main datasets from test(...)"

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

for dataloader, label in raw_main_curves:
    y_true, y_scores = module.step_ROC(dataloader)
    fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_scores, n_bins=10)
    sns.lineplot(x=mean_predicted_value, y=fraction_of_positives, marker='s', label=label, ax=ax)

ax.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
ax.set(
    xlabel="Mean Predicted Probability",
    ylabel="Fraction of Positives",
    title="Calibration Curve (before calibration)",
)
ax.legend(loc='upper left')
plt.show()

In [None]:
# Sub dataloaders — reliability diagram of the *uncalibrated* model per sub-group.
# Loaders are rebuilt from `sub_dataset` (raw `test(...)` output) because the calibration
# cell rebinds `sub_dataloader_1` ... `sub_dataloader_3D` to calibrated loaders.
# A dedicated name avoids clobbering `sub_dataloaders` (the list of DataLoaders).
raw_sub_curves = [
    (DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True), label)
    for dataset, label in zip(sub_dataset, ["1≥", "2≥", "3≥", "3D"])
]
assert len(raw_sub_curves) == 4, "expected 4 sub datasets from test(...)"

# dpi matches the main-horizon figure so exported figures share one resolution.
fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

for dataloader, label in raw_sub_curves:
    y_true, y_scores = module.step_ROC(dataloader)
    fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_scores, n_bins=10)
    sns.lineplot(x=mean_predicted_value, y=fraction_of_positives, marker='s', label=label, ax=ax)

ax.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
ax.set(
    xlabel="Mean Predicted Probability",
    ylabel="Fraction of Positives",
    title="Calibration Curve (before calibration)",
)
ax.legend()
plt.show()

# Calibration

In [None]:
# Fit calibration on each horizon / sub-group's calibration-split predictions, apply it to
# the matching test-split predictions, and report post-calibration results.
#
# The raw (uncalibrated) test loaders are rebuilt from `main_dataset` / `sub_dataset` on
# every run instead of being read from `main_dataloader_6h` etc. This cell rebinds those
# names to the calibrated loaders, so reading them here would calibrate twice on a re-run.
MAIN_LABELS = ["6h", "12h", "18h", "24h", "30h", "36h", "42h", "48h"]
SUB_LABELS = ["1≥", "2≥", "3≥", "3D"]

main_targets = [
    (label, cal_loader, DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True))
    for label, cal_loader, dataset in zip(MAIN_LABELS, main_dataloaders_cal, main_dataset)
]
sub_targets = [
    (label, cal_loader, DataLoader(dataset, batch_size=1, shuffle=False, drop_last=True))
    for label, cal_loader, dataset in zip(SUB_LABELS, sub_dataloaders_cal, sub_dataset)
]
# zip() silently truncates; fail loudly if any split is missing a dataset.
assert len(main_targets) == len(MAIN_LABELS), "main calibration/test dataset count mismatch"
assert len(sub_targets) == len(SUB_LABELS), "sub calibration/test dataset count mismatch"

main_calibrated = {}
for label, cal_loader, raw_loader in main_targets:
    calibrated_loader = module.calibration(cal_loader, raw_loader)
    main_calibrated[label] = calibrated_loader
    print(f"▶ [Main] Result after calibration: {label}")
    module.Result(calibrated_loader)

sub_calibrated = {}
for label, cal_loader, raw_loader in sub_targets:
    calibrated_loader = module.calibration(cal_loader, raw_loader)
    sub_calibrated[label] = calibrated_loader
    print(f"▶ [Sub] Result after calibration: {label}")
    module.Result(calibrated_loader)

# Publish the calibrated loaders under the variable names the "Calibration Curve" cells
# below read (explicit unpacking replaces the former globals()[...] string-key writes).
(
    main_dataloader_6h,
    main_dataloader_12h,
    main_dataloader_18h,
    main_dataloader_24h,
    main_dataloader_30h,
    main_dataloader_36h,
    main_dataloader_42h,
    main_dataloader_48h,
) = (main_calibrated[label] for label in MAIN_LABELS)
(sub_dataloader_1, sub_dataloader_2, sub_dataloader_3, sub_dataloader_3D) = (
    sub_calibrated[label] for label in SUB_LABELS
)

# Calibration Curve (after calibration)

In [None]:
# Main dataloaders — reliability diagram *after* calibration.
# `main_dataloader_6h` ... `main_dataloader_48h` hold the calibrated loaders at this point
# (the calibration cell rebinds them). The pairs get their own name so the
# `main_dataloaders` list of DataLoaders is not overwritten with (loader, label) tuples.
calibrated_main_curves = [
    (main_dataloader_6h, "6h"),
    (main_dataloader_12h, "12h"),
    (main_dataloader_18h, "18h"),
    (main_dataloader_24h, "24h"),
    (main_dataloader_30h, "30h"),
    (main_dataloader_36h, "36h"),
    (main_dataloader_42h, "42h"),
    (main_dataloader_48h, "48h"),
]

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

for dataloader, label in calibrated_main_curves:
    y_true, y_scores = module.step_ROC(dataloader)
    fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_scores, n_bins=10)
    sns.lineplot(x=mean_predicted_value, y=fraction_of_positives, marker='s', label=label, ax=ax)

ax.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
ax.set(
    xlabel="Mean Predicted Probability",
    ylabel="Fraction of Positives",
    title="Calibration Curve (after calibration)",
)
ax.legend(loc='upper left')
plt.show()

In [None]:
# Sub dataloaders — reliability diagram per sub-group *after* calibration.
# `sub_dataloader_1` ... `sub_dataloader_3D` hold the calibrated loaders at this point.
# A dedicated name avoids overwriting the `sub_dataloaders` list of DataLoaders.
calibrated_sub_curves = [
    (sub_dataloader_1, "1≥"),
    (sub_dataloader_2, "2≥"),
    (sub_dataloader_3, "3≥"),
    (sub_dataloader_3D, "3D"),
]

# dpi matches the main-horizon figure so exported figures share one resolution.
fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

for dataloader, label in calibrated_sub_curves:
    y_true, y_scores = module.step_ROC(dataloader)
    fraction_of_positives, mean_predicted_value = calibration_curve(y_true, y_scores, n_bins=10)
    sns.lineplot(x=mean_predicted_value, y=fraction_of_positives, marker='s', label=label, ax=ax)

ax.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
ax.set(
    xlabel="Mean Predicted Probability",
    ylabel="Fraction of Positives",
    title="Calibration Curve (after calibration)",
)
ax.legend()
plt.show()