In [2]:
from monai.transforms import LoadImage
from monai.transforms import (
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, Dataset, decollate_batch

import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import os
import glob
import numpy as np

#print_config()

In [3]:
# Load test data
data_dir = "../data"
test_images = sorted(glob.glob(os.path.join(data_dir, "imagesTs", "*.nii.gz")))

test_data = [{"image": image} for image in test_images]

test_org_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image", allow_smaller=True),
    ]
)

test_org_ds = Dataset(data=test_data, transform=test_org_transforms)

test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=4)

post_transforms = Compose(
    [
        Invertd(
            keys="pred",
            transform=test_org_transforms,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
        ),
        AsDiscreted(keys="pred", argmax=True, to_onehot=2),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
    ]
)

In [4]:

loader = LoadImage()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm="batch",
)
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load("best_metric_model.pth", weights_only=True, map_location=device))
model.eval()

DataParallel(
  (module): UNet(
    (model): Sequential(
      (0): ResidualUnit(
        (conv): Sequential(
          (unit0): Convolution(
            (conv): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
            (adn): ADN(
              (N): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (D): Dropout(p=0.0, inplace=False)
              (A): PReLU(num_parameters=1)
            )
          )
          (unit1): Convolution(
            (conv): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (adn): ADN(
              (N): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (D): Dropout(p=0.0, inplace=False)
              (A): PReLU(num_parameters=1)
            )
          )
        )
        (residual): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      )
      (1): SkipConnection(
        (submodu

In [7]:
from random import randint
nb = randint(0, len(test_org_ds) - 1)

# Get the image and apply any necessary preprocessing
test_item = next(iter(test_org_loader))
test_inputs = test_item["image"].to(device)  # Add batch dim

# Sliding window inference
roi_size = (160, 160, 160)
sw_batch_size = 4
with torch.no_grad():
    test_item["pred"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)

test_result = [post_transforms(i) for i in decollate_batch(test_item)]
test_output = from_engine(["pred"])(test_result)



2025-08-14 22:51:54,240 INFO image_writer.py:197 - writing: out/spleen_1/spleen_1_seg.nii.gz


In [8]:
print(test_output[0].shape)

torch.Size([2, 512, 512, 34])


In [None]:
from monai.visualize.utils import blend_images, matshow3d

original_image = loader(test_output[0].meta["filename_or_obj"])
image = original_image.unsqueeze(0)  # Ensure image has a channel dimension
label = test_output[0].detach().cpu()[1, :, :, :]
label = label.unsqueeze(0)  # Ensure label has a channel dimension
img = blend_images(image, label, cmap="gray")
print(img.shape)

for i in range(img.shape[-1]):
    plt.figure()
    plt.imshow(img[1, :, :, i], cmap="gray")
    plt.axis("off")
    plt.title(f"Slice {i}")
    plt.show()

In [10]:
import torch
import numpy as np
import nibabel as nib

# Convert to NumPy and rearrange axes to (X, Y, Z, T)
# From (3, 512, 512, 34) to (512, 512, 34, 3)
nifti_array = img.permute(1, 2, 3, 0).numpy()

# Create affine (identity matrix for no transformation)
affine = np.eye(4)

# Create NIfTI image
nifti_img = nib.Nifti1Image(nifti_array, affine)

# Save the NIfTI file
nib.save(nifti_img, '../results/output_image2.nii.gz')

print("✅ NIfTI file saved as 'output_image.nii.gz'")


✅ NIfTI file saved as 'output_image.nii.gz'


In [11]:
import torch
import numpy as np
import pydicom
from pydicom.dataset import Dataset, FileDataset
import datetime
import os


# Convert to numpy
volume = img.numpy()

# Make sure output directory exists
output_dir = "../results/dicom2"
os.makedirs(output_dir, exist_ok=True)

# Save each slice (34 total) as a DICOM file
for i in range(volume.shape[-1]):
    for channel in range(volume.shape[0]):
        slice_data = volume[channel, :, :, i]

        # Normalize to 0–255 and convert to uint8
        slice_norm = (slice_data - slice_data.min()) / (slice_data.max() - slice_data.min())
        pixel_array = (slice_norm * 255).astype(np.uint8)

        # Create DICOM metadata
        file_meta = Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.SecondaryCaptureImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.ImplementationClassUID = pydicom.uid.generate_uid()

        # Create FileDataset
        ds = FileDataset(f"", {}, file_meta=file_meta, preamble=b"\0" * 128)
        ds.PatientName = "Test^Patient"
        ds.PatientID = "123456"
        ds.Modality = "OT"  # Other
        ds.SeriesInstanceUID = pydicom.uid.generate_uid()
        ds.StudyInstanceUID = pydicom.uid.generate_uid()
        ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
        ds.SOPClassUID = file_meta.MediaStorageSOPClassUID

        dt = datetime.datetime.now()
        ds.ContentDate = dt.strftime('%Y%m%d')
        ds.ContentTime = dt.strftime('%H%M%S')

        ds.Rows, ds.Columns = pixel_array.shape
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 8
        ds.BitsStored = 8
        ds.HighBit = 7
        ds.PixelRepresentation = 0
        ds.PixelData = pixel_array.tobytes()

        filename = os.path.join(output_dir, f"slice_{i:03d}_channel_{channel}.dcm")
        ds.save_as(filename)

print(f"DICOM files saved to '{output_dir}'")


DICOM files saved to '../results/dicom2'
