In [None]:
import sys
import os

# Make user-installed packages (``pip install --user``) importable for this kernel.
# `site.getusersitepackages()` resolves the per-user site-packages directory of
# the *running* interpreter, instead of a hard-coded python3.10 path that
# silently points at the wrong directory under any other Python version.
import site

user_site_packages = site.getusersitepackages()

# Add the directory to sys.path temporarily (it is removed again after the
# nilearn/itk imports that follow).
if user_site_packages not in sys.path:
    sys.path.insert(0, user_site_packages)

# Importing nilearn and itk after sys.path modification
import nilearn
from nilearn import image, plotting
from nilearn.input_data import NiftiMasker

import itk
from distutils.version import StrictVersion as VS

# Optionally remove from sys.path after use
if user_site_packages in sys.path:
    sys.path.remove(user_site_packages)

# Checking ITK version requirement.
# distutils' StrictVersion is deprecated (it emits the "distutils Version
# classes are deprecated" warning), distutils no longer exists in Python 3.12,
# and StrictVersion rejects strings such as "5.4rc1". Compare the leading
# numeric components of the version string instead.
import re

_itk_version_parts = [int(p) for p in re.findall(r"\d+", itk.Version.GetITKVersion())[:3]]
_itk_version_parts += [0] * (3 - len(_itk_version_parts))  # e.g. "5.1" -> [5, 1, 0]
if tuple(_itk_version_parts) < (5, 0, 0):
    print("ITK 5.0.0 or newer is required.")
    sys.exit(1)

# Proceed with the rest of your imports and configuration
# NOTE(review): several names imported below (pip, perf_counter, cm, cluster,
# line_nd, line_aa, rescale, resize, gaussian, threshold_local,
# gaussian_filter, njit, jit, prange) are not used in the visible code.
# NOTE(review): `pip` in particular is not meant to be imported from user code.
import pip
from time import perf_counter
import warnings
# Silences *all* warnings for the rest of the session, including deprecation
# warnings that announce upcoming API removals.
# NOTE(review): consider narrowing this to specific categories/modules.
warnings.filterwarnings("ignore")

import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm

from sklearn.decomposition import FactorAnalysis
from sklearn import cluster

from skimage import exposure
from skimage.draw import line_nd, line_aa
from skimage.transform import rescale, resize
from skimage.filters import gaussian
from skimage import morphology
from skimage.filters import threshold_otsu, threshold_local
from skimage.morphology import binary_dilation

from scipy import ndimage as ndi
from scipy.ndimage import gaussian_filter

from numba import njit, jit, prange

# Session-wide defaults: large figures and fonts for the plots below, and
# 4-decimal printing of NumPy float arrays.
mpl.rcParams['figure.figsize'] = [15, 15]
mpl.rcParams.update({'font.size': 22})
np.set_printoptions(formatter={'float_kind':'{:0.4f}'.format})


import os  # NOTE(review): duplicate of the `import os` at the top of the file; harmless


###############################################
### Set these:
###############################################
root = "/exports/gorter-hpc/users/ninafultz/highres_pvs/20240307_PVS3"  # Root folder of your project
out_folder = os.path.join(root, "segment/")
biasfield_folder = os.path.join(root, "biasfield/")

# Change current working directory to biasfield_folder; figures saved later
# with a bare file name (e.g. "Loaded_data_fig.png") end up here.
os.chdir(biasfield_folder)

input_img_url = os.path.join(biasfield_folder, "merged_last_three_DelRec_-_pvs_70slices_ME.nii.gz")
input_mask_url = os.path.join(biasfield_folder, "combined_mask.nii.gz")
# NOTE: this is an absolute path prefix inside out_folder, not a bare file-name prefix.
output_prefix = os.path.join(out_folder, "qsm_vasc_")
out_hessian = output_prefix + "_hessian.nii.gz"

# Read the input volume exactly once. Previously it was read three times:
# an unguarded read, then a guarded read whose except branch could never
# trigger (the first read had already succeeded), then a third read. Any read
# error is now reported and re-raised instead of being printed and swallowed.
try:
    in_image = itk.imread(input_img_url, itk.F)
