In [1]:
from src.dataset import get_train_val_loaders, get_test_loader
from src.models import MODEL_DICT
from src.training import inference_aggregator_loop, get_best_threshold

import torch

# Per-fold inference: for every cross-validation checkpoint, pick a decision
# threshold on that fold's validation split, then apply it to the test set.
model_name = "MobileNetV2Tab"
# One checkpoint per fold; the checkpoint at index i is evaluated with fold_id=i.
weights_path = [
    "saved_models/run_2024-03-04_19-26-35_MobileNetV2Tab/fold_0_epoch_12_acc_0.9310344827586207.pth",
    "saved_models/run_2024-03-04_19-26-35_MobileNetV2Tab/fold_1_epoch_12_acc_0.8516483516483516.pth",
    "saved_models/run_2024-03-04_19-26-35_MobileNetV2Tab/fold_2_epoch_8_acc_0.8722527472527473.pth",
    "saved_models/run_2024-03-04_19-26-35_MobileNetV2Tab/fold_3_epoch_0_acc_0.875.pth",
]
# Derive the fold count from the checkpoint list so the loop and the data
# split can never disagree with the number of checkpoints provided.
num_folds = len(weights_path)

# NOTE(review): not used in this cell -- None is passed as the last argument of
# inference_aggregator_loop below. Kept in case later cells rely on it.
loss_function = torch.nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_loader = get_test_loader(batch_size=32)

# Accumulators, one entry per fold (consumed by the submission cell).
test_predictions_list = []
ids_list = []
balanced_accuracy_list = []

for fold, fold_weights in enumerate(weights_path):
    model = MODEL_DICT[model_name]().to(device)
    # map_location remaps tensors that were saved from a CUDA device, so the
    # checkpoint also loads when `device` has fallen back to CPU.
    model.load_state_dict(torch.load(fold_weights, map_location=device))
    model.eval()

    # Only the validation loader of this fold is needed here.
    train_loader, val_loader = get_train_val_loaders(
        batch_size=32,
        num_workers=4,
        pin_memory=True,
        fold_id=fold,
        fold_numbers=num_folds,
    )

    # Validation pass: gives the labels and aggregated predictions that the
    # threshold search below operates on.
    patient_labels, aggregated_predictions, unique_ids, val_loss = (
        inference_aggregator_loop(model, val_loader, device, None)
    )

    # NOTE(review): the meaning of the third argument (False) is defined in
    # src.training.get_best_threshold -- confirm there.
    thresholds_sigmoid, thresholds_logits, balanced_accuracy = get_best_threshold(
        patient_labels, aggregated_predictions, False
    )

    # Test pass: `unique_ids` is intentionally rebound to the test-set ids.
    _, test_predictions, unique_ids, _ = inference_aggregator_loop(
        model, test_loader, device, None
    )

    # Binarise the test predictions with the threshold chosen on validation.
    test_predictions_list.append((test_predictions > thresholds_logits).astype(int))
    ids_list.append(unique_ids)
    balanced_accuracy_list.append(balanced_accuracy)

print(balanced_accuracy_list)
print(sum(balanced_accuracy_list) / len(balanced_accuracy_list))



Unique labels: [-1.] Counts: [3258]
Train dataset:
Unique labels: [0. 1.] Counts: [1957 8228]
Val dataset:
Unique labels: [0. 1.] Counts: [ 635 2633]


100%|██████████| 103/103 [00:45<00:00,  2.28it/s]
100%|██████████| 102/102 [00:33<00:00,  3.00it/s]


Train dataset:
Unique labels: [0. 1.] Counts: [1944 8228]
Val dataset:
Unique labels: [0. 1.] Counts: [ 648 2633]


100%|██████████| 103/103 [00:21<00:00,  4.72it/s]
100%|██████████| 102/102 [00:11<00:00,  8.56it/s]


Train dataset:
Unique labels: [0. 1.] Counts: [1877 8048]
Val dataset:
Unique labels: [0. 1.] Counts: [ 715 2813]


100%|██████████| 111/111 [00:21<00:00,  5.06it/s]
100%|██████████| 102/102 [00:11<00:00,  8.88it/s]


Train dataset:
Unique labels: [0. 1.] Counts: [1998 8079]
Val dataset:
Unique labels: [0. 1.] Counts: [ 594 2782]


100%|██████████| 106/106 [00:21<00:00,  4.97it/s]
100%|██████████| 102/102 [00:10<00:00,  9.92it/s]

[0.9310344827586207, 0.8516483516483516, 0.8722527472527473, 0.875]
0.8824838954149299





In [2]:
# Re-display the per-fold balanced accuracies collected by the fold loop above.
print(balanced_accuracy_list)

[0.9310344827586207, 0.8516483516483516, 0.8722527472527473, 0.875]


In [4]:
import numpy as np
import pandas as pd

# Ensemble the per-fold binary test predictions by majority vote and write the
# submission file.
# Indices of the folds whose predictions take part in the vote.
only_keep = [0, 1, 2, 3]

kept_folds = [idx for idx in range(len(test_predictions_list)) if idx in only_keep]
if not kept_folds:
    raise ValueError("only_keep does not select any fold from test_predictions_list")

# The vote is positional: element k of every fold must describe the same test
# id. Take the ids from the first kept fold and verify the others agree, so a
# reordered fold fails loudly instead of producing a silently wrong submission.
reference_ids = ids_list[kept_folds[0]]
for fold_idx in kept_folds[1:]:
    if not np.array_equal(np.asarray(ids_list[fold_idx]), np.asarray(reference_ids)):
        raise ValueError(
            f"Test ids of fold {fold_idx} are not aligned with fold {kept_folds[0]}"
        )

# Mean of 0/1 votes = fraction of kept folds predicting the positive class.
test_predictions = np.stack(
    [test_predictions_list[fold_idx] for fold_idx in kept_folds]
).mean(axis=0)
# Majority vote; an exact tie (fraction 0.5) resolves to class 1.
merged_predictions = (test_predictions >= 0.5).astype(int)

# Submission ids are the raw ids prefixed with "P".
ids = [f"P{i}" for i in reference_ids]


submission = pd.DataFrame({"Id": ids, "Predicted": merged_predictions})

submission.to_csv("submission.csv", index=False)