# We need to explore the encoded vector of the nnUNet in order to build a decoder that regresses / classifies the output instead of segmenting it.

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported modules
# before each cell executes, so edits to the project sources are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
# Name of the current working directory, resolved with os.path so it also works
# where the path separator is not '/'.
cwd_name = os.path.basename(os.getcwd())
print(cwd_name)
if cwd_name == 'notebooks':
    # Make the sibling `pytorch` directory importable. Guard against duplicates so
    # re-running this cell does not keep growing sys.path.
    if '../pytorch' not in sys.path:
        sys.path.append('../pytorch')
    print('Added path to sys')


In [None]:
# Getting the nnUNet model
from utils import KFoldNNUNetSegmentationDataModule, GenesisSegmentation
import yaml
import torch

# Load the config file. A context manager is used so the file handle is closed
# as soon as parsing finishes instead of being left to the garbage collector.
CONFIG_PATH = '../pytorch/configs/fine_tune_config-regression.yaml'
with open(CONFIG_PATH, 'r') as config_file:
    # NOTE(review): FullLoader is kept to preserve how the config is parsed today;
    # yaml.safe_load would be the stricter choice if the file only contains plain
    # scalars / lists / mappings — TODO confirm.
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Model, built from the loaded config and placed on the CPU.
model = GenesisSegmentation(config=config).to('cpu')

In [None]:
# Investigating the encoder
encoder = model.model.encoder

# One random single-channel patch with the spatial size configured for training.
input_sample = torch.randn(1, 1, *config['data']['patch_size'])

# Only the feature shapes are of interest here, so skip autograd bookkeeping:
# this avoids keeping every intermediate activation alive for a backward pass.
with torch.no_grad():
    output_sample = encoder(input_sample)

# The encoder output is iterated as a sequence of feature maps; show each shape.
[feat.shape for feat in output_sample]

We would like to know if the Whole-Volumetric-OCT-Image can be cropped to a smaller size.

In [None]:
import SimpleITK as sitk
from pathlib import Path
import numpy as np

def get_positive_bounds(image):
    """Return the inclusive index bounds of the strictly positive voxels of a volume.

    Args:
        image: A 3-D ``SimpleITK.Image`` (converted with ``sitk.GetArrayFromImage``)
            or an already extracted 3-D ``numpy.ndarray``.

    Returns:
        A 6-tuple ``(min0, max0, min1, max1, min2, max2)`` holding, for array
        axes 0, 1 and 2 in that order, the smallest and largest index at which
        a value ``> 0`` occurs.

        NOTE(review): SimpleITK documents ``GetArrayFromImage`` as returning the
        array with axes in (z, y, x) order, so callers that label positions 2-3
        as "x" and 4-5 as "y" may have the two swapped — TODO confirm.

    Raises:
        ValueError: If the volume is not 3-D or contains no positive voxel.
    """
    array = image if isinstance(image, np.ndarray) else sitk.GetArrayFromImage(image)
    if array.ndim != 3:
        raise ValueError('Expected a 3-D volume, got {} dimension(s).'.format(array.ndim))

    positive = array > 0
    bounds = []
    for axis in range(3):
        # Collapse the two other axes: what remains marks every index along
        # `axis` whose slice holds at least one positive voxel. This yields the
        # same min/max as np.where on the full mask without materialising three
        # index arrays as long as the number of positive voxels.
        other_axes = tuple(a for a in range(3) if a != axis)
        hits = np.flatnonzero(positive.any(axis=other_axes))
        if hits.size == 0:
            raise ValueError('Volume contains no positive voxels; bounds are undefined.')
        bounds.extend((hits[0], hits[-1]))

    return tuple(bounds)

# Root of the raw nnU-Net dataset; each split keeps its volumes in
# `images<suffix>/` and `labels<suffix>/` below this directory.
DATASET_ROOT = Path('/storage_bizon/naravich/nnUNet_Datasets/nnUNet_raw/Dataset303_Calcium_OCTv2')

# Neutral seed for a running union of (min, max) pairs over three axes: any real
# bound replaces it. If it is still visible in the final printout, no volume was
# found for that kind.
_NO_BOUNDS = (float('inf'), -1) * 3


