# Evaluation

In [None]:
import json
import os

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from main_classification import get_args_parser
from trainer import test_classification
from evaluation.plex_metrics import plex_evaluate
from dataloader import  *

In [None]:
# Parameters
# Start from the argument object returned by get_args_parser() and override
# the fields this notebook relies on.
args = get_args_parser()
# The dataset-selection cell branches on `args.data_set`, so the override has
# to target that attribute (setting `args.dataset` would be silently ignored).
args.data_set = "ChestXray14"
args.data_dir = "/home/ubuntu/data/NIHChest14/images/"
args.test_list = "dataset/Xray14_test_official.txt"
# Path of the metadata file that is read in the evaluation cell; it must be
# filled in before that cell is run.
args.metadata_file = ""
args.model = "ResNet50"
# Path that is handed to test_classification as the saved model.
args.proxy_dir = "/home/ubuntu/models/moco_v3_r50_99e_nih_deit_prepped.pth"

In [None]:
# Class names handed to the dataset (as `possible_labels`) and to the
# evaluation call (as `classes`). "No Finding" is also the label used as
# `underdiagnosis_label` during evaluation.
diseases = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Pneumonia",
    "Pneumothorax",
    "No Finding",
]

In [None]:
# Evaluation settings; this dictionary is read by the plex_evaluate call in
# the evaluation cell.
eval_params = {
    "selective_threshold": 0.1,
    "independent_reg_variable": "StudyDate",
    "subpopulation_groups": [
        "sex_label",
        "race_label",
        "age_binned",
    ],
    "ece_num_bins": 15,
}

In [None]:
# Get data
# Build the test dataset selected by `args.data_set`. Every branch uses the
# test-mode transform; the CheXpert branch additionally forwards `args.nc` and
# the uncertain/unknown label settings.
if args.data_set == "ChestXray14":
    dataset_test = ChestXray14Dataset_general(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=build_transform_classification(
            normalize=args.normalization,
            mode="test",
            test_augment=args.test_augment,
        ),
        possible_labels=diseases,
    )
elif args.data_set == "CheXpert":
    dataset_test = CheXpertDataset(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=build_transform_classification(
            normalize=args.normalization,
            mode="test",
            test_augment=args.test_augment,
            nc=args.nc,
        ),
        uncertain_label=args.uncertain_label,
        unknown_label=args.unknown_label,
        nc=args.nc,
    )
elif args.data_set == "padchest":
    dataset_test = PadchestDataset(
        images_path=args.data_dir,
        file_path=args.test_list,
        augment=build_transform_classification(
            normalize=args.normalization,
            mode="test",
            test_augment=args.test_augment,
        ),
    )
else:
    # Without this branch an unknown name would leave `dataset_test` undefined
    # and only fail later with a confusing NameError.
    raise ValueError(
        f"Unsupported data_set {args.data_set!r}; "
        "expected 'ChestXray14', 'CheXpert' or 'padchest'."
    )

In [None]:
# Prepare the device, the test DataLoader and the checkpoint path.
device = torch.device(args.device)

# Turn on cuDNN benchmark mode before any batches are processed.
cudnn.benchmark = True

# Keep the samples in file order (no shuffling) for evaluation.
data_loader_test = DataLoader(
    dataset=dataset_test,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
)

# Path passed to test_classification as the saved model.
saved_model = os.path.join(args.proxy_dir)

In [None]:
# Run the saved model over the test loader. The two returned values are used
# by the evaluation cell as target labels (y_test) and model outputs (p_test).
y_test, p_test = test_classification(
    saved_model,
    data_loader_test,
    device,
    args,
)


In [None]:
# Evaluate
# The metadata table is passed to plex_evaluate below, so fail early with a
# clear message when no path has been configured.
if not args.metadata_file:
    raise ValueError(
        "args.metadata_file is empty; set it to the metadata file path "
        "before running the evaluation."
    )

# NOTE(review): `pd.read` does not exist in pandas; `read_csv` assumes the
# metadata is stored as CSV — confirm the file format.
meta_data = pd.read_csv(args.metadata_file)

eval_metrics = plex_evaluate(
    p_out=p_test,
    target_labels=y_test,
    # `eval_params` has no "eval_args" key; the dictionary itself holds the
    # evaluation settings, so pass it whole.
    eval_args=eval_params,
    meta_data=meta_data,
    classes=diseases,
    underdiagnosis_label="No Finding",
)
print(eval_metrics)

In [None]:
# Export
# Results go to ./Outputs/Classification/<test-list stem>/<checkpoint stem>_popar_results.txt
# Join the two path components into one string; the original comma-separated
# expression produced a tuple, which os.path.join rejects.
output_path = os.path.join(
    "./Outputs/Classification",
    args.test_list.split('/')[-1].split('.')[0],
)
output_file = os.path.join(
    output_path,
    f"{args.proxy_dir.split('/')[-1].split('.')[0]}_popar_results.txt",
)
# Create the target directory (and any missing parents); a directory that
# already exists is not an error.
os.makedirs(output_path, exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(eval_metrics, f, indent=4)