In [2]:
import numpy as np
import torch

Below is the code for the Tiramisu (FC-DenseNet) segmentation model.


In [35]:
import torch
import torch.nn as nn


class DenseLayer(nn.Sequential):
    """One dense unit: BatchNorm -> SiLU -> 3x3 conv -> channel dropout.

    Takes ``in_channels`` feature maps and produces ``growth_rate`` new ones with
    the same spatial size. The child names ("norm", "silu", "conv", "drop") appear
    in state_dict keys of saved checkpoints and must stay unchanged.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        stages = (
            ("norm", nn.BatchNorm2d(in_channels)),
            ("silu", nn.SiLU(inplace=True)),
            # 3x3 kernel, stride 1, padding 1: height and width are preserved
            ("conv", nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=True)),
            ("drop", nn.Dropout2d(0.2)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def forward(self, x):
        return super().forward(x)


class DenseBlock(nn.Module):
    """Stack of ``n_layers`` DenseLayers with dense (concatenative) connectivity.

    Layer ``k`` receives the block input concatenated with the outputs of layers
    ``0..k-1``, i.e. ``in_channels + k * growth_rate`` channels, and emits
    ``growth_rate`` new maps.

    Return value:
    - ``upsample=False``: the input concatenated with every layer output
      (``in_channels + n_layers * growth_rate`` channels).
    - ``upsample=True``: only the maps produced inside the block
      (``n_layers * growth_rate`` channels).
    """

    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        layers = []
        for depth in range(n_layers):
            layers.append(DenseLayer(in_channels + depth * growth_rate, growth_rate))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        new_maps = []
        for layer in self.layers:
            fresh = layer(x)
            # every later layer sees all maps produced so far; dim 1 is the channel axis
            x = torch.cat([x, fresh], 1)
            if self.upsample:
                new_maps.append(fresh)
        return torch.cat(new_maps, 1) if self.upsample else x


class TransitionDown(nn.Sequential):
    """BatchNorm -> SiLU -> 1x1 conv -> channel dropout -> 2x2 max-pool.

    Keeps the channel count and halves height and width (floor). Child names are
    kept exactly as registered ("norm", "SiLU", "conv", "drop", "maxpool"); the
    norm/conv names are part of checkpoint state_dict keys.
    """

    def __init__(self, in_channels):
        super().__init__()
        pipeline = [
            ("norm", nn.BatchNorm2d(num_features=in_channels)),
            ("SiLU", nn.SiLU(inplace=True)),
            ("conv", nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=True)),
            ("drop", nn.Dropout2d(0.2)),
            ("maxpool", nn.MaxPool2d(2)),
        ]
        for child_name, child in pipeline:
            self.add_module(child_name, child)

    def forward(self, x):
        return super().forward(x)


class TransitionUp(nn.Module):
    """Upsample with a stride-2 transposed conv, then concatenate the skip connection.

    A 3x3 kernel with stride 2 and no padding maps a spatial size H to 2*H + 1.
    The result is center-cropped to the skip's height/width before the two are
    concatenated along the channel axis (upsampled maps first, then ``skip``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convTrans = nn.ConvTranspose2d(
            in_channels, out_channels, 3, stride=2, padding=0, bias=True
        )

    def forward(self, x, skip):
        skip_h, skip_w = skip.shape[2], skip.shape[3]
        upsampled = center_crop(self.convTrans(x), skip_h, skip_w)
        return torch.cat([upsampled, skip], 1)


class Bottleneck(nn.Sequential):
    """Bottom of the U: a single DenseBlock run with ``upsample=True``.

    The output therefore contains only the ``growth_rate * n_layers`` maps created
    inside the block, not the block's input. The single child is named "bottleneck".
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        # nn.Module.__setattr__ registers the sub-module under the name "bottleneck"
        self.bottleneck = DenseBlock(in_channels, growth_rate, n_layers, upsample=True)

    def forward(self, x):
        return super().forward(x)


def center_crop(layer, max_height, max_width):
    """Return the centred ``max_height x max_width`` window of a 4-D tensor.

    ``layer`` must have exactly four dims; the last two are cropped. When the size
    difference is odd, the extra row/column is dropped from the bottom/right
    (offsets use floor division).
    """
    _, _, height, width = layer.size()
    top = (height - max_height) // 2
    left = (width - max_width) // 2
    rows = slice(top, top + max_height)
    cols = slice(left, left + max_width)
    return layer[:, :, rows, cols]


class FCDenseNet(nn.Module):
    """FC-DenseNet ("Tiramisu") encoder-decoder for per-pixel classification.

    Layout: 3x3 first conv -> for each entry of ``down_blocks`` a DenseBlock
    followed by a TransitionDown -> Bottleneck -> for each entry of ``up_blocks``
    a TransitionUp (which concatenates the matching down-path skip connection)
    followed by a DenseBlock -> 1x1 conv -> softmax over the class axis.

    Parameters:
    in_channels (int): channels of the input image.
    down_blocks (sequence of int): number of DenseLayers per down-path DenseBlock.
    up_blocks (sequence of int): number of DenseLayers per up-path DenseBlock; each
        entry consumes one skip connection, so it must have the same length as
        ``down_blocks``.
    bottleneck_layers (int): number of DenseLayers in the bottleneck.
    growth_rate (int): feature maps added by every DenseLayer.
    out_chans_first_conv (int): output channels of the first convolution.
    n_classes (int): number of output classes (channels of the final 1x1 conv).

    Raises:
    ValueError: if ``down_blocks`` is empty or ``len(up_blocks) != len(down_blocks)``.

    The sub-module attribute names (firstconv, denseBlocksDown, transDownBlocks,
    bottleneck, transUpBlocks, denseBlocksUp, finalConv) are the state_dict keys of
    saved checkpoints; renaming them breaks ``load_state_dict``.
    """

    def __init__(
        self,
        in_channels=3,
        down_blocks=(4, 4, 4, 4, 4),
        up_blocks=(4, 4, 4, 4, 4),
        bottleneck_layers=4,
        growth_rate=12,
        out_chans_first_conv=48,
        n_classes=2,
    ):
        super().__init__()
        # Each up block consumes exactly one skip connection produced by a down
        # block: fail fast here instead of deep inside forward().
        if len(down_blocks) == 0:
            raise ValueError("down_blocks must contain at least one block")
        if len(up_blocks) != len(down_blocks):
            raise ValueError(
                "up_blocks and down_blocks must have the same length, "
                f"got {len(up_blocks)} and {len(down_blocks)}"
            )
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        # Channel counts of the skip connections, deepest first (the order the up path uses them).
        skip_connection_channel_counts = []

        ## First Convolution ##

        self.add_module(
            "firstconv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_chans_first_conv,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            ),
        )
        cur_channels_count = out_chans_first_conv

        #####################
        # Downsampling path #
        #####################

        # Each DenseBlock (upsample=False) appends growth_rate * n_layers channels to its input.
        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for n_layers in down_blocks:
            self.denseBlocksDown.append(DenseBlock(cur_channels_count, growth_rate, n_layers))
            cur_channels_count += growth_rate * n_layers
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(TransitionDown(cur_channels_count))

        #####################
        #     Bottleneck    #
        #####################

        # upsample=True: the bottleneck outputs only its growth_rate * bottleneck_layers new maps.
        self.add_module("bottleneck", Bottleneck(cur_channels_count, growth_rate, bottleneck_layers))
        prev_block_channels = growth_rate * bottleneck_layers

        #######################
        #   Upsampling path   #
        #######################

        # TransitionUp keeps prev_block_channels and then concatenates the skip connection.
        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        for i in range(len(up_blocks) - 1):
            self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
            # upsample=True: only the block's new maps are passed further up.
            self.denseBlocksUp.append(DenseBlock(cur_channels_count, growth_rate, up_blocks[i], upsample=True))
            prev_block_channels = growth_rate * up_blocks[i]

        ## Final DenseBlock ##

        # upsample=False: the classifier sees the block input plus all new maps.
        self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels))
        cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1]

        self.denseBlocksUp.append(DenseBlock(cur_channels_count, growth_rate, up_blocks[-1], upsample=False))
        cur_channels_count += growth_rate * up_blocks[-1]

        ## Softmax ##

        self.finalConv = nn.Conv2d(
            in_channels=cur_channels_count,
            out_channels=n_classes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(probabilities, bottleneck_output)``.

        ``probabilities`` is the softmax over dim 1 of the final 1x1 conv and has the
        same spatial size as ``x`` (every TransitionUp crops to its skip's size);
        ``bottleneck_output`` is the bottleneck's output feature map.
        """
        out = self.firstconv(x)

        skip_connections = []
        for dense_block, trans_down in zip(self.denseBlocksDown, self.transDownBlocks):
            out = dense_block(out)
            skip_connections.append(out)
            out = trans_down(out)

        out = self.bottleneck(out)
        bottleneck_output = out

        for trans_up, dense_block in zip(self.transUpBlocks, self.denseBlocksUp):
            skip = skip_connections.pop()  # deepest remaining skip first
            out = trans_up(out, skip)
            out = dense_block(out)

        out = self.finalConv(out)
        out = self.softmax(out)
        return out, bottleneck_output

# Single-channel input slices; 4 output classes: 0 = background, 1 = RV, 2 = MYO, 3 = LV
tiramisu = FCDenseNet(n_classes=4, in_channels=1)

In [36]:


# The checkpoint is a full pickled module. Its parameter names may carry an
# "_orig_mod." prefix (presumably it was saved from a torch.compile wrapper), which
# is stripped so the keys match the plain FCDenseNet above.
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
# `tiramisu` itself is on the CPU here.
tiramisu_trained_model = torch.load("tiramisu_acdc.pt", weights_only=False, map_location="cpu")
tiramisu_state_dict = {
    key.replace('_orig_mod.', ''): tensor
    for key, tensor in tiramisu_trained_model.state_dict().items()
}
tiramisu.load_state_dict(tiramisu_state_dict)
tiramisu.eval()

FCDenseNet(
  (firstconv): Conv2d(1, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (denseBlocksDown): ModuleList(
    (0): DenseBlock(
      (layers): ModuleList(
        (0): DenseLayer(
          (norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (silu): SiLU(inplace=True)
          (conv): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (1): DenseLayer(
          (norm): BatchNorm2d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (silu): SiLU(inplace=True)
          (conv): Conv2d(60, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (2): DenseLayer(
          (norm): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (silu): SiLU(inplace=True)
          (conv): Conv2d(72, 12, kernel_size=(3, 3), strid

In [40]:
example_input = torch.randn(1, 1, 128, 128)  # (batch, channels, H, W) dummy slice
# Inference only: no_grad avoids recording an autograd graph for this forward pass.
with torch.no_grad():
    output, bottleneck_output = tiramisu(example_input)

In [42]:
print(bottleneck_output.size())  # torch.Size of the bottleneck feature map

torch.Size([1, 48, 4, 4])


# Make predictions for the Table 2 data







In [None]:
import numpy as np
import nrrd
import glob
import os
from natsort import natsorted
from skimage.transform import resize

def crop_resize_image(image, new_dim=256):
    """
    Crop a 3D image to its non-zero bounding box, centre-crop that box to a square
    in-plane, and resize the square to ``new_dim x new_dim``.

    Parameters:
    image (np.ndarray): Input volume of shape (x, y, z); zero voxels are treated as background.
    new_dim (int): In-plane size of the output.

    Returns:
    tuple: (processed_image, restore_info)
        - processed_image: array of shape (new_dim, new_dim, z_crop) with the input dtype,
          where z_crop is the number of slices from the first to the last slice that
          contains a non-zero voxel (inclusive).
        - restore_info: dict with
            'original_shape': image.shape,
            'crop_index': tuple of three slices locating the square crop inside ``image``,
            'original_dim': side length of the square crop before resizing,
            'new_dim': new_dim.

    Raises:
    ValueError: if ``image`` contains no non-zero voxel.
    """
    # Step 1: bounding box of the non-zero voxels
    nz = np.nonzero(image)
    if nz[0].size == 0:
        raise ValueError("crop_resize_image: image has no non-zero voxels to crop around")
    min_indices = np.min(nz, axis=1)
    max_indices = np.max(nz, axis=1)
    bbox = tuple(slice(lo, hi + 1) for lo, hi in zip(min_indices, max_indices))
    cropped = image[bbox]

    # Step 2: centre-crop the in-plane axes to a square of the smaller side
    crop_h, crop_w = cropped.shape[:2]
    min_dim = min(crop_h, crop_w)
    start_h = (crop_h - min_dim) // 2
    start_w = (crop_w - min_dim) // 2
    square = cropped[start_h:start_h + min_dim, start_w:start_w + min_dim, :]

    # Location of the square crop in the original image's coordinates
    row0 = min_indices[0] + start_h
    col0 = min_indices[1] + start_w
    crop_index = (
        slice(row0, row0 + min_dim),
        slice(col0, col0 + min_dim),
        slice(min_indices[2], max_indices[2] + 1),
    )

    # Step 3: resize in-plane to new_dim
    current_dim = square.shape[0]
    if current_dim != new_dim:
        # preserve_range=True: otherwise skimage rescales integer images to [0, 1]
        # before interpolating, and the cast back to an integer dtype below would
        # collapse almost every voxel to 0. Floating-point inputs are not rescaled
        # either way.
        resized = resize(square, (new_dim, new_dim, square.shape[2]),
                         order=1,  # linear interpolation
                         anti_aliasing=True,
                         preserve_range=True)
        if np.issubdtype(square.dtype, np.integer):
            resized = np.rint(resized)  # round instead of truncating on the integer cast
        resized = resized.astype(square.dtype)  # output dtype matches input
    else:
        resized = square.copy()

    # Step 4: everything restore_mask needs to map predictions back
    restore_info = {
        'original_shape': image.shape,
        'crop_index': crop_index,
        'original_dim': current_dim,
        'new_dim': new_dim
    }

    return resized, restore_info

def crop_resize_mask(mask, restore_info):
    """
    Crop and resize a 3D segmentation mask with the same geometry that
    ``crop_resize_image`` applied to the corresponding image.

    Parameters:
    mask (np.ndarray): Segmentation mask of shape (x, y, z), same shape as the original image.
    restore_info (dict): The ``restore_info`` returned by ``crop_resize_image``
                         (keys 'original_shape', 'crop_index', 'original_dim', 'new_dim').

    Returns:
    np.ndarray: The processed mask of shape (new_dim, new_dim, z_crop), with the same
                dtype as ``mask``. Only the mask is returned, not a tuple.

    Raises:
    ValueError: if ``mask.shape`` differs from ``restore_info['original_shape']``.
    """
    if mask.shape != restore_info['original_shape']:
        raise ValueError("Mask shape must match original image shape")

    # Step 1: crop to the same square region as the image
    square = mask[restore_info['crop_index']]

    # Step 2: resize to new_dim
    current_dim = square.shape[0]
    new_dim = restore_info['new_dim']

    if current_dim != new_dim:
        # Nearest neighbour without anti-aliasing never blends labels; preserve_range
        # keeps label values from being rescaled to [0, 1] (some scikit-image
        # versions convert every input to float otherwise).
        resized = resize(square, output_shape=(new_dim, new_dim, square.shape[2]),
                         order=0,
                         anti_aliasing=False,
                         preserve_range=True)
        # Match the input dtype, as the no-resize branch does
        resized = resized.astype(square.dtype)
    else:
        resized = square.copy()

    return resized


def restore_mask(processed_mask, restore_info):
    """
    Map a processed 3D label mask back onto the grid of the original image.

    The in-plane square is resized back to ``restore_info['original_dim']`` with
    nearest-neighbour interpolation (no anti-aliasing, value range preserved, so
    label values are never blended). It is then cast to int16 and pasted at
    ``restore_info['crop_index']`` inside an all-zero volume of shape
    ``restore_info['original_shape']``.

    Parameters:
    processed_mask (np.ndarray): Label mask of shape (d, d, z_crop), e.g. a prediction
                                 on the output of ``crop_resize_image``.
    restore_info (dict): Metadata from ``crop_resize_image``.

    Returns:
    np.ndarray: int16 array of shape ``restore_info['original_shape']``; voxels
                outside the crop are 0. The int16 cast truncates fractional values.
    """
    target_side = restore_info['original_dim']
    if processed_mask.shape[0] == target_side:
        square = processed_mask.copy()
    else:
        square = resize(
            processed_mask,
            output_shape=(target_side, target_side, processed_mask.shape[2]),
            order=0,
            anti_aliasing=False,
            preserve_range=True,
        )

    restored = np.zeros(restore_info['original_shape'], dtype=np.int16)
    restored[restore_info['crop_index']] = square.astype(np.int16)
    return restored



def predict_patches(images, model, num_classes=4, batch_size=8, device="cuda"):
    """Run ``model`` over ``images`` in mini-batches and return all predictions.

    Parameters:
    images (torch.Tensor): inputs with samples along dim 0; dims 2 and 3 give the
        spatial size of the output buffer (e.g. shape (n, 1, H, W)).
    model (callable): maps a mini-batch to per-class maps of shape
        (b, num_classes, H, W). If it returns a tuple or list (FCDenseNet returns
        ``(probabilities, bottleneck_output)``), the first element is used.
    num_classes (int): size of the class axis of the output.
    batch_size (int): samples per forward pass; must be >= 1.
    device (str | torch.device | None): device used for inference. ``None`` selects
        the device of the model's first parameter (CPU if the model has none).

    Returns:
    np.ndarray: predictions of shape (n, num_classes, H, W), on the CPU.

    Raises:
    ValueError: if ``batch_size`` < 1 (the batching loop would never advance).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        params = model.parameters() if hasattr(model, "parameters") else iter(())
        first_param = next(iter(params), None)
        device = first_param.device if first_param is not None else "cpu"

    n_samples = images.size(0)
    prediction = torch.zeros(
        (n_samples, num_classes, images.size(2), images.size(3)),
        device=device,
    )
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            y_pred = model(images[start:stop].to(device))
            # Models such as FCDenseNet return (probabilities, extra outputs)
            if isinstance(y_pred, (tuple, list)):
                y_pred = y_pred[0]
            prediction[start:stop] = y_pred
    return prediction.cpu().numpy()

def predict_data_model(data, model, min_size_remove=300):
    """Segment one preprocessed volume and map the labels back to the original grid.

    Parameters:
    data (dict): output of ``preprocess_data_table2``; ``"image"`` holds the stacked
        slices of shape (n_slices, 1, H, W) and ``"restore_info"`` the crop/resize
        metadata from ``crop_resize_image``.
    model (torch.nn.Module): segmentation network. Inference runs on the device of its
        first parameter, so a CPU-only model no longer gets CUDA inputs.
    min_size_remove (int): forwarded to ``remove_small_elements`` (defined outside
        this view); presumably the minimum connected-component size kept -- TODO confirm.

    Returns:
    np.ndarray: int16 label volume with the original image shape (from ``restore_mask``).
    """
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Parameterless or non-Module callables: use the best available device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # class scores per slice: (n_slices, n_classes, H, W)
    probability_output = predict_patches(data["image"], model, device=device)
    # argmax over classes, slice axis moved last: (H, W, n_slices)
    seg = np.argmax(probability_output, axis=1).transpose(1, 2, 0)
    seg = remove_small_elements(seg, min_size_remove=min_size_remove)
    return restore_mask(seg, data["restore_info"])

def make_volume(ndarray, voxel_spacing):
    """Physical volume of a mask: (sum of its values) x (product of the voxel spacings).

    For a binary mask this is the foreground voxel count times the volume of one voxel.
    """
    voxel_volume = np.prod(voxel_spacing)
    return voxel_volume * ndarray.sum()

def compute_dice(mask1, mask2, class_id, eps=1e-6):
    """Smoothed Dice coefficient of one class between two label arrays.

    dice = (2 * |A & B| + eps) / (|A| + |B| + eps), where A and B are the elements of
    ``mask1`` and ``mask2`` equal to ``class_id``. If neither mask contains the
    class, the result is 1.0.
    """
    in_first = mask1 == class_id
    in_second = mask2 == class_id
    overlap = np.sum(np.logical_and(in_first, in_second))
    total = np.sum(in_first) + np.sum(in_second)
    return (2 * overlap + eps) / (total + eps)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
# map_location=device lets a checkpoint saved on GPU load on a CPU-only machine.
model_acdc = torch.load("tiramisu_acdc.pt", weights_only=False, map_location=device)
model_acdc.eval()
model_acdc = model_acdc.to(device)

def preprocess_data_table2(image_nrrd_path, new_dim=None):
    """Load one NRRD volume and turn it into a batch of single-channel 2D slices.

    The volume is read with ``nrrd.read`` and normalised with ``min_max_normalize``
    (defined outside this view). It is then cropped/resized in-plane by
    ``crop_resize_image`` and split along its last axis.

    Parameters:
    image_nrrd_path (str): path to the image .nrrd file.
    new_dim (int | None): in-plane size passed to ``crop_resize_image``; ``None``
        uses ``cfg.DATA.DIM_RESIZE`` (``cfg`` is defined outside this view).

    Returns:
    dict with
        "restore_info": metadata needed by ``restore_mask`` to map predictions back,
        "image": float32 tensor of shape (n_slices, 1, new_dim, new_dim).
    """
    if new_dim is None:
        new_dim = cfg.DATA.DIM_RESIZE
    image, _header = nrrd.read(image_nrrd_path)
    image = min_max_normalize(image)

    resized_image, restore_info = crop_resize_image(image, new_dim)
    # (H, W, n_slices) -> (n_slices, 1, H, W): one single-channel image per slice
    slices = np.ascontiguousarray(np.moveaxis(resized_image, -1, 0)[:, np.newaxis])
    return {
        "restore_info": restore_info,
        "image": torch.from_numpy(slices).float(),
    }
    
# Every sub-directory of "Tables 2/Tables" holds image .nrrd files and "*seg.nrrd"
# segmentations. Entries are natsorted because os.listdir order is arbitrary, which
# would otherwise make the file lists non-reproducible across machines.
table2_path = "Tables 2/Tables/*/"
list_test_image, list_test_mask = [], []
for path in natsorted(glob.glob(table2_path)):
    for file in natsorted(os.listdir(path)):
        if not file.endswith(".nrrd"):
            continue
        if file.endswith("seg.nrrd"):
            list_test_mask.append(os.path.join(path, file))
        elif "seg" not in file:
            list_test_image.append(os.path.join(path, file))




In [None]:
# Predict every Table 2 image, keep only the myocardium, and save it under
# predicted_table2_data/<patient folder>/, e.g.
# "1 - PGF/34 de_high res PSIR EC_PSIR.nrrd" -> "1 - PGF/34 de_high res PSIR EC_PSIR_tiramisu_seg.nrrd"
os.makedirs("predicted_table2_data", exist_ok=True)
for image_path in list_test_image:
    # Only the header is needed here; preprocess_data_table2 reads the voxel data itself.
    image_info = nrrd.read_header(image_path)
    data = preprocess_data_table2(image_path)

    seg = predict_data_model(data, model_acdc)
    # Binary myocardium mask: label 2 (MYO) -> 1; background, RV (1) and LV (3) -> 0.
    seg = (seg == 2).astype(np.uint8)

    # "<patient folder>/<file stem>_tiramisu_seg.nrrd"
    name_predict = "/".join(image_path.split("/")[-2:])
    name_predict = os.path.splitext(name_predict)[0] + "_tiramisu_seg.nrrd"
    os.makedirs(os.path.join("predicted_table2_data", os.path.dirname(name_predict)), exist_ok=True)

    nrrd.write(f"predicted_table2_data/{name_predict}", seg, image_info)
    print("write predicted mask to ", name_predict)
    