In [2]:
import glob
import math
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import torch
from torch.optim import Adam, SGD
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import (
    CacheDataset,
    DataLoader,
    ThreadDataLoader,
    Dataset,
    decollate_batch,
    set_track_meta,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Act, Norm
from monai.networks.nets import UNet, SwinUNETR, AHNet, VNet

from monai.transforms import (
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureTyped,
    FgBgToIndicesd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
)
from monai.utils import set_determinism


from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"  # Use the 3rd and 4th GPU. Indexing starts from 0.

# for profiling
import nvtx
from monai.utils.nvtx import Range
import contextlib  # to improve code readability (combining training/validation loop with and without profiling)

#print_config()
set_determinism(0)

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")

current_directory = os.getcwd()

out_dir = os.path.join(current_directory,"outputs/")

if not os.path.exists(out_dir):
    os.makedirs(out_dir)

root dir is: /tmp/tmpsda545t9


In [None]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import pydicom
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
import nibabel as nib
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


from timm.layers.conv2d_same import Conv2dSame

def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List


# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)


# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
    ih, iw, iz = x.size()[-3:]
    pad_h = get_same_padding(ih, k[0], s[0], d[0])
    pad_w = get_same_padding(iw, k[1], s[1], d[1])
    pad_z = get_same_padding(iz, k[2], s[2], d[2])
    if pad_h > 0 or pad_w > 0 or pad_z > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_z // 2, pad_z - pad_z // 2], value=value)
    return x


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic


def conv3d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (0, 0, 0), dilation: Tuple[int, int, int] = (1, 1, 1), groups: int = 1):
    x = pad_same(x, weight.shape[-3:], stride, dilation)
    return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)


