# We need to explore the encoded vector of the nnUNet to create a decoder that regresses / classifies the output instead of segmenting it.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
# Name of the current working directory. os.path.basename is separator-aware,
# unlike splitting the path on '/', and the directory is only queried once.
cwd_name = os.path.basename(os.getcwd())
print(cwd_name)
if cwd_name == 'notebooks':
    # Make the project's `pytorch` directory importable when the notebook is
    # launched from `notebooks`. The membership check keeps the cell
    # idempotent: re-running it does not append the same entry again.
    if '../pytorch' not in sys.path:
        sys.path.append('../pytorch')
    print('Added path to sys')


notebooks
Added path to sys


In [None]:
# Getting the nnUNet model
from utils import KFoldNNUNetSegmentationDataModule, GenesisSegmentation
import yaml
import torch

# Load the config file. The context manager closes the file handle as soon as
# parsing is done instead of leaving an anonymous open() to the garbage collector.
with open('../pytorch/configs/fine_tune_config-regression.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Model (kept on the CPU for inspection)
model = GenesisSegmentation(config=config).to('cpu')

In [None]:
# Investigating the encoder
encoder = model.model.encoder

# One random single-channel volume with the patch size from the config.
input_sample = torch.randn(1, 1, *config['data']['patch_size'])
# Only the output shapes are inspected, so skip autograd bookkeeping; this
# avoids keeping the graph of a full-size forward pass alive.
with torch.no_grad():
    output_sample = encoder(input_sample)
# The encoder returns one feature map per stage; list their shapes.
[feat.shape for feat in output_sample]

We would like to know if the Whole-Volumetric-OCT-Image can be cropped to a smaller size.

In [None]:
import SimpleITK as sitk
from pathlib import Path
import numpy as np

def get_positive_bounds(image):
    """Return the inclusive index bounds of all strictly positive voxels.

    Parameters
    ----------
    image : SimpleITK.Image or numpy.ndarray
        A 3D volume. Anything that is not already a ``numpy.ndarray`` is
        converted with ``sitk.GetArrayFromImage``.

    Returns
    -------
    tuple
        ``(min0, max0, min1, max1, min2, max2)`` - the smallest and largest
        index holding a value ``> 0`` along array axes 0, 1 and 2.

    Raises
    ------
    ValueError
        If the volume has no positive voxel, or (from the tuple unpacking)
        if the array is not 3-dimensional.
    """
    if not isinstance(image, np.ndarray):
        image = sitk.GetArrayFromImage(image)

    # NOTE(review): the local names follow the original z/x/y labelling.
    # SimpleITK documents GetArrayFromImage as returning arrays indexed
    # [z, y, x], so the second and third pairs may be y and x - confirm.
    zs, xs, ys = np.where(image > 0)

    # np.where already guarantees every returned index is positive, so the
    # former per-voxel assert loop (pure-Python, one iteration per voxel) is
    # dropped. Only the empty case needs explicit handling.
    if zs.size == 0:
        raise ValueError("Image contains no positive voxels; bounds are undefined.")

    return zs.min(), zs.max(), xs.min(), xs.max(), ys.min(), ys.max()

# Root of the raw nnUNet dataset scanned below.
dataset_root = Path('/storage_bizon/naravich/nnUNet_Datasets/nnUNet_raw/Dataset303_Calcium_OCTv2')

# Starting point for a running (min, max) triple: minima start at 1000 and
# maxima at -1, the same sentinels the hand-written version used.
initial_bounds = (1000, -1, 1000, -1, 1000, -1)


def accumulate_positive_bounds(folder, bounds):
    """Widen ``bounds`` with the positive bounds of every ``*.nii.gz`` in ``folder``.

    ``folder`` is a sub-directory name of ``dataset_root``. ``bounds`` is a
    ``(min, max, min, max, min, max)`` tuple; the widened tuple is returned.
    Each file's size and stem are printed while scanning.
    """
    lo_a, hi_a, lo_b, hi_b, lo_c, hi_c = bounds
    for file in (dataset_root / folder).glob('*.nii.gz'):
        img = sitk.ReadImage(str(file))
        print(img.GetSize(), file.stem)
        positive_bounds = get_positive_bounds(img)
        lo_a = min(lo_a, positive_bounds[0])
        hi_a = max(hi_a, positive_bounds[1])
        lo_b = min(lo_b, positive_bounds[2])
        hi_b = max(hi_b, positive_bounds[3])
        lo_c = min(lo_c, positive_bounds[4])
        hi_c = max(hi_c, positive_bounds[5])
    return lo_a, hi_a, lo_b, hi_b, lo_c, hi_c


image_bounds = initial_bounds
label_bounds = initial_bounds

print('train')
image_bounds = accumulate_positive_bounds('imagesTr', image_bounds)
print('label')
label_bounds = accumulate_positive_bounds('labelsTr', label_bounds)
print('test')
image_bounds = accumulate_positive_bounds('imagesTs', image_bounds)
label_bounds = accumulate_positive_bounds('labelsTs', label_bounds)

# Keep the individual names the original cell exposed.
min_img_z, max_img_z, min_img_x, max_img_x, min_img_y, max_img_y = image_bounds
min_z, max_z, min_x, max_x, min_y, max_y = label_bounds

print('Minimum croppable bound is: ')
print(min_z, max_z, min_x, max_x, min_y, max_y)
# The image bounds were computed but never shown before; print them as well so
# the statement about cropping the images can be checked against the output.
print('Positive bound of the images is: ')
print(min_img_z, max_img_z, min_img_x, max_img_x, min_img_y, max_img_y)

As we can see from the image data, the images are 3D and cannot be cropped to a smaller size. The labels, however, can be cropped.

In [None]:
# Define the UNetRegressor
class UNetRegressor(torch.nn.Module):
    """Encoder followed by a convolutional head applied to the deepest feature map.

    The head is three ``Conv3d`` layers with ``ReLU`` in between. The first one
    uses a kernel of ``bottleneck_size``; the other two are 1x1x1 convolutions
    that reduce ``in_channels -> in_channels // 2 -> n_classes``.

    Encoder output shapes recorded in this notebook for an input of shape
    ``(1, 1, 512, 512, 384)``::

        [torch.Size([1, 32, 512, 512, 384]),
         torch.Size([1, 64, 256, 256, 192]),
         torch.Size([1, 128, 128, 128, 96]),
         torch.Size([1, 256, 64, 64, 48]),
         torch.Size([1, 320, 32, 32, 24]),
         torch.Size([1, 320, 32, 16, 12])]

    The defaults (320 channels, kernel ``(32, 16, 12)``) match the last entry,
    so for that input the first convolution covers the whole feature map.

    Parameters
    ----------
    n_classes : int
        Number of output channels of the head.
    task : str
        Selects ``last_activation``: ``Sigmoid`` for ``'regression'``,
        ``Softmax(dim=1)`` for anything else.
    encoder : torch.nn.Module, optional
        Module whose output is indexed with ``[-1]`` in ``forward``. When
        omitted, the notebook-level ``model.model.encoder`` is used, as before.
    in_channels : int
        Channel count of the deepest encoder feature map.
    bottleneck_size : tuple of int
        Kernel size of the first convolution.
    """

    def __init__(self, n_classes: int, task: str = 'regression', encoder=None,
                 in_channels: int = 320, bottleneck_size=(32, 16, 12)):
        super(UNetRegressor, self).__init__()
        # Falling back to the global keeps the original two-argument call working;
        # passing `encoder` removes the dependency on notebook state.
        self.encoder = model.model.encoder if encoder is None else encoder

        hidden_channels = in_channels // 2  # 160 for the default 320
        self.regressor = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels, in_channels, kernel_size=tuple(bottleneck_size), stride=1, bias=True),
            torch.nn.ReLU(),
            torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=1, stride=1, bias=True),
            torch.nn.ReLU(),
            torch.nn.Conv3d(hidden_channels, n_classes, kernel_size=1, stride=1, bias=True),
        )
        # NOTE(review): this activation is created but never applied in
        # forward(); the module returns the raw head output. Left unchanged
        # because it is unclear whether that is intentional - confirm.
        self.last_activation = torch.nn.Sigmoid() if task == 'regression' else torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Run the encoder, keep its last output and apply the head."""
        x = self.encoder(x)[-1]
        x = self.regressor(x)
        return x

In [None]:
# Define the UNetRegressorHead (pooling + linear head for encoder features)
class UNetRegressorHead(torch.nn.Module):
    """Pool a 3D feature map to one vector per sample and map it to ``n_classes`` values.

    Layer order inside ``self.regressor``: adaptive 3D pooling to size 1,
    flatten, dropout (or identity when ``dropout`` is not positive), a linear
    layer ``in_channels -> n_classes``, and a final activation (``ReLU`` for
    ``task='regression'``, ``Softmax(dim=1)`` for ``task='classification'``).

    Encoder output shapes recorded in this notebook for an input of shape
    ``(1, 1, 512, 512, 384)``::

        [torch.Size([1, 32, 512, 512, 384]),
         torch.Size([1, 64, 256, 256, 192]),
         torch.Size([1, 128, 128, 128, 96]),
         torch.Size([1, 256, 64, 64, 48]),
         torch.Size([1, 320, 32, 32, 24]),
         torch.Size([1, 320, 32, 16, 12])]

    Raises
    ------
    ValueError
        If ``pooling`` is not ``'max'``/``'avg'`` or ``task`` is not
        ``'regression'``/``'classification'``.
    """

    def __init__(self, in_channels: int, n_classes: int, pooling="avg", dropout=0.2, task: str = 'regression'):
        super().__init__()

        # Validate the two string switches before building any layer.
        allowed_poolings = ("max", "avg")
        if pooling not in allowed_poolings:
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))

        allowed_tasks = ("regression", "classification")
        if task not in allowed_tasks:
            raise ValueError("Task should be one of ('regression', 'classification'), got {}.".format(task))

        if pooling == "avg":
            pool_layer = torch.nn.AdaptiveAvgPool3d(1)
        else:
            pool_layer = torch.nn.AdaptiveMaxPool3d(1)

        flatten_layer = torch.nn.Flatten()

        if dropout > 0:
            drop_layer = torch.nn.Dropout(p=dropout, inplace=True)
        else:
            drop_layer = torch.nn.Identity()

        linear_layer = torch.nn.Linear(in_channels, n_classes, bias=True)

        if task == 'regression':
            output_layer = torch.nn.ReLU()
        else:
            output_layer = torch.nn.Softmax(dim=1)

        stages = [pool_layer, flatten_layer, drop_layer, linear_layer, output_layer]
        self.regressor = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the pooling/linear head to ``x``."""
        return self.regressor(x)

