In [None]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.losses import DiceLoss
from monai.utils import set_determinism
import onnxruntime
from tqdm import tqdm
import random
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    ResizeWithPadOrCropd,
    EnsureChannelFirstd,
    RandZoomd
)
from monai.data import DataLoader, CacheDataset
import numpy as np
import torch.nn.functional as F
from pathlib import Path


if 'MASTER_ADDR' not in os.environ:
    os.environ['MASTER_ADDR'] = '127.0.0.1'
if 'MASTER_PORT' not in os.environ:
    os.environ['MASTER_PORT'] = '29500'

import torch
import time
from model_builder import UNet3D
from tomultichannel import ConvertToMultiChannel
from preprocess import preprocess_data, show_image
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from glob import glob
import nibabel as nib
from RepeatChannel import RepeatChannelsd
#from monai.data import NiftiSaver
import numpy as np
from U_Mamba_net import U_Mamba_net
from monai.transforms import AsDiscrete, Activations, Compose, Resize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_determinism(seed=0)

NUM_WORKERS = 1
#MODEL_PATH = r"model/Medical_Image_UNet3D.pth"
#MODEL_PATH = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/going_modular/model/Medical_Image_U_Mamba_Net_ssm_16_3D.pth"
MODEL_PATH = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/going_modular/model/Medical_Image_U_Mamba_Net_ssm_8_3D_add_learnable_weight.pth"
BASE_DIR_LINUX = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/data/test_data/"
#BASE_DIR_LINUX = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/data/Task01_BrainTumour/imagesVal"
torch.cuda.empty_cache()

def preprocess_val():
    #val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "imagesVal", "*.nii.gz")))
    val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "*.nii.gz")))
    val_data = [{"image": img} for img in val_images]
    val_transform = Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image"]),
        Orientationd(keys="image", axcodes="RAS"),
        Spacingd(keys="image", pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ResizeWithPadOrCropd(keys="image", spatial_size=(128, 128, 64)), #also try (128, 128, 64)
        RepeatChannelsd(keys=["image"], target_channels=4),
        EnsureTyped(keys=["image"]),
    ])
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=0.1,
        num_workers=NUM_WORKERS
    )
    return val_ds

def predict(image_name):
    #model = UNet3D(in_channels=4, out_channels=3).to(device)
    model = U_Mamba_net(in_channels=4, num_classes=3).to(device)
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    state_dict = checkpoint["model_state_dict"]

    #searh image in directory
    val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "*.nii.gz")))
    matched_images = [img for img in val_images if os.path.basename(img) == image_name]
    if not matched_images:
        raise FileNotFoundError(f"No image found with name {image_name}.nii.gz in {BASE_DIR_LINUX}")
    image_path = matched_images[0]
    image_location = val_images.index(image_path)
    image_basename = os.path.splitext(os.path.basename(image_path))[0]
    basename = os.path.basename(image_path)
    if basename.endswith(".nii.gz"):
        image_basename = basename.replace(".nii.gz", "")
    else:
        image_basename = os.path.splitext(basename)[0]
    
    # --- prepare dataset ---
    val_ds = preprocess_val()

    # load state dict (strip 'module.' if present)
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        out = model(val_input)
        if isinstance(out, tuple):
            out = out[0]
        probs = torch.sigmoid(out)  # [1, 3, 128, 128, 64]

        # Upsample probs, then threshold
        target_size = (240, 240, 155)
        probs_up = F.interpolate(probs, size=target_size, mode="trilinear", align_corners=False)
        union_mask = (probs_up > 0.5).any(dim=1).squeeze(0)  # [H,W,D] boolean

    # Save one file using the source affine
    src = nib.load(image_path)
    affine = src.affine
    mask_np = union_mask.cpu().numpy().astype(np.uint8)

Path("predictions").mkdir(parents=True, exist_ok=True)
save_path = os.path.join("predictions", f"{image_basename}_tumor_union.nii.gz")
nib.save(nib.Nifti1Image(mask_np, affine), save_path)
print(f"[INFO] Saved {save_path}")

