Jupyter notebook that creates the cylinder and ellipse fits from the 2D segmentations.
Written by Dominik Waibel & Niklas Kiermeyer

Required Folder Structure

-- SHAPR_dataset
-- -- mask
-- -- Ellipse_fit
-- -- Cylinder_fit 

The 2D segmentations (input) are located in the mask folder; the fitted volumes are written to the Ellipse_fit and Cylinder_fit folders.

In [None]:
#import dependencies
import os 
import numpy as np
from skimage.io import imread, imsave
from skimage.measure import label, regionprops
from skimage import measure
import trimesh
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import trimesh
import cv2
import math
from skimage.transform import resize
import numpy.linalg as linalg
from pyellipsoid import drawing
from skimage.filters import threshold_otsu
from skimage.feature import shape_index
from scipy.ndimage.measurements import center_of_mass
from skimage.filters import gaussian
import seaborn as sns
import keras.backend as K
from scipy.ndimage.morphology import binary_dilation
from scipy.ndimage import gaussian_filter
import copy
from scipy.ndimage.morphology import binary_fill_holes
from matplotlib.patches import Ellipse

In [None]:
# set the path to the dataset folder
# (the SHAPR_dataset root containing the mask / Ellipse_fit / Cylinder_fit
# sub-folders described in the header of this notebook)
path = "./SHAPR_dataset//"

In [None]:
# Normalize to a maximum of 1 and binarize with Otsu's method, see
# https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_thresholding.html
def norm_thres(data):
    """Scale *data* by its maximum and binarize it with Otsu's threshold.

    Args:
        data: numeric array (e.g. a 3D volume).

    Returns:
        Float array of the same shape: 1.0 where the scaled value exceeds the
        Otsu threshold, 0.0 elsewhere. If the scaled array has no positive
        value it is returned unthresholded (as floats); an all-zero input
        therefore yields zeros, because the 0/0 NaNs are cleared by
        ``np.nan_to_num``.
    """
    scaled = np.nan_to_num(data / np.max(data))
    if not np.max(scaled) > 0:
        # Nothing to threshold (e.g. an empty volume).
        return scaled * 1.0
    foreground = scaled > threshold_otsu(scaled)
    return foreground * 1.0