In [None]:
# Smoke test: push a random batch through the encoder and the pooling head.
encoder = model.model.encoder
# Batch of 10 random single-channel volumes with the configured patch size.
input_sample = torch.randn(10, 1, *config['data']['patch_size'])
output_sample = encoder(input_sample)
regressor = UNetRegressorHead(in_channels=320, n_classes=2, pooling="avg", dropout=0.2, task='classification')
# Only the last encoder output is fed to the head; `output_sample` is reused for the head output.
output_sample = regressor(output_sample[-1])
# Display both the output shape and the values.
output_sample.shape, output_sample

## Prepare the dataloader

In [None]:
from pathlib import Path
import pandas as pd

unlabeled_data_dir = Path('/storage_bizon/naravich/Unlabeled_OCT_by_CADx/') # Yiqing filtered data

In [None]:
# Load the pre-IVL lesion sheet. skiprows=4 drops the first four rows of the
# sheet (presumably a preamble above the header row - confirm against the file).
pre_ivl = pd.read_excel('tabular_data/T1A_PRE_QUANT_LESION.xlsx', skiprows=4)
pre_ivl.head()

### Rename columns to their code names `FLAG_CAL = Flag for calcified nodules` -> `FLAG_CAL`

In [None]:
# Map each header to the code name written in front of ' = '
# (headers without ' = ' map to themselves), then rename in place.
column_names = {header: header.split(' = ')[0] for header in pre_ivl.columns}
pre_ivl.rename(columns=column_names, inplace=True)
pre_ivl.head()