except Exception as e:
    print(f"Error reading image: {e}")
    raise

os.makedirs(out_folder, exist_ok=True)

voxel_size = in_image.GetSpacing()                          # don't touch this
size       = in_image.GetLargestPossibleRegion().GetSize()  # don't touch this

print(f"Voxel Size: {voxel_size}")
print(f"Image Size: {size}")


brightvessels = False                  # Bright=ToF, Dark=SWI,T2star
sigma_minimum = np.mean(voxel_size)    # min voxel size (or mean), can be a little higher
sigma_maximum = sigma_minimum*1.95     # a bit less than twice
number_of_sigma_steps=15               # The more the better, between 10 and 15 is good
fact_version = True                    # Factor analysis (new) vs Max pooling (classic)

alpha=0.8                     # In theory, no need to adjust this
beta=1.0                      # In theory, no need to adjust this
gamma=100                     # 50 # Play with this to adjust for the noise/blur in the vascular segmentation
scaleoutput = True            # Do. Not. Touch. That.

otsu_offset=0.3              # Between -1.0 and 1.0 -> higher = more vessels when thresholded

MIP_half_thickness=8           # Half-thickness of the MIP view
window_half_size=50            # 50 = 100x100 zoomed window
mip_dir="ax" #"sag"  #"cor"    # choice between "sag" "cor" or other (default)

# Center of the window.
# `np.int` was removed in NumPy 1.24; it was only an alias of the builtin `int`.
ix, iy, iz = int(size[0] / 1.6), int(size[1] / 1.5), int(size[2] / 2.)


# Load images with nilearn (used for masking, plotting and saving results).
# `img.get_data()` was removed in nibabel 5.0; `np.asanyarray(img.dataobj)` is
# its replacement and, like get_data(), keeps the stored dtype.
mask_img   = image.load_img(input_mask_url)
mask       = np.asanyarray(mask_img.dataobj)
input_img  = image.load_img(input_img_url)
input_data = np.asanyarray(input_img.dataobj)
# All three spatial zooms; the previous [0:2] silently dropped the z spacing.
voxel_size = mask_img.header.get_zooms()[0:3]

print("Voxel size is: ", voxel_size)
print("Image dim is: ", mask_img.shape)

# Half-extent of the zoom window per axis: the projection axis gets the slab
# half-thickness, the two in-plane axes get the window half-size.
if mip_dir == 'sag':
    sx, sy, sz = MIP_half_thickness, window_half_size, window_half_size
    mip_axis = 0
elif mip_dir == 'cor':
    sx, sy, sz = window_half_size, MIP_half_thickness, window_half_size
    mip_axis = 1
else:
    sx, sy, sz = window_half_size, window_half_size, MIP_half_thickness
    mip_axis = 2

# Window bounds around the centre. Lower bounds are clamped at 0 because a
# negative slice start counts from the END of the axis and would select the
# wrong voxels; upper bounds past the end are already truncated by slicing.
dx, ux = max(0, int(ix - sx)), int(ix + sx)
dy, uy = max(0, int(iy - sy)), int(iy + sy)
dz, uz = max(0, int(iz - sz)), int(iz + sz)

_window = (slice(dx, ux, 1), slice(dy, uy, 1), slice(dz, uz, 1))
_centre = (ix, iy, iz)
_full = slice(None, None)

# mip_cube: zoomed 3D box; mip_view: full in-plane extent, slab along mip_axis;
# sng_cube / sng_view: the single centre slice along mip_axis, zoomed / full.
mip_cube = _window
mip_view = tuple(_window[a] if a == mip_axis else _full for a in range(3))
sng_cube = tuple(_centre[a] if a == mip_axis else _window[a] for a in range(3))
sng_view = tuple(_centre[a] if a == mip_axis else _full for a in range(3))

