# Corrections, Filtering and Segmentation Notebook

## This notebook is tested with the Offshore Langseth Dataset

## The workflow is as follows:
* **Beam hardening correction on each sample in this set**
    * Sample plots: 
        * Before (i) and after (ii) radial correction on one slice, and their radial profiles (iii). Radial profile of average slice through that particular section (before and after) (iv).
* **Histogram equalization**
    * Sample plots:
        * Before (i) and after (ii) histogram equalization on one slice, and their histograms (iii). Histogram of an average slice through that particular section (before and after) (iv).


* **Two different filters after previous steps:**
    1. **Median filter**
    * Sample plots:
        * Before (i) and after (ii) median filter on one slice, and their radial profiles (iii). Histogram of an average slice through that particular section (before and after) (iv).
    2. **Anisotropic diffusion filter**
    * Sample plots:
        * Before (i) and after (ii) anisotropic diffusion filter on one slice, and their radial profiles (iii). Histogram of an average slice through that particular section (before and after) (iv).

* **Various Segmentation Methods**
* **Shape analysis**

In [None]:
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings('error', category=DeprecationWarning)

import skimage
import os
import glob
import time
import psutil
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import random
from IPython.display import HTML
import concurrent, multiprocessing
import concurrent.futures
import re
import Worflow_functions as SRAF
import copy
import pickle
import cc3d
import pyvista as pv 

# Wall-clock reference taken when this cell runs.
time0 = time.time()

if __name__ == '__main__':
    # Worker processes of the ProcessPoolExecutor cells are started with 'spawn'.
    # force=True keeps the cell re-runnable: without it a second execution raises
    # RuntimeError("context has already been set").
    multiprocessing.set_start_method('spawn', force=True)