### Extract the Unique Subject ID (USUBJID) to the format we use in the dataset `CP 61774-105-001` -> `105-001`
That is `[A-Z]{2} [0-9]{4,5}-[0-9]{3}-[0-9]{3}` -> `[0-9]{3}-[0-9]{3}`

In [None]:
import re
import unidecode

def format_subject_id(subject_id: str):
    """Extract the ``NNN-NNN`` part of a raw USUBJID, e.g. ``'CP 61774-105-001'`` -> ``'105-001'``.

    The value is passed through ``unidecode.unidecode`` first (presumably so
    that non-breaking spaces become plain spaces - confirm) and then searched
    for: two uppercase letters, a space, 4-5 digits, a hyphen and the captured
    ``NNN-NNN`` group.

    Parameters
    ----------
    subject_id : str
        Raw subject identifier from the spreadsheet.

    Returns
    -------
    str
        The captured ``NNN-NNN`` group.

    Raises
    ------
    ValueError
        If the pattern is not found in ``subject_id``.
    """
    match = re.search(
        r"[A-Z]{2} [0-9]{4,5}-([0-9]{3}-[0-9]{3})",
        unidecode.unidecode(subject_id),
    )
    if match:
        # group(1) is already exactly the NNN-NNN part; the extra re.sub pass
        # that used to follow returned it unchanged.
        return match.group(1)

    # Build the cleaned value outside the f-string: a backslash escape inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    readable_id = subject_id.strip().replace('\xa0', ' ')
    raise ValueError(f"No match found {readable_id}")