#Set-up window variables
print("Center: ",      ix, iy, iz)
print("Zoom window: ", sx, sy, sz)
print("Low limit: ",   dx, dy, dz)
print("High limit: ",  ux, uy, uz)
print("Size of single will be: ",        mask[sng_view].T.shape)
print("Size of zoomed single will be: ", mask[sng_cube].T.shape)
print("Size of MIP will be: ",           np.min(mask[mip_view], axis=mip_axis).shape)
print("Size of zoomed MIP will be: ",    np.min(mask[mip_cube], axis=mip_axis).shape)

#Plot the loaded images
fig, axes = plt.subplots(1, 3, figsize=(28, 20))
fig.tight_layout()
ax = axes.flatten()
ax[0].imshow(input_data[sng_view].T, origin='lower', cmap=plt.cm.gray)
ax[1].imshow(np.min(input_data[mip_view], axis=mip_axis).T, origin='lower', cmap=plt.cm.gray, alpha=1.0)
ax[2].imshow(np.min(input_data[mip_cube], axis=mip_axis).T, origin='lower', cmap=plt.cm.gray, alpha=1.0)
plt.title("Data loaded")
plt.savefig("Loaded_data_fig.png", bbox_inches='tight')
plt.show()



def FA_combine_scales(img_4D, mask, sigma_min, sigma_max):
    """Combine per-scale vesselness (VED) maps into one map via 1-factor analysis.

    Masked voxels are the samples and the volumes of ``img_4D`` (one per
    scale) are the features. The single latent-factor score per voxel,
    clipped at 0 and percentile-rescaled, is the combined map.

    Parameters
    ----------
    img_4D : Niimg-like
        4D image with one volume per Hessian scale.
    mask : Niimg-like or None
        Mask selecting the voxels used in the analysis. This argument used to
        be accepted but ignored, so NiftiMasker derived its own mask from the
        data. Passing None reproduces that old behaviour.
    sigma_min, sigma_max : float
        Smallest and largest scale; they seed ``noise_variance_init``.

    Returns
    -------
    numpy.ndarray
        3D array of non-negative factor scores. They are rescaled with
        ``exposure.rescale_intensity``, using the 0th-99.8th percentile of the
        scores as the input range.

    Raises
    ------
    ValueError
        If the masked data has no positive value (typically the mask and the
        data are not in the same space/orientation).
    """
    masker = NiftiMasker(mask_img=mask)
    # fit_transform returns (n_volumes, n_masked_voxels).
    data_masked = masker.fit_transform(img_4D)
    print(np.max(data_masked))
    # Explicit raise (not assert) so the check is not stripped under `python -O`.
    if not np.max(data_masked) > 0:
        raise ValueError("Mask and data are not in same space/orient")

    p0, p99 = np.percentile(data_masked, (0, 99.8))
    data_masked = exposure.rescale_intensity(data_masked, in_range=(p0, p99))
    print(np.max(data_masked))

    # One initial noise variance per scale. They are geometrically spaced
    # between the scale bounds and shifted so the finest scale starts at 0.
    noise_variance_init = np.geomspace(sigma_min, sigma_max, data_masked.shape[0]) - sigma_min
    print("Noise variance: ", noise_variance_init)

    method = FactorAnalysis(n_components=1,
                            noise_variance_init=noise_variance_init,
                            tol=0.005, max_iter=200)
    # Transpose so voxels are samples and scales are features.
    X_reduced = method.fit_transform(data_masked.T)
    # NOTE(review): a factor's sign is arbitrary. Negative scores are discarded
    # here on the assumption that the factor loads positively on the scales.
    X_reduced[X_reduced < 0] = 0

    fact_img = image.index_img(masker.inverse_transform(X_reduced.T), 0)
    factimg = nilearn.image.math_img("np.where(x<0, 0, x)", x=fact_img)

    # get_data() was removed in nibabel 5.0; dataobj keeps the same dtype.
    factdata = np.asanyarray(factimg.dataobj)
    p0, p99 = np.percentile(factdata, (0, 99.8))
    return exposure.rescale_intensity(factdata, in_range=(p0, p99))

