In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pywt
import einops
from torch.autograd import Function

In [2]:
def roll(x, n, dim, make_even=False):
    if n < 0:
        n = x.shape[dim] + n

    if make_even and x.shape[dim] % 2 == 1:
        end = 1
    else:
        end = 0

    if dim == 0:
        return torch.cat((x[-n:], x[:-n+end]), dim=0)
    elif dim == 1:
        return torch.cat((x[:,-n:], x[:,:-n+end]), dim=1)
    elif dim == 2 or dim == -2:
        return torch.cat((x[:,:,-n:], x[:,:,:-n+end]), dim=2)
    elif dim == 3 or dim == -1:
        return torch.cat((x[:,:,:,-n:], x[:,:,:,:-n+end]), dim=3)

In [3]:
def prep_filt_afb1d(h0, h1, device=None):
    h0 = np.array(h0[::-1]).ravel()
    h1 = np.array(h1[::-1]).ravel()
    t = torch.get_default_dtype()
    h0 = torch.tensor(h0, device=device, dtype=t).reshape((1, 1, -1))
    h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1))
    return h0, h1

def afb1d(x, h0, h1, dim=-1):
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1,1,1,1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    h = torch.cat([h0, h1] * C, dim=0)

    if x.shape[dim] % 2 == 1:
        if d == 2:
            x = torch.cat((x, x[:,:,-1:]), dim=2)
        else:
            x = torch.cat((x, x[:,:,:,-1:]), dim=3)
        N += 1
    x = roll(x, -L2, dim=d)
    pad = (L-1, 0) if d == 2 else (0, L-1)
    lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
    N2 = N//2
    if d == 2:
        lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
        lohi = lohi[:,:,:N2]
    else:
        lohi[:,:,:,:L2] = lohi[:,:,:,:L2] + lohi[:,:,:,N2:N2+L2]
        lohi = lohi[:,:,:,:N2]

    return lohi
        
class AFB1D(Function):
    @staticmethod
    def forward(ctx, x, h0, h1):

        # Make inputs 4d
        x = x[:, :, None, :]
        h0 = h0[:, :, None, :]
        h1 = h1[:, :, None, :]

        # Save for backwards
        ctx.save_for_backward(h0, h1)
        ctx.shape = x.shape[3]

        lohi = afb1d(x, h0, h1, dim=3)
        x0 = lohi[:, ::2, 0].contiguous()
        x1 = lohi[:, 1::2, 0].contiguous()
        return x0, x1

    @staticmethod
    def backward(ctx, dx0, dx1):
        dx = None
        if ctx.needs_input_grad[0]:
            h0, h1 = ctx.saved_tensors

            # Make grads 4d
            dx0 = dx0[:, :, None, :]
            dx1 = dx1[:, :, None, :]

            dx = sfb1d(dx0, dx1, h0, h1, dim=3)[:, :, 0]

            # Check for odd input
            if dx.shape[2] > ctx.shape:
                dx = dx[:, :, :ctx.shape]

        return dx, None, None, None, None, None

def prep_filt_sfb1d(g0, g1, device=None):
    g0 = np.array(g0).ravel()
    g1 = np.array(g1).ravel()
    t = torch.get_default_dtype()
    g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1))
    g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1))

    return g0, g1