# Preview the converted IDs without modifying the frame.
pre_ivl['USUBJID'].apply(format_subject_id)

In [None]:
# Overwrite USUBJID with the short NNN-NNN form.
# NOTE(review): this cell is not re-runnable - an already converted ID no
# longer matches the pattern, so format_subject_id raises ValueError.
pre_ivl['USUBJID'] = pre_ivl['USUBJID'].apply(format_subject_id)
pre_ivl.head()

### Select only the columns we are interested in

```
MLAS_AS = Area stenosis (%)	
MLAS_SCA = Superficial calcium arc (°)	
MLAS_MSCT = Maximum superficial calcium thickness (mm)	
MCS_LA = Maximum calcium site -  Lumen area (mm2)	
MCS_AS = Maximum calcium site - Area stenosis (%)	
MCS_SCA = Maximum calcium site -  Superficial calcium arc (°)	
MCS_MSCT = Maximum calcium site -   Maximum superficial calcium thickness (mm)	
MCCS_LA = Maximum continuous calcium site - Lumen area (mm2)	
MCCS_AS = Maximum continuous calcium site - Area stenosis (%)	
MCCS_SCA = Maximum continuous calcium site - Superficial calcium arc (°)	
MCCS_MSCT = Maximum continuous calcium site - Maximum superficial calcium thickness (mm)	
MCCS_MINSCT = Maximum continuous calcium site - Minimum superficial calcium thickness (mm)	
MCCS_CSCA = Maximum continuous calcium site - Circumferential superficial calcium	
MCCS_CSCA_270 = Maximum continuous calcium site - Length of circumferential superficial calcium greater than equal to 270(mm)	
MCCS_CSCA_180 = Maximum continuous calcium site - Length of circumferential superficial calcium greater than equal to 180(mm)
FMSA_LA = Final minimum stent area site - Lumen area (mm2)	
FMSA_AS = Final minimum stent area site - Area stenosis (%)	
FMSA_SCA = Final minimum stent area site - Superficial calcium arc (°)	
FMSA_MSCT = Final minimum stent area site - Maximum superficial calcium thickness (mm)
```

In [None]:
# Identifier columns followed by the quantitative lesion metrics to keep.
selected_columns = [
    'USUBJID',
    'STUDY',
    'MLAS_AS',
    'MLAS_SCA',
    'MLAS_MSCT',
    'MCS_LA',
    'MCS_AS',
    'MCS_SCA',
    'MCS_MSCT',
    'MCCS_LA',
    'MCCS_AS',
    'MCCS_SCA',
    'MCCS_MSCT',
    'MCCS_MINSCT',
    'MCCS_CSCA',
    'MCCS_CSCA_270',
    'MCCS_CSCA_180',
    'FMSA_LA',
    'FMSA_AS',
    'FMSA_SCA',
    'FMSA_MSCT',
]