def _union_bounds(running, new):
    """Widen `running` so it also encloses `new`; both are (min, max) * 3 tuples."""
    return tuple(
        pick(old, cur)
        for pick, old, cur in zip((min, max) * 3, running, new)
    )


# Running bounds over every split, kept separately for images and labels.
bounds = {'images': _NO_BOUNDS, 'labels': _NO_BOUNDS}

for split_name, suffix in (('train', 'Tr'), ('test', 'Ts')):
    for kind, heading in (('images', split_name), ('labels', 'label')):
        print(heading)
        # Sorted so the listing is in the same order on every run.
        for file in sorted((DATASET_ROOT / (kind + suffix)).glob('*.nii.gz')):
            img = sitk.ReadImage(str(file))
            print(img.GetSize(), file.stem)
            bounds[kind] = _union_bounds(bounds[kind], get_positive_bounds(img))

# Expose the results under the individual names as well. The z/x/y labels follow
# the positional order returned by get_positive_bounds.
min_img_z, max_img_z, min_img_x, max_img_x, min_img_y, max_img_y = bounds['images']
min_z, max_z, min_x, max_x, min_y, max_y = bounds['labels']

# Report the image bounds too: they are needed to judge whether the images
# themselves can be cropped.
print('Positive-intensity bound of the images is: ')
print(min_img_z, max_img_z, min_img_x, max_img_x, min_img_y, max_img_y)

print('Minimum croppable bound is: ')
print(min_z, max_z, min_x, max_x, min_y, max_y)

As the image data show, the 3D image volumes cannot be cropped to a smaller size. The labels, however, can be cropped.