def skeletonize_mask(i_thresh_img, i_voxel_size=0.6, i_factor=2.0):
    """Skeletonize a thresholded vessel mask and compute distance maps around it.

    Parameters
    ----------
    i_thresh_img : Niimg-like
        Binary (0/non-zero) vessel mask.
    i_voxel_size : float or sequence of float
        Voxel spacing passed as ``sampling`` to the Euclidean distance transforms.
    i_factor : float
        Unused. It belonged to a disabled upsampling path and is kept only for
        call compatibility.

    Returns
    -------
    skel : numpy.ndarray (uint8, values 0/1)
        Skeleton of the one-step dilated mask.
    centerdia : numpy.ndarray
        ``skel * distance``: distance to the vessel border, sampled on the skeleton.
    distance : numpy.ndarray
        Distance of each in-mask voxel to the nearest background voxel.
    inverted_distance : numpy.ndarray
        Distance of each background voxel to the nearest mask voxel (0 inside the mask).
    """
    # get_data() was removed in nibabel 5.0; dataobj keeps the same dtype.
    thresh_data = np.asanyarray(i_thresh_img.dataobj)
    dil_thresh = binary_dilation(thresh_data)

    # skeletonize_3d is deprecated and removed in recent scikit-image. When it
    # is missing, `skeletonize` handles 3D input. Either output is normalised
    # to uint8 0/1.
    skeleton_fn = getattr(morphology, "skeletonize_3d", None) or morphology.skeletonize
    skel = (skeleton_fn(dil_thresh) > 0).astype(np.uint8)

    distance = ndi.distance_transform_edt(thresh_data, sampling=i_voxel_size)
    centerdia = skel * distance

    # Background indicator; its EDT measures how far each outside voxel is from the mask.
    background = 1.0 - thresh_data
    inverted_distance = ndi.distance_transform_edt(background, sampling=i_voxel_size) * background

    return skel, centerdia, distance, inverted_distance

def _slice_otsu_thresholds(volume, axis, offset):
    """Return one threshold per 2D slice of ``volume`` taken along ``axis``.

    For each slice:
    - apply adaptive histogram equalization (``clip_limit=0.01``);
    - compute the Otsu threshold with 400 bins;
    - subtract ``offset`` and floor the result at 0.

    All-zero slices keep a threshold of 0. Their voxels are 0 and can never
    exceed a non-negative threshold, so the exact value does not matter.
    """
    thresholds = np.zeros(volume.shape[axis], dtype=volume.dtype)
    for idx, plane in enumerate(np.moveaxis(volume, axis, 0)):
        if np.max(plane) == 0.0:
            continue
        img_adapteq = exposure.equalize_adapthist(plane, clip_limit=0.01)
        val = threshold_otsu(img_adapteq, nbins=400)
        thresholds[idx] = max(0.0, val - offset)
    return thresholds


def otsu_3D(i_img, offset=0.2):
    """Binarize a 3D image with slice-wise Otsu thresholds along all three axes.

    The data is rescaled using its 0th-99.8th percentile as the input range.
    Each voxel is then compared against the smallest of the thresholds of the
    x-, y- and z-slices that contain it.

    Parameters
    ----------
    i_img : Niimg-like
        3D image to threshold.
    offset : float
        Amount subtracted from every Otsu threshold; higher gives a more
        permissive mask. It was previously overwritten by a hard-coded 0.2,
        which ignored the caller's value.

    Returns
    -------
    numpy.ndarray of bool
        True where the rescaled voxel value exceeds its threshold.
    """
    # get_data() was removed in nibabel 5.0; dataobj keeps the same dtype.
    volume = np.asanyarray(i_img.dataobj)
    p0, p99 = np.percentile(volume, (0, 99.8))
    datacopy = exposure.rescale_intensity(volume, in_range=(p0, p99))

    threshold_x = _slice_otsu_thresholds(datacopy, 0, offset)
    threshold_y = _slice_otsu_thresholds(datacopy, 1, offset)
    threshold_z = _slice_otsu_thresholds(datacopy, 2, offset)

    # The per-voxel minimum via broadcasting equals the old min over a stacked
    # (X, Y, Z, 3) array, without materialising three full-size volumes.
    threshold_min = np.minimum(
        np.minimum(threshold_x[:, None, None], threshold_y[None, :, None]),
        threshold_z[None, None, :],
    )
    return datacopy > threshold_min