def get_occuppied_mem():
    """Print the resident set size of the current (kernel) process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"{rss_bytes / (1024 * 1024):12.4f} MB")
    

## Reading the data

In [None]:
%%time
# Samples to load. Extend `sample_names` to read several samples in parallel,
# e.g. ['sample_1', 'sample_2', 'sample_3', ...].
# The downstream cells work on the single sample selected by `sample_name`.
sample_name = 'sample_1'
sample_names = [sample_name]


# Uncomment to see the used memory
# mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

if __name__ == '__main__':
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # reads the folder names
        info_list_for_dict = [executor.submit(SRAF.prepare_sample, name) for name in sample_names]
        # reads the data (waits for each folder listing before submitting)
        dict_data = [executor.submit(SRAF.create_dict, info.result()) for info in info_list_for_dict]

# Uncomment to see the used memory
    # for sample in concurrent.futures.as_completed(dict_data):
    #     mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    # print(f"{mem_after - mem_before:12.4f} MB")

In [None]:
%%time
# Collect the loaded data from the worker futures (same order as `sample_names`).
dict_data = [future.result() for future in dict_data]

# Key each sample by its name with '-' replaced by '_'; the first entry wins on
# duplicate keys.
whole_data_dict = {}
for name, sample_dict in zip(sample_names, dict_data):
    whole_data_dict.setdefault(name.replace('-','_'), sample_dict)

# mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
# print(f"{mem_after:12.4f} MB")

# Modifies the data by clipping the first and last 50 slices
for sample in whole_data_dict:
    SRAF.modify_dict(whole_data_dict[sample])

## Checking that all sections were imported

In [None]:
# One line per sample, listing its section names.
for sample_key, sections in whole_data_dict.items():
    print(sample_key, ':', list(sections.keys()))

# Alternative:
# for data in whole_data_dict:
#     print(data, *list(whole_data_dict[data].keys()), sep = "\n  ├──")

## Checking the section dimensions

In [None]:
# Dimensions after cropping the first and last 50 slices.
# Slices are stacked along axis 0 of each section volume, so the slice count of a
# section is shape[0]. The old counter counted sections but labelled them "slices".
n_sections = 0
n_slices = 0
for data in whole_data_dict:
    for section in whole_data_dict[data].keys():
        # np.shape avoids materialising a copy just to read the shape.
        shape = np.shape(whole_data_dict[data][section])
        print(shape)
        n_sections += 1
        n_slices += shape[0]
print(f'\nTotal number of sections: {n_sections}')
print(f'Total number of slices: {n_slices}')

## Beam Hardening Correction

In [None]:
# Reloading the packages, not necessary at this moment.
# import importlib
# importlib.reload(SRAF)

In [None]:
%%time
# Beam-hardening (radial) correction of every sample, in sample order.
corrected_sample = {
    sample: SRAF.radially_correct_section(sections)
    for sample, sections in whole_data_dict.items()
}
# get_occuppied_mem()

# Alternative

# get_occuppied_mem()
# if __name__=='__main__':
#     with concurrent.futures.ProcessPoolExecutor() as executor2:
#         #applies radial correction
#         corrected_section = [executor2.submit(SRAF.radially_correct_section, whole_data_dict[sample_name.replace('-','_')], 45) for section in whole_data_dict[sample_name.replace('-','_')]]       
# #     for sample in concurrent.futures.as_completed(dict_data):
        
# get_occuppied_mem()

In [None]:
# Before/after plots of the radial correction for slice 100 (profile at 45 degrees).
sample_key = sample_name.replace('-','_')
SRAF.visualize_radial_correction(
    sample_name=sample_name,
    original_sample=whole_data_dict[sample_key],
    corrected_sample=corrected_sample[sample_key],
    profile_angle_degrees=45,
    slice_to_plot=100,
    save=True,
)

In [None]:
# Print the kernel's current memory use after the radial correction.
get_occuppied_mem()

## Histogram Equalization

In [None]:
%%time
# Adaptive histogram equalization, slice by slice along axis 0, written back into
# `corrected_sample` in place.
# NOTE(review): equalize_adapthist returns float values; if the section arrays are
# integer-typed they would be truncated on assignment — confirm the dtype.
for sample, sections in corrected_sample.items():
    for section, volume in sections.items():
        for slice_idx in range(len(volume)):
            equalized = skimage.exposure.equalize_adapthist(volume[slice_idx, :, :].astype(int))
            volume[slice_idx, :, :] = equalized
        print(f'{sample}, {section} completed.')
get_occuppied_mem()

In [None]:
# Raw data vs. radially corrected + equalized data (slice 100).
sample_key = sample_name.replace('-','_')
SRAF.visualize_filter(
    sample_name=sample_name,
    original_sample=whole_data_dict[sample_key],
    corrected_sample=corrected_sample[sample_key],
    mask_radius=190,
    filter_name='Adapt. Hist. Eq.',
    slice_to_plot=100,
)

## Median Filtering

In [None]:
%%time
if __name__ == '__main__':
    sample = sample_name.replace('-','_')
    # Submit one median-filter task (cube(3) footprint) per section of the
    # selected sample. Leaving the `with` block waits for all tasks to finish.
    with concurrent.futures.ProcessPoolExecutor() as executor3:
        med_sections_Futures = []
        for volume in corrected_sample[sample].values():
            med_sections_Futures.append(
                executor3.submit(skimage.filters.median, volume, skimage.morphology.cube(3))
            )

In [None]:
# Independent copy of the equalized data; the selected sample's sections are
# replaced with the median-filtered volumes in the next cell.
median_filtered_sample=copy.deepcopy(corrected_sample)

In [None]:
# Futures were submitted in section order, so pair them with the section keys.
sample_key = sample_name.replace('-','_')
for section, future in zip(median_filtered_sample[sample_key], med_sections_Futures):
    median_filtered_sample[sample_key][section] = future.result()

In [None]:
# Equalized data vs. its median-filtered version (slice 100).
sample_key = sample_name.replace('-','_')
SRAF.visualize_filter(
    sample_name=sample_name,
    original_sample=corrected_sample[sample_key],
    corrected_sample=median_filtered_sample[sample_key],
    mask_radius=190,
    filter_name='Median Filter',
    slice_to_plot=100,
)

## Anisotropic diffusion

In [None]:
%%time
if __name__ == '__main__':
    sample = sample_name.replace('-','_')
    # Submit one anisotropic-diffusion task (niter=5) per section of the selected
    # sample. Leaving the `with` block waits for all tasks to finish.
    with concurrent.futures.ProcessPoolExecutor() as executor3:
        anid_sections_Futures = []
        for volume in corrected_sample[sample].values():
            anid_sections_Futures.append(
                executor3.submit(SRAF.anisotropic_diffusion, volume, niter=5)
            )

In [None]:
# Copy the equalized data, then swap in the diffusion-filtered section volumes
# (futures are in section order).
sample_key = sample_name.replace('-','_')
anid_filtered_sample = copy.deepcopy(corrected_sample)
for section, future in zip(anid_filtered_sample[sample_key], anid_sections_Futures):
    anid_filtered_sample[sample_key][section] = future.result()

In [None]:
%%time
# Equalized data vs. its anisotropic-diffusion-filtered version (slice 100).
# The label must match the niter=5 passed to SRAF.anisotropic_diffusion in the
# filtering cell (it previously said 10 iterations).
SRAF.visualize_filter(sample_name=sample_name, 
                      original_sample=corrected_sample[sample_name.replace('-','_')], 
                      corrected_sample=anid_filtered_sample[sample_name.replace('-','_')],
                      mask_radius=190,
                      filter_name='Ani. Diff. Filter 5 iter',
                      slice_to_plot=100)

In [None]:
import pickle

# Persist the whole anisotropic-diffusion-filtered dictionary to a .pkl file.
filtered_path = f'{sample_name}_filtered_sample_ready_for_segmentation.pkl'
with open(filtered_path, 'wb') as fp:
    pickle.dump(anid_filtered_sample, fp)

## Median vs Anisotropic Diffusion Comparison:

In [None]:
%%time
# Side-by-side comparison of the two filters on slice 100.
sample_key = sample_name.replace('-','_')
SRAF.compare_filters(
    sample_name,
    data1=median_filtered_sample[sample_key],
    data1_name='Median',
    data2=anid_filtered_sample[sample_key],
    data2_name='Ani. Diff.',
    mask_radius=190,
    slice_to_plot=100,
    save=True,
)

In [None]:
# Print the kernel's current memory use after both filters.
get_occuppied_mem()

## Thresholding and Segmenting

The anisotropic-diffusion-filtered data is used below; it can be swapped for the median-filtered data.

In [None]:
# Import the filtered data if importing locally

In [None]:
# Rebuild the anisotropic-diffusion-filtered dictionary from the worker results
# (identical to the assembly cell in the filtering section).
sample_key = sample_name.replace('-','_')
anid_filtered_sample = copy.deepcopy(corrected_sample)
for section, future in zip(anid_filtered_sample[sample_key], anid_sections_Futures):
    anid_filtered_sample[sample_key][section] = future.result()

# Alternative
# sample = sample_name.replace('-','_')

# for section in anid_filtered_sample[sample]:
#     anid_filtered_sample[sample][section] = SRAF.mask_section(anid_filtered_sample[sample][section],
#                                                                            mask_radius=180)

In [None]:
%%time
# Multi-Otsu thresholds for every section of the selected sample.
# Only the nested dict structure is copied, not the voxel volumes. Each section
# entry of `sample` is replaced by a small threshold array, so a deep copy of the
# full volumes only cost memory and time. Entries of other samples still reference
# the filtered arrays and are never written.
otsu = {key: dict(sections) for key, sections in anid_filtered_sample.items()}

sample = sample_name.replace('-','_')

# Number of classes (threshold_multiotsu returns classes - 1 thresholds)
classes = 4

for section in anid_filtered_sample[sample]:
    otsu[sample][section] = skimage.filters.threshold_multiotsu(anid_filtered_sample[sample][section], classes=classes)

In [None]:
# See thresholds: for the selected sample, one array of `classes - 1` Otsu
# thresholds per section.
otsu

In [None]:
%%time
# Assign every voxel its Otsu class (0 .. classes-1) with np.digitize.
# Only the nested dict structure is copied: each section of `sample` is replaced
# by a fresh array from np.digitize, so deep-copying the filtered volumes first
# was wasted memory. Other samples' entries still reference the filtered arrays;
# the later cells only modify the selected sample's sections.
thresholded_sample = {key: dict(sections) for key, sections in anid_filtered_sample.items()}

sample = sample_name.replace('-','_')

for section in anid_filtered_sample[sample]:
    thresholded_sample[sample][section] = np.digitize(anid_filtered_sample[sample][section], bins=otsu[sample][section])

In [None]:
# If manual adjustments are needed, similar scheme can be used:
# for section in thresholded_sample[sample]:
#     thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==2]=5
#     thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==3]=5
#     thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==4]=5

In [None]:
# Plot
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15,15),sharex=True, sharey=True)
ax[0].imshow(anid_filtered_sample[sample_name.replace('-','_')]['Section 1'][100,:,:],interpolation=None)
im = ax[1].imshow(thresholded_sample[sample_name.replace('-','_')]['Section 1'][100,:,:],interpolation=None)
plt.tight_layout()

In [None]:
# Erode each class volume with a cube(2) footprint.
for section, label_volume in thresholded_sample[sample].items():
    thresholded_sample[sample][section] = skimage.morphology.erosion(
        label_volume, footprint=skimage.morphology.cube(2)
    )

# Slice 100 of Section 1: filtered (left) vs. eroded classes (right).
sample_key = sample_name.replace('-','_')
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 15), sharex=True, sharey=True)
ax[0].imshow(anid_filtered_sample[sample_key]['Section 1'][100, :, :], interpolation=None)
im = ax[1].imshow(thresholded_sample[sample_key]['Section 1'][100, :, :], interpolation=None)
plt.tight_layout()

## Preparing sample for measuring properties

In [None]:
%%time
sample = sample_name.replace('-','_')

# Apply SRAF.mask_section (radius 180) to every class volume of the selected sample.
for section, label_volume in thresholded_sample[sample].items():
    thresholded_sample[sample][section] = SRAF.mask_section(label_volume, mask_radius=180)

In [None]:
# Slice 100 of Section 1 after erosion and masking.
plt.imshow(thresholded_sample[sample]['Section 1'][100,:,:])
plt.colorbar()

Relabeling the classes (each class index is shifted up by one) and assigning 0 to the masked regions, since we don't want their properties measured.

In [None]:
# Relabel each section in place: classes 0,1,2,3 -> 1,2,3,4 and the masked region -> 0.
# The classes are first parked on temporary values 4-7, so no class is overwritten
# before it has been moved. The masked region is zeroed while the classes sit on 4-7,
# and then the temporaries are mapped down to 1-4.
# NOTE(review): 1.17122018 is matched with exact float equality; it appears to be the
# fill value of the masked region — confirm against SRAF.mask_section. If the arrays
# hold integers (np.digitize output), this comparison can never be true.
for section in thresholded_sample[sample]:
    # park classes 0-3 on 4-7
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==0]=4
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==1]=5
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==2]=6
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==3]=7
    # masked region -> 0 (excluded from the property measurements)
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==1.17122018]=0
    # temporaries 4-7 -> final classes 1-4
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==4]=1
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==5]=2
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==6]=3
    thresholded_sample[sample_name.replace('-','_')][section][thresholded_sample[sample_name.replace('-','_')][section]==7]=4

In [None]:
# Merge classes 1 and 2, then close the gap: 2 -> 1, 3 -> 2, 4 -> 3, applied in
# this order so each voxel moves at most once.
sample_key = sample_name.replace('-','_')
for section in thresholded_sample[sample]:
    class_volume = thresholded_sample[sample_key][section]
    for old_label, new_label in ((2, 1), (3, 2), (4, 3)):
        class_volume[class_volume == old_label] = new_label

In [None]:
# Slice 100 of Section 1 after relabeling and merging the classes.
plt.imshow(thresholded_sample[sample]['Section 1'][100,:,:])
plt.colorbar()

## Saving and checking the saved data:

In [None]:
import pickle

# Persist the Otsu-segmented dictionary to a .pkl file.
segmented_path = f'{sample_name}_otsu_segmented.pkl'
with open(segmented_path, 'wb') as fp:
    pickle.dump(thresholded_sample, fp)

In [None]:
# Reload the segmentation just written, to check the saved data.
# Only unpickle files produced by this workflow: pickle can execute arbitrary code.
with open(f'{sample_name}_otsu_segmented.pkl', 'rb') as fp:
    loaded_sample = pickle.load(fp)

In [None]:
# Slice 100 of Section 1: filtered (left) vs. reloaded segmentation (right).
sample_key = sample_name.replace('-','_')
filtered_slice = anid_filtered_sample[sample_key]['Section 1'][100, :, :]
reloaded_slice = loaded_sample[sample_key]['Section 1'][100, :, :]

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 15), sharex=True, sharey=True)
ax[0].imshow(filtered_slice, interpolation=None)
im = ax[1].imshow(reloaded_slice, interpolation=None)
plt.tight_layout()

# Different segmentation methods
## i. Segmenting with region merging

In [None]:
%%time

def weight_boundary(graph, src, dst, n):
    """
    Weight function for merging nodes of a region-boundary adjacency graph.

    Computes the ``"count"`` and ``"weight"`` attributes of the edge between
    ``n`` and the node formed by merging ``src`` and ``dst``. The new weight
    is the count-weighted mean of the existing ``src``-``n`` and ``dst``-``n``
    edge weights; a missing edge contributes weight 0.0 with count 0.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the "count" and "weight" attributes to be
        assigned for the merged node.
    """
    missing_edge = {'weight': 0.0, 'count': 0}
    edge_src = graph[src].get(n, missing_edge)
    edge_dst = graph[dst].get(n, missing_edge)

    total_count = edge_src['count'] + edge_dst['count']
    weighted_sum = edge_src['count'] * edge_src['weight'] + edge_dst['count'] * edge_dst['weight']
    return {'count': total_count, 'weight': weighted_sum / total_count}


def merge_boundary(graph, src, dst):
    """Callback run before two nodes are merged.

    Boundary RAGs need no extra bookkeeping here, so this is a no-op.
    """
    return None

sample = sample_name.replace('-','_')


# Independent copy of the filtered data with SRAF.mask_section (radius 180)
# applied to every section of the selected sample.
masked_sample = copy.deepcopy(anid_filtered_sample)
for section, filtered_volume in anid_filtered_sample[sample].items():
    masked_sample[sample][section] = SRAF.mask_section(filtered_volume, mask_radius=180)

In [None]:
# Region-merging trial on a single slice (index 100 of Section 1).

# for section in anid_filtered_sample[sample]:
#     otsu[sample][section] = skimage.filters.threshold_multiotsu(anid_filtered_sample[sample][section],classes=4)

gimg = masked_sample[sample]['Section 1'][100, :, :]
img_color = skimage.color.gray2rgb(gimg)

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 12), sharex=True, sharey=True)
ax_input, ax_rag, ax_merged, ax_final = ax.ravel()

ax_input.imshow(gimg, interpolation=None)
ax_input.set_title('Filtered Sample')

# SLIC superpixels, then a boundary RAG built on the Sobel edge map.
labels = skimage.segmentation.slic(img_color, compactness=30, n_segments=2000, start_label=1)
edges = skimage.filters.sobel(gimg)
edges_rgb = skimage.color.gray2rgb(edges)

g = skimage.graph.rag_boundary(labels, edges)
lc = skimage.graph.show_rag(labels, g, edges_rgb, img_cmap=None, edge_cmap='viridis',
                            edge_width=1.2, ax=ax_rag)
ax_rag.set_title('Region Adjacency Graph')

# rag_copy=False: `g` itself is modified by the merging.
labels2 = skimage.graph.merge_hierarchical(labels, g, thresh=0.06, rag_copy=False,
                                           in_place_merge=True,
                                           merge_func=merge_boundary,
                                           weight_func=weight_boundary)

skimage.graph.show_rag(labels, g, img_color, ax=ax_merged)
ax_merged.set_title('RAG after hierarchical merging')

# out = skimage.color.label2rgb(labels2, img_color, kind='avg', bg_label=0)
im = ax_final.imshow(labels2, interpolation=None, vmin=labels2.min(), vmax=labels2.max())
ax_final.set_title('Final segmentation')
# plt.tight_layout()
plt.colorbar(lc, ax=ax_rag)
plt.colorbar(im, ax=ax_final)
plt.show()

In [None]:
# Distinct region labels left after hierarchical merging.
np.unique(labels2)

In [None]:
# Collapse the merged regions into three groups by label id:
# >= 18 -> 2, 13..17 -> 1, <= 12 -> 0.
# NOTE(review): these cut-offs act on region ids from merge_hierarchical, not on
# intensities — confirm that grouping by id is intended.
out_mod = np.select(
    condlist=[labels2 >= 18, (labels2 < 18) & (labels2 > 12), labels2 <= 12],
    choicelist=[2, 1, 0],
)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
ax.imshow(out_mod, interpolation=None)

#### The region-merging results are not satisfactory

## ii. Random Walker Segmentation

In [None]:
# Fresh masked copy of the filtered data (SRAF.mask_section, radius 180).
masked_sample = copy.deepcopy(anid_filtered_sample)
for section, filtered_volume in anid_filtered_sample[sample].items():
    masked_sample[sample][section] = SRAF.mask_section(filtered_volume, mask_radius=180)

In [None]:
# Largest value in the masked Section 1 volume.
data = masked_sample[sample]['Section 1']
np.max(data)

In [None]:
data = masked_sample[sample]['Section 1']

# NOTE(review): 1.17122018 appears to be the fill value of the masked region
# (it is matched the same way in the relabeling cell) — confirm against SRAF.mask_section.
MASK_VALUE = 1.17122018

# random_walker seeds: 0 = unlabeled (solved for), positive = phase seed,
# negative = inactive (removed from the graph). A signed dtype is required to mark
# the masked region with -1. Setting it to 0 in an unsigned array was a no-op that
# left the masked voxels to be solved for.
markers = np.zeros(data.shape, dtype=np.int32)

markers[(data > 0) & (data < 0.2)] = 1
markers[(data > 0.5) & (data < 0.7)] = 2
markers[(data > 0.9) & (data < 1)] = 3
# Tolerant comparison instead of exact float equality.
markers[np.isclose(data, MASK_VALUE)] = -1

# Run random walker algorithm. 'bf' (LU factorisation) is documented as very slow
# and memory-hungry for 3-D volumes; the Jacobi-preconditioned conjugate-gradient
# solver ('cg_j') scales to volumes.
labels = skimage.segmentation.random_walker(data, markers, beta=10, mode='cg_j')

### -- Computationally too expensive on the full 3-D volume

## iii. Watershed

In [None]:
# PyVista/trame rendering check on PyVista's bundled random-hills demo mesh.
pv.set_jupyter_backend('trame')
mesh = pv.examples.load_random_hills()
arrows = mesh.glyph(scale='Normals', orient='Normals', tolerance=0.05)

p = pv.Plotter()
p.add_mesh(arrows, color='black')
p.add_mesh(mesh, scalars='Elevation', cmap='terrain', smooth_shading=True)
p.show()

In [None]:
# Load Data.
# The pickle was written with the unmodified `sample_name`, so open it under that
# name. Rewriting `sample_name` with '-' -> '_' first would miss the file whenever
# the name contains '-'; the underscore form is only the dictionary key.
# Only unpickle files produced by this workflow: pickle can execute arbitrary code.
with open(f'{sample_name}_filtered_sample_ready_for_segmentation.pkl', 'rb') as fp:
    loaded_sample = pickle.load(fp)

In [None]:
# Distinct values in the [100:200, 100:200, 100:200] sub-volume of Section 1.
np.unique(loaded_sample[sample_name.replace('-','_')]['Section 1'][100:200,100:200,100:200])

In [None]:
%%time
sample = sample_name.replace('-','_')

# Apply SRAF.mask_section (radius 180) to every loaded section, in place in the dict.
for section, loaded_volume in loaded_sample[sample].items():
    loaded_sample[sample][section] = SRAF.mask_section(loaded_volume, mask_radius=180)

In [None]:
# Select one slice (index 100 of Section 1) to try watershed on.
sample_key = sample_name.replace('-','_')
image = loaded_sample[sample_key]['Section 1'][100, :, :]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), sharex=True, sharey=True)
im = ax.imshow(image, interpolation=None)
plt.tight_layout()
plt.colorbar(mappable=im)

In [None]:
import scipy.ndimage

# Exploratory marker-based watershed on the slice `image` selected above.
# fig, axes = plt.subplots(nrows=2, ncols=3,figsize=(18,15))
fig, axes = plt.subplots(nrows=3, ncols=4,figsize=(25,20))

ax=axes.flatten()
for axis in ax:
    axis.axis('off')
ax[0].imshow(image, interpolation=None)

# 1. Gradients
scharr = skimage.filters.scharr(image)
ax[1].imshow(scharr, interpolation=None)

# 2. Dilation
dilated = skimage.morphology.dilation(image, footprint=skimage.morphology.disk(3))
ax[2].imshow(dilated, interpolation=None)

# 3. Reconstruction by erosion
# https://scikit-image.org/docs/stable/api/skimage.morphology.html#skimage.morphology.reconstruction
# The documented hole-filling recipe seeds with the image on the border and
# image.max() elsewhere (seed[1:-1, 1:-1] = image.max()). Here the "border" is a
# circle of radius `mask_radius` around the slice centre.
mask_radius = 160
n_rows, n_cols = image.shape
center = (n_rows/2, n_cols/2)
# Grids sized from the image instead of a hard-coded 470 x 470;
# x runs along columns, y along rows.
x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

# Calculating the distances from the center
r = np.sqrt((center[1] - x)**2 + (center[0] - y)**2)

# `r` is a float grid, so `r == mask_radius` only matched the few pixels whose
# distance is exactly an integer; select a one-pixel-wide ring instead.
ring = np.abs(r - mask_radius) < 0.5
seed = np.where(ring, image, image.max())

rec = skimage.morphology.reconstruction(seed=seed, mask=image, method='erosion')
ax[3].imshow(rec, interpolation=None)

# What the reconstruction filled in
fgm = rec-image
ax[4].imshow(fgm, interpolation=None)

ax[5].imshow(image, cmap='gray', interpolation=None)
ax[5].imshow(fgm, alpha = 0.8, interpolation=None)
# NOTE(review): 0.52774596 looks like fgm.max() from an earlier run — confirm.
fgm_rescaled = skimage.exposure.rescale_intensity(fgm, in_range=(0,0.52774596),out_range=(0,1))
otsu = skimage.filters.threshold_multiotsu(fgm_rescaled,classes=2)

markers = np.digitize(image,bins=[0,0.5])
distance = scipy.ndimage.distance_transform_edt(fgm)
ax[6].imshow(distance, interpolation=None)

ax[7].imshow(markers, interpolation=None)

local_maxi = skimage.feature.peak_local_max(fgm, min_distance=5)

# peak_local_max returns an (N, 2) array of (row, col) coordinates. Indexing with
# that array directly selects whole rows, so unpack it into per-axis index arrays.
peaks_mask = np.zeros_like(fgm, dtype=bool)
peaks_mask[tuple(local_maxi.T)] = True
ax[8].imshow(peaks_mask, interpolation=None)

# Give every peak its own label so each one seeds a separate basin; a boolean
# array would act as a single marker label.
peak_markers, n_peaks = scipy.ndimage.label(peaks_mask)
labels = skimage.segmentation.watershed(-distance, markers=peak_markers, watershed_line=True)
ax[9].imshow(labels, interpolation=None)

# fgm_dilated = skimage.morphology.dilation(fgm, footprint=skimage.morphology.disk(3))
# ax[6].imshow(image, cmap='gray')
# ax[6].imshow(fgm_dilated, alpha = 0.8)



#     markers = np.zeros(image.shape, dtype=np.uint)
#     markers[(image > 0)&(image < thresholds_array[i][0])] = 1
#     markers[(image > thresholds_array[i][0])&(image < thresholds_array[i][1])] = 2
#     markers[(image > thresholds_array[i][1])&(image < thresholds_array[i][2])] = 3
#     markers[(image > thresholds_array[i][2])] = 4

#     regions = np.digitize(image, bins=thresholds_array[i])

#     regions = regions + np.ones(regions.shape)
# #     labels = skimage.segmentation.watershed(image, regions)

#     sobel = skimage.filters.sobel(image)  
#     img_sobel_digitized = 1- skimage.morphology.binary_dilation(np.digitize(sobel,[0.5]))

#     np.place(regions,img_sobel_digitized == 0,0)
    
#     labels = skimage.segmentation.watershed(image, regions)
#     footprint = skimage.morphology.square(3)
#     labels = skimage.morphology.closing(labels, footprint)
#     ax[i+3].imshow(labels, interpolation=None)


In [None]:
import cv2

# Classes in the digitized `markers` from the previous cell. The expression must be
# last for Jupyter to display it; the trailing import used to suppress the output.
np.unique(markers)

In [None]:
import cv2


def _show(img, **imshow_kwargs):
    """Display one array as its own figure with a colorbar."""
    plt.imshow(img, interpolation=None, **imshow_kwargs)
    plt.colorbar()
    plt.show()


# Marker-based watershed following the OpenCV watershed tutorial.
# otsu = skimage.filters.threshold_multiotsu(fgm_rescaled,classes=2)

# Binarise the slice at 0.5, invert it and rescale to 0 / 255.
thresholded = np.digitize(image, bins=[0.5])
thresholded = skimage.util.invert(thresholded)
thresholded = skimage.exposure.rescale_intensity(thresholded, out_range=(0, 255))
_show(thresholded, cmap='gray')

# noise removal
kernel = np.ones((3, 3), np.uint8)
opening = cv2.morphologyEx(thresholded, cv2.MORPH_OPEN, kernel, iterations=2)
_show(opening, cmap='gray')

# sure background area: dilate the noise-filtered mask. Previously `thresholded`
# was dilated, so the opening above was computed and plotted but never used.
sure_bg = cv2.dilate(opening, kernel, iterations=1)
_show(sure_bg, cmap='gray')

# Finding sure foreground area
distance = scipy.ndimage.distance_transform_edt(sure_bg)
# dist_transform = cv2.distanceTransform(thresholded,cv2.DIST_L2,3)
_show(distance, cmap='gray')

otsu = skimage.filters.threshold_otsu(distance)
sure_fg = np.digitize(distance, bins=[otsu])
_show(sure_fg, cmap='gray')

# Finding unknown regions. sure_fg is 0/1 (not 0/255), so `unknown` equals 255
# exactly where sure_bg is set and sure_fg is not.
unknown = sure_bg - sure_fg
_show(unknown, cmap='gray')

# Marker labelling
# Connected Components determines the connectivity of blob-like regions in a binary image.
sure_fg = sure_fg.astype(np.uint8)
ret, markers = cv2.connectedComponents(sure_fg)

# Add one to all labels so that sure background is not 0, but 1
markers = markers + 1

# Now, mark the region of unknown with zero
markers[unknown == 255] = 0
_show(markers)

# markers = cv2.watershed(col_img.astype(np.uint8),col_mark.astype(np.uint8))
watershed = skimage.segmentation.watershed(image=image,
                                           markers=markers,
                                           watershed_line=True)
plt.imshow(image, interpolation=None)
plt.imshow(watershed, interpolation=None, alpha=0.5)
plt.colorbar()
plt.show()

_show(image)


In [None]:
# Number of pixels whose watershed label equals their marker label.
np.sum(watershed==markers)

In [None]:
# Total pixel count of the slice (was the hard-coded 470*470), for comparison
# with the matching-pixel count above.
watershed.size

## Measuring properties
### A kernel restart is suggested here to free memory, after which only the segmented data is loaded

In [None]:
import cc3d
# https://github.com/seung-lab/connected-components-3d

In [None]:
# Reload the Otsu segmentation. The file was written with the unmodified
# `sample_name`, so open it under that name. Rewriting `sample_name` with
# '-' -> '_' first would miss the file whenever the name contains '-'.
# Only unpickle files produced by this workflow: pickle can execute arbitrary code.
with open(f'{sample_name}_otsu_segmented.pkl', 'rb') as fp:
    loaded_sample = pickle.load(fp)

In [None]:
# Isolating the Fracture Plane: keep voxels of class 1, zero everything else.
section_1 = loaded_sample[sample_name.replace('-','_')]['Section 1']
isolation = section_1 * (section_1 == 1)
plt.imshow(isolation[100, :, :], interpolation=None)
plt.colorbar()
print(np.unique(isolation))

In [None]:
# Get a labeling of the k largest objects in the image.
# The output will be relabeled from 1 to N.

# Keep the 15 largest 26-connected components, relabelled 1..N;
# N is the number of components returned.
labels_out, N = cc3d.largest_k(
    isolation,  # or thresholded sample if the whole workflow is run
    k=15,
    connectivity=26,
    delta=0,
    return_N=True,
)

# labels_in *= (labels_out > 0) # to get original labels

In [None]:
# getting the voxel counts per component
index,counts = np.unique(labels_out,return_counts=True)
combined = np.array([index,counts])
combined = combined.T
combined[combined[:,1].argsort()][::-1]
combined

In [None]:
# creating the array for component properties 
no_of_props_to_collect = 3
props_array = np.zeros((max(index),no_of_props_to_collect))
  
for segid in range(1, N+1):
    extracted_image = labels_out * (labels_out == segid)
    props = skimage.measure.regionprops(extracted_image)
    props_array[segid-1,0]=props[0].area
    props_array[segid-1,1]=props[0].axis_minor_length
    props_array[segid-1,2]=props[0].feret_diameter_max

In [None]:
# Creating a dataframe for props
df = pd.DataFrame(data=props_array)
df.rename(columns={0:'Volume',1:'Axis_min',2:'Feret_max'}, inplace=True)
df['Shape_factor'] = df['Axis_min']/df['Feret_max']
df

In [None]:
# DOES NOT WORK YET, SO AN ALTERNATIVE VISUALIZATION IS PROVIDED IN THE NEXT CELL

# #cmap based on shape factor
# shape_colors=np.zeros((len(df['Shape_factor']),4))
# for i in range(0,3):
#     shape_colors[:,i] = df['Shape_factor'].values
# shape_colors[:,0:4]=1
# pv_lookuptable = pv.LookupTable(values=shape_colors,
#                                 scalar_range=(0,1))
# #lookup table does not work yet


### Shape colored

In [None]:
# Paint every component with its shape factor: label i -> row i-1 of `df`.
# Background (0) and labels without a row stay 0, as with np.select's default.
# A lookup table indexed by the label volume avoids building one full-size
# boolean mask per component.
shape_lut = np.zeros(max(int(labels_out.max()), len(df)) + 1)
shape_lut[1:len(df) + 1] = df['Shape_factor'].to_numpy()
vol_shape_modified = shape_lut[labels_out]

In [None]:
# Saving Colored Components for plotting in spyder

with open(f'{sample_name}_section_1_shapes_colored.pkl', 'wb') as fp:
    pickle.dump(vol_shape_modified, fp)

### Volume colored

In [None]:
# Paint every component with its volume: label i -> row i-1 of `df`.
# Background (0) and labels without a row stay 0, as with np.select's default.
# A lookup table avoids one full-size boolean mask per component.
volume_lut = np.zeros(max(int(labels_out.max()), len(df)) + 1)
volume_lut[1:len(df) + 1] = df['Volume'].to_numpy()
vol_vol_modified = volume_lut[labels_out]

In [None]:
# Saving Colored Components for plotting in spyder

with open(f'{sample_name}_section_1_volumes_colored.pkl', 'wb') as fp:
    pickle.dump(vol_vol_modified, fp)

In [None]:
# Each component's share of the summed component volume.
total_vol = df['Volume'].sum()
normalized_vols = df['Volume'] / total_vol

In [None]:
# Paint every component with its normalized volume: label i -> entry i-1.
# Background (0) and labels without an entry stay 0, as with np.select's default.
# A lookup table avoids one full-size boolean mask per component.
normalized_lut = np.zeros(max(int(labels_out.max()), len(df)) + 1)
normalized_lut[1:len(df) + 1] = np.asarray(normalized_vols, dtype=float)
vol_vol_modified_normalized = normalized_lut[labels_out]

In [None]:
# Saving the normalized-volume-colored components for plotting in Spyder.
# This used to dump `vol_vol_modified` again under the raw-volume file name, so
# the normalized volume was never saved. It now gets its own file.
with open(f'{sample_name}_section_1_volumes_normalized_colored.pkl', 'wb') as fp:
    pickle.dump(vol_vol_modified_normalized, fp)

## Histogram analysis for number of components