Skip to content

Commit

Permalink
Merge pull request #32 from sidward/master
Browse files Browse the repository at this point in the history
Reimplement wavelet using xp functions.
  • Loading branch information
frankong committed Nov 14, 2019
2 parents 5bd25cd + abc4464 commit 63c2e99
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 32 deletions.
11 changes: 6 additions & 5 deletions sigpy/linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,8 @@ def __init__(self, ishape, axes=None, wave_name='db4', level=None):
self.wave_name = wave_name
self.axes = axes
self.level = level
oshape, _ = wavelet.get_wavelet_shape(ishape, wave_name, axes, level)
oshape = wavelet.get_wavelet_shape(ishape, wave_name=wave_name,
axes=axes, level=level)

super().__init__(oshape, ishape)

Expand Down Expand Up @@ -1197,14 +1198,14 @@ def __init__(self, oshape, axes=None, wave_name='db4', level=None):
self.wave_name = wave_name
self.axes = axes
self.level = level
ishape, self.coeff_slices = wavelet.get_wavelet_shape(
oshape, wave_name, axes, level)
ishape = wavelet.get_wavelet_shape(oshape, wave_name=wave_name,
axes=axes, level=level)
super().__init__(oshape, ishape)

def _apply(self, input):
return wavelet.iwt(
input, self.oshape, self.coeff_slices,
wave_name=self.wave_name, axes=self.axes, level=self.level)
input, self.oshape, wave_name=self.wave_name,
axes=self.axes, level=self.level)