class Conv3dSame(nn.Conv3d):
    """ Tensorflow like 'SAME' convolution wrapper for 3d convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv3dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

kernel_type = 'timm3d_res18d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
load_kernel = None
load_last = True
n_blocks = 4
n_folds = 5
backbone = 'resnet18d'

init_lr = 3e-3
batch_size = 4
drop_rate = 0.
drop_path_rate = 0.
loss_weights = [1, 1]
p_mixup = 0.1
drop_rate = 0.
drop_path_rate = 0.
out_dim = 6
        
import torch
import torch.nn as nn
import timm
import segmentation_models_pytorch as smp

class TimmSegModel(nn.Module):
    def __init__(self, backbone, segtype='unet', pretrained=False, num_classes=6):
        super(TimmSegModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=0.1,
            drop_path_rate=0.1,
            pretrained=pretrained
        )

        # Modify the first layer of the encoder for 1-channel input
        self.encoder.conv1 = nn.Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)

        self.decoder = smp.Unet(
            encoder_name=backbone,
            encoder_depth=5,
            encoder_weights=None,  # Weights are already loaded in the encoder part
            decoder_use_batchnorm=True,
            decoder_channels=[256, 128, 64, 32, 16],
            decoder_attention_type='scse',
            in_channels=1,
            classes=num_classes,
            activation=None,  # Use 'None' because we handle activation in the loss function (e.g., CrossEntropy)
        )

    def forward(self, x):
        return self.decoder(x)

In [None]:
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from scipy.signal import savgol_filter


In [None]:
def process_and_get_bounding_boxes_v6(tensor):
    """
    Compute bounding boxes, determine the voxel count per depth, and the inner region for each organ.
    Also computes a bounding box that encompasses all organs based on the inner_slices.
    
    Args:
    - tensor (numpy.ndarray): The input tensor with shape (D, H, W).
    
    Returns:
    - dict: Bounding boxes, voxel counts per depth, and inner regions for each organ.
            Also includes the encompassing bounding box for all organs.
    """
    
    # Initialize bounding boxes dictionary
    bounding_boxes = {}
    
    # List to store all coordinates within the inner_slices for all organs
    all_coords_in_inner_slices = []
    
    for organ_label, organ_name in SEGMENTATION_CODES.items():
        if organ_label == 0:  # Skip background
            continue
        
        # Find the slices where the organ appears
        slices = np.where(np.any(tensor == organ_label, axis=(1, 2)))[0]
        
        if len(slices) == 0:  # If organ does not appear, continue to next organ
            continue
        
        # Compute voxel count for each slice and store in a dictionary with slice index as key
        voxel_counts_per_depth = {slice_idx: np.sum(tensor[slice_idx] == organ_label) for slice_idx in slices}
        
        # Fourier-based smoothing and thresholding
        simple_list = list(voxel_counts_per_depth.values())
        smoothed_list = fourier_smoothing(simple_list, fourier_factor)
        
        # Thresholding logic
        max_value = np.max(smoothed_list)
        threshold = threshold_factor * max_value
        in_threshold = (smoothed_list >= threshold)
        
        # Extract slice indices that are in the inner region
        inner_slices = get_contiguous_inner_slices(in_threshold, slices)
        
        # Gather all coordinates for this organ that are within the inner_slices
        coords = np.argwhere((tensor == organ_label) & (np.isin(tensor, inner_slices)))
        all_coords_in_inner_slices.extend(coords)
        
        # Bounding box computation for this organ
        min_coords = coords.min(axis=0)
        max_coords = coords.max(axis=0)
        
        bounding_boxes[organ_name] = {
            "top_left_front": tuple(min_coords),
            "bottom_right_back": tuple(max_coords),
            "depth_range": (min_coords[0], max_coords[0]),
            "voxel_counts_per_depth": voxel_counts_per_depth,
            "inner_slices": inner_slices
        }
    
    # Compute the encompassing bounding box for all organs based on inner_slices
    all_coords_array = np.array(all_coords_in_inner_slices)
    min_encompassing_coords = all_coords_array.min(axis=0)
    max_encompassing_coords = all_coords_array.max(axis=0)
    bounding_boxes["encompassing"] = {
        "top_left_front": tuple(min_encompassing_coords),
        "bottom_right_back": tuple(max_encompassing_coords),
        "depth_range": (min_encompassing_coords[0], max_encompassing_coords[0])
    }
        
    return bounding_boxes

def fourier_smoothing(y, cutoff_fraction=0.1):
    # Compute the FFT of the input data
    y_fft = np.fft.fft(y)
    
    # Zero out the higher frequencies
    cutoff_idx = int(cutoff_fraction * len(y_fft))
    y_fft[cutoff_idx:-cutoff_idx] = 0
    
    # Compute the inverse FFT to get the smoothed data
    y_smoothed = np.fft.ifft(y_fft).real  # Only take the real part
    
    return y_smoothed
    
def get_contiguous_inner_slices(in_threshold, slices):
    """
    Get contiguous blocks of slices that are above the threshold.
    
    Args:
    - in_threshold (list): A boolean list indicating if the slice is above the threshold.
    - slices (list): List of slice indices.
    
    Returns:
    - list: Contiguous inner slices.
    """
    inner_slices = []
    start_slice = None

    for i, val in enumerate(in_threshold):
        if val:
            if start_slice is None:  # Start of a new block
                start_slice = slices[i]
        else:
            if start_slice is not None:  # End of the current block
                inner_slices.extend(range(start_slice, slices[i]))
                start_slice = None

    if start_slice is not None:  # For the case where the last block reaches the end
        inner_slices.extend(range(start_slice, slices[-1] + 1))

    return inner_slices

# Function to save numpy array as .nii.gz
def save_nii(data, filename):
    img = nib.Nifti1Image(data, affine=np.eye(4))
    nib.save(img, filename)

In [None]:
model_device = torch.device("cuda:1") 

# Create the model
seg_model = TimmSegModel(backbone = 'resnet18', segtype='unet', pretrained = True)
seg_model = convert_3d(seg_model)
# Test the model (just to ensure everything is working)
seg_model = seg_model.to(model_device)

seg_model.load_state_dict(torch.load(os.path.join(current_directory, "resnet18_unet_best_metric_model.pt")))
seg_model.eval()

In [None]:
# Transformation pipeline
test_org_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode="bilinear"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
    ]
)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def squash(inputs, axis=-1):
    """
    The non-linear activation used in Capsule. It drives the length of a large vector to near 1 and small vector to 0
    :param inputs: vectors to be squashed
    :param axis: the axis to squash
    :return: a Tensor with same size as inputs
    """
    norm = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    scale = norm**2 / (1 + norm**2) / (norm + 1e-8)
    return scale * inputs


class DenseCapsule(nn.Module):
    """
    The dense capsule layer. It is similar to Dense (FC) layer. Dense layer has `in_num` inputs, each is a scalar, the
    output of the neuron from the former layer, and it has `out_num` output neurons. DenseCapsule just expands the
    output of the neuron from scalar to vector. So its input size = [None, in_num_caps, in_dim_caps] and output size = \
    [None, out_num_caps, out_dim_caps]. For Dense Layer, in_dim_caps = out_dim_caps = 1.

    :param in_num_caps: number of cpasules inputted to this layer
    :param in_dim_caps: dimension of input capsules
    :param out_num_caps: number of capsules outputted from this layer
    :param out_dim_caps: dimension of output capsules
    :param routings: number of iterations for the routing algorithm
    """
    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DenseCapsule, self).__init__()
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        # x.size=[batch, in_num_caps, in_dim_caps]
        # expanded to    [batch, 1,            in_num_caps, in_dim_caps,  1]
        # weight.size   =[       out_num_caps, in_num_caps, out_dim_caps, in_dim_caps]
        # torch.matmul: [out_dim_caps, in_dim_caps] x [in_dim_caps, 1] -> [out_dim_caps, 1]
        # => x_hat.size =[batch, out_num_caps, in_num_caps, out_dim_caps]
        x_hat = torch.squeeze(torch.matmul(self.weight, x[:, None, :, :, None]), dim=-1)

        # In forward pass, `x_hat_detached` = `x_hat`;
        # In backward, no gradient can flow from `x_hat_detached` back to `x_hat`.
        x_hat_detached = x_hat.detach()

        # The prior for coupling coefficient, initialized as zeros.
        # b.size = [batch, out_num_caps, in_num_caps]
        b = Variable(torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps)).cuda()

        assert self.routings > 0, 'The \'routings\' should be > 0.'
        for i in range(self.routings):
            # c.size = [batch, out_num_caps, in_num_caps]
            c = F.softmax(b, dim=1)

            # At last iteration, use `x_hat` to compute `outputs` in order to backpropagate gradient
            if i == self.routings - 1:
                # c.size expanded to [batch, out_num_caps, in_num_caps, 1           ]
                # x_hat.size     =   [batch, out_num_caps, in_num_caps, out_dim_caps]
                # => outputs.size=   [batch, out_num_caps, 1,           out_dim_caps]
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
                # outputs = squash(torch.matmul(c[:, :, None, :], x_hat))  # alternative way
            else:  # Otherwise, use `x_hat_detached` to update `b`. No gradients flow on this path.
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                # outputs = squash(torch.matmul(c[:, :, None, :], x_hat_detached))  # alternative way

                # outputs.size       =[batch, out_num_caps, 1,           out_dim_caps]
                # x_hat_detached.size=[batch, out_num_caps, in_num_caps, out_dim_caps]
                # => b.size          =[batch, out_num_caps, in_num_caps]
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        return torch.squeeze(outputs, dim=-2)


class PrimaryCapsule(nn.Module):
    """
    Apply Conv2D with `out_channels` and then reshape to get capsules
    :param in_channels: input channels
    :param out_channels: output channels
    :param dim_caps: dimension of capsule
    :param kernel_size: kernel size
    :return: output tensor, size=[batch, num_caps, dim_caps]
    """
    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class CombinedFeatureExtractor(nn.Module):
    def __init__(self, input_channels=1, pretrained=True, unet_feature_dim=1024, capsnet_feature_dim=1024):
        super(CombinedFeatureExtractor, self).__init__()

        # Use pre-trained ResNet as encoder but modify for desired input channels
        resnet = models.resnet18(pretrained=pretrained)
        resnet.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # Remove the classification head

        # Decoder (U-Net style)
        self.decoder_block1 = self.conv_block(512, 256)
        self.decoder_block2 = self.conv_block(256, 128)
        self.decoder_block3 = self.conv_block(128, 64)
        
        with torch.no_grad():
            x = torch.randn(1, input_channels, 256, 256)
            x = self.encoder(x)
            x = self.decoder_block1(F.interpolate(x, scale_factor=2))
            x = self.decoder_block2(F.interpolate(x, scale_factor=2))
            x = self.decoder_block3(F.interpolate(x, scale_factor=2))
        
        # Calculate the number of features and initialize the fully connected layer
        num_features = x.shape[1] * x.shape[2] * x.shape[3]
        
        # Calculate the number of features for U-Net and initialize the fully connected layer
        self.fc_unet = nn.Linear(num_features, unet_feature_dim)  # Assuming 256x256 input size
        
        # CapsNet Layer
        self.primary_capsules = PrimaryCapsule(in_channels=512, out_channels=256, dim_caps=8, kernel_size=1)
        
        self.digit_capsules = DenseCapsule(in_num_caps=2048, 
                                           in_dim_caps=self.primary_capsules.dim_caps,
                                           out_num_caps=capsnet_feature_dim // 8,  # Adjust the number of capsules based on desired feature dimension
                                           out_dim_caps=8, 
                                           routings=3)
        
        self.fc_caps = nn.Linear(capsnet_feature_dim, capsnet_feature_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Encoder (pre-trained CNN)
        encoded_x = self.encoder(x)
        # Decoder (U-Net)
        unet_features = self.decoder_block1(F.interpolate(encoded_x, scale_factor=2))
        unet_features = self.decoder_block2(F.interpolate(unet_features, scale_factor=2))
        unet_features = self.decoder_block3(F.interpolate(unet_features, scale_factor=2))
        
        # Flatten and pass through the fully connected layer for 1D feature vector
        unet_features = unet_features.view(unet_features.size(0), -1)
        unet_features = self.fc_unet(unet_features)
 
        x_reshaped = encoded_x.view(encoded_x.size(0), 512, 1, 1)
        x_reshaped = F.interpolate(x_reshaped, size=(8, 8), mode='nearest')  # Upsample to [batch_size, 512, 7, 7]

        print(x_reshaped.shape)
        caps_features = self.primary_capsules(x_reshaped)
        print(caps_features.shape)

        caps_features = self.digit_capsules(caps_features)
        caps_features = self.fc_caps(caps_features.view(caps_features.size(0), -1))

        # Combine the features from U-Net and CapsNet
        combined_features = torch.cat([unet_features, caps_features], dim=1)

        return unet_features, caps_features

def extract_encompassing_subregion(volume, bounding_boxes):
    # Extract the coordinates from the 'encompassing' bounding box
    top_left_front = bounding_boxes["encompassing"]["top_left_front"]
    bottom_right_back = bounding_boxes["encompassing"]["bottom_right_back"]

    # Use the coordinates to slice into the volume tensor
    subregion = volume[top_left_front[0]:bottom_right_back[0]+1, 
                       top_left_front[1]:bottom_right_back[1]+1, 
                       top_left_front[2]:bottom_right_back[2]+1]

    return subregion

In [None]:
import os
import numpy as np
import nibabel as nib
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load data
train_info = pd.read_csv('train.csv')
train_info.patient_id = train_info.patient_id.astype(str)

curr_df = dicom_df[dicom_df.patient_duplicate_count == 2]
train_info = train_info[train_info.patient_id.isin(curr_df.patient_id)]
train_info = train_info[train_info.extravasation_injury == 1]
curr_df = curr_df[curr_df.patient_id.isin(train_info.patient_id)]

unique_patients = np.unique(curr_df.patient_id)

save_directory = "/workspace/0728tot/ATD/patient_nii"  # Adjust to your path

# Directory where the .nii.gz files are saved
vol_directory = "/workspace/0728tot/ATD/patient_nii"

# Iterate over unique patients
for patient_id in tqdm(unique_patients[0:1]):
    eval_df = curr_df[curr_df.patient_id == patient_id]
    paths = np.array(eval_df.series_dir)
    series_names = np.array(eval_df.series_id)
    
    volumes = []
    
    # For each series in a patient, sample the volume
    for path, series_name in zip(paths, series_names):
        vol = sample_patient_volume(path, depth_downsample_rate=None, lw_downsample_rate=None, 
                                  adjust_pixel_spacing="no", standardize_pixel_array="yes", 
                                  target_pixel_spacing=[1.0, 1.0], target_thickness=1)
        volumes.append((series_name, vol))
    
    # Filtering based on depth criterion
    depths = [vol.shape[0] for _, vol in volumes]
    max_depth = max(depths)
    
    # Remove volumes that don't meet the depth criterion from the list
    volumes = [(series_name, vol) for series_name, vol in volumes if vol.shape[0] >= 0.4 * max_depth]

    # Save volumes and keep track of saved filepaths
    saved_filepaths = []
    
    for series_name, vol in volumes:
        filename = os.path.join(vol_directory, f"{series_name}.nii.gz")
        save_nii(vol, filename)
        saved_filepaths.append(filename)
    
    # Create a data list for the Dataset using the saved_filepaths list
    data_list = [{"image": filepath} for filepath in saved_filepaths]
    
    # Create DataLoader
    test_org_ds = Dataset(data=data_list, transform=test_org_transforms)
    test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=0)
    
    # Initialize an empty list to store bounding box data
    bbox_data_list = []

    with torch.no_grad():
        # Loop through both filepaths (from DataLoader) and volumes
        for (test_data, (series_name, vol)) in zip(test_org_loader, volumes):

            test_inputs = test_data["image"].to(model_device)
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            curr_data = sliding_window_inference(test_inputs, roi_size, sw_batch_size, seg_model)
            ideal_size = list(test_data['image_meta_dict']['spatial_shape'][0].cpu().detach().numpy())
            print(ideal_size)
            # Process the data
            segmented = curr_data.cpu()
            print(segmented.shape)
            upsampled = F.interpolate(segmented, size=[ideal_size[0], segmented.shape[3], segmented.shape[4]], mode='trilinear', align_corners=True)

            merged = torch.argmax(upsampled, dim=1) + 1  # add 1 to move the range from 0-5 to 1-6

            print(merged.shape)

            bounding_boxes_result_v3 = process_and_get_bounding_boxes_v6(merged)


            extracted_vol = extract_encompassing_subregion(vol, bounding_boxes_result_v3)


In [5]:
# Initialize the CapsuleFeatureExtractor
feature_extractor = CombinedFeatureExtractor(input_channels = 3, pretrained=True, unet_feature_dim=1024, capsnet_feature_dim=1024)
feature_extractor.eval()

The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|████████████████████████████████████████████████| 44.7M/44.7M [00:00<00:00, 114MB/s]


CombinedFeatureExtractor(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_ru