# Wavelet Packet Decomposition
## Imports

In [1]:
import tensorflow as tf

## Test input

In [2]:
# 16-sample test signal; the length is a power of two, so it can be split
# into pairs repeatedly by the decomposition below.
input_sig = [
    32, 10, 20, 38, 37, 28, 38, 34,
    18, 24, 18, 9, 23, 24, 28, 34,
]
# Float32 tensor version of the signal, used as input to the decomposition.
tf_sig = tf.constant(input_sig, dtype=tf.float32)

## The HaarClassicWPD class
This class implements the Wavelet Packet Decomposition using the TensorFlow package.
There are two public methods:

### get_level(signal, level)
Computes the WPD using the Haar Classic wavelet function.

Inputs:

*signal* is a 1D Tensor containing the signal

*level* is the desired level of decomposition

Output:

A 2D Tensor containing the decomposition sub-signals.

### get_features_level(signal, level)
Uses the *get_level* method to get the decomposition sub-signals and then computes one feature per sub-signal (the log-sum-exp of its samples).

In [119]:
class HaarClassicWPD:
    """Wavelet Packet Decomposition (WPD) with the classic Haar wavelet.

    One filtering step splits a 1D signal into non-overlapping pairs
    ``(x[2k], x[2k+1])`` and produces two half-length sub-signals:

    * detail (high-pass):       ``(x[2k] - x[2k+1]) / 2``
    * approximation (low-pass): ``(x[2k] + x[2k+1]) / 2``

    All methods are static and only build TensorFlow ops. The public entry
    points are ``get_level`` and ``get_features_level``.
    """

    @staticmethod
    def __log2(x):
        """Return the element-wise base-2 logarithm of the tensor ``x``."""
        # A static method receives no implicit instance, so it must not
        # declare a `self` parameter (doing so made every call fail with a
        # missing-argument TypeError).
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(2, dtype=numerator.dtype))
        return numerator / denominator

    @staticmethod
    def __pairs(signal):
        """Split a 1D signal into its even-indexed and odd-indexed samples.

        The signal is reshaped to ``[len // 2, 2]``, so its length must be
        even. Returns the tuple ``(x[0::2], x[1::2])``.
        """
        pairs = tf.reshape(signal, [tf.shape(signal)[0] // 2, 2])
        return pairs[:, 0], pairs[:, 1]

    @staticmethod
    def __both_filters(signal):
        """Apply both Haar filters to a 1D signal.

        Returns a ``[2, len // 2]`` tensor: row 0 holds the detail
        (half-difference) coefficients, row 1 the approximation (mean)
        coefficients.
        """
        evens, odds = HaarClassicWPD.__pairs(signal)
        return tf.stack([evens - odds, evens + odds], 0) / 2

    @staticmethod
    def __high_pass_filter(signal):
        """Return the Haar detail coefficients ``(x[2k] - x[2k+1]) / 2``."""
        # Differencing neighbouring samples is the high-pass branch; this
        # body previously lived under the low-pass name.
        evens, odds = HaarClassicWPD.__pairs(signal)
        return tf.math.divide(tf.math.subtract(evens, odds), 2)

    @staticmethod
    def __low_pass_filter(signal):
        """Return the Haar approximation coefficients ``(x[2k] + x[2k+1]) / 2``."""
        # Averaging neighbouring samples is the low-pass branch; this body
        # previously lived under the high-pass name.
        evens, odds = HaarClassicWPD.__pairs(signal)
        return tf.math.divide(tf.math.add(evens, odds), 2)

    @staticmethod
    def __sig_to_feature(signal):
        """Reduce one sub-signal to a scalar feature: log(sum(exp(signal)))."""
        return tf.reduce_logsumexp(signal)

    @staticmethod
    def get_level(signal, level):
        """Compute the Haar WPD sub-signals at the requested level.

        Args:
            signal: tensor holding the signal; it is flattened to a single
                row before decomposition. Every step halves the row length,
                so the number of samples must be divisible by ``2 ** level``.
            level: number of decomposition steps. A value below 1 returns
                the flattened signal as a single row.

        Returns:
            A 2D tensor with ``2 ** level`` rows, one per sub-signal. Each
            step replaces row ``i`` by rows ``2i`` (detail) and ``2i + 1``
            (approximation).
        """
        signal = tf.reshape(signal, [1, tf.size(signal)])
        for _ in range(int(level)):
            # [rows, n] -> [rows, 2, n // 2]: both filters applied per row.
            signal = tf.map_fn(HaarClassicWPD.__both_filters, signal)
            sig_shape = tf.shape(signal)
            # Merge the (row, filter) axes so every sub-signal is its own row.
            signal = tf.reshape(signal, [sig_shape[0] * sig_shape[1], sig_shape[2]])
        return signal

    @staticmethod
    def get_features_level(signal, level):
        """Return one log-sum-exp feature per sub-signal of ``get_level``.

        Args:
            signal: tensor holding the signal (see ``get_level``).
            level: number of decomposition steps (see ``get_level``).

        Returns:
            A 1D tensor with one feature per row of
            ``get_level(signal, level)``.
        """
        sub_signals = HaarClassicWPD.get_level(signal, level)
        return tf.map_fn(HaarClassicWPD.__sig_to_feature, sub_signals)

In [122]:
# Enter the Session itself rather than Session().as_default(): both install
# it as the default session used by .eval(), but only the Session context
# manager closes the session (releasing its resources) when the block exits.
with tf.Session():
    level = 3
    print(HaarClassicWPD.get_level(tf_sig, level).eval())
    #print(HaarClassicWPD.get_features_level(tf_sig, level).eval())

[[ 4.375 -2.5  ]
 [ 5.625 -1.25 ]
 [-1.125  1.25 ]
 [ 2.125 -0.5  ]
 [-1.125  3.75 ]
 [-2.875  0.   ]
 [-4.625 -5.   ]
 [29.625 22.25 ]]