def sfb1d(lo, hi, g0, g1, dim=-1):
    C = lo.shape[1]
    d = dim % 4
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=torch.float, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=torch.float, device=lo.device)
    L = g0.numel()
    shape = [1,1,1,1]
    shape[d] = L
    N = 2*lo.shape[d]
    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    s = (2, 1) if d == 2 else (1,2)
    g0 = torch.cat([g0]*C,dim=0)
    g1 = torch.cat([g1]*C,dim=0)
    y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
        F.conv_transpose2d(hi, g1, stride=s, groups=C)
    if d == 2:
        y[:,:,:L-2] = y[:,:,:L-2] + y[:,:,N:N+L-2]
        y = y[:,:,:N]
    else:
        y[:,:,:,:L-2] = y[:,:,:,:L-2] + y[:,:,:,N:N+L-2]
        y = y[:,:,:,:N]
    y = roll(y, 1-L//2, dim=dim)

    return y

class SFB1D(Function):
    @staticmethod
    def forward(ctx, low, high, g0, g1):
        # Make into a 2d tensor with 1 row
        low = low[:, :, None, :]
        high = high[:, :, None, :]
        g0 = g0[:, :, None, :]
        g1 = g1[:, :, None, :]

        ctx.save_for_backward(g0, g1)

        return sfb1d(low, high, g0, g1, dim=3)[:, :, 0]

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            g0, g1, = ctx.saved_tensors
            dy = dy[:, :, None, :]

            dx = afb1d(dy, g0, g1, dim=3)

            dlow = dx[:, ::2, 0].contiguous()
            dhigh = dx[:, 1::2, 0].contiguous()
        return dlow, dhigh, None, None, None, None, None

class DWT1DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0, h1 = wave.dec_lo, wave.dec_hi
        else:
            assert len(wave) == 2
            h0, h1 = wave[0], wave[1]
        filts = prep_filt_afb1d(h0, h1)
        self.register_buffer('h0', filts[0])
        self.register_buffer('h1', filts[1])
        self.J = J

    def forward(self, x):
        assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        highs = []
        x0 = x
        for j in range(self.J):
            x0, x1 = AFB1D.apply(x0, self.h0, self.h1)
            highs.append(x1)
        return x0, highs
        
class WPT1D(torch.nn.Module):
    def __init__(self, wt=DWT1DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt = wt
        self.J = J

    def analysis_one_level(self,x):
        L, H = self.wt(x)
        X = torch.cat([L.unsqueeze(2),H[0].unsqueeze(2)],dim=2)
        X = einops.rearrange(X, 'b c f ℓ -> b (c f) ℓ')
        return X

    def wavelet_analysis(self, x, J):
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x

    def forward(self, x):
        return self.wavelet_analysis(x, J=self.J)
        
class DWT1DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0, g1 = wave.rec_lo, wave.rec_hi
        else:
            assert len(wave) == 2
            g0, g1 = wave[0], wave[1]
        filts = prep_filt_sfb1d(g0, g1)
        self.register_buffer('g0', filts[0])
        self.register_buffer('g1', filts[1])

    def forward(self, coeffs):
        x0, highs = coeffs
        assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        for x1 in highs[::-1]:
            if x1 is None:
                x1 = torch.zeros_like(x0)
            if x0.shape[-1] > x1.shape[-1]:
                x0 = x0[..., :-1]
            x0 = SFB1D.apply(x0, x1, self.g0, self.g1)
        return x0

class IWPT1D(torch.nn.Module):
    def __init__(self, iwt=DWT1DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt = iwt
        self.J = J

    def synthesis_one_level(self, X):
        X = einops.rearrange(X, 'b (c f) ℓ -> b c f ℓ', f=2)
        L, H = torch.split(X, [1, 1], dim=2)
        L = L.squeeze(2)
        H = [H.squeeze(2)]
        y = self.iwt((L, H))
        return y

    def wavelet_synthesis(self, x, J):
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x

    def forward(self, x):
        return self.wavelet_synthesis(x, J=self.J)

In [4]:
x1d = torch.randn(2, 3, 4096)
wt1d = DWT1DForward(wave='bior4.4')
wpt1d = WPT1D(wt=wt1d, J=3)
iwt1d = DWT1DInverse(wave='bior4.4')
iwpt1d = IWPT1D(iwt=iwt1d, J=3)
with torch.no_grad():
    X1d = wpt1d(x1d)
    xhat1d = iwpt1d(X1d)
assert (xhat1d - x1d).abs().max() < 1e-5

In [5]:
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
    h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
    if h0_row is None:
        h0_row, h1_row = h0_col, h1_col
    else:
        h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)

    h0_col = h0_col.reshape((1, 1, -1, 1))
    h1_col = h1_col.reshape((1, 1, -1, 1))
    h0_row = h0_row.reshape((1, 1, 1, -1))
    h1_row = h1_row.reshape((1, 1, 1, -1))
    return h0_col, h1_col, h0_row, h1_row

def afb2d(x, filts):
    tensorize = [not isinstance(f, torch.Tensor) for f in filts]
    if len(filts) == 2:
        h0, h1 = filts
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                h0, h1, device=x.device)
        else:
            h0_col = h0
            h0_row = h0.transpose(2,3)
            h1_col = h1
            h1_row = h1.transpose(2,3)
    elif len(filts) == 4:
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                *filts, device=x.device)
        else:
            h0_col, h1_col, h0_row, h1_row = filts
    else:
        raise ValueError("Unknown form for input filts")

    lohi = afb1d(x, h0_row, h1_row, dim=3)
    y = afb1d(lohi, h0_col, h1_col, dim=2)

    return y

class AFB2D(Function):
    @staticmethod
    def forward(ctx, x, h0_row, h1_row, h0_col, h1_col):
        ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
        ctx.shape = x.shape[-2:]
        lohi = afb1d(x, h0_row, h1_row, dim=3)
        y = afb1d(lohi, h0_col, h1_col, dim=2)
        s = y.shape
        y = y.reshape(s[0], -1, 4, s[-2], s[-1])
        low = y[:,:,0].contiguous()
        highs = y[:,:,1:].contiguous()
        return low, highs

    @staticmethod
    def backward(ctx, low, highs):
        dx = None
        if ctx.needs_input_grad[0]:
            h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors
            lh, hl, hh = torch.unbind(highs, dim=2)
            lo = sfb1d(low, lh, h0_col, h1_col, dim=2)
            hi = sfb1d(hl, hh, h0_col, h1_col, dim=2)
            dx = sfb1d(lo, hi, h0_row, h1_row, dim=3)
            if dx.shape[-2] > ctx.shape[-2] and dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:ctx.shape[-2], :ctx.shape[-1]]
            elif dx.shape[-2] > ctx.shape[-2]:
                dx = dx[:,:,:ctx.shape[-2]]
            elif dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:,:ctx.shape[-1]]
        return dx, None, None, None, None, None

