Skip to content

Commit

Permalink
Merge pull request #25 from fbcotter/fergal/1d
Browse files Browse the repository at this point in the history
Add 1D Transform
  • Loading branch information
fbcotter committed Dec 28, 2020
2 parents a51adce + 5a32bdd commit edc3a61
Show file tree
Hide file tree
Showing 18 changed files with 621 additions and 87 deletions.
23 changes: 23 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@ If you use this repo, please cite my PhD thesis, chapter 3: https://doi.org/10.1

__ https://github.com/kymatio/kymatio

New in version 1.3.0
~~~~~~~~~~~~~~~~~~~~

- Added 1D DWT support

.. code:: python
import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse # or simply DWT1D, IDWT1D
dwt = DWT1DForward(wave='db6', J=3)
X = torch.randn(10, 5, 100)
yl, yh = dwt(X)
print(yl.shape)
>>> torch.Size([10, 5, 22])
print(yh[0].shape)
>>> torch.Size([10, 5, 55])
print(yh[1].shape)
>>> torch.Size([10, 5, 33])
print(yh[2].shape)
>>> torch.Size([10, 5, 22])
idwt = DWT1DInverse(wave='db6')
x = idwt((yl, yh))

New in version 1.2.0
~~~~~~~~~~~~~~~~~~~~

Expand Down
14 changes: 14 additions & 0 deletions pytorch_wavelets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,33 @@
'DTCWTInverse',
'DWTForward',
'DWTInverse',
'DWT1DForward',
'DWT1DInverse',
'DTCWT',
'IDTCWT',
'DWT',
'IDWT',
'DWT1D',
'DWT2D',
'IDWT1D',
'IDWT2D',
'ScatLayer',
'ScatLayerj2'
]

from pytorch_wavelets._version import __version__
from pytorch_wavelets.dtcwt.transform2d import DTCWTForward, DTCWTInverse
from pytorch_wavelets.dwt.transform2d import DWTForward, DWTInverse
from pytorch_wavelets.dwt.transform1d import DWT1DForward, DWT1DInverse
from pytorch_wavelets.scatternet import ScatLayer, ScatLayerj2

# Some aliases
DTCWT = DTCWTForward
IDTCWT = DTCWTInverse
DWT = DWTForward
IDWT = DWTInverse
DWT2D = DWT
IDWT2D = IDWT

DWT1D = DWT1DForward
IDWT1D = DWT1DInverse
2 changes: 1 addition & 1 deletion pytorch_wavelets/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# IMPORTANT: before release, remove the 'devN' tag from the release name
__version__ = '1.2.4'
__version__ = '1.3.0'
192 changes: 169 additions & 23 deletions pytorch_wavelets/dwt/lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,65 @@ def backward(ctx, low, highs):
return dx, None, None, None, None, None


class AFB1D(Function):
    """ Single-level 1d analysis filter bank with a hand-written backward.

    Defining the backward pass explicitly means autograd does not have to
    retain the input tensor, which saves memory.

    Inputs:
        x (torch.Tensor): Input to decompose
        h0: lowpass analysis filter
        h1: highpass analysis filter
        mode (int): padding mode code - obtain it with mode_to_int. The mode
            is passed as an integer rather than a string because gradcheck
            raises an error when given a string argument.

    Returns:
        x0: Tensor of shape (N, C, L') - lowpass coefficients
        x1: Tensor of shape (N, C, L') - highpass coefficients
    """
    @staticmethod
    def forward(ctx, x, h0, h1, mode):
        mode = int_to_mode(mode)

        # Lift everything to 4d by inserting a dummy row dimension so the
        # 2d machinery can be reused
        x4 = x[:, :, None, :]
        h0_4 = h0[:, :, None, :]
        h1_4 = h1[:, :, None, :]

        # Stash what backward needs
        ctx.save_for_backward(h0_4, h1_4)
        ctx.shape = x4.shape[3]
        ctx.mode = mode

        # Filter along the last dim; low/high interleave over the channels
        lohi = afb1d(x4, h0_4, h1_4, mode=mode, dim=3)
        lo = lohi[:, ::2, 0].contiguous()
        hi = lohi[:, 1::2, 0].contiguous()
        return lo, hi

    @staticmethod
    def backward(ctx, dx0, dx1):
        dx = None
        if ctx.needs_input_grad[0]:
            h0, h1 = ctx.saved_tensors
            mode = ctx.mode

            # The gradient of analysis is synthesis; grads go back to 4d first
            d0 = dx0[:, :, None, :]
            d1 = dx1[:, :, None, :]
            dx = sfb1d(d0, d1, h0, h1, mode=mode, dim=3)[:, :, 0]

            # An odd-length input reconstructs one sample too long - trim it
            if dx.shape[2] > ctx.shape:
                dx = dx[:, :, :ctx.shape]

        return dx, None, None, None, None, None