print("+-+-+- DOING MULTI-SCALE METHOD")

# ITK template types: the input image type, plus a same-dimension image of
# symmetric second-rank tensors (double precision) holding the Hessian.
ImageType = type(in_image)
Dimension = in_image.GetImageDimension()
HessianPixelType = itk.SymmetricSecondRankTensor[itk.D, Dimension]
HessianImageType = itk.Image[HessianPixelType, Dimension]
print(HessianImageType)

# A single objectness filter, configured once and shared by every
# multi-scale filter created later. Setters are applied in the same order.
objectness_filter = itk.HessianToObjectnessMeasureImageFilter[HessianImageType, ImageType].New()
_objectness_settings = (
    ("SetBrightObject", brightvessels),           # False -> dark vessels (SWI/T2*)
    ("SetScaleObjectnessMeasure", scaleoutput),
    ("SetAlpha", alpha),
    ("SetBeta", beta),
    ("SetGamma", gamma),                          # higher = less noise
)
for _setter_name, _setting in _objectness_settings:
    getattr(objectness_filter, _setter_name)(_setting)

# Directory for the per-scale VED volumes and the maps derived from them.
outfold = output_prefix + '_scales/'
if not os.path.exists(outfold):
    os.makedirs(outfold)


# `output_prefix` is an absolute path (out_folder/qsm_vasc_). The previous
# `outfold + output_prefix` therefore glued two absolute paths together and
# pointed into a nested directory that is never created, so the first
# per-scale write failed. Everything derived from the scales now lives
# directly inside `outfold`, named after the prefix's file-name part.
scale_prefix = os.path.join(outfold, os.path.basename(output_prefix))

# `get_data()` was removed in nibabel 5.0; `np.asanyarray(img.dataobj)` is the
# dtype-preserving replacement. Read the input once for all plots below.
input_volume = np.asanyarray(input_img.dataobj)


def _masked_rescaled_ved(ved_itk, verbose=False):
    """Return an ITK VED image as a masked, percentile-rescaled array.

    Steps:
    - transpose the ITK array view with (2, 1, 0), as before;
    - multiply by the global `mask`;
    - rescale with the 0th-99.8th percentile as input range.

    With ``verbose`` the min/max before and after rescaling are printed.
    """
    ved = np.transpose(itk.GetArrayViewFromImage(ved_itk), (2, 1, 0)) * mask
    if verbose:
        print(np.min(ved))
        print(np.max(ved))
    p0, p99 = np.percentile(ved, (0, 99.8))
    ved = exposure.rescale_intensity(ved, in_range=(p0, p99))
    if verbose:
        print(np.min(ved))
        print(np.max(ved))
    return ved


def _plot_ved_mips(input_vol, ved_vol, view, title, out_png):
    """Show and save three projections of `view` along `mip_axis`.

    Panels: the input min-projection, the same with the VED max-projection
    overlaid, and the VED max-projection alone.
    """
    input_mip = np.min(input_vol[view], axis=mip_axis).T
    ved_mip = np.max(ved_vol[view], axis=mip_axis).T
    fig, axes = plt.subplots(1, 3, figsize=(28, 14))
    fig.tight_layout()
    ax = axes.flatten()
    ax[0].imshow(input_mip, origin='lower')
    ax[1].imshow(input_mip, origin='lower', cmap=plt.cm.gray, alpha=1.0)
    ax[1].imshow(ved_mip, origin='lower', cmap='viridis', alpha=0.7)
    ax[2].imshow(ved_mip, origin='lower')
    plt.title(title)
    plt.savefig(out_png, bbox_inches='tight')
    plt.show()