def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
    g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
    if g0_row is None:
        g0_row, g1_row = g0_col, g1_col
    else:
        g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)

    g0_col = g0_col.reshape((1, 1, -1, 1))
    g1_col = g1_col.reshape((1, 1, -1, 1))
    g0_row = g0_row.reshape((1, 1, 1, -1))
    g1_row = g1_row.reshape((1, 1, 1, -1))

    return g0_col, g1_col, g0_row, g1_row

def sfb2d(ll, lh, hl, hh, filts):
    tensorize = [not isinstance(x, torch.Tensor) for x in filts]
    if len(filts) == 2:
        g0, g1 = filts
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1)
        else:
            g0_col = g0
            g0_row = g0.transpose(2,3)
            g1_col = g1
            g1_row = g1.transpose(2,3)
    elif len(filts) == 4:
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts)
        else:
            g0_col, g1_col, g0_row, g1_row = filts
    else:
        raise ValueError("Unknown form for input filts")

    lo = sfb1d(ll, lh, g0_col, g1_col, dim=2)
    hi = sfb1d(hl, hh, g0_col, g1_col, dim=2)
    y = sfb1d(lo, hi, g0_row, g1_row, dim=3)

    return y
        
class SFB2D(Function):
    @staticmethod
    def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col):
        ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)

        lh, hl, hh = torch.unbind(highs, dim=2)
        lo = sfb1d(low, lh, g0_col, g1_col, dim=2)
        hi = sfb1d(hl, hh, g0_col, g1_col, dim=2)
        y = sfb1d(lo, hi, g0_row, g1_row, dim=3)
        return y

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
            dx = afb1d(dy, g0_row, g1_row, dim=3)
            dx = afb1d(dx, g0_col, g1_col, dim=2)
            s = dx.shape
            dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
            dlow = dx[:,:,0].contiguous()
            dhigh = dx[:,:,1:].contiguous()
        return dlow, dhigh, None, None, None, None, None

class DWT2DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        else:
            if len(wave) == 2:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = h0_col, h1_col
            elif len(wave) == 4:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = wave[2], wave[3]
        filts = prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J

    def forward(self, x):
        yh = []
        ll = x
        for j in range(self.J):
            ll, high = AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row)
            yh.append(high)
        return ll, yh

class WPT2D(torch.nn.Module):
    def __init__(self, wt=DWT2DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt  = wt
        self.J = J
    def analysis_one_level(self,x):
        L, H = self.wt(x)
        X = torch.cat([L.unsqueeze(2),H[0]],dim=2)
        X = einops.rearrange(X, 'b c f h w -> b (c f) h w')
        return X
    def wavelet_analysis(self,x,J):
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x
    def forward(self, x):
        return self.wavelet_analysis(x,J=self.J)

class DWT2DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        else:
            if len(wave) == 2:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = g0_col, g1_col
            elif len(wave) == 4:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = wave[2], wave[3]
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])

    def forward(self, coeffs):
        yl, yh = coeffs
        ll = yl
        for h in yh[::-1]:
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], ll.shape[-1], device=ll.device)
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[...,:-1,:]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[...,:-1]
            ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row)
        return ll