In [None]:
# perform an ellipse fit
def fitEllipse(cont):
    """Fit an ellipse to 2D points with the direct least-squares method.

    The general conic A x^2 + B xy + C y^2 + D x + E y + F = 0 is fitted in the
    least-squares sense under the ellipse-specific constraint 4AC - B^2 = 1
    (Fitzgibbon, Pilu & Fisher). The centre, semi-axis lengths and rotation
    angle are then derived from the conic coefficients in closed form.

    Args:
        cont: array of shape (N, 2) with the (x, y) coordinates of the points to
            fit (e.g. an OpenCV contour squeezed to 2D). At least 6 points are
            needed for a well-posed fit.

    Returns:
        [cx, cy, a, b, angle]: centre coordinates, the two semi-axis lengths
        (NaN replaced by 0) and the rotation angle in degrees. The angle comes
        from ``0.5 * arctan(...)`` and so lies in (-45, 45) degrees, i.e. it is
        only defined modulo 90 degrees. NOTE(review): which of ``a``/``b`` lies
        along that angle is not enforced. Values can have complex dtype when
        ``np.linalg.eig`` returns complex eigenvectors; callers should use the
        real part.

    Raises:
        numpy.linalg.LinAlgError: if the scatter matrix ``D^T D`` is singular
            (too few or degenerate points).
    """
    x = cont[:, 0][:, None]
    y = cont[:, 1][:, None]

    # Design matrix: one row [x^2, xy, y^2, x, y, 1] per point.
    D = np.hstack([x * x, x * y, y * y, x, y, np.ones(x.shape)])
    S = np.dot(D.T, D)
    # Constraint matrix so that coef^T C coef = 4AC - B^2.
    C = np.zeros([6, 6])
    C[0, 2] = C[2, 0] = 2
    C[1, 1] = -1
    E, V = np.linalg.eig(np.dot(np.linalg.inv(S), C))

    # inv(S) C has the same inertia as C: one positive, two negative and three
    # zero eigenvalues. The ellipse solution is the eigenvector of the unique
    # positive eigenvalue; picking the largest |eigenvalue| could select a
    # hyperbolic solution instead.
    n = np.argmax(E.real)
    coef = V[:, n]

    # -------------------Fit ellipse-------------------
    # Switch to the half-coefficient convention of the closed-form formulas:
    # a x^2 + 2b xy + c y^2 + 2d x + 2f y + g = 0
    b, c, d, f, g, a = coef[1] / 2., coef[2], coef[3] / 2., coef[4] / 2., coef[5], coef[0]
    num = b * b - a * c
    cx = (c * d - b * f) / num
    cy = (a * f - b * d) / num

    angle = 0.5 * np.arctan(2 * b / (a - c)) * 180 / np.pi
    up = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
    down1 = (b * b - a * c) * ((c - a) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
    down2 = (b * b - a * c) * ((a - c) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
    semi_a = np.nan_to_num(np.sqrt(abs(up / down1)))
    semi_b = np.nan_to_num(np.sqrt(abs(up / down2)))

    return [cx, cy, semi_a, semi_b, angle]

In [None]:
# fit an ellipse to the 2D segmentation mask
def get_ellipse(mask):
    """Build a 64x64x64 ellipsoid volume from the ellipse fitted to *mask*.

    The ellipse is fitted to the largest contour of the mask. Its centre and
    in-plane semi-axes become the ellipsoid's X/Y centre and radii. The Z
    radius is a quarter of the sum of the two semi-axes, and the ellipsoid is
    centred on Z = 32.

    Args:
        mask: 2D uint8 binary image, as accepted by ``cv2.findContours``.
            NOTE(review): assumed to be 64x64, since the fit is placed in a
            fixed 64^3 grid — confirm against the dataset.

    Returns:
        Float array from ``pyellipsoid.drawing.make_ellipsoid_image`` (times
        1.0). An all-zero (64, 64, 64) array is returned when no contour is
        found, the contour has too few points, or the fit is singular.
    """
    empty = np.zeros((64, 64, 64))
    # mode 1 = cv2.RETR_LIST, method 2 = cv2.CHAIN_APPROX_SIMPLE. Indexing with
    # [-2] picks the contour list under both the OpenCV 3 (3-tuple) and the
    # OpenCV 4 (2-tuple) return signatures.
    contours = cv2.findContours(mask, 1, 2)[-2]
    if len(contours) < 1:
        return empty
    # contours[0] is arbitrary when the mask has several components; use the
    # largest one, which is the object of interest.
    cnt = max(contours, key=cv2.contourArea)
    # fitEllipse needs at least 6 points for a well-posed least-squares fit.
    if len(cnt) <= 5:
        return empty
    try:
        x, y, MA, ma, angle = fitEllipse(cnt[:, 0])
    except np.linalg.LinAlgError:
        # Singular scatter matrix (degenerate contour): no ellipse can be fitted.
        return empty

    image_shape = (64, 64, 64)
    # Define an ellipsoid, axis order is: X, Y, Z.
    # np.linalg.eig may yield complex values with negligible imaginary parts,
    # so only the real parts are used.
    ell_center = (np.real(x), np.real(y), 32)
    ell_radii = (MA, ma, (MA + ma) / 4)
    ell_angles = np.deg2rad([0, 0, np.real(angle)])
    ellipse3d = drawing.make_ellipsoid_image(image_shape, ell_center, ell_radii, ell_angles)
    return ellipse3d * 1.0

In [None]:
# fit a cylinder to the 2D segmentation
def get_cylinder(mask):
    """Extrude *mask* along Z into a 64x64x64 cylinder-like volume.

    An ellipse is fitted to the largest contour of the mask. The mask is then
    stacked into a slab whose height (in slices) is the rounded sum of the
    ellipse's two semi-axes divided by two, truncated and clamped to [1, 64].
    The slab is centred on slice 32. The mask is resized to (height, 64, 64)
    with ``skimage.transform.resize``, so slab values are interpolated and may
    be non-binary.

    Args:
        mask: 2D uint8 binary image, as accepted by ``cv2.findContours``.
            NOTE(review): assumed to be 64x64 — it is resized to 64x64 in-plane.

    Returns:
        (64, 64, 64) float array. It is all zeros when no contour is found, the
        contour has too few points for a fit, or the fit is singular.
    """
    empty = np.zeros((64, 64, 64))
    # mode 1 = cv2.RETR_LIST, method 2 = cv2.CHAIN_APPROX_SIMPLE; [-2] works
    # with both the OpenCV 3 and OpenCV 4 return signatures.
    contours = cv2.findContours(mask, 1, 2)[-2]
    if len(contours) < 1:
        return empty
    # Use the largest component rather than an arbitrary first contour.
    cnt = max(contours, key=cv2.contourArea)
    # Same minimum as get_ellipse: with fewer than 6 points the 6x6 scatter
    # matrix in fitEllipse is rank-deficient (inv raises or returns noise).
    if len(cnt) <= 5:
        return empty
    try:
        _, _, MA, ma, _ = fitEllipse(cnt[:, 0])
    except np.linalg.LinAlgError:
        return empty

    # Slab height in slices (this was named "z_radius", but it is the full
    # extent along Z); clamp it to the volume.
    height = int(round(MA + ma) / 2)
    height = min(max(height, 1), 64)
    slab = resize(mask[np.newaxis, ...], (height, 64, 64), preserve_range=True)

    mask3d = np.zeros((64, 64, 64))
    start = 32 - height // 2
    mask3d[start:start + height, :, :] = slab
    return mask3d

In [None]:
# Process every 2D mask in <path>/mask and write the fitted volumes to
# <path>/Cylinder_fit and <path>/Ellipse_fit under the same file name.
# (The original listed the dataset root instead of the mask folder and used
# an undefined `test_path`.)
mask_dir = os.path.join(path, "mask")
files = sorted(os.listdir(mask_dir))  # sorted for a deterministic order
print("found", len(files), "files")

# Make sure the output folders exist before writing into them.
for out_folder in ("Cylinder_fit", "Ellipse_fit"):
    os.makedirs(os.path.join(path, out_folder), exist_ok=True)

for index, file in enumerate(files):
    print(index, file)
    # get the 2D segmentation ("mask")
    mask = imread(os.path.join(mask_dir, file))
    mask = np.array(binary_fill_holes(mask).astype("uint8"))
    # perform the cylinder and ellipse fit
    ellipse = np.nan_to_num(get_ellipse(mask))
    cylinder = norm_thres(np.nan_to_num(get_cylinder(mask)))
    # save the cylinder and ellipse to the respective folders
    imsave(os.path.join(path, "Cylinder_fit", file), (cylinder * 255.).astype("uint8"))
    imsave(os.path.join(path, "Ellipse_fit", file), (ellipse * 255.).astype("uint8"))
    