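# Evaluation

Evaluate fine-tuned chest X-ray classification checkpoints on a test set. The notebook computes per-class AUROC for every model, plus `plex_evaluate` metrics at several decision thresholds.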

## 0. Setup

In [None]:
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from main_classification import get_args_parser
from trainer import test_classification
from evaluation.plex_metrics import plex_evaluate
from dataloader import *
from utils import metric_AUROC

In [None]:
# Parameters: start from the default values of main_classification's argument parser.
args = get_args_parser(main_args=False).get_default_values()

# Pick the compute device. Apple-silicon "mps" stays the first choice, but fall back to
# CUDA and then CPU so the notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    args.device = "mps"
elif torch.cuda.is_available():
    args.device = "cuda"
else:
    args.device = "cpu"

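## 1. Define parameters

The three model-list cells below are alternatives. Each one overwrites `args.nc`, `model_list` and `diseases_model`, so run only the one you want to evaluate. With "Run All", the last one wins.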

In [None]:
# Data path. Alternatives used with this notebook:
#   data_dir:  /Users/felixkrones/python_projects/data/NIH/images/
#              /Users/felixkrones/python_projects/data/ChestXpert/
#              /Users/felixkrones/python_projects/data/Padchest/0/
#              /Users/felixkrones/python_projects/data/VinDrCXR/
#   test_list: dataset/Xray14_test_official.txt
#              CheXpert_valid_official_frontal.csv
#              CheXpert_test_Glocker.csv
#              PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv
#              dataset/VinDrCXR_test_pe_global_one.txt
args.data_dir = "/Users/felixkrones/python_projects/data/NIH/images/"
args.test_list = "dataset/Xray14_test_official.txt"

# A per-sample metadata table is only configured for the official CheXpert validation split.
# Match on the file name so the check still fires when the list is given with a directory
# prefix (as the other test lists are, e.g. "dataset/...").
args.metadata_file = ""
if os.path.basename(args.test_list) == "CheXpert_valid_official_frontal.csv":
    args.metadata_file = "dataset/chestxpert_valid_metadata.csv"

In [None]:
# Benchmark models: ResNet50 checkpoints trained on ChestXray14 from "random" and "imagenet"
# initialisation, runs 0-2 of each (random runs first, then imagenet runs).
# NOTE(review): args.nc is presumably the number of input image channels - confirm.
args.nc = 3
_BENCHMARK_ROOT = "/Users/felixkrones/python_projects/models/BenchmarkTransferLearning_f/Classification/ChestXray14"
model_list = [
    ("ResNet50", f"{_BENCHMARK_ROOT}/ResNet50_{init}/ResNet50_{init}_run_{run}_best.pth.tar")
    for init in ("random", "imagenet")
    for run in range(3)
]
# Column order of the model's predictions; used to build index_to_test_model.
diseases_model = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
    "Pneumothorax", "Consolidation", "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening",
    "Hernia",
]

In [None]:
# GMML models, nc 1: vit_small checkpoints under GMML/Finetune/CXR8, five seeds per
# experiment. Each experiment is preceded by a ("", tag) separator row, which the evaluation
# loop prints as a header instead of evaluating.
args.nc = 1
_GMML_ROOT = "/Users/felixkrones/python_projects/models/GMML/Finetune/CXR8"
_GMML_SEEDS = (0, 11, 21, 42, 100)
_GMML_EXPERIMENTS = (  # (separator tag, fine-tuning folder under _GMML_ROOT)
    ("2.1", "2.1_random_scratch_1D"),
    ("0.2.1", "0.2.1_scratch_NIH_1D_v1"),
    ("3.1", "3.1_scratch_MIMIC_1D"),
    ("3.3", "3.3_scratch_OCT_1D"),
    ("3.5", "3.5_scratch_CovidxCT_1D"),
)
model_list = [
    row
    for tag, folder in _GMML_EXPERIMENTS
    for row in [("", tag)]
    + [
        ("vit_small", f"{_GMML_ROOT}/{folder}/seed_{seed}/best_checkpoint.pth")
        for seed in _GMML_SEEDS
    ]
]
# Column order of the model's predictions (alphabetical for these checkpoints).
diseases_model = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion", "Emphysema",
    "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule", "Pleural_Thickening",
    "Pneumonia", "Pneumothorax",
]

In [None]:
# GMML models, nc 3: vit_small checkpoints under GMML/Finetune/CXR8, five seeds per
# experiment. Each experiment is preceded by a ("", tag) separator row, which the evaluation
# loop prints as a header instead of evaluating.
args.nc = 3
_GMML_ROOT = "/Users/felixkrones/python_projects/models/GMML/Finetune/CXR8"
_GMML_SEEDS = (0, 11, 21, 42, 100)
_GMML_EXPERIMENTS = (  # (separator tag, fine-tuning folder under _GMML_ROOT)
    ("2.2", "2.2_imagenet_scratch_3D"),
    ("0.2.6", "0.2.6_timm_imagenet_3D"),
    ("3.2", "3.2_timm_MIMIC_3D"),
    ("3.4", "3.4_timm_OCT_3D"),
    ("3.6", "3.6_timm_CovidxCT_3D"),
)
model_list = [
    row
    for tag, folder in _GMML_EXPERIMENTS
    for row in [("", tag)]
    + [
        ("vit_small", f"{_GMML_ROOT}/{folder}/seed_{seed}/best_checkpoint.pth")
        for seed in _GMML_SEEDS
    ]
]
# Column order of the model's predictions (alphabetical for these checkpoints).
diseases_model = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion", "Emphysema",
    "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule", "Pleural_Thickening",
    "Pneumonia", "Pneumothorax",
]

In [None]:
# Define eval params: settings forwarded to plex_evaluate. The evaluation loop supplies
# "decision_threshold" itself, once for each value in decision_thresholds.
eval_params = {
    "selective_threshold": [0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "independent_reg_variable": "StudyDate",
    "subpopulation_groups": ["sex_label", "race_label"],
    "ece_num_bins": 15,
}
decision_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]

# Diseases to evaluate, spelled as in diseases_model. This defaults to every class the model
# predicts. To evaluate a subset instead, use e.g.
#   diseases_to_test = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion"]
# Take a copy so later cells can rebind or edit it without touching diseases_model.
diseases_to_test = list(diseases_model)

# Prediction column of the model for each tested disease (raises ValueError if a name is
# not one of the model's classes).
index_to_test_model = [diseases_model.index(disease) for disease in diseases_to_test]

## 2. Load data

In [None]:
# Get data: build the test dataset that args.data_dir points at, and map each disease in
# diseases_to_test onto that dataset's label columns (index_to_test_dataset).
#
# For datasets that spell effusion "Pleural Effusion", diseases_to_test is rebound to that
# spelling. Only the exact name "Effusion" is mapped, so re-running this cell leaves the list
# unchanged instead of producing "Pleural Pleural Effusion".
data_dir_lower = args.data_dir.lower()
test_transform = build_transform_classification(
    normalize=args.normalization, mode="test", test_augment=args.test_augment, nc=args.nc
)
if "nih" in data_dir_lower:
    dataset_test = ChestXray14Dataset(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=test_transform,
        nc=args.nc,
    )
    # Label column order used to slice this dataset's targets.
    diseases = [
        "Atelectasis",
        "Cardiomegaly",
        "Effusion",
        "Infiltration",
        "Mass",
        "Nodule",
        "Pneumonia",
        "Pneumothorax",
        "Consolidation",
        "Edema",
        "Emphysema",
        "Fibrosis",
        "Pleural_Thickening",
        "Hernia",
    ]
    index_to_test_dataset = [diseases.index(disease) for disease in diseases_to_test]
elif "chestxpert" in data_dir_lower:
    diseases_to_test = [
        "Pleural Effusion" if disease == "Effusion" else disease for disease in diseases_to_test
    ]
    dataset_test = CheXpertDataset(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=test_transform,
        uncertain_label=args.uncertain_label,
        unknown_label=args.unknown_label,
        nc=args.nc,
    )
    # Label column order used to slice this dataset's targets.
    diseases = [
        "No Finding",
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]
    index_to_test_dataset = [diseases.index(disease) for disease in diseases_to_test]
elif "padchest" in data_dir_lower:
    diseases_to_test = [
        "Pleural Effusion" if disease == "Effusion" else disease for disease in diseases_to_test
    ]
    dataset_test = PadchestDataset(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=test_transform,
        diseases_to_test=diseases_to_test,
        nc=args.nc,
    )
    # PadChest label names are looked up in lower case.
    diseases = dataset_test.possible_labels
    index_to_test_dataset = [
        diseases.index(disease.lower()) for disease in diseases_to_test
    ]
elif "vindr" in data_dir_lower:
    dataset_test = VinDrCXR(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=test_transform,
        nc=args.nc,
    )
    # VinDr-CXR spells it "Pleural effusion"; diseases_to_test itself keeps the model names.
    diseases = dataset_test.possible_labels
    index_to_test_dataset = [
        diseases.index("Pleural effusion" if disease == "Effusion" else disease)
        for disease in diseases_to_test
    ]
else:
    raise ValueError(f"Dataset {args.data_dir} not supported")
print(f"Dataset size: {len(dataset_test)}")
print(f"index_to_test_dataset: {index_to_test_dataset}")
print(f"index_to_test_model: {index_to_test_model}")

# Label columns (dataset) and prediction columns (model) are paired position by position
# in the evaluation loop, so all three lists must have the same length.
if not (len(diseases_to_test) == len(index_to_test_dataset) == len(index_to_test_model)):
    print(f"len(index_to_test_dataset): {len(index_to_test_dataset)}")
    print(f"len(index_to_test_model): {len(index_to_test_model)}")
    print(f"len(diseases_to_test): {len(diseases_to_test)}")
    raise ValueError("Number of classes does not match the number of diseases to test")

# Get dataloader (the model itself is loaded per checkpoint in the evaluation loop).
device = torch.device(args.device)
# cuDNN autotuning; only affects the CUDA backend.
cudnn.benchmark = True
data_loader_test = DataLoader(
    dataset=dataset_test,
    # Deterministic, dataset-order iteration. NOTE(review): presumably so per-sample outputs
    # line up with the metadata table rows - confirm.
    sampler=torch.utils.data.SequentialSampler(dataset_test),
    batch_size=args.batch_size,
    num_workers=args.workers,
    # Pinned host memory only speeds up host-to-CUDA copies. PyTorch warns that it is not
    # supported on MPS, so request it for CUDA only.
    pin_memory=device.type == "cuda",
    drop_last=False,
)

## 3. Evaluate each model

In [None]:
# Evaluate every checkpoint in model_list. Results are collected in two places:
#   metrics[model_path] -> list of plex_evaluate results, one per decision threshold
#   mean_auc            -> mean AUROC per model, interleaved with "-- <tag> --" markers
# Rows with an empty model name are experiment separators and are only printed.
metrics = {}
mean_auc = []

# The metadata table is identical for every model and threshold, so read it from disk once.
# Each plex_evaluate call still gets its own copy, exactly as when the CSV was re-read per call.
meta_data_all = pd.read_csv(args.metadata_file) if args.metadata_file else None

for model_name, model_path in model_list:
    if not model_name:
        print(f"-------------- Starting with experiment {model_path} --------------")
        mean_auc.append(f"-- {model_path} --")
        continue

    print(f"-------------- Model {model_name} from path: {model_path} --------------")

    # Load model: test_classification gets the checkpoint path. The architecture name is
    # passed via args (presumably read by test_classification - confirm).
    saved_model = model_path
    args.model_name = model_name

    # Get predictions
    y_test, p_test = test_classification(saved_model, data_loader_test, device, args)

    # Keep only the tested classes; column k of both tensors refers to diseases_to_test[k].
    y_test_filtered = y_test[:, index_to_test_dataset].type(torch.int64)
    p_test_filtered = p_test[:, index_to_test_model]

    # For PadChest, a sample counts as atelectasis-positive if any of the dataset's
    # atelectasis sub-labels is positive. Reduce per sample (dim=1) so each row keeps its own
    # label rather than one value for the whole test set.
    if "padchest" in args.data_dir.lower() and "Atelectasis" in diseases_to_test:
        index_atelectasis = [i for i, d in enumerate(diseases) if "atelectasis" in d.lower()]
        if index_atelectasis:
            y_test_filtered[:, diseases_to_test.index("Atelectasis")] = (
                y_test[:, index_atelectasis].max(dim=1).values.type(torch.int64)
            )

    # Default metrics: per-class AUROC. Only positive entries enter the mean
    # (NOTE(review): presumably metric_AUROC reports undefined classes as <= 0 - confirm).
    all_results = metric_AUROC(y_test_filtered, p_test_filtered)
    mean_over_all_classes = np.array([i for i in all_results if i > 0]).mean()

    # Print results
    print(f"diseases_to_test: {diseases_to_test}")
    print(f"index_to_test_dataset: {index_to_test_dataset}")
    try:
        print(f"Count from dataset: {sum(dataset_test.img_label[:, index_to_test_dataset])}")
        print(f"Count from __get__: {sum(y_test_filtered)}")
    except Exception:
        # Fallback for datasets whose img_label does not support 2-D indexing (presumably a
        # plain list of label rows). Catch Exception rather than everything, so that
        # KeyboardInterrupt still stops the run.
        print(
            f"Count: {sum(torch.from_numpy(np.array(dataset_test.img_label))[:, index_to_test_dataset])}"
        )
        print(f"Count from __get__: {sum(y_test_filtered.cpu().numpy())}")
    print(f"AUC: {all_results}")
    print(f"Mean AUC: {round(mean_over_all_classes, 4)}")
    mean_auc.append(mean_over_all_classes)

    # Evaluate at each decision threshold. Each call gets its own copy of eval_params with
    # the threshold added, so the shared config dict is never mutated between cells.
    eval_metrics = [
        plex_evaluate(
            preds=p_test_filtered.cpu().numpy(),
            target_labels=y_test_filtered.cpu().numpy(),
            eval_args={**eval_params, "decision_threshold": decision_threshold},
            meta_data=None if meta_data_all is None else meta_data_all.copy(),
        )
        for decision_threshold in decision_thresholds
    ]

    # Save metrics
    metrics[model_path] = eval_metrics

print("Mean AUCs for all models")
print(mean_auc)

## 4. Print and save

In [None]:
# Dump every collected metric, grouped by checkpoint path and then by decision threshold.
for ckpt_path, per_threshold_results in metrics.items():
    print(
        f"---------------------------------------- Model path: {ckpt_path} ----------------------------------------"
    )
    for threshold_result in per_threshold_results:
        threshold = threshold_result["decision_threshold"]
        print(f"--------------- Decision threshold: {threshold} ---------------")
        for metric_name, metric_value in threshold_result.items():
            print(f"{metric_name}: {metric_value}")

In [None]:
# Export all collected metrics as JSON to
#   ./Outputs/Evaluation/<test list name>/<model group>/results.txt
# model_list mixes checkpoints with ("", tag) separator rows, so the output folder is named
# after the first real checkpoint.
first_model_path = next((path for name, path in model_list if name), None)
if first_model_path is None:
    raise ValueError("model_list contains no checkpoints; nothing to export")

checkpoint_file = os.path.basename(first_model_path)
if "_run" in checkpoint_file:
    # e.g. ResNet50_random_run_0_best.pth.tar -> ResNet50_random
    model_group = checkpoint_file.split("_run")[0]
else:
    # e.g. .../2.2_imagenet_scratch_3D/seed_0/best_checkpoint.pth -> 2.2_imagenet_scratch_3D.
    # Every GMML checkpoint file is named best_checkpoint.pth, so use the experiment folder.
    model_group = os.path.basename(os.path.dirname(os.path.dirname(first_model_path)))

test_list_name = args.test_list.split("/")[-1].split(".")[0]
output_path = os.path.join("./Outputs/Evaluation", test_list_name, model_group)
output_file = os.path.join(output_path, "results.txt")
os.makedirs(output_path, exist_ok=True)
with open(output_file, "w") as f:
    # default=str: values that json cannot serialise natively (presumably numpy scalars or
    # arrays from plex_evaluate) are written as strings instead of aborting the export.
    json.dump(metrics, f, indent=4, default=str)