class IWPT2D(torch.nn.Module):
    def __init__(self, iwt=DWT2DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt  = iwt
        self.J = J
    def synthesis_one_level(self,X):
        X = einops.rearrange(X, 'b (c f) h w -> b c f h w', f=4)
        L, H = torch.split(X, [1, 3], dim=2)
        L = L.squeeze(2)
        H = [H]
        y = self.iwt((L, H))
        return y
    def wavelet_synthesis(self,x,J):
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x
    def forward(self, x):
        return self.wavelet_synthesis(x,J=self.J)

In [6]:
x2d = torch.randn(2, 3, 64, 64)
wt2d = DWT2DForward(wave='bior4.4')
wpt2d = WPT2D(wt=wt2d, J=3)
iwt2d = DWT2DInverse(wave='bior4.4')
iwpt2d = IWPT2D(iwt=iwt2d, J=3)
with torch.no_grad():
    X2d = wpt2d(x2d)
    xhat2d = iwpt2d(X2d)
assert (xhat2d - x2d).abs().max() < 1e-5

In [7]:
def prep_filt_afb3d(h0_x, h1_x,
                    h0_y=None, h1_y=None,
                    h0_z=None, h1_z=None,
                    device=None):
    # If not provided, default Y/Z filters to the X filters
    if h0_y is None or h1_y is None:
        h0_y, h1_y = h0_x, h1_x
    if h0_z is None or h1_z is None:
        h0_z, h1_z = h0_x, h1_x
    # Prepare 1D filters for each dimension
    h0_x, h1_x = prep_filt_afb1d(h0_x, h1_x, device=device)
    h0_y, h1_y = prep_filt_afb1d(h0_y, h1_y, device=device)
    h0_z, h1_z = prep_filt_afb1d(h0_z, h1_z, device=device)
    h0_x = h0_x.reshape(1, 1, -1, 1, 1)
    h1_x = h1_x.reshape(1, 1, -1, 1, 1)
    h0_y = h0_y.reshape(1, 1, 1, -1, 1)
    h1_y = h1_y.reshape(1, 1, 1, -1, 1)
    h0_z = h0_z.reshape(1, 1, 1, 1, -1)
    h1_z = h1_z.reshape(1, 1, 1, 1, -1)
    return h0_x, h1_x, h0_y, h1_y, h0_z, h1_z

def afb3d(x, filts):
    # Unpack filters
    if len(filts) == 2:
        h0 = filts[0]
        h1 = filts[1]
        h0_x, h1_x = h0, h1
        h0_y, h1_y = h0, h1
        h0_z, h1_z = h0, h1
    elif len(filts) == 6:
        h0_x, h1_x, h0_y, h1_y, h0_z, h1_z = filts
    else:
        raise ValueError("Unknown form for input filts; expected length 2 or 6.")

    # Helper: apply the 1D analysis filter bank (afb1d) along one axis of a 5D tensor.
    # It reshapes the input into 4D, calls afb1d (which uses conv2d), then restores the shape.
    def _afb1d_along_axis(x, h0, h1, axis):
        # x: (B, C, D, H, W)
        B, C, D, H, W = x.shape
        if axis == 4:  # along z-axis (width)
            # Merge D and H into the batch dimension.
            x_reshaped = x.reshape(B * D * H, C, W)            # (B*D*H, C, W)
            x_reshaped = x_reshaped.unsqueeze(2)                # (B*D*H, C, 1, W)
            # Apply 1D filter along last dimension (dim=3) using afb1d.
            out = afb1d(x_reshaped, h0, h1, dim=3)              # (B*D*H, 2*C, 1, new_W)
            new_W = out.shape[-1]
            out = out.squeeze(2)                                # (B*D*H, 2*C, new_W)
            # Reshape back to (B, 2*C, D, H, new_W)
            out = out.reshape(B, D, H, 2 * C, new_W).permute(0, 3, 1, 2, 4)
            return out
        elif axis == 3:  # along y-axis (height)
            # Merge D and W into the batch dimension.
            x_reshaped = x.reshape(B * D * W, C, H)            # (B*D*W, C, H)
            x_reshaped = x_reshaped.unsqueeze(2)               # (B*D*W, C, 1, H)
            out = afb1d(x_reshaped, h0, h1, dim=3)             # (B*D*W, 2*C, 1, new_H)
            new_H = out.shape[-1]
            out = out.squeeze(2)                               # (B*D*W, 2*C, new_H)
            # Reshape back to (B, 2*C, D, new_H, W)
            out = out.reshape(B, D, W, 2 * C, new_H).permute(0, 3, 1, 4, 2)
            return out
        elif axis == 2:  # along x-axis (depth)
            # Merge H and W into the batch dimension.
            x_reshaped = x.reshape(B * H * W, C, D)            # (B*H*W, C, D)
            x_reshaped = x_reshaped.unsqueeze(2)               # (B*H*W, C, 1, D)
            out = afb1d(x_reshaped, h0, h1, dim=3)             # (B*H*W, 2*C, 1, new_D)
            new_D = out.shape[-1]
            out = out.squeeze(2)                               # (B*H*W, 2*C, new_D)
            # Reshape back to (B, 2*C, new_D, H, W)
            out = out.reshape(B, H, W, 2 * C, new_D).permute(0, 3, 4, 1, 2)
            return out
        else:
            raise ValueError("Axis must be 2, 3, or 4 for 3D input.")

    # Sequentially apply the 1D filter banks along z, then y, then x.
    out_z = _afb1d_along_axis(x, h0_z, h1_z, axis=4)  # -> shape: (B, 2*C, D, H, W_z)
    out_y = _afb1d_along_axis(out_z, h0_y, h1_y, axis=3)  # -> shape: (B, 4*C, D, H_y, W_z)
    out_x = _afb1d_along_axis(out_y, h0_x, h1_x, axis=2)  # -> shape: (B, 8*C, D_x, H_y, W_z)

    # Group the channel dimension to separate the eight subbands.
    B, ch8, new_D, new_H, new_W = out_x.shape
    C_orig = ch8 // 8
    out_x = out_x.reshape(B, C_orig, 8, new_D, new_H, new_W)
    # The 0th index is the lowpass (approximation) subband; the remaining 7 are details.
    low = out_x[:, :, 0, :, :, :]
    highs = out_x[:, :, 1:, :, :, :]
    return low, highs

def prep_filt_afb3d(h0_x, h1_x,
                    h0_y=None, h1_y=None,
                    h0_z=None, h1_z=None,
                    device=None):
    # If not provided, default Y/Z filters to X filters
    if h0_y is None or h1_y is None:
        h0_y, h1_y = h0_x, h1_x
    if h0_z is None or h1_z is None:
        h0_z, h1_z = h0_x, h1_x

    # Prepare them as 1D filters
    h0_x, h1_x = prep_filt_afb1d(h0_x, h1_x, device=device)
    h0_y, h1_y = prep_filt_afb1d(h0_y, h1_y, device=device)
    h0_z, h1_z = prep_filt_afb1d(h0_z, h1_z, device=device)

    # Reshape into (1,1,...) for conv usage
    # X-filters will be convolved along the depth (dim=2),
    # Y-filters along height (dim=3), Z-filters along width (dim=4).
    h0_x = h0_x.reshape(1, 1, -1, 1, 1)
    h1_x = h1_x.reshape(1, 1, -1, 1, 1)

    h0_y = h0_y.reshape(1, 1, 1, -1, 1)
    h1_y = h1_y.reshape(1, 1, 1, -1, 1)

    h0_z = h0_z.reshape(1, 1, 1, 1, -1)
    h1_z = h1_z.reshape(1, 1, 1, 1, -1)

    return h0_x, h1_x, h0_y, h1_y, h0_z, h1_z


def afb3d(x, filts):

    # Unpack filters
    if len(filts) == 2:
        # same filters for all dimensions
        h0 = filts[0]
        h1 = filts[1]
        h0_x, h1_x = h0, h1
        h0_y, h1_y = h0, h1
        h0_z, h1_z = h0, h1
    elif len(filts) == 6:
        h0_x, h1_x, h0_y, h1_y, h0_z, h1_z = filts
    else:
        raise ValueError("filts must be either length 2 or 6 for AFB3D.")

    B, C, D, H, W = x.shape

    # -- Helper: apply AFB1D along a particular axis of a 5D tensor --
    def _afb1d_along_axis(x_5d, h0, h1, axis):
        Bx, Cx, Dx, Hx, Wx = x_5d.shape
        if axis == 4:
            # axis = width
            # merge (B, D, H) into the batch for the 1D call
            x_reshaped = x_5d.permute(0, 2, 3, 1, 4)  # (B, D, H, C, W)
            x_reshaped = x_reshaped.reshape(Bx * Dx * Hx, Cx, Wx)  # (batch, C, L)
            x_reshaped = x_reshaped[:, :, None, :]                 # (batch, C, 1, W)

            lohi = afb1d(x_reshaped, h0, h1, dim=3)  # returns (batch, 2*C, 1, newW)
            # shape out: (batch, 2*C, 1, newW)
            outC = lohi.shape[1]
            outW = lohi.shape[-1]
            lohi = lohi.squeeze(2)  # (batch, 2*C, newW)

            # reshape back to (B, 2*C, D, H, W')
            lohi = lohi.reshape(Bx, Dx, Hx, outC, outW)
            lohi = lohi.permute(0, 3, 1, 2, 4)  # (B, outC, D, H, W')
            return lohi

        elif axis == 3:
            # axis = height
            x_reshaped = x_5d.permute(0, 2, 4, 1, 3)  # (B, D, W, C, H)
            x_reshaped = x_reshaped.reshape(Bx * Dx * Wx, Cx, Hx)
            x_reshaped = x_reshaped[:, :, None, :]  # (batch, C, 1, H)

            lohi = afb1d(x_reshaped, h0, h1, dim=3)  # (batch, 2*C, 1, newH)
            outC = lohi.shape[1]
            outH = lohi.shape[-1]
            lohi = lohi.squeeze(2)  # (batch, 2*C, newH)

            # reshape back
            lohi = lohi.reshape(Bx, Dx, Wx, outC, outH)
            lohi = lohi.permute(0, 3, 1, 4, 2)  # (B, outC, D, newH, W)
            return lohi

        elif axis == 2:
            # axis = depth
            x_reshaped = x_5d.permute(0, 3, 4, 1, 2)  # (B, H, W, C, D)
            x_reshaped = x_reshaped.reshape(Bx * Hx * Wx, Cx, Dx)
            x_reshaped = x_reshaped[:, :, None, :]  # (batch, C, 1, D)

            lohi = afb1d(x_reshaped, h0, h1, dim=3)  # (batch, 2*C, 1, newD)
            outC = lohi.shape[1]
            outD = lohi.shape[-1]
            lohi = lohi.squeeze(2)  # (batch, 2*C, newD)

            # reshape back
            lohi = lohi.reshape(Bx, Hx, Wx, outC, outD)
            lohi = lohi.permute(0, 3, 4, 1, 2)  # (B, outC, newD, H, W)
            return lohi
        else:
            raise ValueError("Axis must be one of (2,3,4) for a (B,C,D,H,W) tensor.")

    # 1) Filter along Z (width=dim=4)
    out_z = _afb1d_along_axis(x, h0_z, h1_z, axis=4)  # shape: (B, 2C, D, H, Wz)
    # 2) Filter along Y (height=dim=3)
    out_y = _afb1d_along_axis(out_z, h0_y, h1_y, axis=3)  # shape: (B, 4C, D, Hy, Wz)
    # 3) Filter along X (depth=dim=2)
    out_x = _afb1d_along_axis(out_y, h0_x, h1_x, axis=2)  # shape: (B, 8C, Dx, Hy, Wz)

    # Now separate the 8 subbands in the channel dimension:
    B2, ch8, Dx, Hy, Wz = out_x.shape
    assert (ch8 % 8) == 0, "Channel dimension must be multiple of 8."
    C2 = ch8 // 8

    # Reshape to (B, C2, 8, D', H', W')
    out_x = out_x.view(B2, C2, 8, Dx, Hy, Wz)

    # The [0]-th subband is the lowpass
    low = out_x[:, :, 0, :, :, :]    # (B, C2, Dx, Hy, Wz)
    # The other [1..7] are the highpass
    highs = out_x[:, :, 1:, :, :, :] # (B, C2, 7, Dx, Hy, Wz)

    return low, highs


class AFB3D(Function):
    @staticmethod
    def forward(ctx, x, h0_x, h1_x, h0_y, h1_y, h0_z, h1_z):
        ctx.save_for_backward(h0_x, h1_x, h0_y, h1_y, h0_z, h1_z)
        # Save original D,H,W for potential odd-size correction in backward
        ctx.original_shape = x.shape[-3:]

        # Perform forward 3D decomposition
        low, highs = afb3d(x, (h0_x, h1_x, h0_y, h1_y, h0_z, h1_z))
        return low, highs

    @staticmethod
    def backward(ctx, dlow, dhigh):
        dx = None
        if ctx.needs_input_grad[0]:
            # Retrieve saved filters
            h0_x, h1_x, h0_y, h1_y, h0_z, h1_z = ctx.saved_tensors
            # Call your 3D synthesis filter bank (you must implement sfb3d(...) similarly!)
            dx = sfb3d(dlow, dhigh, (h0_x, h1_x, h0_y, h1_y, h0_z, h1_z))
            
            # If the original D,H,W were odd, dx might be one sample longer
            D, H, W = ctx.original_shape
            if dx.shape[-3] > D:
                dx = dx[..., :D, :, :]
            if dx.shape[-2] > H:
                dx = dx[..., :H, :]
            if dx.shape[-1] > W:
                dx = dx[..., :W]

        # The rest of the returned gradients (for the filters) are None
        return dx, None, None, None, None, None, None

def prep_filt_sfb3d(g0_x, g1_x,
                    g0_y=None, g1_y=None,
                    g0_z=None, g1_z=None,
                    device=None):
    # If Y or Z filters not provided, default to the X filters
    if g0_y is None or g1_y is None:
        g0_y, g1_y = g0_x, g1_x
    if g0_z is None or g1_z is None:
        g0_z, g1_z = g0_x, g1_x

    # Prep them as standard 1D synthesis filters
    g0_x, g1_x = prep_filt_sfb1d(g0_x, g1_x, device=device)
    g0_y, g1_y = prep_filt_sfb1d(g0_y, g1_y, device=device)
    g0_z, g1_z = prep_filt_sfb1d(g0_z, g1_z, device=device)

    # Reshape for use with 3D transposed convolutions:
    #  - X filters: (1,1,L,1,1)
    #  - Y filters: (1,1,1,L,1)
    #  - Z filters: (1,1,1,1,L)
    g0_x = g0_x.reshape(1, 1, -1, 1, 1)
    g1_x = g1_x.reshape(1, 1, -1, 1, 1)

    g0_y = g0_y.reshape(1, 1, 1, -1, 1)
    g1_y = g1_y.reshape(1, 1, 1, -1, 1)

    g0_z = g0_z.reshape(1, 1, 1, 1, -1)
    g1_z = g1_z.reshape(1, 1, 1, 1, -1)

    return g0_x, g1_x, g0_y, g1_y, g0_z, g1_z

def _sfb1d_along_axis(x_5d, g0, g1, axis):
    Bx, Cx, Dx, Hx, Wx = x_5d.shape
    # The channel count must be even since it represents concatenated low/high bands.
    assert Cx % 2 == 0, "Channel dimension must be even for synthesis along an axis."
    
    if axis == 4:
        # Synthesis along width.
        # Permute so that width becomes the last dimension in a (B, D, H, C, W) tensor.
        x_perm = x_5d.permute(0, 2, 3, 1, 4).contiguous()  # (B, D, H, C, W)
        x_reshaped = x_perm.view(Bx * Dx * Hx, Cx, Wx)       # (B*D*H, C, W)
        
    elif axis == 3:
        # Synthesis along height.
        # Permute so that height is last: (B, D, W, C, H)
        x_perm = x_5d.permute(0, 2, 4, 1, 3).contiguous()    # (B, D, W, C, H)
        x_reshaped = x_perm.view(Bx * Dx * Wx, Cx, Hx)       # (B*D*W, C, H)
        
    elif axis == 2:
        # Synthesis along depth.
        # Permute so that depth is last: (B, H, W, C, D)
        x_perm = x_5d.permute(0, 3, 4, 1, 2).contiguous()    # (B, H, W, C, D)
        x_reshaped = x_perm.view(Bx * Hx * Wx, Cx, Dx)       # (B*H*W, C, D)
    else:
        raise ValueError("Axis must be one of 2 (depth), 3 (height), or 4 (width).")
    
    # Split the channel dimension into two halves: low and high parts.
    lo = x_reshaped[:, :Cx // 2, :]
    hi = x_reshaped[:, Cx // 2:, :]
    # sfb1d expects inputs of shape (N, C, 1, L), so add a singleton dimension.
    lo = lo.unsqueeze(2)  # shape: (N, C//2, 1, L)
    hi = hi.unsqueeze(2)
    # Apply 1D synthesis along the last dimension.
    y = sfb1d(lo, hi, g0, g1, dim=3)  # (N, C//2, 1, new_L)
    y = y.squeeze(2)  # now shape: (N, C//2, new_L)
    
    new_L = y.shape[-1]
    # Reshape and invert the permutation to return to 5D.
    if axis == 4:
        y = y.view(Bx, Dx, Hx, Cx // 2, new_L)   # (B, D, H, C//2, new_W)
        y = y.permute(0, 3, 1, 2, 4).contiguous()   # (B, C//2, D, H, new_W)
    elif axis == 3:
        y = y.view(Bx, Dx, Wx, Cx // 2, new_L)       # (B, D, W, C//2, new_H)
        y = y.permute(0, 3, 1, 4, 2).contiguous()      # (B, C//2, D, new_H, W)
    elif axis == 2:
        y = y.view(Bx, Hx, Wx, Cx // 2, new_L)       # (B, H, W, C//2, new_D)
        y = y.permute(0, 3, 4, 1, 2).contiguous()      # (B, C//2, new_D, H, W)
    return y

def sfb3d(low, highs, filts):
    # Combine the subbands along a new subband axis.
    # low is the 0th subband; highs are the remaining 7.
    Y = torch.cat([low.unsqueeze(2), highs], dim=2)  # shape: (B, C, 8, D, H, W)
    B, C, eight, D, H, W = Y.shape
    # Collapse the subband axis into the channel dimension.
    Y = Y.view(B, C * eight, D, H, W)  # shape: (B, 8C, D, H, W)
    
    # Unpack synthesis filters.
    g0_x, g1_x, g0_y, g1_y, g0_z, g1_z = filts
    
    # Reverse the analysis steps in the opposite order.
    # Analysis was: first along width (axis 4), then height (axis 3), then depth (axis 2).
    # Therefore, synthesis is performed in the order: along depth, then height, then width.
    Y = _sfb1d_along_axis(Y, g0_x, g1_x, axis=2)  # Combine subbands along depth.
    Y = _sfb1d_along_axis(Y, g0_y, g1_y, axis=3)  # Then along height.
    Y = _sfb1d_along_axis(Y, g0_z, g1_z, axis=4)  # Finally along width.
    
    return Y

class SFB3D(Function):
    @staticmethod
    def forward(ctx, low, highs, g0_x, g1_x, g0_y, g1_y, g0_z, g1_z):
        ctx.save_for_backward(g0_x, g1_x, g0_y, g1_y, g0_z, g1_z)
        y = sfb3d(low, highs, (g0_x, g1_x, g0_y, g1_y, g0_z, g1_z))
        return y

    @staticmethod
    def backward(ctx, dy):
        g0_x, g1_x, g0_y, g1_y, g0_z, g1_z = ctx.saved_tensors
        dlow, dhigh = afb3d(dy, (g0_x, g1_x, g0_y, g1_y, g0_z, g1_z))
        return dlow, dhigh, None, None, None, None, None, None

class DWT3DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        # Process the wave parameter
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            # Use the same decomposition filters for all dimensions.
            h0_x, h1_x = wave.dec_lo, wave.dec_hi
            h0_y, h1_y = h0_x, h1_x
            h0_z, h1_z = h0_x, h1_x
        else:
            # Allow either a 2-tuple or a 6-tuple.
            if len(wave) == 2:
                h0_x, h1_x = wave[0], wave[1]
                h0_y, h1_y = h0_x, h1_x
                h0_z, h1_z = h0_x, h1_x
            elif len(wave) == 6:
                h0_x, h1_x = wave[0], wave[1]
                h0_y, h1_y = wave[2], wave[3]
                h0_z, h1_z = wave[4], wave[5]
            else:
                raise ValueError("wave must be either a 2-tuple or a 6-tuple of filters.")
        
        # Preprocess filters for the 3D analysis bank.
        filts = prep_filt_afb3d(h0_x, h1_x, h0_y, h1_y, h0_z, h1_z)
        self.register_buffer('h0_x', filts[0])
        self.register_buffer('h1_x', filts[1])
        self.register_buffer('h0_y', filts[2])
        self.register_buffer('h1_y', filts[3])
        self.register_buffer('h0_z', filts[4])
        self.register_buffer('h1_z', filts[5])
        self.J = J

    def forward(self, x):
        assert x.ndim == 5, "DWT3DForward expects a 5D input (B, C, D, H, W)"
        highs = []
        ll = x
        for j in range(self.J):
            # AFB3D.apply returns (low, highs) where highs is a tensor containing the 7 detail subbands.
            ll, high = AFB3D.apply(ll, self.h0_x, self.h1_x,
                                        self.h0_y, self.h1_y,
                                        self.h0_z, self.h1_z)
            highs.append(high)
        return ll, highs

class WPT3D(torch.nn.Module):
    def __init__(self, wt=DWT3DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt = wt
        self.J = J

    def analysis_one_level(self, x):
        # Perform one level of 3D DWT.
        # L has shape (B, C, D, H, W)
        # H is a list where H[0] has shape (B, C, 7, D, H, W)
        L, H = self.wt(x)
        # Create a new subband axis:
        # Unsqueeze L to shape (B, C, 1, D, H, W) and then concatenate with the detail subbands.
        X = torch.cat([L.unsqueeze(2), H[0]], dim=2)  # now X has shape (B, C, 8, D, H, W)
        # Merge the channel and subband dimensions for a complete wavelet packet representation.
        X = einops.rearrange(X, 'b c f d h w -> b (c f) d h w')
        return X

    def wavelet_analysis(self, x, J):
        # Recursively apply one-level analysis J times.
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x

    def forward(self, x):
        # Expect x to be a 5D tensor: (B, C, D, H, W)
        return self.wavelet_analysis(x, J=self.J)

class DWT3DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        # If wave is given as a string, convert to a pywt.Wavelet.
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        # Determine synthesis filters.
        if isinstance(wave, pywt.Wavelet):
            # Use the same reconstruction filters for all dimensions.
            g0_x, g1_x = wave.rec_lo, wave.rec_hi
            g0_y, g1_y = g0_x, g1_x
            g0_z, g1_z = g0_x, g1_x
        else:
            # Otherwise, expect either a 2-tuple or a 6-tuple.
            if len(wave) == 2:
                g0_x, g1_x = wave[0], wave[1]
                g0_y, g1_y = g0_x, g1_x
                g0_z, g1_z = g0_x, g1_x
            elif len(wave) == 6:
                g0_x, g1_x = wave[0], wave[1]
                g0_y, g1_y = wave[2], wave[3]
                g0_z, g1_z = wave[4], wave[5]
            else:
                raise ValueError("wave must be either a 2-tuple or a 6-tuple of filters.")
        # Preprocess the synthesis filters.
        filts = prep_filt_sfb3d(g0_x, g1_x, g0_y, g1_y, g0_z, g1_z)
        self.register_buffer('g0_x', filts[0])
        self.register_buffer('g1_x', filts[1])
        self.register_buffer('g0_y', filts[2])
        self.register_buffer('g1_y', filts[3])
        self.register_buffer('g0_z', filts[4])
        self.register_buffer('g1_z', filts[5])

    def forward(self, coeffs):
        ll, highs = coeffs
        # Iterate in reverse order over the detail coefficients.
        for h in highs[::-1]:
            # If a detail subband is missing, replace with zeros.
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 7,
                                ll.shape[-3], ll.shape[-2], ll.shape[-1],
                                device=ll.device, dtype=ll.dtype)
            # If the lowpass tensor has an extra sample (due to odd sizes), crop it.
            if ll.shape[-3] > h.shape[-3]:
                ll = ll[..., :-1, :, :]
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :, :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :, :, :-1]
            # Synthesize the current level using the 3D synthesis filter bank.
            ll = SFB3D.apply(ll, h,
                             self.g0_x, self.g1_x,
                             self.g0_y, self.g1_y,
                             self.g0_z, self.g1_z)
        return ll

class IWPT3D(torch.nn.Module):
    def __init__(self, iwt=DWT3DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt = iwt
        self.J = J

    def synthesis_one_level(self, X):
        # Rearrange to separate the subband dimension: (B, C, 8, D, H, W)
        X = einops.rearrange(X, 'b (c f) d h w -> b c f d h w', f=8)
        # Split into lowpass (first subband) and highpass (remaining 7 subbands)
        L, H = torch.split(X, [1, 7], dim=2)
        # Remove the subband dimension from the lowpass
        L = L.squeeze(2)  # now (B, C, D, H, W)
        H = H.squeeze(2)  # now (B, C, 7, D, H, W)
        # Wrap the detail coefficients in a list (as expected by DWT3DInverse)
        y = self.iwt((L, [H]))
        return y

    def wavelet_synthesis(self, x, J):
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x

    def forward(self, x):
        return self.wavelet_synthesis(x, J=self.J)

In [8]:
x3d = torch.randn(2, 3, 16, 16, 16)
wt3d = DWT3DForward(wave='bior4.4')
wpt3d = WPT3D(wt=wt3d, J=3)
iwt3d = DWT3DInverse(wave='bior4.4')
iwpt3d = IWPT3D(iwt=iwt3d, J=3)
with torch.no_grad():
    X3d = wpt3d(x3d)
    xhat3d = iwpt3d(X3d)
assert (xhat3d - x3d).abs().max() < 1e-5

AssertionError: 

In [12]:
(xhat3d-x3d).std()

tensor(1.1802)