In [1]:
# Import dependencies
import os
import copy
import numpy as np
from PIL import Image
from scipy import ndimage
from math import floor, sqrt
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

In [2]:
def scanCube(dx, dy, dz, distSqValues):
    """Return, per squared radius, the squared reach needed from an offset point.

    For every squared radius rSq in ``distSqValues``, scan the integer rows
    (j, k >= 0) of the ball x*x + y*y + z*z <= rSq. For each row, take its
    outermost x, floor(sqrt(rSq - j*j - k*k)). Record the largest squared
    distance from the point (-|dx|, -|dy|, -|dz|) to those row-end points.

    Args:
        dx, dy, dz: integer offset; only the absolute values are used.
        distSqValues: 1-D NumPy integer array of squared radii.

    Returns:
        Integer NumPy array with one entry per element of ``distSqValues``.
        For the zero offset every entry is 2147483647 (Java's
        Integer.MAX_VALUE) instead.
    """
    count = distSqValues.size
    if not (dx or dy or dz):
        # Zero offset: sentinel value, copied from Java's Integer.MAX_VALUE.
        return np.full(count, 2147483647)

    ox, oy, oz = -abs(dx), -abs(dy), -abs(dz)
    out = np.full(count, 0)
    for idx in range(count):
        rSq = distSqValues[idx]
        span = 1 + floor(sqrt(rSq))
        farthest = 0
        for k in range(span + 1):
            zTerm = (k - oz) ** 2
            for j in range(span + 1):
                left = rSq - (k * k + j * j)
                if left < 0:
                    # Row (j, k) lies outside the ball.
                    continue
                xEdge = floor(sqrt(left)) - ox
                farthest = max(farthest, zTerm + (j - oy) ** 2 + xEdge * xEdge)
        out[idx] = farthest
    return out

In [3]:
def createTemplate(distSqValues):
    """Build the ridge-test template for every squared radius.

    Args:
        distSqValues: 1-D NumPy integer array of squared radii.

    Returns:
        A list of three NumPy arrays. Entry ``n - 1`` holds ``scanCube``'s
        result for an offset with ``n`` nonzero unit components:
        (1, 0, 0), (1, 1, 0) and (1, 1, 1) respectively.
    """
    unitOffsets = ((1, 0, 0), (1, 1, 0), (1, 1, 1))
    return [np.array(scanCube(ox, oy, oz, distSqValues))
            for ox, oy, oz in unitOffsets]

In [4]:
def stripExtension(name):
    """Return ``name`` without its final file extension.

    Args:
        name: file name or path, e.g. ``"dir/img.png"``.

    Returns:
        ``name`` minus its last extension (``"dir/img"``). Names without an
        extension, and the empty string, are returned unchanged.
    """
    # os.path.splitext already handles the empty string, so no length guard
    # is needed.
    return os.path.splitext(name)[0]

In [13]:
def distanceRidges(dist_img):
    """Keep only the distance-ridge pixels of a 2-D distance map.

    A foreground pixel (value > 0) is a ridge point unless one of its
    8 neighbours has a rounded squared distance at least as large as the
    ``createTemplate`` value for that neighbour direction. In that case the
    neighbour's disc already covers this pixel's disc. Ridge pixels keep
    their distance value; all other pixels become 0.

    The loop structure appears to be a single-slice (d = 1) port of a Java
    implementation. Squared distances are rounded with floor(v * v + 0.5),
    as in the original loops.

    Args:
        dist_img: distance map of shape (H, W) or (H, W, 1), e.g. from
            ``scipy.ndimage.distance_transform_edt``. It is not modified.

    Returns:
        float64 array of shape (H, W) with the distance value at ridge points
        and 0 elsewhere.

    Raises:
        ValueError: if ``dist_img`` holds more than H * W values.
    """
    # The flat index ind = i + w * j must match NumPy's C-order ravel, so i
    # walks columns (width = shape[1]) and j walks rows (height = shape[0]).
    h, w = dist_img.shape[0], dist_img.shape[1]
    s = np.asarray(dist_img, dtype=np.float64).ravel()
    if s.size != w * h:
        raise ValueError(
            "distanceRidges expects an (H, W) or (H, W, 1) array, got shape %s"
            % (dist_img.shape,))

    # Rounded squared distance of every pixel. distSqValues lists the distinct
    # values in ascending order; radIdx[ind] is the position of pixel ind's
    # value in that list (the template's radius index).
    sq = np.floor(s * s + 0.5).astype(np.int64)
    distSqValues, radIdx = np.unique(sq, return_inverse=True)
    rSqTemplate = createTemplate(distSqValues)

    # 8-neighbourhood offsets with their count of nonzero components
    # (1 = edge neighbour, 2 = diagonal), which selects the template row.
    neighbours = [(dx, dy, abs(dx) + abs(dy))
                  for dy in (-1, 0, 1) for dx in (-1, 0, 1) if dx or dy]

    # Float output: ridge distances must not be truncated to integers.
    ridges = np.zeros(w * h, dtype=np.float64)
    for ind in np.flatnonzero(s > 0):
        j, i = divmod(int(ind), w)
        radius = radIdx[ind]
        isRidge = True
        for dx, dy, numComp in neighbours:
            i1, j1 = i + dx, j + dy
            if (0 <= i1 < w and 0 <= j1 < h
                    and sq[i1 + w * j1] >= rSqTemplate[numComp - 1][radius]):
                # A neighbour's disc covers this pixel's disc.
                isRidge = False
                break
        if isRidge:
            ridges[ind] = s[ind]

    return ridges.reshape(h, w)

