In [1]:
%load_ext autoreload
%autoreload 2
import os

from trainer import LitModel
import torch 
from shared_modules.data_module_all import DataModule
from shared_modules.utils import load_config
from tqdm import tqdm
from monai.metrics import DiceMetric
from shared_modules.plotting import plot_metrics, plot_confusion, plot_difference
from shared_modules.torch_metrics import PicaiMetric
from shared_modules.post_transforms import get_post_transforms
from monai.data import decollate_batch

`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_eval



Please cite the following paper when using Report Guided Annotations:

Bosma, J.S., et al. "Semi-supervised learning with report-guided lesion annotation for deep learning-based prostate cancer detection in bpMRI" to be submitted


If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/Report-Guided-Annotation



In [2]:
# Output toggles for the post-transform pipelines built below.
SAVE_PROB_MAPS = False  # save the ensemble probability maps ("prob" key)
SAVE_PREDS = False      # save the output of the "pca" post-transform pipeline

# Requested DataLoader/caching workers; capped at the host's CPU count below so
# the notebook does not oversubscribe machines with fewer cores.
MAX_NUM_WORKERS = 90

config = load_config("config.yaml")
config.data.data_dir = "../../../data/"
config.data.json_list = "../../../json_datalists/picai/all_samples.json"
gpu = 0
config.gpus = [gpu]
config.cache_rate = 1.0
config.transforms.label_keys = ["pca", "prostate"]
config.transforms.crop_key = "prostate"
# Alternative input layout kept for reference: image_keys = ["t2w", "adc", "hbv"].
config.transforms.image_keys = ["image"]
# os.cpu_count() may return None, hence the fallback to 1.
config.num_workers = min(MAX_NUM_WORKERS, os.cpu_count() or 1)

# First label key ("pca") is the target that predictions are evaluated against.
label_key = config.transforms.label_keys[0]


In [3]:
weights_folder = "../../../gc_algorithms/base_container/models/swin_unetr/weights/"
device = f"cuda:{gpu}"

# One checkpoint per fold (f0.ckpt ... f4.ckpt); together they form the ensemble.
models = []
for fold_idx in range(5):
    fold_model = LitModel.load_from_checkpoint(
        f"{weights_folder}f{fold_idx}.ckpt",
        config=config,
        map_location=device,
    )
    # Inference mode: disable randomness, dropout, etc.
    fold_model.eval()
    fold_model.to(gpu)
    models.append(fold_model)

monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().


In [4]:
# Build the data module and prepare only the test split for inference.
dm = DataModule(config=config)
dm.setup("test")

dl = dm.test_dataloader()

Loading dataset: 100%|██████████| 1499/1499 [05:43<00:00,  4.37it/s]


In [5]:
def _make_post_transforms(key, out_dir, save_mask):
    """Build a post-processing pipeline that reverts `key` back to original size.

    Every pipeline is driven by the "test" transforms of `dm`, using the label key's
    transform history (`orig_key=label_key`). `save_mask` presumably toggles writing
    the result under `out_dir` — confirm against `get_post_transforms`.
    """
    return get_post_transforms(
        key=key,
        orig_key=label_key,
        orig_transforms=dm.transforms["test"],
        out_dir=out_dir,
        keep_n_largest_components=0,
        output_postfix="",
        output_dtype="float32",
        save_mask=save_mask,
    )


prob_map_post_transforms = _make_post_transforms("prob", "output/prob/", SAVE_PROB_MAPS)

# Own directory: with output_postfix="" both pipelines would likely write
# same-named files into one folder and overwrite each other.
# NOTE(review): "pca" is the label key, so SAVE_PREDS saves the reverted
# ground-truth label, not model predictions — confirm this is intended.
pca_post_transforms = _make_post_transforms("pca", "output/pca/", SAVE_PREDS)

In [6]:
# Ensemble inference: average the per-fold sigmoid probability maps, then revert
# both the ensemble probability and the ground-truth label to original size.
all_probs = []
all_gts = []

with torch.no_grad():
    for batch in tqdm(dl):
        x = batch["image"].to(gpu)

        # Channel 1 of each fold's logits, sliced as [:, 1:2] to keep a (B, 1, ...)
        # layout so every sample in the batch is kept (for B == 1 this equals the
        # former logits[0, 1][None][None]).
        # NOTE(review): sigmoid on one channel of a multi-channel output — confirm
        # the model was trained with per-channel sigmoid rather than softmax.
        fold_probs = [torch.sigmoid(model.inferer(x)[:, 1:2]) for model in models]
        batch["prob"] = torch.stack(fold_probs).mean(dim=0)

        # Decollate once, before any key is overwritten, and run both
        # inverse pipelines on the same per-sample dicts.
        for sample in decollate_batch(batch):
            prob = prob_map_post_transforms(sample)["prob"]
            gt = pca_post_transforms(sample)["pca"]
            # Drop the channel dimension; picai_eval consumes plain numpy arrays.
            all_probs.append(prob[0, ...].cpu().numpy())
            all_gts.append(gt[0, ...].cpu().numpy())
    

100%|██████████| 1499/1499 [15:29<00:00,  1.61it/s]


In [8]:
import picai_eval
from report_guided_annotation import extract_lesion_candidates

# Evaluate the ensemble probability maps against the reverted ground-truth labels.
metrics = picai_eval.evaluate(
    y_det=all_probs,
    y_true=all_gts,
    # Presumably element [0] is the lesion-candidate map; confirm in report_guided_annotation.
    y_det_postprocess_func=lambda pred: extract_lesion_candidates(pred, threshold="dynamic")[0],
    y_true_postprocess_func=lambda y: y,  # labels are used as-is
    num_parallel_calls=16,
)

# Make sure the output directory exists so saving does not fail on a fresh checkout.
os.makedirs("results", exist_ok=True)
metrics.save("results/metrics.json")

# Last expression of the cell, so Jupyter actually renders the metrics summary.
metrics

In [None]:
# Plot the evaluation metrics computed above.
# NOTE(review): the meaning of the positional argument 56 is not visible here — TODO
# confirm against plot_metrics' signature and hoist it to a named config constant.
plot_metrics(metrics,56)

In [None]:
# Confusion plot for the evaluation metrics computed above.
# NOTE(review): 56 is the same unexplained positional value as in the plot_metrics
# cell, and threshold=0.5 is presumably the detection cutoff — confirm both against
# plot_confusion's signature.
plot_confusion(metrics, 56, threshold=0.5)