# Evaluation

In [None]:
import json
from main_classification import get_args_parser
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from trainer import test_classification
from evaluation.plex_metrics import plex_evaluate
from dataloader import  *
from utils import metric_AUROC

In [None]:
# Parameters
# Other data_dir values used with this notebook:
#   /Users/felixkrones/python_projects/data/NIH/images/
#   /Users/felixkrones/python_projects/data/ChestXpert/
#   /Users/felixkrones/python_projects/data/Padchest/0/
# Other test_list values:
#   dataset/Xray14_test_official.txt
#   CheXpert_valid_official_frontal.csv
#   CheXpert_test_Glocker.csv
#   PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv
args = get_args_parser(main_args=False).get_default_values()
args.device = "mps"
args.data_dir = "/Users/felixkrones/python_projects/data/ChestXpert/"
args.test_list = "dataset/CheXpert_valid_official_frontal.csv"
args.metadata_file = ""  # empty string -> no metadata file is read during evaluation
args.model_name = "vit_small"
args.proxy_dir = "/Users/felixkrones/python_projects/models/gmml_nih1000e_timm_run_4_best.pth.tar"

# Class names in the order of the model's prediction columns.
diseases_model = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
    'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia',
]
# Findings to evaluate; every name must appear in diseases_model
# (list.index raises ValueError otherwise).
diseases_to_test = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pneumonia", "Pneumothorax"]
# Model output column of each tested finding, in diseases_to_test order.
index_to_test_model = list(map(diseases_model.index, diseases_to_test))

In [None]:
# Define eval params
# Settings handed to plex_evaluate as `eval_args`; their exact meaning is
# defined in evaluation.plex_metrics (ece_num_bins presumably sets the
# number of calibration bins -- confirm there).
eval_params = dict(
    decision_threshold=0.5,
    selective_threshold=0.1,
    independent_reg_variable="StudyDate",
    subpopulation_groups=["sex_label", "race_label", "age_binned"],
    ece_num_bins=15,
)

In [None]:
# Get data
# Build the test dataset selected by args.data_dir and find, for every entry of
# diseases_to_test, its column in that dataset's label vector
# (index_to_test_dataset). Raises ValueError for unknown datasets or missing labels.
test_transform = build_transform_classification(normalize=args.normalization, mode="test", test_augment=args.test_augment)
data_dir_lower = args.data_dir.lower()
if "nih" in data_dir_lower:
    dataset_test = ChestXray14Dataset(images_path=args.data_dir, file_path=args.test_list, augment=test_transform)
    diseases = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
                'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
                'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']
    index_to_test_dataset = [diseases.index(disease) for disease in diseases_to_test]
elif "chestxpert" in data_dir_lower:
    dataset_test = CheXpertDataset(images_path=args.data_dir, file_path=args.test_list, augment=test_transform)
    diseases = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
                'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax',
                'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
    index_to_test_dataset = [diseases.index(disease) for disease in diseases_to_test]
elif "padchest" in data_dir_lower:
    # PadChest names this finding "Pulmonary Edema". Map by exact match (not
    # str.replace) so re-running this notebook cell is idempotent instead of
    # producing "Pulmonary Pulmonary Edema".
    diseases_to_test = ["Pulmonary Edema" if disease == "Edema" else disease for disease in diseases_to_test]
    dataset_test = PadchestDataset(images_path=args.data_dir, file_path=args.test_list, augment=test_transform, diseases_to_test=diseases_to_test)
    diseases = dataset_test.possible_labels
    # Matched on lower-cased names, so possible_labels presumably holds
    # lower-case strings (and is array-like, given .tolist()) -- confirm in dataloader.
    padchest_labels = diseases.tolist()
    index_to_test_dataset = [padchest_labels.index(disease.lower()) for disease in diseases_to_test]
else:
    raise ValueError(f"Dataset {args.data_dir} not supported")