# .copy() makes the selection an independent frame. Without it, the column
# assignment and in-place dropna performed on `pre_ivl` in a later cell act on
# a slice of the original frame and trigger pandas' SettingWithCopyWarning.
pre_ivl = pre_ivl[selected_columns].copy()
pre_ivl.head()

### To make everything easy for the dataloader, we will put the absolute path of the image in the DataFrame

Check that all the images are named with 'Final', 'Pre', or 'Post'

In [None]:
# Print the stem of every NIfTI file whose name carries none of the expected
# stage tags; no output means every file is tagged.
stage_tags = ('Final', 'Pre', 'Post')
image_path = (unlabeled_data_dir / 'NiFTI').glob('*.nii.gz')
for i in image_path:
    if not any(tag in i.stem for tag in stage_tags):
        print(i.stem)

In [None]:
def resolve_image_path(subject_id: str, stage: str = 'Pre', image_dir=None):
    """Find the image file of a subject for a given acquisition stage.

    Files are matched with the glob pattern ``<subject id without '-'><stage>*``
    inside ``image_dir``.

    Parameters
    ----------
    subject_id : str
        Subject identifier such as ``'105-001'``; hyphens are removed before matching.
    stage : str
        Stage tag expected right after the subject id in the file name
        (default ``'Pre'``, the tag this cell originally hard-coded).
    image_dir : path-like, optional
        Directory to search. Defaults to ``unlabeled_data_dir / 'NiFTI'``.

    Returns
    -------
    pathlib.Path or None
        The first match in sorted order, or ``None`` when nothing matches.
    """
    if image_dir is None:
        image_dir = unlabeled_data_dir / 'NiFTI'
    pattern = "{}{}*".format(subject_id.replace('-', ''), stage)
    # glob yields entries in arbitrary order; sorting makes the choice
    # reproducible when several files match the same subject.
    matches = sorted(Path(image_dir).glob(pattern))
    return matches[0] if matches else None

# Attach the resolved image path to every row, then drop the rows for which
# no image file was found (resolve_image_path returned None).
pre_ivl['image_path'] = pre_ivl['USUBJID'].apply(resolve_image_path)
pre_ivl.dropna(subset=['image_path'], inplace=True)
pre_ivl.head()


### Lastly, replace the `.` placeholders with null values

In [None]:
import numpy as np
# Replace the spreadsheet's missing-value marker (eight non-breaking spaces
# followed by '.') with NaN on a copy and write that copy to disk; `pre_ivl`
# itself keeps the placeholders.
# NOTE(review): the exploration section reads 'tabular_data/Pre_IVL.csv',
# which differs in case from the file written here - confirm which is intended.
pre_ivl_export = pre_ivl.replace({'\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0.': np.nan}, inplace=False)
pre_ivl_export.to_csv('tabular_data/pre_ivl.csv', index=False)

## Do the same thing for post stent

In [None]:
import numpy as np
import re
import unidecode

# Load the post-stent lesion sheet, skipping the first four rows of the sheet.
post_stent = pd.read_excel('tabular_data/T6_POST_STENT_LESION.xlsx', skiprows=4)

# Map each header to the code name written in front of ' = ' and rename in place.
column_names = {header: header.split(' = ')[0] for header in post_stent.columns}
post_stent.rename(columns=column_names, inplace=True)