def _adjoint_linop(self):
return Wavelet(self.oshape, axes=self.axes,
Expand Down
234 changes: 207 additions & 27 deletions sigpy/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,242 @@
"""
import numpy as np
import pywt
from sigpy import backend, util
from sigpy import backend

__all__ = ['fwt', 'iwt']


def get_wavelet_shape(shape, wave_name, axes, level):
zshape = [((i + 1) // 2) * 2 for i in shape]
input = np.zeros(shape)
tmp = fwt(input, wave_name=wave_name, axes=axes, level=level)

tmp = pywt.wavedecn(
np.zeros(zshape), wave_name, mode='zero', axes=axes, level=level)
tmp, coeff_slices = pywt.coeffs_to_array(tmp, axes=axes)
oshape = tmp.shape
return tmp.shape

return oshape, coeff_slices

def apply_dec_along_axis(input, axes, dec_lo, dec_hi, level, apply_zpad):
"""Apply wavelet decomposition along axes.
def fwt(input, wave_name='db4', axes=None, level=None):
Helper function to recursively apply decomposition wavelet filters
along axes.
Args:
input (array): Input array.
axes (tuple of int): Axes to perform wavelet transform.
dec_lo (array): Wavelet coefficients for approximation coefficients.
dec_hi (array): Wavelet coefficients for decimation coefficients.
level (int): Level to determine amount of zero-padding.
apply_zpad (bool): Set to true to apply z-pad.
"""
assert type(axes) == tuple
assert dec_lo.shape == dec_hi.shape

if (len(axes) == 0):
return input

# Loading sigpy.
device = backend.get_device(input)
xp = device.xp

axis = axes[0]

# Zero padding.
x = input
if (apply_zpad):
pad_size = (1 + (dec_hi.size * level + x.shape[axis])//(2**level)) * \
2 ** level - x.shape[axis]
pad_array = [(0, pad_size) if k == axis else (0, 0)
for k in range(len(x.shape))]
x = xp.pad(x, pad_array, 'constant', constant_values=(0, 0))

# Fourier space.
X = xp.fft.fftn(x, axes=(axis,))

lo = xp.zeros((x.shape[axis],)).astype(xp.complex64)
lo[:dec_lo.size] = dec_lo
lo = xp.reshape(xp.fft.fftn(xp.roll(lo, -(dec_lo.size//2)), axes=(0,)),
[lo.size if k == axis else 1 for k in range(len(x.shape))])

hi = xp.zeros((x.shape[axis],)).astype(xp.complex64)
hi[:dec_hi.size] = dec_hi
hi = xp.reshape(xp.fft.fftn(xp.roll(hi, -(dec_hi.size//2)), axes=(0,)),
[hi.size if k == axis else 1 for k in range(len(x.shape))])

# Apply convolutions.
y_lo = xp.fft.ifftn(X * lo, axes=(axis,))
y_hi = xp.fft.ifftn(X * hi, axes=(axis,))

# Sub-sampling
y_lo = xp.take(y_lo, [t * 2 for t in range(0, y_lo.shape[axis]//2)],
axis=axis)
y_hi = xp.take(y_hi, [t * 2 for t in range(0, y_hi.shape[axis]//2)],
axis=axis)

# Apply recursion to other axis and concatenate.
return xp.concatenate((apply_dec_along_axis(y_lo, axes[1:], dec_lo,
dec_hi, level, apply_zpad),
apply_dec_along_axis(y_hi, axes[1:], dec_lo,
dec_hi, level, apply_zpad)), axis=axis)


def apply_rec_along_axis(input, axes, rec_lo, rec_hi):
"""Apply wavelet recomposition along axes.
Helper function to recursively apply decomposition wavelet filters
along axes. Assumes input has been appropriately zero-padded by
apply_dec_along_axis (used by fwt).
Args:
input (array): Input array.
axes (tuple of int): Axes to perform wavelet transform.
rec_lo (array): Wavelet coefficients for approximation coefficients.
rec_hi (array): Wavelet coefficients for decimation coefficients.
"""
assert type(axes) == tuple
assert rec_lo.shape == rec_hi.shape

if (len(axes) == 0):
return input

# Load sigpy.
device = backend.get_device(input)
xp = device.xp

axis = axes[0]

# Preparing filters.
lo = xp.zeros((input.shape[axis],)).astype(xp.complex64)
lo[:rec_lo.size] = rec_lo
lo = xp.reshape(xp.fft.fftn(xp.roll(lo, 1-(rec_lo.size//2)), axes=(0,)),
[lo.size if k == axis else 1
for k in range(len(input.shape))])

hi = xp.zeros((input.shape[axis],)).astype(xp.complex64)
hi[:rec_hi.size] = rec_hi
hi = xp.reshape(xp.fft.fftn(xp.roll(hi, 1-(rec_hi.size//2)), axes=(0,)),
[hi.size if k == axis else 1
for k in range(len(input.shape))])

# Coefficient indices.
lo_coeffs = tuple([slice(0, input.shape[k]//2)
if k == axis else slice(0, None)
for k in range(len(input.shape))])
hi_coeffs = tuple([slice(input.shape[k]//2, None)
if k == axis else slice(0, None)
for k in range(len(input.shape))])

# Extracting coefficients.
x_lo = xp.zeros(input.shape).astype(xp.complex64)
x_hi = xp.zeros(input.shape).astype(xp.complex64)

sample_idx = tuple([slice(0, None, 2)
if k == axis else slice(0, None)
for k in range(len(input.shape))])
x_lo[sample_idx] = input[lo_coeffs]
x_hi[sample_idx] = input[hi_coeffs]

# Apply convolutions.
X_lo = xp.fft.fftn(x_lo, axes=(axis,))
X_hi = xp.fft.fftn(x_hi, axes=(axis,))
y_lo = xp.fft.ifftn(X_lo * lo, axes=(axis,))
y_hi = xp.fft.ifftn(X_hi * hi, axes=(axis,))

# Apply recursion to other axis and concatenate.
return apply_rec_along_axis(y_lo + y_hi, axes[1:], rec_lo, rec_hi)


def fwt(input, wave_name='db4', axes=None, level=None, apply_zpad=True):
"""Forward wavelet transform.
Args:
input (array): Input array.
axes (None or tuple of int): Axes to perform wavelet transform.
wave_name (str): Wavelet name.
axes (None or tuple of int): Axes to perform wavelet transform.
level (None or int): Number of wavelet levels.
apply_zpad (bool): If true, zero-pad for linear convolution.
"""
device = backend.get_device(input)
input = backend.to_device(input, backend.cpu_device)
xp = device.xp

if axes is None:
axes = tuple([k for k in range(len(input.shape))
if input.shape[k] > 1])

if (type(axes) == int):
axes = (axes,)

wavdct = pywt.Wavelet(wave_name)
dec_lo = xp.array(wavdct.dec_lo)
dec_hi = xp.array(wavdct.dec_hi)

if level is None:
level = pywt.dwt_max_level(
xp.min(xp.array([input.shape[ax] for ax in axes])),
dec_lo.size)

if level <= 0:
return input

zshape = [((i + 1) // 2) * 2 for i in input.shape]
zinput = util.resize(input, zshape)
assert level > 0

coeffs = pywt.wavedecn(
zinput, wave_name, mode='zero', axes=axes, level=level)
output, _ = pywt.coeffs_to_array(coeffs, axes=axes)
y = apply_dec_along_axis(input, axes, dec_lo, dec_hi, level, apply_zpad)
approx_idx = tuple([slice(0, y.shape[k]//2)
if k in axes else slice(0, None)
for k in range(len(input.shape))])
y[approx_idx] = fwt(y[approx_idx], wave_name=wave_name,
axes=axes, level=level-1, apply_zpad=False)

output = backend.to_device(output, device)
return output
return y


def iwt(input, oshape, coeff_slices, wave_name='db4', axes=None, level=None):
def iwt(input, oshape, wave_name='db4', axes=None, level=None, inplace=False):
"""Inverse wavelet transform.
Args:
input (array): Input array.
oshape (tuple of ints): Output shape.
coeff_slices (list of slice): Slices to split coefficients.
axes (None or tuple of int): Axes to perform wavelet transform.
oshape (tuple): Output shape.
wave_name (str): Wavelet name.
axes (None or tuple of int): Axes to perform wavelet transform.
level (None or int): Number of wavelet levels.
inplace (bool): Modify input array in place.
"""
device = backend.get_device(input)
input = backend.to_device(input, backend.cpu_device)
xp = device.xp

if axes is None:
axes = tuple([k for k in range(len(input.shape))
if input.shape[k] > 1])

if (type(axes) == int):
axes = (axes,)

wavdct = pywt.Wavelet(wave_name)
rec_lo = xp.array(wavdct.rec_lo)
rec_hi = xp.array(wavdct.rec_hi)

if level is None:
level = pywt.dwt_max_level(
xp.min(xp.array([input.shape[ax] for ax in axes])),
rec_lo.size)

if level <= 0:
return input

assert level > 0
for ax in axes:
assert input.shape[ax] % 2 == 0

x = input if inplace else input.astype(xp.complex64).copy()

approx_idx = tuple([slice(0, input.shape[k]//2)
if k in axes else slice(0, None)
for k in range(len(input.shape))])
x[approx_idx] = iwt(x[approx_idx], input[approx_idx].shape,
wave_name=wave_name, axes=axes, level=level-1,
inplace=True)

input = pywt.array_to_coeffs(input, coeff_slices, output_format='wavedecn')
output = pywt.waverecn(input, wave_name, mode='zero', axes=axes)
output = util.resize(output, oshape)
y = apply_rec_along_axis(x, axes, rec_lo, rec_hi)
crop_idx = tuple([slice(0, oshape[k])
if k in axes else slice(0, None)
for k in range(len(input.shape))])

output = backend.to_device(output, device)
return output
return y[crop_idx]

0 comments on commit 63c2e99

Please sign in to comment.