def afb2d(x, filts, mode='zero'):
""" Does a single level 2d wavelet decomposition of an input. Does separate
row and column filtering by two calls to
Expand Down Expand Up @@ -635,6 +694,55 @@ def backward(ctx, dy):
return dlow, dhigh, None, None, None, None, None


class SFB1D(Function):
    """ Single-level 1d wavelet reconstruction with a hand-written backward.

    Defining the backward pass explicitly means autograd does not have to
    retain the input tensors, which saves memory.

    Inputs:
        low (torch.Tensor): Lowpass to reconstruct of shape (N, C, L)
        high (torch.Tensor): Highpass to reconstruct of shape (N, C, L)
        g0: lowpass synthesis filter
        g1: highpass synthesis filter
        mode (int): padding mode code - obtain it with mode_to_int. The mode
            is passed as an integer rather than a string because gradcheck
            raises an error when given a string argument.

    Returns:
        y: reconstructed Tensor
           NOTE(review): the original doc states shape (N, C*2, L'); the
           trailing [:, :, 0] squeeze suggests (N, C, L') - confirm against
           sfb1d.
    """
    @staticmethod
    def forward(ctx, low, high, g0, g1, mode):
        mode = int_to_mode(mode)

        # Lift everything to 4d (a one-row image) so the 2d machinery applies
        lo = low[:, :, None, :]
        hi = high[:, :, None, :]
        g0 = g0[:, :, None, :]
        g1 = g1[:, :, None, :]

        ctx.mode = mode
        ctx.save_for_backward(g0, g1)

        return sfb1d(lo, hi, g0, g1, mode=mode, dim=3)[:, :, 0]

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            g0, g1 = ctx.saved_tensors
            mode = ctx.mode

            # The gradient of synthesis is analysis; grad goes back to 4d
            dy4 = dy[:, :, None, :]
            dx = afb1d(dy4, g0, g1, mode=mode, dim=3)

            # Even channels carry the lowpass grad, odd channels the highpass
            dlow = dx[:, ::2, 0].contiguous()
            dhigh = dx[:, 1::2, 0].contiguous()
        return dlow, dhigh, None, None, None, None, None


def sfb2d_nonsep(coeffs, filts, mode='zero'):
""" Does a single level 2d wavelet reconstruction of wavelet coefficients.
Does not do separable filtering.
Expand Down Expand Up @@ -777,21 +885,43 @@ def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
Returns:
(g0_col, g1_col, g0_row, g1_row)
"""
g0_col = np.array(g0_col).ravel()
g1_col = np.array(g1_col).ravel()
t = torch.get_default_dtype()
g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
if g0_row is None:
g0_row = g0_col
if g1_row is None:
g1_row = g1_col
g0_col = torch.tensor(g0_col, device=device, dtype=t).reshape((1,1,-1,1))
g1_col = torch.tensor(g1_col, device=device, dtype=t).reshape((1,1,-1,1))
g0_row = torch.tensor(g0_row, device=device, dtype=t).reshape((1,1,1,-1))
g1_row = torch.tensor(g1_row, device=device, dtype=t).reshape((1,1,1,-1))
g0_row, g1_row = g0_col, g1_col
else:
g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)

g0_col = g0_col.reshape((1, 1, -1, 1))
g1_col = g1_col.reshape((1, 1, -1, 1))
g0_row = g0_row.reshape((1, 1, 1, -1))
g1_row = g1_row.reshape((1, 1, 1, -1))

return g0_col, g1_col, g0_row, g1_row


def prep_filt_sfb1d(g0, g1, device=None):
    """
    Puts the 1d synthesis filters into the form the sfb1d function expects.

    Each filter bank is flattened and reshaped into a (1, 1, length) tensor.
    No mirror-imaging is applied, since sfb2d uses conv2d_transpose, which
    behaves like true convolution.

    Inputs:
        g0 (array-like): low pass filter bank
        g1 (array-like): high pass filter bank
        device: which device to put the tensors on to

    Returns:
        (g0, g1)
    """
    dtype = torch.get_default_dtype()
    flat0 = np.array(g0).ravel()
    flat1 = np.array(g1).ravel()
    t0 = torch.tensor(flat0, device=device, dtype=dtype).reshape((1, 1, -1))
    t1 = torch.tensor(flat1, device=device, dtype=dtype).reshape((1, 1, -1))
    return t0, t1


def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
"""
Prepares the filters to be of the right form for the afb2d function. In
Expand All @@ -810,20 +940,36 @@ def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
Returns:
(h0_col, h1_col, h0_row, h1_row)
"""
h0_col = np.array(h0_col[::-1]).ravel()
h1_col = np.array(h1_col[::-1]).ravel()
t = torch.get_default_dtype()
h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
if h0_row is None:
h0_row = h0_col
else:
h0_row = np.array(h0_row[::-1]).ravel()
if h1_row is None:
h1_row = h1_col
        h0_row, h1_row = h0_col, h1_col
else:
h1_row = np.array(h1_row[::-1]).ravel()
h0_col = torch.tensor(h0_col, device=device, dtype=t).reshape((1,1,-1,1))
h1_col = torch.tensor(h1_col, device=device, dtype=t).reshape((1,1,-1,1))
h0_row = torch.tensor(h0_row, device=device, dtype=t).reshape((1,1,1,-1))
h1_row = torch.tensor(h1_row, device=device, dtype=t).reshape((1,1,1,-1))
h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)

h0_col = h0_col.reshape((1, 1, -1, 1))
h1_col = h1_col.reshape((1, 1, -1, 1))
h0_row = h0_row.reshape((1, 1, 1, -1))
h1_row = h1_row.reshape((1, 1, 1, -1))
return h0_col, h1_col, h0_row, h1_row


def prep_filt_afb1d(h0, h1, device=None):
    """
    Puts the 1d analysis filters into the form the afb1d function expects.

    Each filter bank is mirror-imaged (reversed), flattened, and reshaped
    into a (1, 1, length) tensor. The reversal is needed because afb2d uses
    conv2d, which computes correlation rather than convolution.

    Inputs:
        h0 (array-like): low pass column filter bank
        h1 (array-like): high pass column filter bank
        device: which device to put the tensors on to

    Returns:
        (h0, h1)
    """
    dtype = torch.get_default_dtype()
    rev0 = np.array(h0[::-1]).ravel()
    rev1 = np.array(h1[::-1]).ravel()
    t0 = torch.tensor(rev0, device=device, dtype=dtype).reshape((1, 1, -1))
    t1 = torch.tensor(rev1, device=device, dtype=dtype).reshape((1, 1, -1))
    return t0, t1

0 comments on commit edc3a61

Please sign in to comment.