In [None]:
import matplotlib.pyplot as plt

from skimage.color import rgb2gray, rgba2rgb
import skimage.io as skio
from skimage.util.shape import view_as_windows
from skimage.transform import resize

import numpy as np
import pandas as pd
from scipy import stats  
from einops import rearrange, reduce, repeat

import itertools
import os
import glob
import warnings
import tqdm

In [None]:
class CurvRectValues:
    """Quantify curvilinear vs. rectilinear structure in images with banana-wavelet filters.

    On construction a bank of banana-wavelet kernels is generated
    (``__generate_kernels__``). ``processing_images`` then reads image files in
    blocks, filters them in the frequency domain and, for every image, records:

    * ``curv_values`` / ``rect_values``: one scalar per image, the summed
      "unique" curvilinear / rectilinear response divided by x**2 / 2
      (x = image height);
    * ``curv_max`` / ``rect_max``: per-pixel maximum response over the
      corresponding kernel set, stacked along axis 2 (one slice per image);
    * ``curv_unique`` / ``rect_unique``: the max-response maps with pixels
      zeroed wherever the other kernel family responds at least as strongly.
    """

    def __init__(self):
        self.files           = []
        self.image_size      = 128   # images are resized to image_size x image_size
        self.randstate       = 49    # NOTE(review): not used anywhere in this class
        self.file_block_size = 20    # number of images filtered together per batch
        self.images          = []    # one list of patch-normalised images per block
        self.fft_images      = []    # one list of their 2-D FFTs per block
        self.cg_parameters   = {'kA':0.3,'bend':0.01,'orientation':45*np.pi/180}
        self.curv_values     = []
        self.rect_values     = []
        self.curv_max        = []
        self.rect_max        = []
        self.curv_unique     = []
        self.rect_unique     = []
        self.__generate_kernels__()


    def __bananakernel__(self):
        """Generate one banana-wavelet kernel from ``self.cg_parameters``.

        Banana wavelets are bent Gabor-like kernels; filtering an image with a
        bank of them is used to quantify curvature.

        Parameters are read from the ``self.cg_parameters`` dictionary
        (dictionary key in parentheses):

        kA       ('kA'):          length of the wave vector K (required)
        bA       ('bend'):        bending value b (required; real part used if complex)
        alphaA   ('orientation'): direction of the wave vector in radians (default 45 deg)
        mA       ('mA'):          magnitude value m (default 3)
        sigmaXbend ('sigmaXbend'): Gaussian width factor along the carrier (default 2)
        sigmaYbend ('sigmaYbend'): Gaussian width factor across the carrier (default 2)
        xA_half  ('xA_half'):     half of the kernel size in x (default image_size/2)
        yA_half  ('yA_half'):     half of the kernel size in y (default image_size/2)
        xA_shift ('xA_shift', or legacy 'x_shift'): center shift in x (default 0)
        yA_shift ('yA_shift', or legacy 'y_shift'): center shift in y (default 0)

        Intermediate quantities, for reference:
        preFactorA:  pre-factor p
        DCPartRealA: real DC part
        DCPartImagA: imaginary DC part
        gaussPartA:  Gaussian envelope

        filter frequency (cycles/object) = xA*kA/(2*pi*mA)
        kernel size: 2*4*sigmaYbend*mA*(1/kA)

        Returns (space_kernel, freq_kernel):
        space_kernel: complex array on a square grid whose side is 2*xA_half
                      (rounded up to an even number)
        freq_kernel:  ``np.fft.ifft2(space_kernel)``
        If kA <= 0, mA <= 0, or any of kA/bA/alphaA/mA is NaN, both outputs are
        arrays of shape (2*xA_half, 2*yA_half) filled with 1e-20.

        last updated 5/23/2021 (previously 3/23/2021)
        """
        kA         = self.cg_parameters['kA']
        bA         = self.cg_parameters['bend']
        alphaA     = self.cg_parameters.get('orientation',45*np.pi/180)
        mA         = self.cg_parameters.get('mA',3)
        sigmaXbend = self.cg_parameters.get('sigmaXbend',2)
        sigmaYbend = self.cg_parameters.get('sigmaYbend',2)
        xA_half    = self.cg_parameters.get('xA_half',self.image_size/2)
        yA_half    = self.cg_parameters.get('yA_half',self.image_size/2)
        # documented names are xA_shift / yA_shift; the legacy x_shift / y_shift still work
        xA_shift   = self.cg_parameters.get('xA_shift', self.cg_parameters.get('x_shift',0))
        yA_shift   = self.cg_parameters.get('yA_shift', self.cg_parameters.get('y_shift',0))

        if isinstance(bA, complex):
            warnings.warn('bA has to be real number. However your input is a complex number')
            bA = np.real(bA)

        if any(x<=0 for x in np.array([kA, mA])) or any(np.isnan(np.array([kA, bA, alphaA, mA]))):
            # Degenerate parameters: return tiny constant kernels.
            # np.ones needs integer sizes, and xA_half defaults to a float.
            out_of_range_value = 10**-20
            degenerate_shape = (int(2*xA_half), int(2*yA_half))
            SpaceKernel = np.ones(degenerate_shape)*out_of_range_value
            FreKernel   = np.ones(degenerate_shape)*out_of_range_value
            return SpaceKernel, FreKernel

        kernel_size = 2*xA_half
        if kernel_size%2 !=0:
            kernel_size = kernel_size + 1
        grid = np.arange(-kernel_size/2, kernel_size/2, 1)
        xA, yA = np.meshgrid(grid, grid)
        xA = xA - xA_shift
        yA = yA - yA_shift

        # rotate the coordinate frame by alphaA, then bend it along the carrier axis
        xRotL = np.cos(alphaA)*xA + np.sin(alphaA)*yA
        yRotL = np.cos(alphaA)*yA - np.sin(alphaA)*xA
        xRotBendL = xRotL + bA * (yRotL)**2
        yRotBendL = yRotL

        # Gaussian envelope and carrier, shared by the DC estimate and the final kernel
        gaussPartA = np.exp(-0.5*(kA)**2*((xRotBendL/sigmaXbend)**2 + (yRotBendL/(mA*sigmaYbend))**2))
        cos_part = np.cos(kA*xRotBendL)
        sin_part = np.sin(kA*xRotBendL)

        # make the kernel DC free: subtract the envelope-weighted mean of the carrier,
        # but never less than the analytic DC value for the real part
        denominatorL = np.sum(gaussPartA)
        DCValueAnalysis = np.exp(-0.5 * sigmaXbend * sigmaXbend)
        if denominatorL==0:
            DCPartRealA = DCValueAnalysis
            DCPartImagA = 0
        else:
            DCPartRealA = np.sum(gaussPartA*cos_part)/denominatorL
            DCPartImagA = np.sum(gaussPartA*sin_part)/denominatorL
            if DCPartRealA < DCValueAnalysis:
                DCPartRealA = DCValueAnalysis
                DCPartImagA = 0

        # space kernel
        preFactorA = kA**2
        realteilL  = preFactorA*gaussPartA*(cos_part - DCPartRealA)
        imagteilL  = preFactorA*gaussPartA*(sin_part - DCPartImagA)

        # normalise by the mean of the real and imaginary L2 norms, scaled by kA**2
        normRealL   = np.sqrt(np.sum(realteilL**2))
        normImagL   = np.sqrt(np.sum(imagteilL**2))
        normFactorL = kA**2

        total_std = normRealL + normImagL
        if total_std == 0:
            total_std = 10**20
        norm_realteilL = realteilL*normFactorL/(0.5*total_std)
        norm_imagteilL = imagteilL*normFactorL/(0.5*total_std)

        space_kernel = norm_realteilL + norm_imagteilL*1j
        # NOTE(review): a frequency-domain kernel is usually fft2(space_kernel);
        # ifft2 yields a scaled, frequency-mirrored version (correlation instead of
        # convolution). Kept as-is because downstream magnitudes depend on it.
        freq_kernel = np.fft.ifft2(space_kernel)
        return space_kernel, freq_kernel


    def __generate_kernels__(self, kA_scales=(1, 5, 8)):
        """Build the curvilinear and rectilinear banana-wavelet kernel banks.

        kA_scales: iterable of numbers; scale s gives kA = 2*pi / sqrt(2)**s.

        One kernel is generated per (kA, bend, orientation) combination via
        ``__bananakernel__``. Zero-bend kernels are rectilinear and all others are
        curvilinear. Orientations are 8 evenly spaced angles in [0, 2*pi), i.e.
        multiples of 45 degrees. The results are stored in ``self.kernels`` under
        'curv_freq', 'curv_space', 'rect_freq' and 'rect_space'; the space
        kernels keep only their real part.

        Side effect: ``self.cg_parameters`` is left holding the parameters of the
        last kernel generated.
        """
        all_kA = [2*np.pi/((np.sqrt(2))**x) for x in kA_scales]
        bends = [0, 0.02, 0.07, 0.10, 0.18, 0.45]
        # endpoint=False: 0 and 2*pi are the same orientation. Including both
        # produced a duplicate kernel and uneven 2*pi/7 spacing.
        alphaA = np.linspace(0, 2*np.pi, 8, endpoint=False).tolist()

        curv_freq_kernels, rect_freq_kernels, curv_space, rect_space = [], [], [], []
        for kA, bA, orien in itertools.product(all_kA, bends, alphaA):
            self.cg_parameters['kA']          = kA
            self.cg_parameters['bend']        = bA/8   # bend values are scaled down by 8
            self.cg_parameters['orientation'] = orien
            neuron, Freq_kernel = self.__bananakernel__()
            if bA == 0:
                rect_freq_kernels.append(Freq_kernel)
                rect_space.append(neuron.real)
            else:
                curv_freq_kernels.append(Freq_kernel)
                curv_space.append(neuron.real)

        self.kernels = {'curv_freq':curv_freq_kernels, 'curv_space':curv_space,
                        'rect_freq':rect_freq_kernels, 'rect_space':rect_space}


    def __patchnorm__(self,image):
        """Locally normalise a 2-D grayscale image in non-overlapping square patches.

        If the image maximum is in (0, 1], the image is first rescaled so that its
        maximum becomes 255. The image is then cut into non-overlapping
        patch_size x patch_size patches. patch_size is 3 if the height is divisible
        by 3, otherwise 4, and is stored in ``self.patch_size``. Each patch is
        divided by its L2 norm (norms below 1 are clamped to 1), and the patches
        are stitched back together. Trailing rows/columns that do not fill a
        whole patch are dropped.

        Returns a dict with:
        'local_norm':         per-patch L2 norms (after clamping)
        'local_normed_image': the stitched, normalised image
        'total_local_norm':   sqrt of the sum of the per-patch norms
        """
        image_max = np.max(image)
        # the lower bound guards against all-zero images (previously a division by zero)
        if 0 < image_max <= 1:
            image = 255*image/image_max

        # patch size is chosen from the image height only
        patch_size = 3 if image.shape[0] % 3 == 0 else 4
        self.patch_size = patch_size

        # non-overlapping windows: shape (rows, cols, patch_size, patch_size)
        patches = view_as_windows(image, (patch_size,patch_size), patch_size)

        # L2 norm of every patch; clamping at 1 avoids dividing by tiny norms
        local_norm = np.sqrt(np.einsum('ijkl->ij',patches**2))
        local_norm[local_norm<1] = 1

        normed_patches = patches/local_norm[:,:,np.newaxis,np.newaxis]

        # stitch the normalised patches back into a 2-D image
        local_normed_image = rearrange(normed_patches,'h w c d -> (h c) (w d)')

        return {'local_norm':local_norm, 'local_normed_image':local_normed_image,
                'total_local_norm':np.sqrt(local_norm.sum())}


    def processing_images(self, files):
        """Compute curvilinear / rectilinear values for a list of image files.

        Images are processed in blocks of ``self.file_block_size``. All per-run
        results on the instance are reset first: ``images``, ``fft_images``,
        ``curv_values``, ``rect_values``, ``curv_max``, ``rect_max``,
        ``curv_unique`` and ``rect_unique``. The per-file values are written to
        ``<directory of files[0]>_curv_rect_values.csv``.

        Raises ValueError if ``files`` is empty.
        """
        files = list(files)
        if not files:
            raise ValueError('processing_images() needs at least one image file')

        self.files = files
        # reset every per-run accumulator so repeated calls do not mix results
        self.images          = []
        self.fft_images      = []
        self.curv_values     = []
        self.rect_values     = []
        self.curv_max        = []
        self.rect_max        = []
        self.curv_unique     = []
        self.rect_unique     = []
        folder_name = os.path.dirname(files[0])
        print(f'processing {len(files)} images...')
        for i in tqdm.tqdm(range(0, len(files), self.file_block_size)):
            block_files   = files[i:i + self.file_block_size]
            _, fft_images = self.__read_images__(block_files)
            self.__calcuate_curv_rect_values__(fft_images)

        # per-block (H, W, n_block) arrays -> one (H, W, n_files) array each
        self.curv_max    = np.dstack(self.curv_max)
        self.curv_unique = np.dstack(self.curv_unique)

        self.rect_max    = np.dstack(self.rect_max)
        self.rect_unique = np.dstack(self.rect_unique)
        df = pd.DataFrame({'files':self.files,
                           'curv_values':self.curv_values,
                           'rect_values':self.rect_values})
        df.to_csv(f'{folder_name}_curv_rect_values.csv', index=False)


    @staticmethod
    def _to_gray_255(orig_image):
        """Convert an image array from ``skio.imread`` to float grayscale in [0, 255].

        Handles 2-D grayscale, gray+alpha (2 channels), RGB and RGBA (4 channels).
        Integer grayscale data is scaled by its dtype maximum, matching the
        [0, 1] scaling ``rgb2gray`` applies to colour images.
        """
        if orig_image.ndim == 2:
            gray = orig_image.astype(float)
            if np.issubdtype(orig_image.dtype, np.integer):
                gray = gray/np.iinfo(orig_image.dtype).max
            return gray*255
        n_channels = orig_image.shape[2]
        if n_channels == 2:
            # gray + alpha: keep only the intensity channel
            return CurvRectValues._to_gray_255(orig_image[..., 0])
        if n_channels == 4:
            orig_image = rgba2rgb(orig_image)
        return rgb2gray(orig_image)*255


    def __read_images__(self,file_block):
        """Read, grayscale, resize, patch-normalise and FFT a block of image files.

        Each file is converted to float grayscale in [0, 255] (see
        ``_to_gray_255``), resized to image_size x image_size, passed through
        ``__patchnorm__`` and transformed with ``np.fft.fft2``. The block's lists
        are also appended to ``self.images`` / ``self.fft_images``.

        Returns (images_list, fft_images_list).
        """
        images_list, fft_images_list = [], []
        for image_name in file_block:
            gray_image = self._to_gray_255(skio.imread(image_name))
            image = resize(gray_image, (self.image_size,self.image_size))

            output_image = self.__patchnorm__(image)['local_normed_image']

            images_list.append(output_image)
            fft_images_list.append(np.fft.fft2(output_image))

        self.images.append(images_list)
        self.fft_images.append(fft_images_list)
        return images_list, fft_images_list


    def __get_max_image__(self,fft_image_list,kernel_list):
        """Filter a block of FFT'd images with a kernel set and keep the max response.

        fft_image_list: list of equally shaped 2-D FFT arrays (one per image)
        kernel_list:    list of frequency kernels of the same 2-D shape

        Returns an array of shape (H, W, n_images) holding, per pixel, the maximum
        over kernels of sqrt(|response|) divided by the kernel's L2 norm.
        """
        # (H, W, n_kernels)
        all_kernels = np.dstack(kernel_list)

        # L2 norm of every frequency kernel, used for normalisation
        all_kernels_power =  np.einsum('ijk,ijk->k',np.abs(all_kernels),np.abs(all_kernels))
        all_kernels_power =  np.sqrt(all_kernels_power)

        # (H, W, n_images)
        fft_images        = np.dstack(fft_image_list)
        # broadcast to (H, W, n_kernels, n_images), then back to space over axes 0 and 1
        all_conved_images = np.abs(np.fft.ifft2(fft_images[:,:,np.newaxis,:]*all_kernels[:,:,:,np.newaxis],axes=(0,1)))
        all_conved_images = np.power(all_conved_images,1/2) ## power correction
        all_conved_images = all_conved_images/all_kernels_power[np.newaxis, np.newaxis,:,np.newaxis]

        max_images = np.max(all_conved_images,axis=2)
        return max_images


    def __calcuate_curv_rect_values__(self, fft_image_list):
        """Compute and store curvilinear / rectilinear responses for one block.

        Appends the max-response maps and "unique" maps (pixels where one family
        strictly exceeds the other) to the instance lists. It also extends
        ``curv_values`` / ``rect_values`` with each image's unique-map sum
        divided by x**2 / 2, where x is the image height.
        """
        curv_max_response = self.__get_max_image__(fft_image_list, self.kernels['curv_freq'])
        rect_max_response = self.__get_max_image__(fft_image_list, self.kernels['rect_freq'])

        x, y,_ = curv_max_response.shape
        self.curv_max.append(curv_max_response)
        self.rect_max.append(rect_max_response)

        curv_unique = np.where(curv_max_response>rect_max_response, curv_max_response, 0)
        curv_values = np.einsum('ijk->k',curv_unique)
        self.curv_unique.append(curv_unique)

        rect_unique = np.where(rect_max_response>curv_max_response, rect_max_response, 0)
        rect_values = np.einsum('ijk->k',rect_unique)
        self.rect_unique.append(rect_unique)

        self.curv_values.extend(curv_values/(2*(x/2)**2))
        self.rect_values.extend(rect_values/(2*(x/2)**2))

    def show_kernel_example(self):
        """Plot one curvilinear and one rectilinear kernel with their log |frequency kernel|."""
        fig,ax = plt.subplots(2,2,figsize=(8,6))
        ax = ax.flat
        ax[0].imshow(self.kernels['curv_space'][100])
        ax[0].set(title='curvilinear kernel')
        ax[1].imshow((np.log(np.abs(self.kernels['curv_freq'][100]))))
        ax[1].set(title='power of the curvilinear kernel')

        ax[2].imshow(self.kernels['rect_space'][15])
        ax[2].set(title='rectilinear kernel')
        ax[3].imshow((np.log(np.abs(self.kernels['rect_freq'][15]))))
        ax[3].set(title='power of the rectilinear kernel')

        plt.tight_layout()
        plt.show()

    def show_curv_rect_example(self):
        """Plot the max and unique response maps of the first processed image.

        The maps are fftshift-ed for display because the kernels are centred in
        their grids, which shifts the filter responses.
        """
        fig, ax = plt.subplots(2,2,figsize=(8,6))
        ax = ax.flat
        ax[0].imshow(np.fft.fftshift(self.curv_max[:,:,0]))
        ax[0].set(title='a max curvilinear image')

        ax[1].imshow(np.fft.fftshift(self.rect_max[:,:,0]))
        ax[1].set(title='a max rectilinear image')

        ax[2].imshow(np.fft.fftshift(self.curv_unique[:,:,0]))
        ax[2].set(title='a unique curvilinear image')

        ax[3].imshow(np.fft.fftshift(self.rect_unique[:,:,0]))
        ax[3].set(title='a unique rectilinear image')
        plt.tight_layout()
        plt.show()


    def show_correlation(self):
        """Scatter-plot curvilinear vs. rectilinear values with their Pearson correlation."""
        tmp_corr = stats.pearsonr(self.curv_values, self.rect_values)
        fig,ax = plt.subplots(1,1)
        ax.scatter(self.curv_values,self.rect_values)
        ax.set(xlabel='curvilinear values',
               ylabel='rectilinear values',
               title=f'correlation:{tmp_corr[0]:.3f}')

        plt.tight_layout()
        plt.show()
        

In [None]:
def test():
    file_list = glob.glob('examples/*.png')
    curvrect_score= CurvRectValues()
    curvrect_score.processing_images(file_list)
    curvrect_score.show_correlation()

In [None]:
if __name__ == '__main__':
    # Run the demo when executed as the top-level module (this includes a
    # Jupyter kernel); it is skipped when this code is imported as a module.
    test()