In [196]:
import keras.backend as keras_backend
import numpy as np

from traits.api import (Bool, Enum, Float, HasStrictTraits, Int, List, Property,
                        provides, Tuple)

from blusky.wavelets.i_wavelet_2d import IWavelet2D

class Cascade2D(HasStrictTraits):
    """
    The idea here is to implement a cascade of convolvolution and modulus opertations. 
    Suppose I had a sequence of wavelets, \psi1, \psi2, ... 

    |x * \psi1| 
    |x * \psi2| -> output
        .
        .
        .
     |
     |
     ---> ||x * \psi1| * \psi2|
          ||x * \psi1| * \psi3|
                .               -> output
                .
          ||x * \psi2| * \psi3|
                .
                .
                  |
                  |
                  ---> .. etc ..    
    """
    # provide a list of wavelets to define the cascade, the order is important, the
    # wavelets are applied in order.
    wavelets = List(IWavelet2D)
    
    # The depth of the transform, how many successive conv/abs iterations to perform, 
    # this should be less than or equal to the number of wavelets supplied.
    depth = Int(2)
    
    #: Subsequent convolutions can be applied to downsampled images for efficiency
    # Provide some options with Keras for this:
    # Max - MaxPooling (take max value in a window)
    # Average - AveragePooling (average values in window)
    # Stride - Set a stride when applying the convolutions
    pooling_type = Enum(["none", "stride", "max", "average"])

    # Size of the poolin to apply at each step, "0" means no pooling, 
    # negative numbers will cause the output to be upsampled by that factor
    pooling_size = Int
    
    # shape of the input tile
    shape = Tuple(Int)
    
    #: In 2D we will apply the transform over a set of wavelets are different 
    # orientation, define that here in degrees.
    angles = Tuple
    
    def _init_weights(self, shape, dtype=None, wavelet2d=None, real_part=True):
        """
        Create an initializer for DepthwiseConv2D layers. We need these layers instead
        of Conv2D because we don't want it to stack.
        
        Parameters
        ----------
        
        wavelet2d - IWavelet2D
            An object to create a wavelet.      
        """
        if dtype is None:
            dtype = np.float32
            
        # nx/ny is the image shape, num_inp/outp are the number of
        # channels inpit/output.
        nx, ny, num_inp, num_outp = shape
        
        if num_outp != len(self.angles):
            raise RuntimeError("weights: mismatch dimension num angles.")
        
        weights = np.zeros(shape, dtype=dtype)  

        for iang, ang in enumerate(self.angles):
            wav = wavelet2d.kernel(ang)

            # keras does 32-bit real number convolutions
            if(real_part):
                x = wav.real.astype(np.float32)
            else:
                x = wav.imag.astype(np.float32)

            # we don't want to introduce a phase, put the wavelet
            # in the corner.
            x = np.roll(x, shape[0]//2, axis=1)
            x = np.roll(x, shape[1]//2, axis=0)

            # apply to each input channel
            for ichan in range(shape[2]):
                weights[:,:,ichan,iang] = x[:shape[0],:shape[1]]

        return keras_backend.variable(value=weights, dtype=dtype)        
    
    def _convolve_and_abs(self, wavelet, inp, stride=1):
        """
        Implement the operations for |x*\psi|
        """
        square = Lambda(lambda x : keras_backend.square(x), trainable = False)
        add = Add(trainable=False)
        sqrt = Lambda(lambda x : keras_backend.sqrt(x), trainable = False)

        real_part = DepthwiseConv2D(kernel_size=wavelet.shape,
                     depth_multiplier=len(self.angles),
                     data_format='channels_last',
                     padding="same", 
                     strides=stride,
                     trainable=False,
                     depthwise_initializer=lambda args : self._init_weights(args, 
                                                                            real_part=True,
                                                                            wavelet2d=wavelet))(inp)
        real_part = square(real_part)

        imag_part = DepthwiseConv2D(kernel_size=wavelet.shape,
                     depth_multiplier=len(self.angles),
                     data_format='channels_last',
                     padding="same", 
                     strides=stride, 
                     trainable=False,                                    
                     depthwise_initializer=lambda args : self._init_weights(args, 
                                                                            real_part=False,
                                                                            wavelet2d=wavelet))(inp)
        imag_part = square(imag_part)
    
        result = add([real_part, imag_part])
        return sqrt(result)
    
    def _convolve_and_pool(self, inp, wavelets):
        """
        Iterate over each wavelet and pool.
        """
        
        stride = 1
        if self.pooling_type in ("stride", "max", "average"):
            stride = self.pooling_size 
        
        if self.pooling_type == "max":
            pooling = MaxPooling2D(pool_size=(stride,stride), padding='valid', trainable=False)
            return [pooling(self._convolve_and_abs(wav, inp, stride=1)) for wav in wavelets]
        elif self.pooling_type == "average":
            pooling = AveragePooling2D(pool_size=(stride,stride), padding='valid', trainable = False)
            return [pooling(self._convolve_and_abs(wav, inp, stride=1)) for wav in wavelets]
       
        return [self._convolve_and_abs(wav, inp, stride=stride) for wav in wavelets]
        
    def transform(self, inp):
        """
        Apply abs/conv operations to arbitrary order.
        """
        
        a = self._convolve_and_pool(inp, self.wavelets)
        b0 = self._convolve_and_pool(Concatenate()(a[0]), self.wavelets[1:])
        b1 = self._convolve_and_pool(Concatenate()(a[1]), self.wavelets[2:])
        
        print (a)        
        print (b)
        return [a,b]
        """
        # FIXME - if it's multi-spectral input split into an array
        last = [inp]
        transform_order = []
        for order in range(0, self.depth):
            last = [self._convolve_and_pool(i, self.wavelets[order:]) for i in last]
            transform_order.append(last)
        
        return transform_order
        """

In [198]:
inp = Input(shape=(99,99,1))
cascade = Cascade2D(angles=(0.0, 45., 90.), wavelets=[wav1, wav2, wav3], pooling_type="max", pooling_size=2)
result = cascade.transform(inp)
#model = Model(inputs=inp, outputs=result)
#model.summary()

[<tf.Tensor 'max_pooling2d_27/MaxPool:0' shape=(?, 49, 49, 3) dtype=float32>, <tf.Tensor 'max_pooling2d_27_1/MaxPool:0' shape=(?, 49, 49, 3) dtype=float32>]
[<tf.Tensor 'max_pooling2d_28/MaxPool:0' shape=(?, 24, 24, 18) dtype=float32>]


In [169]:
from keras.layers import MaxPooling2D, AveragePooling2D, Conv2D, Concatenate, Input, DepthwiseConv2D, Concatenate, AveragePooling2D, BatchNormalization, Lambda, Add
import keras.backend as keras_backend
from keras.models import Sequential, Model

In [170]:
from blusky.wavelets.morlet2d import Morlet2D
wav = Morlet2D(sample_rate=0.004, 
               center_frequency=50., 
               bandwidth=(20.,10.), 
               crop=3.5, 
               taper=False)

In [183]:
inp = Input(shape=(99,99,1))
cascade = Cascade2D(angles=(0.0, 45., 90.))
result = cascade._convolve_and_abs(wav, inp)
model = Model(inputs=inp, outputs=result)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_73 (InputLayer)           (None, 99, 99, 1)    0                                            
__________________________________________________________________________________________________
depthwise_conv2d_135 (Depthwise (None, 99, 99, 3)    3270        input_73[0][0]                   
__________________________________________________________________________________________________
depthwise_conv2d_136 (Depthwise (None, 99, 99, 3)    3270        input_73[0][0]                   
__________________________________________________________________________________________________
lambda_128 (Lambda)             (None, 99, 99, 3)    0           depthwise_conv2d_135[0][0]       
                                                                 depthwise_conv2d_136[0][0]       
__________

In [199]:
from blusky.wavelets.morlet2d import Morlet2D

wav1 = Morlet2D(sample_rate=0.004, 
               center_frequency=50., 
               bandwidth=(40.,10.), 
               crop=3.5, 
               taper=False)

wav2 = Morlet2D(sample_rate=0.004, 
               center_frequency=25., 
               bandwidth=(20.,10.), 
               crop=3.5, 
               taper=False)

wav3 = Morlet2D(sample_rate=0.004, 
               center_frequency=25., 
               bandwidth=(10.,5.), 
               crop=3.5, 
               taper=False)


inp = Input(shape=(99,99,1))
cascade = Cascade2D(angles=(0.0, 45., 90.), wavelets=[wav1, wav2, wav3], pooling_type="max", pooling_size=2)
result = cascade._convolve_and_pool(inp, [wav1, wav2, wav3])
model = Model(inputs=inp, outputs=result)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_84 (InputLayer)           (None, 99, 99, 1)    0                                            
__________________________________________________________________________________________________
depthwise_conv2d_200 (Depthwise (None, 99, 99, 3)    3270        input_84[0][0]                   
__________________________________________________________________________________________________
depthwise_conv2d_201 (Depthwise (None, 99, 99, 3)    3270        input_84[0][0]                   
__________________________________________________________________________________________________
depthwise_conv2d_202 (Depthwise (None, 99, 99, 3)    3270        input_84[0][0]                   
__________________________________________________________________________________________________
depthwise_