
I have a question to ask you #1

Closed
happycaoyue opened this issue Jan 4, 2019 · 9 comments

Comments

@happycaoyue

In pytorch_wavelets.dwt.transform2d.py you use

filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1],
                  hl[None,::-1,::-1], hh[None,::-1,::-1]], axis=0)

I don't understand why you reverse ll, lh, hl, and hh here. Could you explain?

@fbcotter
Owner

fbcotter commented Jan 4, 2019

Do you mean why I take the negative strides? It's because PyTorch's conv2d does cross-correlation rather than true convolution, so to prepare for this later I flip the filters.
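
A minimal sketch of what is meant here (my addition, not from the thread): F.conv2d is cross-correlation, so applying it with a spatially flipped kernel reproduces true 2D convolution. The comparison against scipy.signal.convolve2d is only illustrative.

import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d

x = np.random.randn(8, 8).astype(np.float32)
k = np.random.randn(3, 3).astype(np.float32)

# reference: true 2D convolution on the 'valid' region
ref = convolve2d(x, k, mode='valid')

# F.conv2d with the kernel flipped in both spatial dims matches it
xt = torch.from_numpy(x)[None, None]                                   # (N, C, H, W)
kt = torch.from_numpy(np.ascontiguousarray(k[::-1, ::-1]))[None, None]
out = F.conv2d(xt, kt)[0, 0].numpy()

print(np.allclose(ref, out, atol=1e-5))                                # True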

@happycaoyue
Author

Thanks for your answer. In DWTForward you flip the filters, but in DWTInverse you don't flip them.

@happycaoyue
Author

happycaoyue commented Jan 5, 2019

import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F


def dwt(x):
    # 2x2 Haar-style analysis filters (note: NOT flipped)
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])

    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)
    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)
    C = x.shape[1]
    filters = torch.cat([weight, ] * C, dim=0)
    xs = torch.from_numpy(x).to(torch.float)
    # stride-2 grouped conv2d: one-level decomposition per input channel
    y = F.conv2d(xs, filters, groups=C, stride=2)

    return y.numpy()

def idwt(y):
    # the same (unflipped) filters are reused for the inverse
    ll = np.array([[0.5, 0.5], [0.5, 0.5]])
    lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
    hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
    filts = np.stack([ll[None,], lh[None,],
                      hl[None,], hh[None,]],
                     axis=0)

    filts = np.copy(filts)
    weight = nn.Parameter(
        torch.tensor(filts).to(torch.get_default_dtype()),
        requires_grad=False)

    C = int(y.shape[1] / 4)
    filters = torch.cat([weight, ] * C, dim=0)
    ys = torch.from_numpy(y).to(torch.float)
    # conv_transpose2d with stride 2 performs the reconstruction
    b = F.conv_transpose2d(ys, filters, groups=C, stride=2)
    return b.numpy()

x = [[1,2,3,4] , [5,6,7,8] , [9,10,11,12] , [13,14,15,16]]
x = np.array(x)
x = np.expand_dims(x, 0)
x = np.expand_dims(x, 0)


d1 = dwt(x)
d2 = dwt(d1)
i2 = idwt(d2)
i1 = idwt(i2)
print(i1)

>>>[[[[ 1.  2.  3.  4.]
>>>   [ 5.  6.  7.  8.]
>>>   [ 9. 10. 11. 12.]
>>>   [13. 14. 15. 16.]]]]

I also changed your code to test it; there I use the Haar wavelet:

import pywt

w = pywt.Wavelet('haar')
ll = np.outer(w.dec_lo, w.dec_lo)
lh = np.outer(w.dec_hi, w.dec_lo)
hl = np.outer(w.dec_lo, w.dec_hi)
hh = np.outer(w.dec_hi, w.dec_hi)

If both the DWT and the IDWT flip the filters, or neither of them does, the output is the same as the input. Can you tell me why? Thank you.

@fbcotter
Owner

fbcotter commented Jan 5, 2019

I don't flip the filters in the inverse because I use conv_transpose2d, which does true convolution. The end result is that both the forward and inverse do proper convolution with non-flipped filters.

Your code example works because you've effectively swapped the analysis and synthesis filters. If you run dwt(x) and compare the output to pywt.dwt2(x, 'haar') you will see that your wavelet coefficients are wrong.

Here's an extra bit of info which might explain it a bit more:
Note that for the haar, like all orthogonal wavelets, analysis = flipped(synthesis).

Your dwt: the filters are correct, but you use correlation rather than convolution, which equals convolution with the flipped analysis filters, i.e. convolution with the synthesis filters.
Your idwt: you use the analysis filters with true convolution.
Result: analysis and synthesis are swapped.
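
A quick check of the "analysis = flipped(synthesis)" claim above (my own sketch, not part of the original comment), using pywt's Haar filters:

import numpy as np
import pywt

w = pywt.Wavelet('haar')
# for an orthogonal wavelet, the synthesis (rec) filters are the
# time-reversed analysis (dec) filters
print(np.allclose(w.rec_lo, w.dec_lo[::-1]))  # True
print(np.allclose(w.rec_hi, w.dec_hi[::-1]))  # True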

@fbcotter fbcotter closed this as completed Jan 5, 2019
@happycaoyue
Author

happycaoyue commented Jan 5, 2019

Thanks for your reply.
I want to build a DWT-CNN-IDWT network, that is, a CNN that operates in the Haar wavelet domain.
Have I made any mistakes in the code below?

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWTForward(nn.Module):

    def __init__(self):
        super().__init__()
        # 2x2 Haar analysis filters, flipped so that conv2d's
        # cross-correlation performs true convolution
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        filts = np.stack([ll[None,::-1,::-1], lh[None,::-1,::-1],
                          hl[None,::-1,::-1], hh[None,::-1,::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):

        C = x.shape[1]
        filters = torch.cat([self.weight,] * C, dim=0)

        y = F.conv2d(x, filters, groups=C, stride=2)

        return y


class DWTInverse(nn.Module):
    def __init__(self):
        super().__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        C = int(x.shape[1] / 4)
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv_transpose2d(x, filters, groups=C, stride=2)
        return y

@fbcotter
Owner

fbcotter commented Jan 8, 2019

That's great to hear. You shouldn't need to rewrite the DWTForward and DWTInverse functions yourself. If you want to do something like that, you could try something like the code below:

import torch
import torch.nn as nn
from pytorch_wavelets import DWT, IDWT

class Layer(nn.Module):
    def __init__(self, C, F):
        super().__init__()
        self.dwt = DWT(J=1, wave='haar')
        # one learnable 3x3 conv per subband
        self.ll_gain = nn.Conv2d(C, F, 3, padding=1)
        self.lh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hl_gain = nn.Conv2d(C, F, 3, padding=1)
        self.hh_gain = nn.Conv2d(C, F, 3, padding=1)
        self.idwt = IDWT(wave='haar')

    def forward(self, x):
        yl, yh = self.dwt(x)
        yl = self.ll_gain(yl)
        lh = self.lh_gain(yh[0][:,:,0])
        hl = self.hl_gain(yh[0][:,:,1])
        hh = self.hh_gain(yh[0][:,:,2])
        yh = (torch.stack((lh, hl, hh), dim=2), )
        y = self.idwt((yl, yh))
        return y
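
A quick usage sketch (my addition, not from the original comment; the batch and spatial sizes are arbitrary, and C == F is assumed so the subband channel counts match):

layer = Layer(C=3, F=3)
x = torch.randn(2, 3, 32, 32)      # (batch, channels, height, width)
y = layer(x)
print(y.shape)                     # torch.Size([2, 3, 32, 32])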

I wrote a paper recently about learning in the wavelet space, although I used a Dual Tree Complex Wavelet transform rather than the DWT with a Haar wavelet. You can see the paper behind it and the code here.

@happycaoyue
Author


Thanks for your reply. I am using deep learning to solve low-level vision problems in the wavelet domain, and your work has given me a great many ideas. Thank you.

@varun19299

@fbcotter just to confirm: the DWT and IDWT modules are differentiable and support .backward() etc.? (This seems to be true in a simple test I ran.)
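
For reference, a minimal sketch of such a test (my addition, not from the thread), using the package's DWTForward/DWTInverse modules:

import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=1, wave='haar')
ifm = DWTInverse(wave='haar')

x = torch.randn(1, 3, 32, 32, requires_grad=True)
yl, yh = xfm(x)
y = ifm((yl, yh))
y.sum().backward()
print(x.grad is not None)  # True: gradients flow through both transforms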

@varun19299

Okay, oops, the docs state this already.

Is there a functional API available for the wavelet transforms?