print(f"Dataset size: {len(dataset_test)}")
print(f"index_to_test_dataset: {index_to_test_dataset}")
print(f"index_to_test_model: {index_to_test_model}")
# Dataset and model column selections must both cover every tested finding,
# otherwise the filtered labels and predictions would not line up.
if not len(diseases_to_test) == len(index_to_test_dataset) == len(index_to_test_model):
    print(f"len(index_to_test_dataset): {len(index_to_test_dataset)}")
    print(f"len(index_to_test_model): {len(index_to_test_model)}")
    print(f"len(diseases_to_test): {len(diseases_to_test)}")
    raise ValueError("Number of classes does not match the number of diseases to test")

In [None]:
# Get dataloader and model
device = torch.device(args.device)
# cuDNN autotuning; only has an effect when running on a CUDA device.
cudnn.benchmark = True
data_loader_test = DataLoader(
    dataset=dataset_test,
    batch_size=args.batch_size,
    shuffle=False,  # deterministic evaluation order
    num_workers=args.workers,
    # Pinned host memory only speeds up host->CUDA copies; on MPS PyTorch
    # warns that pinning is unsupported, so enable it for CUDA only.
    pin_memory=device.type == "cuda",
)
# Checkpoint path handed to test_classification (a single-argument
# os.path.join was a no-op, so use the value directly).
saved_model = args.proxy_dir

In [None]:
# Get predictions
# Run the checkpoint over the whole test loader; returns the ground-truth
# labels and model predictions for the test set.
y_test, p_test = test_classification(
    saved_model,
    data_loader_test,
    device,
    args,
)

In [None]:
# Filter predictions
# Keep only the tested findings: labels are selected by the dataset's column
# indices, predictions by the model's column indices, so columns align.
y_test_filtered, p_test_filtered = (
    y_test[:, index_to_test_dataset],
    p_test[:, index_to_test_model],
)

In [None]:
# Default metrics
# metric_AUROC presumably yields one AUROC per tested class; report their
# unweighted mean rounded to 4 decimals.
all_results = metric_AUROC(y_test_filtered, p_test_filtered, len(diseases_to_test))
mean_over_all_classes = np.mean(np.array(all_results))
print(round(mean_over_all_classes, 4))

In [None]:
# Evaluate
# Convert the filtered tensors to NumPy arrays and load optional metadata
# (None when args.metadata_file is empty) before handing everything to plex_evaluate.
preds_np = p_test_filtered.cpu().numpy()
targets_np = y_test_filtered.cpu().numpy()
meta_df = None
if args.metadata_file:
    meta_df = pd.read_csv(args.metadata_file)
eval_metrics = plex_evaluate(
    preds=preds_np,
    target_labels=targets_np,
    eval_args=eval_params,
    meta_data=meta_df,
    classes=diseases_to_test,
    underdiagnosis_label="No Finding",
)

In [None]:
# Show every computed metric as "name: value", one per line.
for metric_name, metric_value in eval_metrics.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Export
# Write eval_metrics as JSON to
#   ./Outputs/Classification/<test-list stem>/<checkpoint stem>_popar_results.txt
# where each "stem" is the file name up to its first '.'.
# (Previously output_path was accidentally a tuple, which made os.path.join fail.)
test_list_stem = args.test_list.split('/')[-1].split('.')[0]
model_stem = args.proxy_dir.split('/')[-1].split('.')[0]
output_path = os.path.join("./Outputs/Classification", test_list_stem)
output_file = os.path.join(output_path, f"{model_stem}_popar_results.txt")
os.makedirs(output_path, exist_ok=True)


def _to_json_compatible(obj):
    """Convert array-like / scalar values exposing ``.tolist()`` (e.g. NumPy
    or torch values, if plex_evaluate returns any) into plain Python types."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


with open(output_file, "w", encoding="utf-8") as f:
    json.dump(eval_metrics, f, indent=4, default=_to_json_compatible)