In [11]:
def localThickness(dist_ridges):
    """Turn a distance-ridge image into a local-thickness map.

    Every ridge pixel (value r > 0) is the centre of a disc with squared
    radius rSq = floor(r * r + 0.5). Each output pixel gets 2 * sqrt(m),
    where m is the largest rSq among the discs that contain it
    (di**2 + dj**2 <= rSq), or 0 when no disc does. This is the
    single-slice (d = 1) case of the original loop structure.

    Args:
        dist_ridges: ridge image of shape (H, W) or (H, W, 1), e.g. the
            output of ``distanceRidges``. It is copied, not modified.

    Returns:
        float64 array of shape (H, W) holding the local thickness, expressed
        as a diameter.

    Raises:
        ValueError: if ``dist_ridges`` holds more than H * W values.
    """
    h, w = dist_ridges.shape[0], dist_ridges.shape[1]
    # Private float copy: radii must not be truncated to int, and the
    # caller's array must not be overwritten.
    radii = np.array(dist_ridges, dtype=np.float64).reshape(-1)
    if radii.size != w * h:
        raise ValueError(
            "localThickness expects an (H, W) or (H, W, 1) array, got shape %s"
            % (dist_ridges.shape,))

    # Largest covering squared radius per pixel. The row-major (h, w) layout
    # means flat index i + w * j matches NumPy's ravel order.
    best = np.zeros((h, w), dtype=np.float64)
    for ind in np.flatnonzero(radii > 0):
        j, i = divmod(int(ind), w)
        r = radii[ind]
        rSquared = floor(r * r + 0.5)
        # Ceiling of r bounds how many pixels the disc can reach.
        rInt = floor(r)
        if rInt < r:
            rInt += 1
        j0, j1 = max(j - rInt, 0), min(j + rInt, h - 1)
        i0, i1 = max(i - rInt, 0), min(i + rInt, w - 1)
        dj2 = (np.arange(j0, j1 + 1) - j)[:, None] ** 2
        di2 = (np.arange(i0, i1 + 1) - i)[None, :] ** 2
        inside = dj2 + di2 <= rSquared
        window = best[j0:j1 + 1, i0:i1 + 1]  # basic slice -> view into best
        np.maximum(window, np.where(inside, rSquared, 0.0), out=window)

    # Convert squared radii to diameters.
    thickness = 2.0 * np.sqrt(best)
    print("Local Thickness complete")
    return thickness

In [None]:
def main(image_path='test_ridges_img.png', lut_path='LUT_fire.csv'):
    """Run the local-thickness pipeline on one image and plot every stage.

    Steps:
        1. Load the image.
        2. Threshold it to a foreground mask (0 < value < 250).
        3. Compute the Euclidean distance map.
        4. Extract distance ridges.
        5. Compute local thickness.
        6. Show the four stages stacked vertically, coloured with a colormap
           loaded from ``lut_path``.

    Args:
        image_path: input image path. It must decode to a single-channel 2-D
            array; an (H, W, 1) array is accepted and squeezed.
        lut_path: comma-separated file of colormap rows, passed to
            ``matplotlib.colors.ListedColormap``.

    Raises:
        ValueError: if the image is not a single-channel 2-D image.
    """
    # Load the image; the context manager closes the file handle.
    with Image.open(image_path) as img:
        test_img = np.array(img)

    # The ridge / thickness steps only process a single 2-D slice.
    if test_img.ndim == 3 and test_img.shape[2] == 1:
        test_img = test_img[:, :, 0]
    if test_img.ndim != 2:
        raise ValueError(
            "Wrong input image dimension. Make sure input is a single-channel 2D image.")

    # Foreground mask: pixels strictly between 0 and 250.
    binary_test = (test_img > 0) & (test_img < 250)
    dist_img = ndimage.distance_transform_edt(binary_test)
    dist_ridges = distanceRidges(dist_img)
    # Pass a copy so dist_ridges stays intact for plotting even if
    # localThickness writes into its input.
    local_thickness = localThickness(copy.deepcopy(dist_ridges))

    # Visualization
    my_cmap = ListedColormap(np.loadtxt(lut_path, delimiter=','))
    panels = [
        (test_img, 'gray', 'test image'),
        (dist_img, my_cmap, 'distance image masked'),
        (dist_ridges, my_cmap, 'distance ridges'),
        (local_thickness, my_cmap, 'local thickness'),
    ]
    fig, axs = plt.subplots(len(panels), 1, constrained_layout=True, figsize=(15, 30))
    for ax, (data, cmap, title) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
    plt.show()


main()