In [1]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRangePercentiles,
    ScaleIntensityRange,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    EnsureTyped,
    NormalizeIntensityd,
)
from monai.handlers.utils import from_engine
from monai.losses import GeneralizedDiceLoss
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from tqdm import tqdm
import time

# Print MONAI, PyTorch and optional-dependency versions so the run can be reproduced.
print_config()

MONAI version: 1.4.dev2415
Numpy version: 1.26.4
Pytorch version: 2.2.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 189d1865c1b5b228b9d9e5e95ed40969eda7badc
MONAI __file__: C:\Users\<username>\.conda\envs\nnUNet\Lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.12.0
Pillow version: 10.2.0
Tensorboard version: 2.16.2
gdown version: 4.7.3
TorchVision version: 0.17.1
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

In [6]:
directory = "demo_data2"
# Gather the held-out test volumes (compressed NIfTI) in a deterministic, sorted order.
test_images = sorted(
    glob.glob(os.path.join(directory, "imagesTs", "*.nii.gz"))
)
# Dictionary-based MONAI transforms consume one dict per sample, keyed by "image".
test_data = [dict(image=path) for path in test_images]
test_data

[{'image': 'demo_data2\\imagesTs\\case_00210_0000.nii.gz'}]

In [8]:
keys = "image"
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessing applied to every test volume before inference. Invertd (below)
# replays this exact Compose in reverse, so the two must stay in sync.
test_org_transforms = Compose(
    transforms=[
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
        EnsureTyped(keys=keys),
        # NOTE(review): Spacingd runs before Orientationd, so pixdim is applied in the
        # file's native axis order rather than RAS -- confirm this matches training.
        Spacingd(keys=keys, pixdim=(1, 0.78, 0.78), mode="bilinear"),
        Orientationd(keys=keys, axcodes="RAS"),
        NormalizeIntensityd(keys=keys),
        # NOTE(review): no select_fn is given, so the library's default foreground rule
        # is applied to the already-normalized image -- confirm this matches training.
        CropForegroundd(keys=keys, source_key="image"),
    ]
)
test_org_ds = Dataset(data=test_data, transform=test_org_transforms)
test_org_loader = DataLoader(dataset=test_org_ds, batch_size=1, num_workers=4)

# Per-sample post-processing: map the network output back onto the original image
# grid, take the argmax over classes, and write the result under ./out.
post_transforms = Compose(
    transforms=[
        Invertd(
            keys="pred",
            transform=test_org_transforms,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,  # invert with the forward interpolation mode, not nearest
            to_tensor=True,
        ),
        # NOTE(review): to_onehot=3 means the saved *_seg file is a 3-channel one-hot
        # volume, not a single-channel label map -- drop to_onehot if a label map is wanted.
        AsDiscreted(keys="pred", argmax=True, to_onehot=3),
        SaveImaged(
            keys="pred",
            meta_keys="pred_meta_dict",
            output_dir="./out",
            output_postfix="seg",
            resample=False,
        ),
    ]
)

In [9]:
# Rebuild the network with the architecture used for training so the saved
# state_dict keys and tensor shapes line up.
model = UNet(
    spatial_dims=3,  # 3D volumetric segmentation
    in_channels=1,  # single input channel
    out_channels=3,  # 3 classes, background included
    channels=(32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# map_location lets a checkpoint saved on a GPU be restored when `device` fell
# back to "cpu"; weights_only=True refuses arbitrary pickled objects, since the
# file is loaded straight into load_state_dict and should hold only tensors.
checkpoint = torch.load(
    os.path.join(directory, "best_metric_model.pth"),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint)
model.eval()

# Sliding-window settings do not change per batch, so define them once.
roi_size = (160, 160, 160)
sw_batch_size = 4

with torch.no_grad():
    # Use distinct names for the batch and the outputs so the global `test_data`
    # file list is not overwritten; otherwise re-running the transforms cell after
    # this one would build its Dataset from model outputs instead of file paths.
    for batch_data in test_org_loader:
        test_inputs = batch_data["image"].to(device)
        batch_data["pred"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
        # Split the batch into per-sample dicts, then invert the preprocessing,
        # discretize, and save each prediction via post_transforms.
        test_outputs = [post_transforms(item) for item in decollate_batch(batch_data)]

2024-04-25 01:22:55,912 INFO image_writer.py:197 - writing: out\case_00210_0000\case_00210_0000_seg.nii.gz
