In [1]:
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.data import DataLoader

from monai import config
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, LoadImage, SaveImage, ScaleIntensity, EnsureType
from monai.utils import set_determinism


In [2]:
# Directory for the synthetic evaluation data. Create it up front: PIL's Image.save does not
# create missing parent directories, so the PNG writes below would fail on a fresh checkout.
tempdir = "tmp_evaluation"
os.makedirs(tempdir, exist_ok=True)
config.print_config()
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Seed the RNGs so the synthetic data (and therefore the reported Dice) is reproducible.
set_determinism(seed=0)

NUM_IMAGES = 5  # number of synthetic image/label pairs to generate
IMAGE_SIZE = 128  # width and height of each synthetic image, in pixels

print(f"generating synthetic data to {tempdir} (this may take a while)")
for i in range(NUM_IMAGES):
    im, seg = create_test_image_2d(IMAGE_SIZE, IMAGE_SIZE, num_seg_classes=1)
    Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
    Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

# Build the file lists from the indices just written (instead of globbing the directory) so
# stale files left over from earlier runs are never picked up and image/label pairs line up.
images = [os.path.join(tempdir, f"img{i:d}.png") for i in range(NUM_IMAGES)]
segs = [os.path.join(tempdir, f"seg{i:d}.png") for i in range(NUM_IMAGES)]

# define transforms for image and segmentation
imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
val_ds = ArrayDataset(images, imtrans, segs, segtrans)
# sliding window inference for one image at every iteration
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Post-processing of raw network outputs: sigmoid activation, then binarisation via AsDiscrete.
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
# Writes each post-processed prediction as a PNG under ./output.
saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Architecture must match the one used to train the checkpoint loaded in the next cell.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

MONAI version: 0.7.0+43.g7b1b772a
Numpy version: 1.21.2
Pytorch version: 1.10.0a0+3fd9dcf
MONAI flags: HAS_EXT = True, USE_COMPILED = False
MONAI rev id: 7b1b772a4ad30c259696001a1a2380c52adffb65

Optional dependencies:
Pytorch Ignite version: 0.4.6
Nibabel version: 3.2.1
scikit-image version: 0.18.3
Pillow version: 8.2.0
Tensorboard version: 2.6.0
gdown version: 4.0.2
TorchVision version: 0.11.0a0
tqdm version: 4.62.1
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.3.3
einops version: 0.3.2
transformers version: 4.11.3
mlflow version: 1.20.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

generating synthetic data to tmp_evaluation (this may take a while)


In [3]:
# Switch to evaluation mode so any train/eval-dependent layers (e.g. dropout, normalisation)
# use their inference behaviour; the original cell had this call commented out.
model.eval()
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host (and vice versa).
# Left as the last expression so the notebook displays the key-matching result.
model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth", map_location=device))

<All keys matched successfully>

In [4]:
# Sliding-window settings: window (ROI) size and number of windows per forward pass.
# Loop-invariant, so defined once outside the loop.
roi_size = (96, 96)
sw_batch_size = 4

# Clear any partial state left by an earlier (possibly interrupted) run of this cell, so
# re-running it never mixes stale per-image results into the aggregated Dice.
dice_metric.reset()

with torch.no_grad():
    for val_data in val_loader:
        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
        # Split the batch into per-item tensors and post-process each prediction.
        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
        val_labels = decollate_batch(val_labels)
        # compute metric for current iteration
        dice_metric(y_pred=val_outputs, y=val_labels)
        for val_output in val_outputs:
            saver(val_output)
    # aggregate the final mean dice result
    print("evaluation metric:", dice_metric.aggregate().item())
    # reset the status
    dice_metric.reset()

file written: /opt/monai/mnt/PyTorch/NoteBooks/GettingStarted/output/0/0_seg.png.
file written: /opt/monai/mnt/PyTorch/NoteBooks/GettingStarted/output/1/1_seg.png.
file written: /opt/monai/mnt/PyTorch/NoteBooks/GettingStarted/output/2/2_seg.png.
file written: /opt/monai/mnt/PyTorch/NoteBooks/GettingStarted/output/3/3_seg.png.
file written: /opt/monai/mnt/PyTorch/NoteBooks/GettingStarted/output/4/4_seg.png.
evaluation metric: 0.9880849719047546
