In [1]:
import tensorflow as tf

In [2]:
# Example signal: 16 raw samples, so it can be halved four times.
input_sig = [
    32, 10, 20, 38, 37, 28, 38, 34,
    18, 24, 18, 9, 23, 24, 28, 34,
]
# The same samples as a float32 graph constant.
tf_sig = tf.constant(value=input_sig, dtype=tf.float32)

In [3]:
class HaarClassicWPD:
    """Haar wavelet packet decomposition expressed as TensorFlow graph ops.

    The class is a pure namespace: every method is static and nothing is
    stored on instances. Each decomposition level splits every existing row
    into a half-difference (high-pass / detail) row and an average
    (low-pass / approximation) row, both half as long as the input row.
    """

    @staticmethod
    def __log2(x):
        """Return the element-wise base-2 logarithm of ``x``.

        Computed as ``ln(x) / ln(2)``; the constant 2 is created with the
        dtype of ``tf.math.log(x)`` so the division does not mix dtypes.
        This helper is not referenced by the other methods of this class.
        """
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(2, dtype=numerator.dtype))
        return numerator / denominator

    @staticmethod
    def __pair_samples(signal):
        """Group consecutive samples of the last axis into pairs.

        Returns a tensor whose shape is the leading dimensions of ``signal``
        followed by ``[n // 2, 2]``, where ``n`` is the size of the last
        axis. The reshape fails when ``n`` is odd.
        """
        leading_dims = tf.shape(signal)[:-1]
        return tf.reshape(signal, tf.concat([leading_dims, [-1, 2]], axis=0))

    @staticmethod
    def __low_pass_filter(signal):
        """Haar low-pass step: ``(x[2i] + x[2i+1]) / 2`` along the last axis."""
        pairs = HaarClassicWPD.__pair_samples(signal)
        return tf.math.divide(tf.math.add(pairs[..., 0], pairs[..., 1]), 2)

    @staticmethod
    def __high_pass_filter(signal):
        """Haar high-pass step: ``(x[2i] - x[2i+1]) / 2`` along the last axis."""
        pairs = HaarClassicWPD.__pair_samples(signal)
        return tf.math.divide(tf.math.subtract(pairs[..., 0], pairs[..., 1]), 2)

    @staticmethod
    def __sig_to_feature(signal):
        """Collapse the last axis of ``signal`` with log-sum-exp."""
        return tf.reduce_logsumexp(signal, axis=-1)

    @staticmethod
    def get_level(signal, level):
        """Decompose ``signal`` down to the requested packet level.

        Args:
            signal: Tensor-like samples; it is flattened into a single row
                before decomposition.
            level: Number of decomposition steps. Values below 1 return the
                flattened signal unchanged.

        Returns:
            A 2-D tensor with one row per packet. Every step doubles the row
            count and halves the row length. Within a step the detail
            (high-pass) rows of all current packets are stacked above their
            approximation (low-pass) rows.

        Raises:
            ValueError: If the statically known row length is odd at a step
                that still has to split it.
        """
        # -1 lets TensorFlow infer the length, giving a [1, n] packet matrix.
        signal = tf.reshape(signal, [1, -1])
        curr_level = 1
        while curr_level <= level:
            # Fail early with a readable message when the static shape already
            # shows the rows cannot be paired; otherwise the reshape inside
            # the filters reports the problem at run time.
            static_shape = signal.shape
            width = static_shape.as_list()[-1] if static_shape.ndims is not None else None
            if width is not None and width % 2:
                raise ValueError(
                    "Cannot compute level %s: packet length %d is odd, so the "
                    "samples cannot be split into pairs." % (curr_level, width)
                )
            # Both filters work on every row at once, so no per-row map is needed.
            detail = HaarClassicWPD.__high_pass_filter(signal)
            approximation = HaarClassicWPD.__low_pass_filter(signal)
            signal = tf.concat([detail, approximation], 0)
            curr_level += 1
        return signal

    @staticmethod
    def get_features_level(signal, level):
        """Return one log-sum-exp feature per packet at ``level``.

        The result is a 1-D tensor whose i-th entry is the log-sum-exp of
        row i of ``get_level(signal, level)``.
        """
        return HaarClassicWPD.__sig_to_feature(HaarClassicWPD.get_level(signal, level))

In [5]:
level = 3
# Build both graph outputs up front so a single run evaluates them together.
packets = HaarClassicWPD.get_level(tf_sig, level)
features = HaarClassicWPD.get_features_level(tf_sig, level)

# `with tf.Session() as sess` closes the session on exit, whereas
# `tf.Session().as_default()` only installs it as default and leaves it open.
with tf.Session() as sess:
    packets_value, features_value = sess.run([packets, features])

print(packets_value)
print(features_value)

[[ 4.375 -2.5  ]
 [-1.125  3.75 ]
 [-1.125  1.25 ]
 [-4.625 -5.   ]
 [ 5.625 -1.25 ]
 [-2.875  0.   ]
 [ 2.125 -0.5  ]
 [29.625 22.25 ]]
[ 4.376033    3.757606    1.3389394  -4.1018767   5.626033    0.05488219
  2.1949363  29.625626  ]