if __name__ == "__main__":
    image_name = input("Enter the image file name: ")
    if not image_name.endswith(".nii.gz"):
        image_name += ".nii.gz"
    # Check if the image file exists
    while not os.path.isfile(os.path.join(BASE_DIR_LINUX, image_name)):
        print(f"Image file {image_name} not found in {BASE_DIR_LINUX}. Please enter a valid file name.")
        image_name = input("Enter the image file name: ")
        if not image_name.endswith(".nii.gz"):
            image_name += ".nii.gz"
    predict(image_name)

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Val input shape: torch.Size([1, 4, 128, 128, 64])
Upsampled probs shape: torch.Size([1, 3, 240, 240, 155])
Binarized preds shape: torch.Size([3, 240, 240, 155])
[INFO] Saved predictions/BraTS-SSA-00009-000-t1c_class0.nii.gz
[INFO] Saved predictions/BraTS-SSA-00009-000-t1c_class1.nii.gz
[INFO] Saved predictions/BraTS-SSA-00009-000-t1c_class2.nii.gz


In [5]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.losses import DiceLoss
from monai.utils import set_determinism
import onnxruntime
from tqdm import tqdm
import random
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    ResizeWithPadOrCropd,
    EnsureChannelFirstd,
    RandZoomd
)
from monai.data import DataLoader, CacheDataset
import numpy as np
import torch.nn.functional as F
from pathlib import Path


if 'MASTER_ADDR' not in os.environ:
    os.environ['MASTER_ADDR'] = '127.0.0.1'
if 'MASTER_PORT' not in os.environ:
    os.environ['MASTER_PORT'] = '29500'

import torch
import time
from model_builder import UNet3D
from tomultichannel import ConvertToMultiChannel
from preprocess import preprocess_data, show_image
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from glob import glob
import nibabel as nib
from RepeatChannel import RepeatChannelsd
#from monai.data import NiftiSaver
import numpy as np
from U_Mamba_net import U_Mamba_net
from monai.transforms import AsDiscrete, Activations, Compose, Resize
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_determinism(seed=0)

NUM_WORKERS = 1
#MODEL_PATH = r"model/Medical_Image_UNet3D.pth"
#MODEL_PATH = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/going_modular/model/Medical_Image_U_Mamba_Net_ssm_16_3D.pth"
MODEL_PATH = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/going_modular/model/Medical_Image_U_Mamba_Net_ssm_8_3D_add_learnable_weight.pth"
BASE_DIR_LINUX = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/data/test_data/"
#BASE_DIR_LINUX = r"/home/luudh/luudh/MyFile/medical_image_lab/monai/data/Task01_BrainTumour/imagesVal"
torch.cuda.empty_cache()

def preprocess_val():
    #val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "imagesVal", "*.nii.gz")))
    val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "*.nii.gz")))
    val_data = [{"image": img} for img in val_images]
    val_transform = Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image"]),
        Orientationd(keys="image", axcodes="RAS"),
        Spacingd(keys="image", pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ResizeWithPadOrCropd(keys="image", spatial_size=(128, 128, 64)), #also try (128, 128, 64)
        RepeatChannelsd(keys=["image"], target_channels=4),
        EnsureTyped(keys=["image"]),
    ])
    val_ds = CacheDataset(
        data=val_data,
        transform=val_transform,
        cache_rate=0.1,
        num_workers=NUM_WORKERS
    )
    return val_ds

