In [31]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

The scattering transform has a tree-like architecture. The `CascadeTree` class defined below builds the framework of the transform, keeping track of the recipe for each convolution and its scale.

In [32]:
class Node:
    """One convolution in the scattering cascade.

    Stores the recipe of the convolution (here just a readable name such as
    ``"|x*psi_1|"``) and its scale, plus links to the node it was derived
    from and to the nodes derived from it.
    """

    def __init__(self, name, scale, parent=None):
        # Readable recipe describing how this signal is obtained.
        self.name = name
        # Scale of the last wavelet applied (the CascadeTree root uses 0).
        self.scale = scale
        # Node this one was derived from; None for the root.
        self.parent = parent
        # Nodes obtained from this one by one more convolution.
        self.children = list()

    def __str__(self):
        # "%d" renders the scale as an integer (floats are truncated).
        label = self.name
        level = self.scale
        return "( %s, %d )" % (label, level)

class CascadeTree:
    """Tree of the convolution paths computed by a scattering transform.

    The root (layer 0) is the input signal. Every node gets one child per
    wavelet whose scale is strictly larger than its own, named
    ``|<parent>*psi_<scale>|``, so only scale-increasing paths are kept.
    The tree has ``order`` layers below the root.

    Parameters
    ----------
    order : int
        Maximum number of cascaded wavelet convolutions, i.e. the index of
        the deepest layer.
    """

    def __init__(self, order):
        self.order = order
        # Populated by generate(); None until the tree has been built.
        self.root_node = None

    def generate(self, wavelet_bank, input_name):
        """(Re)build the cascade for ``wavelet_bank`` applied to ``input_name``.

        Parameters
        ----------
        wavelet_bank : iterable
            Wavelets ordered by increasing scale; the wavelet at position
            ``i`` is given scale ``i + 1``. Only the number of wavelets is
            used -- node names are built from the scale index, not from the
            bank's entries.
        input_name : str
            Label of the input signal, stored in the root node.
        """
        self.root_node = Node(name=input_name, scale=0)
        # Count the wavelets once, so a generator/iterator bank is not
        # exhausted after the first parent node.
        n_wavelets = len(list(wavelet_bank))
        current_layer = [self.root_node]
        # Layer 0 (the root) already exists, so `order` passes build layers
        # 1..order. Looping order + 1 times would add a spurious extra layer.
        for _ in range(self.order):
            next_layer = []
            for parent in current_layer:
                # Scales run 1..n_wavelets; keep only those above the parent's.
                for scale in range(parent.scale + 1, n_wavelets + 1):
                    child = Node(name="|%s*psi_%d|" % (parent.name, scale),
                                 scale=scale, parent=parent)
                    parent.children.append(child)
                    next_layer.append(child)
            current_layer = next_layer

    def _layers(self):
        """Yield the node lists of layers 0..order, root layer first.

        Raises
        ------
        RuntimeError
            If generate() has not been called yet.
        """
        if self.root_node is None:
            raise RuntimeError(
                "CascadeTree.generate() must be called before traversing the tree")
        layer = [self.root_node]
        for _ in range(self.order + 1):
            yield layer
            # Children in parent order, each parent's children in creation order.
            layer = [child for node in layer for child in node.children]

    def display(self):
        """Print every node, layer by layer, under a "Layer k" heading.

        Raises RuntimeError if generate() has not been called yet.
        """
        for stage, layer in enumerate(self._layers()):
            print("Layer %d" % stage)
            for node in layer:
                print(node)

    def get_convolutions(self):
        """Return every non-root node of the tree, layer by layer.

        All first-order convolutions come first, then all second-order ones,
        and so on up to ``order``. The root (the input itself) is excluded.
        Raises RuntimeError if generate() has not been called yet.
        """
        layers = self._layers()
        next(layers)  # skip layer 0: the input is not a convolution
        return [node for layer in layers for node in layer]

Let's test this out by creating a wavelet bank with three wavelets. `CascadeTree` assumes the wavelets are ordered by increasing scale (powers of 2): the wavelet at position `i` in the bank is assigned scale `i + 1`.

In [35]:
# Three wavelets ordered by increasing scale; their positions give scales 1, 2, 3.
wavelet_bank = ['psi_1', 'psi_2', 'psi_3']
cascade = CascadeTree(order=2)
cascade.generate(wavelet_bank, input_name='x')
cascade.display()

Layer 0
( x, 0 )
Layer 1
( |x*psi_1|, 1 )
( |x*psi_2|, 2 )
( |x*psi_3|, 3 )
Layer 2
( ||x*psi_1|*psi_2|, 2 )
( ||x*psi_1|*psi_3|, 3 )
( ||x*psi_2|*psi_3|, 3 )


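The `get_convolutions()` method returns a list of all the convolutions in the cascade (every node except the input), ordered layer by layer: first all first-order convolutions, then all second-order ones.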

In [36]:
# Print each convolution node: all first-order ones, then all second-order ones.
for node in cascade.get_convolutions():
    print(node)

( |x*psi_1|, 1 )
( |x*psi_2|, 2 )
( |x*psi_3|, 3 )
( ||x*psi_1|*psi_2|, 2 )
( ||x*psi_1|*psi_3|, 3 )
( ||x*psi_2|*psi_3|, 3 )
