# Post-processing of nnUNet predictions of SKM-TEA data
## First check connected components

In [None]:
# imports
import cc3d
import json
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import SimpleITK as sitk

In [None]:
from acvl_utils.morphology.morphology_helper import generic_filter_components, remove_components_cc3d, cc3d_label_with_component_sizes

In [None]:
# Root folder of the nnUNet raw dataset, relative to this notebook.
# The trailing slash is kept because sub-paths are built from this value
# by plain string concatenation.
nnunet_data_dir = (
    '../../data/nnUNet_raw/'
    'Dataset361_Menisci/'
)

In [None]:
# Sub-folders of the dataset root used in this notebook:
#   gt_path   - ground-truth label volumes
#   im_path   - input image volumes
#   pred_path - predictions (folder name suggests z-score normalisation)
gt_path = f'{nnunet_data_dir}labels_all_skmtea/'
im_path = f'{nnunet_data_dir}images_all_skmtea/'
pred_path = f'{nnunet_data_dir}zscore_preds_all_skmtea/'

In [None]:
# Alphabetically sorted names of everything in the ground-truth folder;
# the last line shows the first five entries as a quick sanity check.
gt_files = sorted(os.listdir(gt_path))
gt_files[:5]

In [None]:
# Alphabetically sorted names of everything in the image folder,
# followed by a preview of the first five entries.
im_files = sorted(os.listdir(im_path))
im_files[:5]

In [None]:
# Alphabetically sorted names of everything in the prediction folder,
# followed by a preview of the first five entries.
pred_files = sorted(os.listdir(pred_path))
pred_files[:5]

In [None]:
# For every ground-truth mask, record the size of the smallest connected
# component that is still present after small-component removal.
# smallest_components[i] belongs to gt_files[i].
smallest_components = []

# Flip to True to print component details for files with more than two
# components (off by default, so the cell stays quiet).
show_details = False

for file in gt_files:
    # load the mask and convert it to a numpy array
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path + file))

    # component sizes before and after small-component removal
    # (presumably `props` maps component id -> voxel count, and the filter
    # drops components below 100 voxels - confirm against acvl_utils)
    labels, props = cc3d_label_with_component_sizes(gt, connectivity=26)
    labels_filtered, props_filtered = cc3d_label_with_component_sizes(
        remove_components_cc3d(gt, 100, threshold_type='min', connectivity=26),
        connectivity=26,
    )

    if show_details and len(props) > 2:
        print(f'File: {file}')
        print(f'Number of components before: {len(props)}')
        print(f'Number of components after: {len(props_filtered)}')
        print(f'Properties before: {props}')
        print(f'Properties after: {props_filtered}')
        print('----------------------------------------')

    # Store the smallest remaining component size. default=0 keeps the list
    # aligned with gt_files when nothing survives the filter (an empty
    # mapping would otherwise make min() raise ValueError).
    smallest_components.append(min(props_filtered.values(), default=0))

In [None]:
# One line per ground-truth file: list position, then the recorded
# smallest component size.
for i, comp in enumerate(smallest_components):
    print(f'{i} {comp}')

In [None]:
# For every predicted mask, record the size of the smallest connected
# component that is still present after small-component removal.
# smallest_pred_components[i] belongs to pred_files[i].
smallest_pred_components = []

# Flip to True to print component details for files with more than two
# components (off by default, so the cell stays quiet).
show_details = False

for file in pred_files:
    # load the prediction and convert it to a numpy array
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred_path + file))

    # component sizes before and after small-component removal.
    # The "before" labelling must run on `pred`; it previously used the
    # stale `gt` array left over from an earlier cell.
    # (presumably `props` maps component id -> voxel count, and the filter
    # drops components below 100 voxels - confirm against acvl_utils)
    labels, props = cc3d_label_with_component_sizes(pred, connectivity=26)
    labels_filtered, props_filtered = cc3d_label_with_component_sizes(
        remove_components_cc3d(pred, 100, threshold_type='min', connectivity=26),
        connectivity=26,
    )

    if show_details and len(props) > 2:
        print(f'File: {file}')
        print(f'Number of components before: {len(props)}')
        print(f'Number of components after: {len(props_filtered)}')
        print(f'Properties before: {props}')
        print(f'Properties after: {props_filtered}')
        print('----------------------------------------')

    # Store the smallest remaining component size. default=0 keeps the list
    # aligned with pred_files when nothing survives the filter (an empty
    # mapping would otherwise make min() raise ValueError).
    smallest_pred_components.append(min(props_filtered.values(), default=0))

In [None]:
# One line per prediction file: list position, then the recorded
# smallest component size.
for i, comp in enumerate(smallest_pred_components):
    print(f'{i} {comp}')

In [None]:
# Inspect a single prediction in detail (ninth entry of the sorted list)
file = pred_files[8]

# Read the volume as a numpy array.
# NOTE: despite its name, `gt` holds a *prediction* here (read from pred_path);
# the name is kept because a later cell plots `gt`.
gt = sitk.GetArrayFromImage(sitk.ReadImage(pred_path + file))

# connected-component labelling of that volume with 26-connectivity
labels, props = cc3d_label_with_component_sizes(gt, connectivity=26)

# report how many components were found, plus the full properties mapping
print(f"File {file} has {len(props)} components: {props}")

# Show a sum projection of the volume along axis 1.
# This plots the raw mask values in `gt`, not the component ids in `labels`.
projection = np.sum(gt, axis=1)
plt.figure(figsize=(10, 10))
plt.imshow(projection)
plt.title(file)
plt.show()

In [None]:
# Show one slice of an input image next to the same slice of the mask.
# NOTE(review): the image is im_files[52], while `gt` is whatever an earlier
# cell left in the namespace - confirm both refer to the same case.
im = sitk.GetArrayFromImage(sitk.ReadImage(im_path + im_files[52]))

# slice 40 along the first array axis, image on the left, mask on the right
slice_index = 40
panels = ((im, 'Image'), (gt, 'Mask'))

plt.figure(figsize=(10, 10))
for position, (volume, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(volume[slice_index, ...])
    plt.title(title)
plt.show()

In [None]:
# The crop is too low. The cropping will need to be corrected and the predictions re-run :(