In [1]:
import torch
from system import System
from mamba_mic.data_modules.hnts_mrg import HNTSMRGDataModule
from monai.inferers import sliding_window_inference
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    Activations,
    AsDiscreted,
    Invertd,
    KeepLargestConnectedComponentd,
    Lambdad,
)
from lightning.pytorch import seed_everything
from tqdm import tqdm
import re
import nibabel as nib
import numpy as np

seed_everything(42)

[rank: 0] Seed set to 42


42

In [2]:
task = "midRT"
run_id = "fdsi3lou"
checkpoint = "model-epoch=219-val_loss=0.46"
checkpoint_path = f"lightning_logs/{run_id}/checkpoints/{checkpoint}.ckpt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = System.load_from_checkpoint(checkpoint_path=checkpoint_path)
model.eval()
model.to(device)

/cluster/home/eriksalv/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.


System(
  (net): DynUNet(
    (input_block): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(4, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
      )
      (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
      (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    )
    (downsamples): ModuleList(
      (0): UnetBasicBlock(
        (conv1): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 1), padding=(1, 1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
        (lrelu): LeakyReLU(negative_slope=0.01, in

In [3]:
data_module = HNTSMRGDataModule(batch_size=1, task=task)
data_module.prepare_data()
data_module.setup()
val_set = data_module.val_set

In [4]:
dice = DiceMetric(include_background=False, reduction="mean_batch")
hd95 = HausdorffDistanceMetric(
    distance_metric="euclidean",
    percentile=95,
    include_background=False,
    reduction="mean_batch",
)

post_transforms = Compose(
    [
        AsDiscreted(keys="pred", argmax=True, dim=0, to_onehot=3),
        Invertd(
            keys="pred",
            transform=val_set.transform,
            orig_keys="label",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
            device="cpu",
        ),
        Invertd(
            keys="label",
            transform=val_set.transform,
            orig_keys="label",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
            device="cpu",
        ),
        KeepLargestConnectedComponentd(
            keys="pred", applied_labels=[1], is_onehot=True, connectivity=1
        ),
    ]
)

In [5]:
gtvp = []
gtvn = []

model.eval()
with torch.no_grad():
    for val_data in tqdm(val_set, total=20):
        path = val_data["label"].meta["filename_or_obj"]
        val_number = re.search("[0-9]+", path).group(0)

        x, y = (
            val_data["image"].to(device).unsqueeze(0),
            val_data["label"].to(device).unsqueeze(0),
        )
        val_data["pred"] = sliding_window_inference(
            x, roi_size=[192, 192, 48], overlap=0.5, sw_batch_size=2, predictor=model
        ).squeeze(0)

        postprocessed = post_transforms(val_data)
        y_pred, y = postprocessed["pred"], postprocessed["label"]
        dice(y_pred=y_pred.unsqueeze(0), y=y.unsqueeze(0))

        dices = dice.aggregate()
        dice.reset()

        gtvp.append(dices[0].item())
        gtvn.append(dices[1].item())

        nib.save(
            nib.Nifti1Image(
                torch.argmax(y_pred, dim=0).type(torch.float).numpy(),
                affine=val_data["image"].meta["original_affine"],
            ),
            f"./data/HNTS-MRG/{task}/{run_id}_{checkpoint}_{val_number}_pred.nii.gz",
        )

100%|██████████| 20/20 [15:08<00:00, 45.42s/it]


In [6]:
print(np.array(gtvp).mean())
print(np.array(gtvn).mean())

0.2861071038991213
0.6310926109552384