def predict(image_name):
    # model
    model = U_Mamba_net(in_channels=4, num_classes=3).to(device)
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    state_dict = checkpoint["model_state_dict"]
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    # find image
    val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "*.nii.gz")))
    matched_images = [img for img in val_images if os.path.basename(img) == image_name]
    if not matched_images:
        raise FileNotFoundError(f"No image found with name {image_name} in {BASE_DIR_LINUX}")
    image_path = matched_images[0]
    image_location = val_images.index(image_path)
    basename = os.path.basename(image_path)
    image_basename = basename[:-7] if basename.endswith(".nii.gz") else os.path.splitext(basename)[0]

    # original affine for correct spatial placement
    orig_affine = nib.load(image_path).affine

    # dataset / transforms
    val_ds = preprocess_val()  # -> (C, Z, Y, X) MetaTensor after your pipeline

    with torch.no_grad():
        # input
        val_img_meta = val_ds[image_location]
        val_input = val_img_meta["image"].unsqueeze(0).to(device)   # (1, 4, Z, Y, X)
        print(f"Val input shape: {val_input.shape}")                # (1, 4, Z, Y, X)

        # forward -> logits (no thresholding here)
        logits = model(val_input)                                   # (1, C, Z, Y, X)
        print(f"Logits shape: {logits.shape}")

        # softmax (mutually exclusive classes) then resize in Torch
        probs = torch.softmax(logits, dim=1)                        # (1, C, Z, Y, X)

        # target volume size; you previously used (240, 240, 155) (H, W, D)
        # PyTorch expects size=(D, H, W) -> (155, 240, 240)
        target_dhw = (155, 240, 240)
        resized = F.interpolate(probs, size=target_dhw, mode="nearest")  # (1, C, 155, 240, 240)
        print(f"Resized probs shape: {resized.shape}")

        # argmax to single-label mask
        mask = resized.argmax(dim=1)[0].to(torch.uint8).cpu().numpy()    # (155, 240, 240)

        # save nifti with original affine
        os.makedirs("predictions", exist_ok=True)
        save_path = os.path.join("predictions", f"{image_basename}_predicted.nii.gz")
        nib.save(nib.Nifti1Image(mask, affine=orig_affine), save_path)
        print(f"[INFO] Saved: {save_path}")

if __name__ == "__main__":
    image_name = input("Enter the image file name: ")
    if not image_name.endswith(".nii.gz"):
        image_name += ".nii.gz"
    # Check if the image file exists
    while not os.path.isfile(os.path.join(BASE_DIR_LINUX, image_name)):
        print(f"Image file {image_name} not found in {BASE_DIR_LINUX}. Please enter a valid file name.")
        image_name = input("Enter the image file name: ")
        if not image_name.endswith(".nii.gz"):
            image_name += ".nii.gz"
    predict(image_name)

Val input shape: torch.Size([1, 4, 128, 128, 64])


AttributeError: 'tuple' object has no attribute 'shape'

In [4]:
def predict(image_name):
    # model
    model = U_Mamba_net(in_channels=4, num_classes=3).to(device)
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    state_dict = checkpoint["model_state_dict"]
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    # find image
    val_images = sorted(glob(os.path.join(BASE_DIR_LINUX, "*.nii.gz")))
    matched_images = [img for img in val_images if os.path.basename(img) == image_name]
    if not matched_images:
        raise FileNotFoundError(f"No image found with name {image_name} in {BASE_DIR_LINUX}")
    image_path = matched_images[0]
    image_location = val_images.index(image_path)
    basename = os.path.basename(image_path)
    image_basename = basename[:-7] if basename.endswith(".nii.gz") else os.path.splitext(basename)[0]

    # original affine for correct spatial placement
    orig_affine = nib.load(image_path).affine

    # dataset / transforms
    val_ds = preprocess_val()  # -> (C, Z, Y, X) MetaTensor after your pipeline

    with torch.no_grad():
        # input
        val_img_meta = val_ds[image_location]
        val_input = val_img_meta["image"].unsqueeze(0).to(device)   # (1, 4, Z, Y, X)
        print(f"Val input shape: {val_input.shape}")                # (1, 4, Z, Y, X)

        # forward -> logits (no thresholding here)
        logits = model(val_input)                                   # (1, C, Z, Y, X)
        print(f"Logits shape: {logits.shape}")

        # softmax (mutually exclusive classes) then resize in Torch
        probs = torch.softmax(logits, dim=1)                        # (1, C, Z, Y, X)

        # target volume size; you previously used (240, 240, 155) (H, W, D)
        # PyTorch expects size=(D, H, W) -> (155, 240, 240)
        target_dhw = (155, 240, 240)
        resized = F.interpolate(probs, size=target_dhw, mode="nearest")  # (1, C, 155, 240, 240)
        print(f"Resized probs shape: {resized.shape}")

        # argmax to single-label mask
        mask = resized.argmax(dim=1)[0].to(torch.uint8).cpu().numpy()    # (155, 240, 240)

        # save nifti with original affine
        os.makedirs("predictions", exist_ok=True)
        save_path = os.path.join("predictions", f"{image_basename}_predicted.nii.gz")
        nib.save(nib.Nifti1Image(mask, affine=orig_affine), save_path)
        print(f"[INFO] Saved: {save_path}")