In [1]:
import albumentations as A 
import cv2 
import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn

from importlib import import_module
from skp.configs import Config
from skp.toolbox.functions import plot_3d_image_side_by_side

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [None]:
# Dataset configuration for the TotalSegmentator-derived 3D organ dataset.
# NOTE(review): machine-specific absolute path — point DATA_ROOT at your local copy.
DATA_ROOT = "/home/ian/datasets/totalsegmentator"

cfg = Config()
cfg.dataset = "totalclassifier.seg_3d"
cfg.annotations_file = f"{DATA_ROOT}/train_val_organ_classification_original_splits.csv"
cfg.data_dir = f"{DATA_ROOT}/pngs_for_slice_organ_classification/"
cfg.seg_data_dir = f"{DATA_ROOT}/segs_for_slice_organ_classification/"
cfg.cv2_load_flag = cv2.IMREAD_UNCHANGED  # keep the file's original bit depth / channels
cfg.num_workers = 1
cfg.pin_memory = True
cfg.num_slices = 128
cfg.image_height = 224
cfg.image_width = 224

# Albumentations applies one set of sampled parameters to every target in a single
# Compose call, so registering image1..image{N-1} / mask1..mask{N-1} as extra targets
# gives every slice of the volume identical augmentation.
# (Presumably the dataset passes slice 0 as the default "image"/"mask" — confirm in skp.datasets.)
additional_targets = {f"image{idx}": "image" for idx in range(1, cfg.num_slices)}
additional_targets.update({f"mask{idx}": "mask" for idx in range(1, cfg.num_slices)})
cfg.train_transforms = A.Compose([
    # scale_limit is offset by 1 in albumentations, so (-0.5, 1.5) samples scale
    # factors in [0.5, 2.5]. NOTE(review): confirm this (not [0.5, 1.5]) is intended.
    A.RandomScale(scale_limit=(-0.5, 1.5), p=1),
    A.PadIfNeeded(min_height=cfg.image_height, min_width=cfg.image_width, border_mode=cv2.BORDER_CONSTANT, p=1),
    A.RandomCrop(height=cfg.image_height, width=cfg.image_width, p=1),
    A.VerticalFlip(p=0.5),
    A.Transpose(p=0.5),
], additional_targets=additional_targets)
cfg.val_transforms = None  # presumably "no transforms" in val mode — confirm in the dataset
cfg.use_4channels = True
cfg.classes_subset = "class_map_part_organs"
cfg.fold = 0
cfg.fold = 0

In [3]:
# Resolve the dataset module named by cfg.dataset and build its validation split.
dataset_module = import_module(f"skp.datasets.{cfg.dataset}")
dataset = dataset_module.Dataset(cfg, mode="val")