def _plot_input_ved_mask(input_panel, ved_panel, mask_panel, out_png):
    """Show and save three 2D panels side by side: input, VED map, thresholded mask."""
    fig, axes = plt.subplots(1, 3, figsize=(28, 12))
    fig.tight_layout()
    ax = axes.flatten()
    ax[0].imshow(input_panel.T, origin='lower')
    ax[1].imshow(ved_panel.T, origin='lower', cmap='viridis')
    ax[2].imshow(mask_panel.T, origin='lower', cmap='viridis')
    plt.savefig(out_png, bbox_inches='tight')
    plt.show()


print("sigma_minimum is: " + str(sigma_minimum))
print("sigma_maximum is: " + str(sigma_maximum))
for step, sigma in enumerate(np.geomspace(sigma_minimum, sigma_maximum, number_of_sigma_steps)):

    scale_str = str(int(sigma*100000)).zfill(8)
    out_scale_prefix = scale_prefix + "_" + scale_str
    if os.path.exists(out_scale_prefix + "_VED.nii.gz"):
        continue  # computed in an earlier run

    print(" +-+- DOING sigma scale " + str(sigma) + ": " + str(step+1) + " of " + str(number_of_sigma_steps))
    multi_scale_filter = itk.MultiScaleHessianBasedMeasureImageFilter[ImageType, HessianImageType, ImageType].New()
    multi_scale_filter.SetInput(in_image)
    multi_scale_filter.SetHessianToMeasureFilter(objectness_filter)
    multi_scale_filter.SetSigmaStepMethodToLogarithmic()
    multi_scale_filter.SetSigmaMinimum(sigma)
    multi_scale_filter.SetSigmaMaximum(sigma)
    multi_scale_filter.SetNumberOfSigmaSteps(1)
    # The per-scale Hessian was generated but never used; don't keep it in memory.
    multi_scale_filter.SetGenerateHessianOutput(False)
    multi_scale_filter.Update()

    print("  +- ..write VED of sigma " + str(sigma) + ": " + str(step+1) + " of " + str(number_of_sigma_steps))
    VED_output = multi_scale_filter.GetOutput()
    ved_array = _masked_rescaled_ved(VED_output)

    ved_img = nilearn.image.new_img_like(input_img, ved_array, copy_header=True)
    ved_img.to_filename(out_scale_prefix + "_VED.nii.gz")

    scale_title = "VED for: " + str(sigma) + ": " + str(step+1) + " of " + str(number_of_sigma_steps)
    _plot_ved_mips(input_volume, ved_array, mip_view, scale_title, out_scale_prefix + '_fig.png')
    _plot_ved_mips(input_volume, ved_array, mip_cube, scale_title, out_scale_prefix + '_ZOOM.png')


out_classic_prefix = output_prefix + "_CLASSIC"
classic_ved_path = out_classic_prefix + "_VED.nii.gz"
classic_hessian_path = out_classic_prefix + "_best_hessian.nii.gz"

# Both cached files are required. The VED used to be written before the
# Hessian, so an interruption in between left a cache that the load branch
# could not read. The VED is now written last.
if not (os.path.exists(classic_ved_path) and os.path.exists(classic_hessian_path)):

    print("+-+-+- DOING CLASSIC METHOD")
    multi_scale_filter = itk.MultiScaleHessianBasedMeasureImageFilter[ImageType, HessianImageType, ImageType].New()
    multi_scale_filter.SetInput(in_image)
    multi_scale_filter.SetHessianToMeasureFilter(objectness_filter)
    multi_scale_filter.SetSigmaStepMethodToLogarithmic()
    multi_scale_filter.SetSigmaMinimum(sigma_minimum)
    multi_scale_filter.SetSigmaMaximum(sigma_maximum)
    multi_scale_filter.SetNumberOfSigmaSteps(number_of_sigma_steps)
    multi_scale_filter.SetGenerateHessianOutput(True)
    multi_scale_filter.Update()
    print("+-+-+- DONE CLASSIC")

    print("+-+- Check output size")
    Hessian_output = multi_scale_filter.GetHessianOutput()
    VED_output = multi_scale_filter.GetOutput()
    size = Hessian_output.GetBufferedRegion().GetSize()
    print("%d,%d,%d" % (size[0], size[1], size[2]))

    print("+-+- Making new image")
    hessian_view = itk.GetArrayViewFromImage(Hessian_output)
    hessian_array = np.transpose(hessian_view, (2, 1, 0, 3))
    print(hessian_array.shape)

    ved_array = _masked_rescaled_ved(VED_output, verbose=True)

    # The old titles reused the per-scale loop's last sigma/step, which do not
    # describe the classic (all-scales) result.
    classic_title = "Classic VED (" + str(number_of_sigma_steps) + " scales)"
    _plot_ved_mips(input_volume, ved_array, mip_view, classic_title, out_classic_prefix + '_fig.png')
    _plot_ved_mips(input_volume, ved_array, mip_cube, classic_title, out_classic_prefix + '_ZOOM.png')

    hessian_img = nilearn.image.new_img_like(input_img, hessian_array, copy_header=True)
    hessian_img.to_filename(classic_hessian_path)

    ved_img = nilearn.image.new_img_like(input_img, ved_array, copy_header=True)
    ved_img.to_filename(classic_ved_path)