def format_subject_id(subject_id: str):
    """Extract the ``NNN-NNN`` part of a raw USUBJID, e.g. ``'CP 61774-105-001'`` -> ``'105-001'``.

    Same contract as the function of the same name defined for the pre-IVL
    table earlier in this notebook; this definition shadows it.

    The value is passed through ``unidecode.unidecode`` first (presumably so
    that non-breaking spaces become plain spaces - confirm) and then searched
    for: two uppercase letters, a space, 4-5 digits, a hyphen and the captured
    ``NNN-NNN`` group.

    Returns
    -------
    str
        The captured ``NNN-NNN`` group.

    Raises
    ------
    ValueError
        If the pattern is not found in ``subject_id``.
    """
    match = re.search(
        r"[A-Z]{2} [0-9]{4,5}-([0-9]{3}-[0-9]{3})",
        unidecode.unidecode(subject_id),
    )
    if match:
        # group(1) is already exactly the NNN-NNN part; the extra re.sub pass
        # that used to follow returned it unchanged.
        return match.group(1)

    # Build the cleaned value outside the f-string: a backslash escape inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    readable_id = subject_id.strip().replace('\xa0', ' ')
    raise ValueError(f"No match found {readable_id}")

# Overwrite USUBJID with the short NNN-NNN form.
post_stent['USUBJID'] = post_stent['USUBJID'].apply(format_subject_id)

selected_columns = [
    'USUBJID',
    'STUDY',

    # Categorical
    'MAL_PRES',
    'MAL_PROX',
    'MAL_DIS',
    'MAL_SBOD',
    'MMAL_CF',
    'MMAL_NCF',
    'CF_PRES',
    'CF_3',
    'CF_2',
    'CF_1',

    # Morphological
    'MMAL_SA',
    'MMAL_LA',
    'MMAL_ARC',
    'MMAL_AR',
    'MMAL_PAR',
    'MAL_LEN',
    'MAL_THICK',
    'TOT_CAL0',
    'TOT_CAL',
    'MEAN_CF',
    'MAX_CF',
    'TOT_CFLEN',
    'MAX_CFDEP',
    'MAX_CFWID',
    'MAX_CFTHK',
    'MAX_CARC',
    'MIN_CARC',
]

# .copy() makes the selection an independent frame. Without it, the
# `image_path` assignment and in-place dropna further down this cell act on a
# slice of the original frame and trigger pandas' SettingWithCopyWarning.
post_stent = post_stent[selected_columns].copy()

# Print the stem of every NIfTI file whose name carries none of the expected
# stage tags; no output means every file is tagged.
stage_tags = ('Final', 'Pre', 'Post')
image_path = (unlabeled_data_dir / 'NiFTI').glob('*.nii.gz')
for i in image_path:
    if not any(tag in i.stem for tag in stage_tags):
        print(i.stem)

def resolve_image_path(subject_id: str, stage: str = 'Final', image_dir=None):
    """Find the image file of a subject for a given acquisition stage.

    Files are matched with the glob pattern ``<subject id without '-'><stage>*``
    inside ``image_dir``. This definition shadows the pre-IVL function of the
    same name defined earlier in the notebook; here the default stage is
    ``'Final'``.

    Parameters
    ----------
    subject_id : str
        Subject identifier such as ``'105-001'``; hyphens are removed before matching.
    stage : str
        Stage tag expected right after the subject id in the file name.
    image_dir : path-like, optional
        Directory to search. Defaults to ``unlabeled_data_dir / 'NiFTI'``.

    Returns
    -------
    pathlib.Path or None
        The first match in sorted order, or ``None`` when nothing matches.
    """
    if image_dir is None:
        image_dir = unlabeled_data_dir / 'NiFTI'
    pattern = "{}{}*".format(subject_id.replace('-', ''), stage)
    # glob yields entries in arbitrary order; sorting makes the choice
    # reproducible when several files match the same subject.
    matches = sorted(Path(image_dir).glob(pattern))
    return matches[0] if matches else None

# Attach the resolved image path to every row, then drop the rows for which
# no image file was found (resolve_image_path returned None).
post_stent['image_path'] = post_stent['USUBJID'].apply(resolve_image_path)
post_stent.dropna(subset=['image_path'], inplace=True)

# Replace the missing-value marker (eight non-breaking spaces followed by '.')
# with NaN on a copy and write that copy to disk.
post_stent_export = post_stent.replace({'\xa0\xa0\xa0\xa0\xa0\xa0\xa0\xa0.': np.nan}, inplace=False)
post_stent_export.to_csv('tabular_data/Post_Stent.csv', index=False)


# Let's build the dataloader
Let's start with Pre-IVL