[        adrenal_gland_left_label  adrenal_gland_right_label  aorta_label  \
742533                         0                          0            0   
742534                         0                          0            0   
742535                         0                          0            0   
742536                         0                          0            0   
742537                         0                          0            0   
...                          ...                        ...          ...   
742707                         0                          0            0   
742708                         0                          0            0   
742709                         0                          0            0   
742710                         0                          0            0   
742711                         0                          0            0   

        atrial_appendage_left_label  autochthon_left_label  \
742533                  

In [4]:
# Number of samples in the validation split.
num_samples = len(dataset)
print(num_samples)

171


In [4]:
# Print every label_map key whose mapped value is non-zero, one per line.
nonzero_keys = [key for key, mapped in dataset.label_map.items() if mapped != 0]
for key in nonzero_keys:
    print(key)

1
2
13
17
18
21
40
41
42
43
44
45
46
47
48
49
50
52
82
84
86
90
91
92


In [5]:
# Pick one sample for inspection. A seeded generator makes the choice reproducible
# under Restart & Run All; change SAMPLE_SEED to look at other samples.
SAMPLE_SEED = 0
sample_rng = np.random.default_rng(SAMPLE_SEED)
idx = int(sample_rng.integers(len(dataset)))  # uniform over [0, len(dataset))
# NOTE(review): in the recorded run this indexing raised
#   TypeError: '>' not supported between instances of 'int' and 'NoneType'
# from inside the Dataset implementation (not visible here). One possible cause is a
# value left as None that the dataset compares to an int — TODO investigate.
batch = dataset[idx]

TypeError: '>' not supported between instances of 'int' and 'NoneType'

In [None]:
# Shapes of the input volume and its target, displayed as a tuple.
tuple(batch[key].shape for key in ("x", "y"))

NameError: name 'batch' is not defined

In [6]:
# Convert to NumPy for plotting; axis 1 of x is moved to the end
# (presumably the channel axis — confirm against the dataset).
x_npy = batch["x"].numpy().transpose(0, 2, 3, 1)
y_npy = batch["y"].numpy()
print(np.unique(y_npy))  # distinct target values present in this sample
# Plot channel 0 of x next to y using the project's side-by-side 3D viewer.
plot_3d_image_side_by_side(x_npy[..., 0], y_npy, num_images=24, axis=0)

NameError: name 'batch' is not defined

In [19]:
# Build the 3D segmentation model and smoke-test one forward pass on `batch`.
# NOTE: this rebinds `cfg`, replacing the dataset config from the earlier cell.
cfg = Config()
cfg.model = "segmentation.base_3d"
cfg.backbone = "x3d_l"
cfg.decoder_type = "Unet3dDecoder"
cfg.num_input_channels = 4
cfg.enable_gradient_checkpointing = True
cfg.dim0_strides = [2, 2, 2, 2, 2]
cfg.decoder_n_blocks = 5
cfg.decoder_out_channels = [256, 128, 64, 32, 16]
cfg.decoder_norm_layer = "bn"
cfg.decoder_attention_type = None
cfg.decoder_separable_conv = True
# NOTE(review): 24 equals the number of non-zero label_map entries listed earlier;
# if background (0) needs its own output channel this should be 25 — confirm.
cfg.num_classes = 24

cfg.num_slices = 128
cfg.image_height = 224
cfg.image_width = 224

model = import_module(f"skp.models.{cfg.model}").Net(cfg)

# Swap the first two axes of x and add a leading batch dim (presumably
# (slices, C, H, W) -> (1, C, slices, H, W) — confirm against the dataset).
# This is only a shape check, so skip building an autograd graph for a full 3D volume.
with torch.no_grad():
    y = model({"x": batch["x"].permute(1, 0, 2, 3).unsqueeze(0)})
print(model)
print(y["logits"].shape)

Enabling gradient checkpointing ...


Net(
  (encoder): X3DEncoder(
    (features): ModuleList(
      (0): ResNetBasicStem(
        (conv): Conv2plus1d(
          (conv_t): Conv3d(4, 24, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
          (conv_xy): Conv3d(24, 24, kernel_size=(5, 1, 1), stride=(1, 1, 1), padding=(2, 0, 0), groups=24, bias=False)
        )
        (norm): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (1): ResStage(
        (res_blocks): ModuleList(
          (0): ResBlock(
            (branch1_conv): Conv3d(24, 24, kernel_size=(1, 1, 1), stride=(2, 2, 2), bias=False)
            (branch2): BottleneckBlock(
              (conv_a): Conv3d(24, 54, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
              (norm_a): BatchNorm3d(54, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (act_a): ReLU()
              (conv_b): Conv3d(54, 54, kernel_size=(3, 3, 3), stride=(2, 2,

In [21]:
class ModelWrapper(nn.Module):
    """Expose a dict-in/dict-out model as a plain tensor-in/tensor-out module.

    MONAI's sliding window inference calls ``predictor(window_tensor)``; the wrapped
    model instead takes ``{"x": tensor}`` and returns a dict with a ``"logits"`` entry.
    This adapter bridges the two by packing the input and unpacking the logits.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model({"x": x})
        return outputs["logits"]

In [28]:
from monai.inferers.utils import sliding_window_inference

# Smoke-test MONAI sliding-window inference with the model built above.
wrapped_model = ModelWrapper(model)
# Single source of truth for the target volume size instead of repeated literals.
volume_size = (cfg.num_slices, cfg.image_height, cfg.image_width)
# Resample the sample to exactly volume_size (F.interpolate defaults to mode="nearest").
test_input = torch.nn.functional.interpolate(batch["x"].permute(1, 0, 2, 3).unsqueeze(0), volume_size)
# roi_size equals the input size, so this runs a single window (hence overlap=0.0).
# Inference only: no autograd graph needed.
with torch.no_grad():
    out = sliding_window_inference(
        inputs=test_input,
        roi_size=volume_size,
        sw_batch_size=2,
        predictor=wrapped_model,
        overlap=0.0,
    )
print(out.shape)

torch.Size([1, 24, 128, 224, 224])


In [None]:
import torch
from skp.models.encoders3d import get_encoder
from skp.models.segmentation.decoders.convunext_3d import DecoderBlock  # NOTE(review): unused in the visible cells
from skp.models.utils import change_num_input_channels
from skp.toolbox.functions import count_parameters

# Build the x3d_l encoder on its own and inspect its multi-scale feature maps.
encoder_cfg = Config()
encoder_cfg.dim0_strides = [2, 2, 2, 2, 2]
encoder_cfg.backbone = "x3d_l"
encoder_cfg.enable_gradient_checkpointing = True
encoder = get_encoder(encoder_cfg)
# Adapt the encoder to 4 input channels (same as num_input_channels in the model cell).
change_num_input_channels(encoder, 4)
# Shape inspection only: no_grad avoids retaining activations for backprop.
# `y` (the feature list) is consumed by the decoder cells below.
with torch.no_grad():
    y = encoder(batch["x"].permute(1, 0, 2, 3).unsqueeze(0))
print([feat.shape for feat in y])
count_parameters(encoder)

In [None]:
from skp.models.segmentation.decoders.unet_3d import Unet3dDecoder
from skp.models.segmentation.decoders.unext_3d import UneXt3dDecoder

# Try the UneXt 3D decoder on the encoder features `y` from the previous cell.
decoder_cfg = Config()
decoder_cfg.decoder_n_blocks = 5
decoder_cfg.decoder_out_channels = [256, 128, 64, 32, 16]
# decoder_cfg.decoder_center_block = False
decoder_cfg.decoder_norm_layer = "bn"
decoder_cfg.decoder_attention_type = None
# decoder_cfg.decoder_separable_conv = True
decoder_cfg.decoder_single_block = False
decoder_cfg.decoder_use_res_conv = False
# Per-level input channels; must match the channel dims of the feature maps in `y`
# (printed by the encoder cell) — update together if the backbone changes.
decoder_cfg.encoder_channels = [24, 24, 48, 96, 192]
decoder = UneXt3dDecoder(decoder_cfg)
# Shape / parameter-count check only: skip autograd bookkeeping.
with torch.no_grad():
    out = decoder(y)
count_parameters(decoder)

In [None]:
from skp.models.segmentation.decoders.unet_3d import Unet3dDecoder
from skp.models.segmentation.decoders.unext_3d import UneXt3dDecoder

# Compare against the plain U-Net 3D decoder (wider channels, separable convs).
# Rebinds `decoder` and `out` from the UneXt cell above.
decoder_cfg = Config()
decoder_cfg.decoder_n_blocks = 5
decoder_cfg.decoder_out_channels = [512, 256, 128, 64, 32]
decoder_cfg.decoder_center_block = False
decoder_cfg.decoder_norm_layer = "bn"
decoder_cfg.decoder_attention_type = None
decoder_cfg.decoder_separable_conv = True
# decoder_cfg.decoder_use_res_conv = False
# Per-level input channels; must match the channel dims of the feature maps in `y`
# (printed by the encoder cell) — update together if the backbone changes.
decoder_cfg.encoder_channels = [24, 24, 48, 96, 192]
decoder = Unet3dDecoder(decoder_cfg)
# Shape / parameter-count check only: skip autograd bookkeeping.
with torch.no_grad():
    out = decoder(y)
count_parameters(decoder)

In [None]:
# Shape of each element the decoder returned.
print([feature_map.shape for feature_map in out])

In [None]:
class SegModel(nn.Module):
    """Compose an encoder and a decoder into one module: ``decoder(encoder(x))``.

    Args:
        encoder: module (or callable) applied to the raw input.
        decoder: module (or callable) applied to the encoder's output.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)

In [None]:
# Fixed module path: MONAI's inference utilities live in `monai.inferers`
# (`monai.inferiors` does not exist and raises ImportError).
from monai.inferers.utils import sliding_window_inference


