In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest

# from torch.nn.parameter import Parameter

#PyWavelets package
import pywt

In [18]:
class HaarDWT(nn.Module):
    """Short summary.

    Attributes
    ----------
    c_filter : pytorch Tensor
        Filter to produce c_i coefficients in DWT. Sometimes referred to as h_0.
    d_filter : type
        Filter to produce d_i coefficients in DWT. Sometimes referred to as h_1.
    padder : pytorch nn Module
        Appropriately pads odd-length arrays
    level : int
        Level of the DWT
    """
    def __init__(self, level=None, input_len=None):
        super().__init__()
        self.filter_len = 2
        if level is None and input_len is None:
            raise ValueError("A level or input length must be specified")
        if input_len is not None:
            self.level = self.max_dwt_level(input_len)
        else:
            self.level = level
        self.c_filter = torch.tensor(np.divide(np.array([1., 1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,1,2))
        self.d_filter = torch.tensor(np.divide(np.array([1., -1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,1,2))

        # If the input array is odd-length, pad the array by repeating the
        # last element
        self.padder = nn.ReplicationPad1d((0,1))

    def max_dwt_level(self, data_len):
        """
        This is a function to compute the maximum level DWT that is possible
        on a 1D input of length data_len. This formula is copied from
        PyWavelets: https://tinyurl.com/y9u7yvbw
        """
        return int(np.floor(np.log2(data_len / (self.filter_len - 1))))

    def filter(self, x):
        # To get odd-length arrays matching pywt output, we need to repeat the
        # last element of the array
        if x.shape[-1] % 2:
            x = self.padder(x)

        c_out = F.conv1d(x, self.c_filter, stride=2)
        d_out = F.conv1d(x, self.d_filter, stride=2)
        out = torch.cat((c_out, d_out), axis=-1)
        return out


    def forward(self, x):
        """
        Expects input of size (batch_n, 1, xlen) where
            batch_num unconstrained and xlen is a power of 2
        """
        xlen = x.size()[-1]

        # This is bit manipulation to assert that xlen is a power of 2
        if not ((xlen & (xlen-1) == 0) and xlen != 0):
            raise ValueError("Input array length {} is not power of 2".format(xlen))
        if self.level > self.max_dwt_level(xlen):
            raise ValueError("Input array length {} gives max DWT level {}".format(xlen, self.max_dwt_level(xlen)))

        for l in range(self.level):
            level_i_arr_idx = int(xlen / (2 ** l))
            x_in = x[:,:,:level_i_arr_idx]
            x[:,:,:level_i_arr_idx] = self.filter(x_in)
        return x

class IHaarDWT(nn.Module):
    """Short summary.

    Attributes
    ----------
    c_filter : pytorch Tensor
        Filter to produce c_i coefficients in DWT. Sometimes referred to as h_0.
    d_filter : type
        Filter to produce d_i coefficients in DWT. Sometimes referred to as h_1.
    padder : pytorch nn Module
        Appropriately pads odd-length arrays
    level : int
        Level of the DWT
    """
    def __init__(self, level=None, input_len=None):
        super().__init__()
        self.filter_len = 2
        if level is None and input_len is None:
            raise ValueError("A level or input length must be specified")
        if input_len is not None:
            self.level = self.max_dwt_level(input_len)
        else:
            self.level = level
        self.c_filter = torch.tensor(np.divide(np.array([1., 1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,2))
        self.d_filter = torch.tensor(np.divide(np.array([1., -1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,2))

    def max_dwt_level(self, data_len):
        """
        This is a function to compute the maximum level DWT that is possible
        on a 1D input of length data_len. This formula is copied from
        PyWavelets: https://tinyurl.com/y9u7yvbw
        """
        return int(np.floor(np.log2(data_len / (self.filter_len - 1))))

    def unfilter(self, x):
        xlen = x.size()[-1]
        batch_num = x.size()[0]
        x = x.reshape((batch_num, 2, int(xlen / 2))).permute(0,2,1)
        # Does x A^T
        c_out = F.linear(x, self.c_filter)
        d_out = F.linear(x, self.d_filter)
        out = torch.cat((c_out, d_out), axis=2) #.flatten(start_dim=1)
        out = out.reshape((batch_num, 1, xlen))
        return out

    def forward(self, x):
        """
        Expects input of size (batch_n, 1, xlen) where
            batch_num unconstrained and xlen is a power of 2
        """
        # x has shape (batch_num,1,xlen). We need to transform it into (2,xlen / 2)
        xlen = x.size()[-1]

        if xlen % 2:
            raise ValueError("Expected even-length input but recieved length {}".format(xlen))
        for l in range(self.level-1, -1, -1):
            level_i_arr_idx = int(xlen / (2 ** l))
            x_in = x[:,:,:level_i_arr_idx]
            x[:,:,:level_i_arr_idx] = self.unfilter(x_in)
        return x


In [61]:
class TestWaveletBlock1d(unittest.TestCase):
    def setUp(self):
        # a_arrays = tested against pywt reference for DWT and IDWT
        self.a_arrays = {}
        self.a_arrays['000_ones'] = np.ones((1,1,128))
        self.a_arrays['001_ones'] = np.ones((1,1,256))
        self.a_arrays['002_zeros'] = np.zeros((1,1,128))
        self.a_arrays['003_zeros'] = np.zeros((1,1,256))
        self.a_arrays['004_linspace'] = np.linspace(0, 10, 512).reshape((1,1,512))
        self.a_arrays['005_linspace'] = np.linspace(0, 10, 1024).reshape((1,1,1024))
        self.a_arrays['006_sin'] = np.sin(np.linspace(0, 10, 512)).reshape((1,1,512))
        self.a_arrays['007_cos'] = np.cos(np.linspace(0, 10, 1024)).reshape((1,1,1024))
        self.a_arrays['008_abs'] = np.abs(np.linspace(-33, 33, 1024)).reshape((1,1,1024))

        # b_arrays = tested by taking DWT and then taking IDWT and comparing
        # the original and output
        self.b_arrays = {}
        self.b_arrays['009_runif'] = np.random.uniform(size=10240).reshape((10,1,1024))
        self.b_arrays['010_runif'] = np.random.uniform(size=5120).reshape((10,1,512))
        self.b_arrays['011_runif'] = np.random.uniform(size=11520).reshape((90,1,128))
        self.b_arrays['010_rnorm'] = np.random.normal(size=10240).reshape((10,1,1024))

    @staticmethod
    def torch_to_numpy(ten):
        return ten.detach().numpy()

    @staticmethod
    def numpy_to_torch(arr):
        
        return torch.from_numpy(arr).float()

        
    def assert_close_and_show_diff(self, ref, ans):
        errors = ref - ans
        errors = errors.flatten()
        l_inf = np.linalg.norm(errors, ord=np.inf)
        l_2 = np.linalg.norm(errors, ord=2)
        s = "Errors: L_2 = {:.03e} L_inf = {:.03e}".format(l_2, l_inf)

        # the atol is a bit high here (I set it higher than default to pass
        # the tests), but there are some numerical errors
        # in torch that I can't figure out how to get around without using
        # double precision
        self.assertTrue(np.allclose(ref, ans, atol=5e-06), s)

    def test_WaveletBlock1d_init(self):
        WB_obj = WaveletBlock1d(input_len=32, width=4, keep=16)
        self.assertIsInstance(WB_obj, WaveletBlock1d)
        
    def get_identity_transform(self, arr, width):
        input_len = arr.shape[-1]
        WB_obj = WaveletBlock1d(input_len=input_len, width=width, keep=input_len)
        
        # TODO: Figure out what identity_matrix_weight needs to be.
        # It should be of shape (width, width, input_len)
        identity_matrix_weights = nn.Parameter(torch.eye(width).reshape(1,width,width))
        WB_obj.weights = identity_matrix_weights
        out_ten = WB_obj(self.numpy_to_torch(arr))
        return self.torch_to_numpy(out_ten)
        
        
    def test_identity(self):
        for k,v in self.a_arrays.items():
            for width in [4,8,16]:
                with self.subTest(k=k, width=width):
                    identity_transform_out = self.get_identity_transform(v, width)
                    self.assert_close_and_show_diff(v, identity_transform_out)


In [62]:
class WaveletBlock1d(nn.Module):
    def __init__(self, input_len, width, keep):
        super(WaveletBlock1d, self).__init__()
        self.input_len = input_len
        self.width = width
        self.keep = keep
        
        self.DWT = HaarDWT(input_len=input_len)
        self.IDWT = IHaarDWT(input_len=input_len)
        self.linear_layer = nn.Conv1d(self.width, self.width, 1, bias=None)
        
        self.scale = (1 / (self.width**2))
        self.weights = nn.Parameter(self.scale * torch.rand(self.width, 
                                                            self.width,
                                                            self.keep))
        
    def forward(self, x):
        print(x.size())
        # x has shape (batch_size, width, input_len)
        # Do DWT row-by-row
        z = torch.zeros(x.size())
        xlen = x.size()[-1]
        for i, row in enumerate(x.split(1, dim=1)): #row is a slice of the width axis
            out = self.DWT(row)
            
            # We want to keep only the specified number of coefficients.
            # The high-DWT-level coefficients are towards the left of the array. 
            z[:,i,:self.keep] = out.view(-1, self.keep)  

        # ok so now z has the DWT coefficients. z is of shape (batch_size, width, input_len)
        z = torch.einsum('bix,iox->box',z[:,:,:self.keep], self.weights)
        
        out = torch.zeros(z.size())
        for i, row in enumerate(z.split(1, dim=1)):
            print(row.size())
            idwt_out = self.IDWT(row)
            out[:, i, :] = idwt_out.view(-1, xlen)
        return out

In [63]:
test_obj = TestWaveletBlock1d()
test_obj.setUp()

In [64]:
ref = np.ones((1,3,8))
# ref_ten = test_obj.numpy_to_torch(ref)
ans = test_obj.get_identity_transform(ref, 3)
print("REF SHAPE: {}, ANS SHAPE: {}".format(ref.shape, ans.shape))
print("ANS: {}".format(ans))
test_obj.assert_close_and_show_diff(ref, ans)

torch.Size([1, 3, 8])


RuntimeError: size of dimension does not match previous size, operand 1, dim 0

In [33]:
c_filter = torch.divide(torch.tensor([1., 1.]), torch.sqrt(torch.tensor(2.))).reshape((1,2))
tensor_print(c_filter, "c_filter")
x = torch.sqrt(torch.tensor([[2,2]], dtype=torch.float)).reshape((2,1))
# print(x)
tensor_print(x, 'x')
tensor_print(torch.mm(x, c_filter))

NameError: name 'tensor_print' is not defined

In [78]:
inv_sqrt_2 = torch.divide(torch.tensor(1.), torch.sqrt(torch.tensor(2.)))
sqrt_2 = torch.sqrt(torch.tensor(2.))
out = torch.multiply(sqrt_2, inv_sqrt_2)
print(out, float(out))

tensor(1.0000) 0.9999999403953552


In [9]:
print(ref)
print(ans)
print(ref - ans)

[[[1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0.99999994 0.99999994 0.99999994 0.99999994 0.99999994 0.99999994
   0.99999994 0.99999994]]]
[[[5.9604645e-08 5.9604645e-08 5.9604645e-08 5.9604645e-08 5.9604645e-08
   5.9604645e-08 5.9604645e-08 5.9604645e-08]]]


In [31]:
suite = unittest.TestSuite()
suite.addTest(TestHaarDWT("test_dwt_against_reference"))
suite.addTest(TestHaarDWT("test_idwt_against_reference"))
suite.addTest(TestHaarDWT("test_invertibility"))
runner = unittest.TextTestRunner()
runner.run(suite)

DWT_REF: [[[1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 1.41421356 1.41421356
   1.41421356 1.41421356 1.41421356 1.41421356 0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.         0.        

.

DWT_REF: [[[ 1.93727885e-01  6.36534480e-01  1.07934108e+00  1.52214767e+00
    1.96495426e+00  2.40776086e+00  2.85056745e+00  3.29337405e+00
    3.73618064e+00  4.17898724e+00  4.62179383e+00  5.06460043e+00
    5.50740702e+00  5.95021362e+00  6.39302021e+00  6.83582681e+00
    7.27863340e+00  7.72144000e+00  8.16424659e+00  8.60705319e+00
    9.04985978e+00  9.49266638e+00  9.93547297e+00  1.03782796e+01
    1.08210862e+01  1.12638928e+01  1.17066994e+01  1.21495059e+01
    1.25923125e+01  1.30351191e+01  1.34779257e+01  1.39207323e+01
    1.43635389e+01  1.48063455e+01  1.52491521e+01  1.56919587e+01
    1.61347653e+01  1.65775719e+01  1.70203785e+01  1.74631851e+01
    1.79059917e+01  1.83487983e+01  1.87916049e+01  1.92344115e+01
    1.96772181e+01  2.01200247e+01  2.05628312e+01  2.10056378e+01
    2.14484444e+01  2.18912510e+01  2.23340576e+01  2.27768642e+01
    2.32196708e+01  2.36624774e+01  2.41052840e+01  2.45480906e+01
    2.49908972e+01  2.54337038e+01  2.58765104e+01  2


FAIL: test_dwt_against_reference (__main__.TestHaarDWT) (k='000_ones', level=2)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-26-5a6879821525>", line 116, in test_dwt_against_reference
    self.assert_close_and_show_diff(dwt_ref, dwt_ans)
  File "<ipython-input-26-5a6879821525>", line 65, in assert_close_and_show_diff
    self.assertTrue(np.allclose(ref, ans), s)
AssertionError: False is not true : Errors: L_2 = 6.895e-07 L_inf = 1.192e-07

FAIL: test_dwt_against_reference (__main__.TestHaarDWT) (k='000_ones', level=3)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-26-5a6879821525>", line 116, in test_dwt_against_reference
    self.assert_close_and_show_diff(dwt_ref, dwt_ans)
  File "<ipython-input-26-5a6879821525>", line 65, in assert_close_and_show_diff
    self.assertTrue(np.allclose(ref, ans), s)
AssertionError: False 

<unittest.runner.TextTestResult run=3 errors=0 failures=42>

In [88]:
def pywt_DWT(a):
    c,d = pywt.dwt(a, 'haar')
    return np.concatenate((c,d))
class HaarDWT(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_filter = torch.tensor(np.divide(np.array([1., 1.]), np.sqrt(2))).reshape((1,1,2))
        self.d_filter = torch.tensor(np.divide(np.array([1., -1.]), np.sqrt(2))).reshape((1,1,2))
        
        # If the input array is odd-length, pad the array by repeating the last element
        self.padder = nn.ReplicationPad1d((0,1))
    def forward(self, x):
        # To get odd-length arrays matching pywt output, we need to repeat the last element of the array
        if x.shape[-1] % 2:
            x = self.padder(x)
        c_out = F.conv1d(x, self.c_filter, stride=2)
        d_out = F.conv1d(x, self.d_filter, stride=2)
        out = torch.cat((c_out, d_out), axis=-1)
        return out

In [89]:
def compare_DWT(x):
    ref = pywt_DWT(x)
#     print(ref)
    x_pt = torch.tensor(x).reshape(1,1,x.shape[-1])
    ans_pt = pt_DWT().forward(x_pt)
    ans = ans_pt.numpy().reshape(ref.shape)
#     print(ans)
    if not np.allclose(ref,ans):
        print(ref)
        print(ans)
        print(ref - ans)
    else:
        print("Passed")

In [102]:
def run_dn_tests(full=True, verbose=False):
    d_arrays = {}
    d_arrays['000_runif_33_128'] = np.random.uniform(size=(33,1,128))
    d_arrays['001_runif_33_128'] = np.random.uniform(size=(33,1,128))
    d_arrays['002_runif_33_128'] = np.random.uniform(size=(33,1,128))

    v_arrays = {}
    v_arrays['000_ints_1_1_8'] = np.array([[[1,2,3,4,5,6,7,8]]])
    v_arrays['000_ones_1_1_8'] = np.ones((1,1,8))
    v_arrays['000_ones_3_4_8'] = np.ones((3,1,8))
    v_arrays['000_runif_4_8'] = np.random.uniform(size=(3,1,8))
    
    if verbose:
        for k, v in v_arrays.items():
            for level in [1,2]:
                v_ten = numpy_to_torch(v)
                print("INPUT ARRAY: {}".format(v))
#                 dwt_ref = np.array([pywt_reference_HaarDWT_leveln(i, n=level) for i in v])
#                 print("DWT REF: {}".format(dwt_ref))
                DWT_obj = HaarDWT(level=level)
                IDWT_obj = IHaarDWT(level=level)
                dwt_ans = DWT_obj(v_ten)
                print("DWT ANS: {}".format(dwt_ans))
#                 idwt_ans = torch_to_numpy(IDWT_obj(dwt_ans))
                ans = torch_to_numpy(IDWT_obj(dwt_ans))
                print("ANS: {}".format(ans))
                finish_str = k + " level_{}".format(level)
                try: 
                    b = np.allclose(v, ans)
                    if b:
                        finish_str += ": PASS"
                    else:
                        finish_str += ": FAIL"

                    print("ARR - ANS: {}".format(v - ans))
                except ValueError:
                    finish_str += ": FAIL: INCOMPATIBLE DIMENSIONS"
                print(finish_str)
    if full:
        for k, v in d_arrays.items():
            for level in [1,2,3]:
                v_ten = numpy_to_torch(v)
                DWT_obj = HaarDWT(level=level)
                IDWT_obj = IHaarDWT(level=level)
                dwt = DWT_obj(v_ten)
                ans = torch_to_numpy(IDWT_obj(dwt))
                finish_str = k + " level_{}".format(level)
                try: 
                    b = np.allclose(v, ans)
                    if b:
                        finish_str += ": PASS"
                    else:
                        finish_str += ": FAIL"

                except ValueError:
                    finish_str += ": FAIL: INCOMPATIBLE DIMENSIONS"
                print(finish_str)

In [90]:
test_arr = []
# test_arr.append(np.sin(np.linspace(-7, 1, 100)))
# test_arr.append(np.power(np.linspace(-1, 1, 1000), 3))
test_arr.append(np.linspace(-1, 3, 5))
# test_arr.append(np.array([1,1,1,1,1], dtype=float))

for i in test_arr:
    compare_DWT(i)


NEW X: tensor([[[-1.,  0.,  1.,  2.,  3.,  3.]]], dtype=torch.float64)
Passed


In [29]:
plt.plot(x, label='x')
plt.plot(x_recons, label='reconstruction')
plt.plot(x - x_recons, label='diff')
plt.legend()
plt.show()

NameError: name 'x' is not defined

In [10]:
print(det_c)

[ 0.02828139  0.02712704  0.0259727   0.02481836  0.02366402  0.02250967
  0.02135533  0.02020099  0.01904665  0.01789231  0.01673796  0.01558362
  0.01442928  0.01327494  0.01212059  0.01096625  0.00981191  0.00865757
  0.00750322  0.00634888  0.00519454  0.0040402   0.00288586  0.00173151
  0.00057717 -0.00057717 -0.00173151 -0.00288586 -0.0040402  -0.00519454
 -0.00634888 -0.00750322 -0.00865757 -0.00981191 -0.01096625 -0.01212059
 -0.01327494 -0.01442928 -0.01558362 -0.01673796 -0.01789231 -0.01904665
 -0.02020099 -0.02135533 -0.02250967 -0.02366402 -0.02481836 -0.0259727
 -0.02712704 -0.02828139]


In [6]:
import torch.nn as nn
import pywt
import pytorch_wavelets.dwt.lowlevel as lowlevel
import pytorch_wavelets
import torch
import numpy as np

class myDWTForward(nn.Module):
    """ Performs a 2d DWT Forward decomposition of an image

    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to
            pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class,
            or can be a two tuple of array-like objects for the analysis low and
            high pass filters.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme
        separable (bool): whether to do the filtering separably or not (the
            naive implementation can be faster on a gpu).
        """
    def __init__(self, J=1, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        else:
            if len(wave) == 2:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = h0_col, h1_col
            elif len(wave) == 4:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = wave[2], wave[3]

        # Prepare the filters
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode

    def forward(self, x):
        """ Forward pass of the DWT.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh)
                coefficients. yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        yh = []
        ll = x
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel transform
        for j in range(self.J):
            # Do 1 level of the transform
            ll, high = lowlevel.AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)
            yh.append(high)

        return ll, yh



class myDWTInverse(nn.Module):
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str or pywt.Wavelet): Which wavelet to use
        C: deprecated, will be removed in future
    """
    def __init__(self, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        else:
            if len(wave) == 2:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = g0_col, g1_col
            elif len(wave) == 4:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = wave[2], wave[3]
        # Prepare the filters
        filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel inverse transform
        for h in yh[::-1]:
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                ll.shape[-1], device=ll.device)

            # 'Unpad' added dimensions
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[...,:-1,:]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[...,:-1]
            ll = lowlevel.SFB2D.apply(
                ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
        return ll

In [14]:
x = torch.tensor(np.ones((100,1)).dot(np.linspace(0, 99, 100).reshape((1,100))), requires_grad=True).reshape((1, 100,100))
print(x.size())

torch.Size([1, 100, 100])


In [15]:
DWT = pytorch_wavelets.DWTForward()

In [16]:
out_tens = DWT(x)

IndexError: tuple index out of range