## Generate superpixel-based pseudolabels


### Overview

This is the third step for data preparation

Input: normalized images

Output: pseudolabel candidates for all the images

In [1]:
%reset
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import copy
import skimage

from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.measure import label 
import scipy.ndimage.morphology as snm
from skimage import io
import matplotlib.pyplot as plt
import argparse
import numpy as np
import glob

import SimpleITK as sitk
import os

def to01(x):
    """
    Min-max normalize an array to the range [0, 1].

    Args:
        x: numpy array (anything exposing .min() / .max() and arithmetic).

    Returns:
        (x - min) / (max - min). For a constant input (max == min) an
        all-zero array is returned instead of the NaNs a 0/0 division
        would produce.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        # constant image: x - lo is already all zeros; dividing by 1.0 keeps
        # the same float promotion as the regular path without a 0/0
        return (x - lo) / 1.0
    return (x - lo) / span



Once deleted, variables cannot be recovered. Proceed (y/[n])? y


**Summary**

a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background

b. Generate superpixels as pseudolabels

**Configurations of pseudolabels**

```python
# default setting of minimum superpixel sizes
segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)
# you can also try other configs
segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)
```


In [126]:
# glob pattern matching the normalized input volumes
img_bname = '../data/chaos_MR_train_T2_histv3/image_*.nii.gz'
# all paths matching the pattern (glob gives no ordering guarantee)
imgs = glob.glob(img_bname)
# output directory for the generated masks and pseudolabels
out_dir = "../data/chaos_MR_train_T2_histv3/"


In [127]:
imgs

['../data/chaos_MR_train_T2_histv3/image_33.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_31.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_10.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_37.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_21.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_2.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_38.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_19.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_3.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_20.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_36.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_15.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_1.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_39.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_22.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_34.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_13.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_32.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_5.nii.gz

In [128]:
# order the volumes by their integer id (image_<id>.nii.gz) so that e.g. 2 comes before 10
imgs = sorted(imgs, key = lambda path: int(path.rpartition('_')[2].split('.nii.gz')[0]))

In [129]:
imgs

['../data/chaos_MR_train_T2_histv3/image_1.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_2.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_3.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_5.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_8.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_10.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_13.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_15.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_19.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_20.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_21.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_22.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_31.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_32.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_33.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_34.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_36.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_37.nii.gz',
 '../data/chaos_MR_train_T2_histv3/image_38.nii.gz'

In [133]:
MODE = 'MIDDLE' # name of the superpixel-size preset; 'MIDDLE' (the default) is the only one implemented, and the name is also embedded in the output file names

# wrapper for processing a 3d image slice-by-slice in 2d
def superpix_vol(img, method = 'fezlen', **kwargs):
    """
    Run a 2d superpixel segmentation on every slice of a volume.

    Loops over the first axis of `img` (assuming image with axis z, x, y)
    and stacks the per-slice label maps.

    Args:
        img: 3d array; img[ii, ...] is segmented independently for each ii.
        method: segmentation algorithm. 'fezlen' (kept for backward
            compatibility) or 'felzenszwalb' select
            skimage.segmentation.felzenszwalb.
        **kwargs: forwarded to the segmentation function; they override the
            defaults of the preset selected by the module-level MODE
            (e.g. min_size = 100, sigma = 0.8).

    Returns:
        Array of zeros-initialised floats with the shape of `img`, holding the
        label map of each slice. Labels are per slice, not consistent across
        slices.

    Raises:
        NotImplementedError: for an unknown `method` or an unknown MODE.
    """
    if method in ('fezlen', 'felzenszwalb'):
        seg_func = skimage.segmentation.felzenszwalb
    else:
        raise NotImplementedError("unknown superpixel method: {!r}".format(method))

    # preset lookup does not depend on the slice, so do it once up front
    if MODE == 'MIDDLE':
        seg_kwargs = {'min_size': 400, 'sigma': 1}
    else:
        raise NotImplementedError("unknown MODE: {!r}".format(MODE))
    # caller-supplied settings take precedence over the preset
    seg_kwargs.update(kwargs)

    out_vol = np.zeros(img.shape)
    for ii in range(img.shape[0]):
        out_vol[ii, ...] = seg_func(img[ii, ...], **seg_kwargs)

    return out_vol

# thresholding the intensity values to get a binary mask of the patient
def fg_mask2d(img_2d, thresh = 1e-4 + 50): # change this by your need
    """
    Build a foreground mask for one 2d slice.

    Pixels strictly above `thresh` are foreground; only the largest
    connected component is kept and its holes are filled.

    Args:
        img_2d: 2d intensity array.
        thresh: intensity threshold; tune it to the intensity range of the
            data at hand.

    Returns:
        Boolean array with the shape of `img_2d`. When no pixel exceeds the
        threshold the mask is all False (same dtype as the regular path, so
        callers always get a boolean mask).
    """
    above = np.asarray(img_2d) > thresh
    if not above.any():
        # empty slice: nothing to label, return the all-False mask as is
        return above

    labels = label(above)
    # bincount index 0 counts the background, so skip it and shift the
    # argmax back by one to recover the label id of the largest component
    largest_id = np.argmax(np.bincount(labels.ravel())[1:]) + 1
    return snm.binary_fill_holes(labels == largest_id)

# remove superpixels within the empty regions
def superpix_masking(raw_seg2d, mask2d):
    """
    Zero out superpixels outside the foreground mask and relabel the rest.

    Label 0 is reserved for "outside the mask" in the output, so the input
    superpixel that carries label 0 is first moved to a fresh label
    (max label + 1) and is then relabelled like every other superpixel.

    Args:
        raw_seg2d: 2d label map (labels may start at 0); not modified.
        mask2d: 2d mask of the same shape, truthy / 1 inside the foreground.

    Returns:
        Float array of the same shape. 0 marks pixels outside the mask; the
        surviving superpixels get labels counted from 1 in the order of their
        original label, with the former label 0 last. A label number is
        consumed even when its superpixel is fully masked out, so the output
        labels are not necessarily contiguous.
    """
    # explicit copy: the in-place edit below must never touch the caller's array
    seg = np.array(raw_seg2d, dtype=np.int32)
    labels = np.unique(seg)
    zero_alias = labels.max() + 1
    seg[seg == 0] = zero_alias
    seg = seg * mask2d

    # every original non-zero label, followed by the alias that now holds the
    # former label 0 (appending max instead of max + 1 would drop that
    # superpixel and visit the max label twice)
    candidates = [lb for lb in labels if lb != 0] + [zero_alias]

    out_seg2d = np.zeros(seg.shape)
    for lb_new, lbv in enumerate(candidates, start=1):
        out_seg2d[seg == lbv] = lb_new

    return out_seg2d
            
def superpix_wrapper(img, verbose = False):
    """
    Compute foreground masks and masked superpixels for a whole volume.

    The raw superpixels come from superpix_vol(img); each slice is then
    paired with the foreground mask of the matching image slice
    (fg_mask2d) and cleaned with superpix_masking.

    Args:
        img: 3d image array, iterated along its first axis.
        verbose: print the index of every slice as it is processed.

    Returns:
        (fg_mask_vol, processed_seg_vol): two float arrays with the shape of
        the raw superpixel volume.
    """
    seg_vol = superpix_vol(img)
    mask_vol = np.zeros(seg_vol.shape)
    masked_seg_vol = np.zeros(seg_vol.shape)

    for z, seg_slice in enumerate(seg_vol):
        if verbose:
            print("doing {} slice".format(z))
        slice_mask = fg_mask2d(img[z, ...] )
        slice_seg = superpix_masking(seg_slice, slice_mask)
        mask_vol[z] = slice_mask
        masked_seg_vol[z] = slice_seg

    return mask_vol, masked_seg_vol
        
# copy spacing and orientation info between sitk objects
def copy_info(src, dst):
    """
    Transfer spacing, origin and direction from `src` onto `dst`.

    Args:
        src: object providing GetSpacing / GetOrigin / GetDirection.
        dst: object providing SetSpacing / SetOrigin / SetDirection;
            modified in place.

    Returns:
        `dst`, for call chaining.
    """
    for field in ("Spacing", "Origin", "Direction"):
        value = getattr(src, "Get" + field)()
        getattr(dst, "Set" + field)(value)
    return dst


def strip_(img, lb):
    """
    Keep only the requested label(s) of a label map, zeroing everything else.

    Args:
        img: label map; cast to int32 before comparison.
        lb: a single label (int, float or numpy scalar) or a collection of
            labels (list, tuple, set or numpy array). Each label is compared
            as int(label) and written back as float(label).

    Returns:
        For a single label: float32 array equal to `lb` where img == lb and
        0 elsewhere. For a collection: float64 array holding each requested
        label at its own pixels and 0 elsewhere; a label listed more than
        once is written once, not summed.

    Raises:
        TypeError: if `lb` is neither a number nor a supported collection.
    """
    img = np.int32(img)
    if isinstance(lb, (int, float, np.integer, np.floating)):
        return np.float32(img == int(lb)) * float(lb)
    if isinstance(lb, (list, tuple, set, np.ndarray)):
        out = np.zeros(img.shape)
        for _lb in lb:
            # assignment (not +=) so duplicated labels are not counted twice
            out[img == int(_lb)] = float(_lb)
        return out
    raise TypeError("lb must be a number or a collection of numbers, got {}".format(type(lb).__name__))

In [135]:
# Generate pseudolabels for every image and save them
# For each input volume this writes two files into out_dir:
#   fgmask_<id>.nii.gz           foreground mask volume
#   superpix-<MODE>_<id>.nii.gz  masked superpixel (pseudolabel) volume
for img_fid in imgs:
    # (debug) to process a single volume only, drop the loop and use: img_fid = imgs[0]

    # numeric id taken from the file name image_<id>.nii.gz
    idx = os.path.basename(img_fid).split("_")[-1].split(".nii.gz")[0]
    im_obj = sitk.ReadImage(img_fid)

    # superpix_wrapper iterates over the first array axis
    # NOTE(review): presumably the slice (z) axis of the sitk array - confirm
    out_fg, out_seg = superpix_wrapper(sitk.GetArrayFromImage(im_obj) )
    out_fg_o = sitk.GetImageFromArray(out_fg ) 
    out_seg_o = sitk.GetImageFromArray(out_seg )

    # carry spacing / origin / direction of the source image over to the outputs
    out_fg_o = copy_info(im_obj, out_fg_o)
    out_seg_o = copy_info(im_obj, out_seg_o)
    seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
    msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')
    sitk.WriteImage(out_fg_o, msk_fid)
    sitk.WriteImage(out_seg_o, seg_fid)
    print(f'image with id {idx} has finished')


1 has finished
2 has finished
3 has finished
5 has finished
8 has finished
10 has finished
13 has finished
15 has finished
19 has finished
20 has finished
21 has finished
22 has finished
31 has finished
32 has finished
33 has finished
34 has finished
36 has finished
37 has finished
38 has finished
39 has finished