In [None]:
# Define the UNetRegressor
class UNetRegressor(torch.nn.Module):
    """Encoder followed by a small convolutional head on its deepest feature map.

    Encoder feature shapes recorded in this notebook for one sample:

        >>> input_sample = torch.randn(1, 1, 512, 512, 384)
        >>> output_sample = encoder(input_sample)
        >>> [feat.shape for feat in output_sample]
        [torch.Size([1, 32, 512, 512, 384]),
         torch.Size([1, 64, 256, 256, 192]),
         torch.Size([1, 128, 128, 128, 96]),
         torch.Size([1, 256, 64, 64, 48]),
         torch.Size([1, 320, 32, 32, 24]),
         torch.Size([1, 320, 32, 16, 12])]

    The defaults of `bottleneck_channels` and `bottleneck_size` correspond to the
    last entry of that listing.

    Args:
        n_classes: Number of output channels of the head.
        task: ``'regression'`` selects a Sigmoid for `last_activation`; any other
            value selects ``Softmax(dim=1)``.
        encoder: Module whose output is indexed with ``[-1]`` to obtain the
            feature map fed to the head. When ``None``, the encoder of the
            notebook-level ``model`` is used.
        bottleneck_channels: Channel count of that deepest feature map.
        bottleneck_size: Kernel size of the first convolution. When it equals the
            spatial size of the deepest feature map, that map collapses to a
            single position.
    """

    def __init__(self, n_classes: int, task: str = 'regression', encoder=None,
                 bottleneck_channels: int = 320, bottleneck_size=(32, 16, 12)):
        super().__init__()
        # Fall back to the notebook-level model only when no encoder is injected.
        self.encoder = model.model.encoder if encoder is None else encoder

        hidden_channels = bottleneck_channels // 2
        self.regressor = torch.nn.Sequential(
            # With no padding and stride 1, a kernel as large as the feature map
            # leaves one output position per channel.
            torch.nn.Conv3d(bottleneck_channels, bottleneck_channels,
                            kernel_size=tuple(bottleneck_size), stride=1, bias=True),
            torch.nn.ReLU(),
            torch.nn.Conv3d(bottleneck_channels, hidden_channels, kernel_size=1, stride=1, bias=True),
            torch.nn.ReLU(),
            torch.nn.Conv3d(hidden_channels, n_classes, kernel_size=1, stride=1, bias=True),
        )
        # NOTE(review): this activation is constructed but forward() does not
        # apply it, so the module returns the raw head output — confirm whether
        # that is intended (e.g. a loss that expects logits).
        self.last_activation = torch.nn.Sigmoid() if task == 'regression' else torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Run the encoder, keep its last feature map and apply the head.

        Returns the output of the final convolution; `last_activation` is not
        applied.
        """
        x = self.encoder(x)[-1]
        x = self.regressor(x)
        return x

In [None]:
# Define the UNetRegressorHead
class UNetRegressorHead(torch.nn.Module):
    """Pool-and-project head applied to a single feature map.

    The input is pooled to one value per channel, flattened, optionally passed
    through dropout, projected by a linear layer and sent through a final
    activation (ReLU for regression, softmax over dim 1 for classification).

    Encoder feature shapes recorded in this notebook for one sample:

        >>> input_sample = torch.randn(1, 1, 512, 512, 384)
        >>> output_sample = encoder(input_sample)
        >>> [feat.shape for feat in output_sample]
        [torch.Size([1, 32, 512, 512, 384]),
         torch.Size([1, 64, 256, 256, 192]),
         torch.Size([1, 128, 128, 128, 96]),
         torch.Size([1, 256, 64, 64, 48]),
         torch.Size([1, 320, 32, 32, 24]),
         torch.Size([1, 320, 32, 16, 12])]

    Args:
        in_channels: Input feature count of the linear layer, i.e. the channel
            count of the feature map after pooling.
        n_classes: Number of outputs of the linear layer.
        pooling: ``"avg"`` or ``"max"`` adaptive pooling to a single position.
        dropout: Dropout probability; values ``<= 0`` disable dropout.
        task: ``'regression'`` or ``'classification'``.

    Raises:
        ValueError: If `pooling` or `task` is not one of the supported values.
    """

    def __init__(self, in_channels: int, n_classes: int, pooling="avg", dropout=0.2, task: str = 'regression'):
        super().__init__()

        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))

        if task not in ("regression", "classification"):
            raise ValueError("Task should be one of ('regression', 'classification'), got {}.".format(task))

        # Pick each stage up front; the Sequential below only fixes their order.
        if pooling == "avg":
            pool = torch.nn.AdaptiveAvgPool3d(1)
        else:
            pool = torch.nn.AdaptiveMaxPool3d(1)

        if dropout > 0:
            drop = torch.nn.Dropout(p=dropout, inplace=True)
        else:
            # Placeholder so the layer positions stay the same without dropout.
            drop = torch.nn.Identity()

        projection = torch.nn.Linear(in_channels, n_classes, bias=True)

        if task == 'regression':
            final_activation = torch.nn.ReLU()
        else:
            final_activation = torch.nn.Softmax(dim=1)

        self.regressor = torch.nn.Sequential(
            pool,
            torch.nn.Flatten(),
            drop,
            projection,
            final_activation,
        )

    def forward(self, x):
        """Apply pooling, flattening, dropout, projection and the final activation."""
        return self.regressor(x)

In [None]:
encoder = model.model.encoder
# Batch of 10 random single-channel patches with the configured spatial size.
input_sample = torch.randn(10, 1, *config['data']['patch_size'])

# This cell only checks output shape and values, so run without autograd to
# avoid retaining the encoder's intermediate activations for a backward pass.
with torch.no_grad():
    encoder_features = encoder(input_sample)

regressor = UNetRegressorHead(in_channels=320, n_classes=2, pooling="avg", dropout=0.2, task='classification')

# Feed only the last encoder feature map to the head. The head is left in its
# default training mode, so dropout is active and the values vary between runs.
with torch.no_grad():
    output_sample = regressor(encoder_features[-1])

output_sample.shape, output_sample

In [None]:
# NOTE(review): leftover exploratory call. `img` is whatever volume the last
# bounds loop in the earlier cell left bound (hidden notebook state), and
# GetPixel presumably needs a voxel index, e.g. img.GetPixel(0, 0, 0) —
# TODO confirm the intent or delete this cell.
img.GetPixel()