In [2]:
import os
import numpy as np
import torch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd,
    ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld, ToTensord,
    ScaleIntensityd,SpatialPadd,
    LambdaD,
    ConcatItemsd,
    AsDiscrete,
    NormalizeIntensityd,
    GaussianSmoothd
)
from monai.data import CacheDataset, DataLoader, Dataset
from monai.config import print_config
from monai.networks.nets import UNETR
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.metrics import DiceMetric

print_config()


MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.2+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/<username>/.conda/envs/unetSSL/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.12.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.1
einops version: 0.8.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: 1.0.0
clearml version: NOT INSTALLED or UNKNOWN VERSION.


In [3]:
# printing the results
device = torch.device("cpu")

#test dir
test_dir = "../Data/VoiceUsers/Val/Nasal25/"

image_dim = 64

def threshold_image(image):
    # Threshold the image: set values below 0.08 to 0
    return np.where(image < 0.2, 0, image)

def binarize_label(label):
    return (label > 0).astype(label.dtype)

# Gather test dataset
val_nrrd_files = sorted([os.path.join(test_dir, f) for f in os.listdir(
    test_dir) if f.endswith(".nrrd") and not f.endswith(".seg.nrrd")])
val_seg_nrrd_files = sorted([os.path.join(test_dir, f)
                            for f in os.listdir(test_dir) if f.endswith(".seg.nrrd")])

#Test datalist
test_datalist = [{"image": img, "label": lbl}
                       for img, lbl in zip(val_nrrd_files, val_seg_nrrd_files)]

# Test transforms
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    LambdaD(keys=["image"], func=threshold_image),
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=False),
    ScaleIntensityd(keys=["image"], minv=0, maxv=1),
    LambdaD(keys=["label"], func=binarize_label),
    CropForegroundd(keys=["image", "label"], source_key="image"),
    SpatialPadd(keys=["image", "label"], spatial_size=(image_dim, image_dim, image_dim)),
    ToTensord(keys=["image", "label"]),
])

#Test loaders
test_ds = CacheDataset(data=test_datalist, transform=val_transforms,
                      cache_num=6, cache_rate=1.0, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=5,
                        shuffle=False, num_workers=4, pin_memory=True)

print(f"Training datalist setup: {test_datalist[0]}")

Loading dataset:   0%|          | 0/5 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 5/5 [00:00<00:00,  7.24it/s]

Training datalist setup: {'image': '../Data/VoiceUsers/Val/Nasal25/11 3DGRE_I FEET 3D.nrrd', 'label': '../Data/VoiceUsers/Val/Nasal25/11 3DGRE_I FEET 3D.seg.nrrd'}





In [4]:
print(f"length of the test samples {len(test_datalist)}")

length of the test samples 5


In [5]:
import tqdm
# Define the model (for example, UNet)
model = UNETR(
    in_channels=1,
    out_channels=1,
    img_size=(image_dim, image_dim, image_dim),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="conv",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.1,
)
model = model.to(device)

# Load pre-trained model weights (modify the path as needed)
# model.load_state_dict(torch.load(
#     "./logs/FineTune/FrenchSpeaker-Thresh-255-10000St.pth"))

########################################## Loading the MLFlow models
model_path = "../logs/FinalSSL/BetterResult1.pth" #Be careful about the path.

# Load the model using the correct path
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()
####################################




UNETR(
  (vit): ViT(
    (patch_embedding): PatchEmbeddingBlock(
      (patch_embeddings): Conv3d(1, 768, kernel_size=(16, 16, 16), stride=(16, 16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (blocks): ModuleList(
      (0-11): 12 x TransformerBlock(
        (mlp): MLPBlock(
          (linear1): Linear(in_features=768, out_features=3072, bias=True)
          (linear2): Linear(in_features=3072, out_features=768, bias=True)
          (fn): GELU(approximate='none')
          (drop1): Dropout(p=0.1, inplace=False)
          (drop2): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): SABlock(
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (qkv): Linear(in_features=768, out_features=2304, bias=False)
          (input_rearrange): Rearrange('b h (qkv l d) -> qkv b l h d', qkv=3, l=12)
          (out_rearrange): Rearrange('b h l d -> b l (h d)')
          (drop_out

In [6]:
from tqdm import tqdm

# Iterate through the test loader with a progress bar
epoch_iterator = tqdm(test_loader)

for step, batch in enumerate(epoch_iterator, start=1):  # Start enumeration at 1
    test_input = batch["image"].to(device)  # Move the image to the specified device
    label_input = batch["label"].to(device)
    #test_outputs = sliding_window_inference(test_input, (64, 64, 64), 4, model)

    overlap_factor = 0  # Example overlay factor, adjust as needed
    test_outputs = sliding_window_inference(test_input, (image_dim, image_dim, image_dim), 4, model, overlap=overlap_factor)
    

100%|██████████| 1/1 [00:22<00:00, 22.26s/it]


In [17]:
image_num = 2

from monai.transforms import RemoveSmallObjects, FillHoles

# Define the transform
remove_small_segments = remove_small_segments = Compose([
    AsDiscrete(sigmoid=True),
    #AsDiscrete(argmax=True),
    RemoveSmallObjects(min_size=600, connectivity=1, independent_channels=True),
    FillHoles(applied_labels=None, connectivity=None)
])

# Apply the transform to your segmentation map
segmentation_map = remove_small_segments(test_outputs[image_num,:,:,:,:]).to(device)

RuntimeError: applying transform <monai.transforms.post.array.RemoveSmallObjects object at 0x7fe568354d50>

In [14]:
test_outputs.shape, test_input.shape, label_input.shape, segmentation_map.shape

(torch.Size([5, 2, 256, 256, 64]),
 torch.Size([5, 1, 256, 256, 64]),
 torch.Size([5, 1, 256, 256, 64]),
 torch.Size([2, 256, 256, 64]))

In [15]:
print(np.min(test_outputs), np.min(label_input))

-22.991959 0.0


In [16]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

# Create a slider to select which slice to display
slice_slider = widgets.IntSlider(value=0, min=0, max=60, description='Slice')


# Function to update the output based on the slider value
def update_slice(slice_index):
    clear_output(wait=True)  # Clear previous output
    
    # Create a figure with a single row of three subplots
    plt.figure(figsize=(10, 10))  # Adjust the figure size to prevent rows from stacking

    plt.subplot(1, 3, 1)  # Create a subplot for the slice
    plt.imshow(segmentation_map.detach().numpy()[0, :, :, slice_index] > 0.1, cmap='gray')
    plt.title(f'Slice {slice_index}')
    plt.axis('off')

    plt.subplot(1, 3, 2)  # Create a subplot for the label input visualization
    plt.imshow(label_input[image_num, 0, :, :, slice_index], cmap='gray')
    plt.title('Ref')
    plt.axis('off')

    plt.subplot(1, 3, 3)  # Create a subplot for the test input visualization
    plt.imshow(test_input[image_num, 0, :, :, slice_index], cmap='gray')
    plt.title('Test Output')
    plt.axis('off')

    plt.show()

# Display the slider and the initial plot
widgets.interactive(update_slice, slice_index=slice_slider)

interactive(children=(IntSlider(value=0, description='Slice', max=60), Output()), _dom_classes=('widget-intera…