else:
    ved_img = nilearn.image.load_img(classic_ved_path)
    hessian_img = nilearn.image.load_img(classic_hessian_path)
    ved_array = np.asanyarray(ved_img.dataobj)
    hessian_array = np.asanyarray(hessian_img.dataobj)
    size = ved_array.shape

print("+-+-+- Load VED images")
# Only the per-scale files end in "_VED.nii.gz"; the derived maps below do not.
VED_images = nilearn.image.load_img(scale_prefix + "*_VED.nii.gz", wildcards=True)
print("Shape is: ", VED_images.shape)

max_path = scale_prefix + "_VED_max.nii.gz"
if not os.path.exists(max_path):
    print("+-+- Compute max image")
    max_VED_images = nilearn.image.math_img("np.max(a, axis=3)", a=VED_images)
    max_VED_images.to_filename(max_path)
else:
    max_VED_images = image.load_img(max_path)

fa_path = scale_prefix + "_VED_FactorAnalysis.nii.gz"
if not os.path.exists(fa_path):
    print("+-+- Compute factory analysis")
    fact_data = FA_combine_scales(VED_images, mask_img, sigma_minimum, sigma_maximum)
    fact_img = image.new_img_like(input_img, fact_data)
    fact_img.to_filename(fa_path)
else:
    fact_img = image.load_img(fa_path)

ved_img = fact_img if fact_version else max_VED_images

# NOTE(review): the cached threshold file name does not encode `fact_version`;
# delete it after switching versions, or it will be reused.
otsu_path = scale_prefix + "_VED_otsu_thresh.nii.gz"
if not os.path.exists(otsu_path):
    print("+-+- Compute OTSU 3D analysis")
    thr_data = otsu_3D(ved_img, offset=otsu_offset)
    otsu_tresh_img = image.new_img_like(ved_img, thr_data)
    otsu_tresh_img.to_filename(otsu_path)
else:
    otsu_tresh_img = image.load_img(otsu_path)

thresh_img = otsu_tresh_img
# The map that was actually thresholded (factor analysis or max pooling);
# previously this always read fact_img regardless of `fact_version`.
ved_data = np.asanyarray(ved_img.dataobj)
thresh_data = np.asanyarray(otsu_tresh_img.dataobj)

#Print results
_plot_input_ved_mask(input_volume[sng_view], ved_data[sng_view], thresh_data[sng_view],
                     scale_prefix + '_VED_processed.png')
for _view, _suffix in ((mip_view, '_VED_processed_MIP.png'),
                       (mip_cube, '_VED_processed_MIP_ZOOM.png')):
    _plot_input_ved_mask(np.min(input_volume[_view], axis=mip_axis),
                         np.max(ved_data[_view], axis=mip_axis),
                         np.max(thresh_data[_view], axis=mip_axis),
                         scale_prefix + _suffix)

distutils Version classes are deprecated. Use packaging.version instead.
Setuptools is replacing distutils.