In [3]:
import yaml
from utils import KFoldNNUNetTabularDataModule

# Reload the regression fine-tuning config (same file as at the top of the notebook).
with open('../pytorch/configs/fine_tune_config-regression.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Build the tabular data module from that config.
dataModule = KFoldNNUNetTabularDataModule(config=config)

In [4]:
# Prepare the 'fit' stage and get the training dataloader.
dataModule.setup('fit')
train_dataloader = dataModule.train_dataloader()
# Pull only the first batch: the loop binds `batch` and stops immediately.
for batch in train_dataloader:
    break


[Subject(Keys: ('MCS_MSCT', 'image', 'location'); images: 1), Subject(Keys: ('MCS_MSCT', 'image', 'location'); images: 1)]
dict_keys(['MCS_MSCT', 'image', 'location'])


In [8]:
# Inspect the batch: its keys, the image tensor shape and the target shape.
print(batch.keys())
print(batch['image'].shape)
print(batch['MCS_MSCT'].shape)

dict_keys(['MCS_MSCT', 'image'])
torch.Size([2, 1, 512, 512, 384])
torch.Size([2])


In [None]:
# Print every configured output metric for the first element of the batch.
# NOTE(review): the recorded output of another cell shows `batch.keys()`
# returning dict keys ('MCS_MSCT', 'image'); if `batch` is that dict,
# `batch[0]` fails - confirm which batch format this cell expects.
for metric in config['data']['output_metrics']:
    print(metric, batch[0][metric])

In [None]:
# NOTE(review): `tio` is not imported in any visible cell (presumably
# `torchio`); on a fresh kernel this line raises NameError unless it is
# imported elsewhere.
batch[0]['image'][tio.DATA]

In [None]:
# Collect the first 10 batches of the training dataloader (fewer if it is shorter).
batches = []
for batch_number, batch in enumerate(train_dataloader, start=1):
    batches.append(batch)
    if batch_number == 10:
        break

In [None]:
# Plot the first element of the second collected batch.
# NOTE(review): positional indexing assumes each batch is a sequence of
# subjects rather than a dict - confirm against the dataloader's collate format.
batches[1][0].plot()

In [None]:
# Print every configured output metric for the sample plotted above
# (first element of the second collected batch).
for metric in config['data']['output_metrics']:
    print(metric, batches[1][0][metric])

## Exploration of the data

In [None]:
import pandas as pd

pre_ivl = pd.read_csv('tabular_data/Pre_IVL.csv')
# This previously loaded 'Post_Stent.csv' a second time, which made `post_ivl`
# an exact duplicate of `post_stent` and silently skewed every Post_IVL figure.
# NOTE(review): the file name follows the '<Modality>.csv' pattern of the other
# two tables - confirm that 'tabular_data/Post_IVL.csv' exists.
post_ivl = pd.read_csv('tabular_data/Post_IVL.csv')
post_stent = pd.read_csv('tabular_data/Post_Stent.csv')

### Percentage of non-NA values in each column

In [None]:
# Per column: number of non-NA entries divided by the number of rows, in percent.
100 * pre_ivl.count(axis=0) / pre_ivl.shape[0]

# data = [
#     'MCS_LA',
#     'MCS_SCA',
#     'MCS_MSCT',
# ]

In [None]:
# Per column: number of non-NA entries divided by the number of rows, in percent.
# The denominator must be this frame's own row count; it previously used
# pre_ivl.shape[0], which is only correct if both tables have equally many rows.
100 * post_ivl.count(axis=0) / post_ivl.shape[0]

In [None]:
# Per column: number of non-NA entries divided by the number of rows, in percent.
# The denominator must be this frame's own row count; it previously used
# pre_ivl.shape[0], which is only correct if both tables have equally many rows.
100 * post_stent.count(axis=0) / post_stent.shape[0]

### Counting # of images in modality A available as inputs to predict metrics of modality B

In [None]:
# For every (input modality, output table) pair, report how many rows of the
# output table have a non-NA '<input modality>_image_path' entry.
modality_names = ['Pre_IVL', 'Post_IVL', 'Post_Stent']
modality_frames = [pre_ivl, post_ivl, post_stent]
for output_modality_name, output_modality in zip(modality_names, modality_frames):
    total = output_modality.shape[0]
    for input_modality in modality_names:
        count = output_modality[f'{input_modality}_image_path'].count()
        percentage = 100 * count / total
        print(f'{input_modality} -> {output_modality_name}')
        print(f'  {count} / {total} = {percentage:.2f}%')

### Plotting the distribution of the metrics

In [None]:
import seaborn as sns
import numpy as  np

In [None]:
# Summary statistics of the pre-IVL table.
pre_ivl.describe()

Selecting the 3 columns that have the fewest NA values

In [None]:
'''
MLAS_AS = Area stenosis (%)	
MLAS_SCA = Superficial calcium arc (°)	
MLAS_MSCT = Maximum superficial calcium thickness (mm)	
MCS_LA = Maximum calcium site -  Lumen area (mm2)	
MCS_AS = Maximum calcium site - Area stenosis (%)	
MCS_SCA = Maximum calcium site -  Superficial calcium arc (°)	
MCS_MSCT = Maximum calcium site -   Maximum superficial calcium thickness (mm)	
MCCS_LA = Maximum continuous calcium site - Lumen area (mm2)	
MCCS_AS = Maximum continuous calcium site - Area stenosis (%)	
MCCS_SCA = Maximum continuous calcium site - Superficial calcium arc (°)	
MCCS_MSCT = Maximum continuous calcium site - Maximum superficial calcium thickness (mm)	
MCCS_MINSCT = Maximum continuous calcium site - Minimum superficial calcium thickness (mm)	
MCCS_CSCA = Maximum continuous calcium site - Circumferential superficial calcium	
MCCS_CSCA_270 = Maximum continuous calcium site - Length of circumferential superficial calcium greater than equal to 270(mm)	
MCCS_CSCA_180 = Maximum continuous calcium site - Length of circumferential superficial calcium greater than equal to 180(mm)
FMSA_LA = Final minimum stent area site - Lumen area (mm2)	
FMSA_AS = Final minimum stent area site - Area stenosis (%)	
FMSA_SCA = Final minimum stent area site - Superficial calcium arc (°)	
FMSA_MSCT = Final minimum stent area site - Maximum superficial calcium thickness (mm)
'''

# Pairwise scatter plots of the three selected metrics (raw values).
metric_columns = ['MCS_LA', 'MCS_SCA', 'MCS_MSCT']
sns.pairplot(pre_ivl[metric_columns])

Visually, there are no outliers in the data, so it is safe to use min-max normalization.

In [None]:
def min_max_norm(x):
    """Rescale ``x`` linearly so that its minimum maps to 0 and its maximum to 1."""
    lowest = x.min()
    return (x - lowest) / (x.max() - lowest)


# Normalise each selected metric column independently, then plot the result.
min_max_normed_pre_ivl = pre_ivl[['MCS_LA', 'MCS_SCA', 'MCS_MSCT']].apply(min_max_norm, axis=0)
sns.pairplot(min_max_normed_pre_ivl)

In [None]:
def z_norm(x):
    """Centre ``x`` on its mean and divide by its standard deviation (``x.std()``)."""
    return (x - x.mean()) / x.std()


# Standardise each selected metric column independently, then plot the lower
# triangle of the pair grid only.
z_normed_pre_ivl = pre_ivl[['MCS_LA', 'MCS_SCA', 'MCS_MSCT']].apply(z_norm, axis=0)
sns.pairplot(z_normed_pre_ivl, corner=True)

In [None]:
def log_norm(x):
    """Natural logarithm of ``x``, element-wise."""
    return np.log(x)


# Log-transform each selected metric column, then plot.
# NOTE(review): any zero in these columns becomes -inf under np.log - confirm
# the metrics are strictly positive before relying on this plot.
log_normed_pre_ivl = pre_ivl[['MCS_LA', 'MCS_SCA', 'MCS_MSCT']].apply(log_norm, axis=0)
sns.pairplot(log_normed_pre_ivl)

In [None]:
# Summary statistics of the post-IVL table.
post_ivl.describe()