In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pywt
import einops
from torch.autograd import Function

In [2]:
def roll(x, n, dim, make_even=False):
    if n < 0:
        n = x.shape[dim] + n

    if make_even and x.shape[dim] % 2 == 1:
        end = 1
    else:
        end = 0

    if dim == 0:
        return torch.cat((x[-n:], x[:-n+end]), dim=0)
    elif dim == 1:
        return torch.cat((x[:,-n:], x[:,:-n+end]), dim=1)
    elif dim == 2 or dim == -2:
        return torch.cat((x[:,:,-n:], x[:,:,:-n+end]), dim=2)
    elif dim == 3 or dim == -1:
        return torch.cat((x[:,:,:,-n:], x[:,:,:,:-n+end]), dim=3)

In [3]:
def prep_filt_afb1d(h0, h1, device=None):
    h0 = np.array(h0[::-1]).ravel()
    h1 = np.array(h1[::-1]).ravel()
    t = torch.get_default_dtype()
    h0 = torch.tensor(h0, device=device, dtype=t).reshape((1, 1, -1))
    h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1))
    return h0, h1

def afb1d(x, h0, h1, dim=-1):
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=torch.float, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1,1,1,1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    h = torch.cat([h0, h1] * C, dim=0)

    if x.shape[dim] % 2 == 1:
        if d == 2:
            x = torch.cat((x, x[:,:,-1:]), dim=2)
        else:
            x = torch.cat((x, x[:,:,:,-1:]), dim=3)
        N += 1
    x = roll(x, -L2, dim=d)
    pad = (L-1, 0) if d == 2 else (0, L-1)
    lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
    N2 = N//2
    if d == 2:
        lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2]
        lohi = lohi[:,:,:N2]
    else:
        lohi[:,:,:,:L2] = lohi[:,:,:,:L2] + lohi[:,:,:,N2:N2+L2]
        lohi = lohi[:,:,:,:N2]

    return lohi
        
class AFB1D(Function):
    @staticmethod
    def forward(ctx, x, h0, h1):

        # Make inputs 4d
        x = x[:, :, None, :]
        h0 = h0[:, :, None, :]
        h1 = h1[:, :, None, :]

        # Save for backwards
        ctx.save_for_backward(h0, h1)
        ctx.shape = x.shape[3]

        lohi = afb1d(x, h0, h1, dim=3)
        x0 = lohi[:, ::2, 0].contiguous()
        x1 = lohi[:, 1::2, 0].contiguous()
        return x0, x1

    @staticmethod
    def backward(ctx, dx0, dx1):
        dx = None
        if ctx.needs_input_grad[0]:
            h0, h1 = ctx.saved_tensors

            # Make grads 4d
            dx0 = dx0[:, :, None, :]
            dx1 = dx1[:, :, None, :]

            dx = sfb1d(dx0, dx1, h0, h1, dim=3)[:, :, 0]

            # Check for odd input
            if dx.shape[2] > ctx.shape:
                dx = dx[:, :, :ctx.shape]

        return dx, None, None, None, None, None

def prep_filt_sfb1d(g0, g1, device=None):
    g0 = np.array(g0).ravel()
    g1 = np.array(g1).ravel()
    t = torch.get_default_dtype()
    g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1))
    g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1))

    return g0, g1

def sfb1d(lo, hi, g0, g1, dim=-1):
    C = lo.shape[1]
    d = dim % 4
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=torch.float, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=torch.float, device=lo.device)
    L = g0.numel()
    shape = [1,1,1,1]
    shape[d] = L
    N = 2*lo.shape[d]
    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    s = (2, 1) if d == 2 else (1,2)
    g0 = torch.cat([g0]*C,dim=0)
    g1 = torch.cat([g1]*C,dim=0)
    y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
        F.conv_transpose2d(hi, g1, stride=s, groups=C)
    if d == 2:
        y[:,:,:L-2] = y[:,:,:L-2] + y[:,:,N:N+L-2]
        y = y[:,:,:N]
    else:
        y[:,:,:,:L-2] = y[:,:,:,:L-2] + y[:,:,:,N:N+L-2]
        y = y[:,:,:,:N]
    y = roll(y, 1-L//2, dim=dim)

    return y

class SFB1D(Function):
    @staticmethod
    def forward(ctx, low, high, g0, g1):
        # Make into a 2d tensor with 1 row
        low = low[:, :, None, :]
        high = high[:, :, None, :]
        g0 = g0[:, :, None, :]
        g1 = g1[:, :, None, :]

        ctx.save_for_backward(g0, g1)

        return sfb1d(low, high, g0, g1, dim=3)[:, :, 0]

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            g0, g1, = ctx.saved_tensors
            dy = dy[:, :, None, :]

            dx = afb1d(dy, g0, g1, dim=3)

            dlow = dx[:, ::2, 0].contiguous()
            dhigh = dx[:, 1::2, 0].contiguous()
        return dlow, dhigh, None, None, None, None, None

class DWT1DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0, h1 = wave.dec_lo, wave.dec_hi
        else:
            assert len(wave) == 2
            h0, h1 = wave[0], wave[1]
        filts = prep_filt_afb1d(h0, h1)
        self.register_buffer('h0', filts[0])
        self.register_buffer('h1', filts[1])
        self.J = J

    def forward(self, x):
        assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        highs = []
        x0 = x
        for j in range(self.J):
            x0, x1 = AFB1D.apply(x0, self.h0, self.h1)
            highs.append(x1)
        return x0, highs
        
class WPT1D(torch.nn.Module):
    def __init__(self, wt=DWT1DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt = wt
        self.J = J

    def analysis_one_level(self,x):
        L, H = self.wt(x)
        X = torch.cat([L.unsqueeze(2),H[0].unsqueeze(2)],dim=2)
        X = einops.rearrange(X, 'b c f ℓ -> b (c f) ℓ')
        return X

    def wavelet_analysis(self, x, J):
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x

    def forward(self, x):
        return self.wavelet_analysis(x, J=self.J)
        
class DWT1DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0, g1 = wave.rec_lo, wave.rec_hi
        else:
            assert len(wave) == 2
            g0, g1 = wave[0], wave[1]
        filts = prep_filt_sfb1d(g0, g1)
        self.register_buffer('g0', filts[0])
        self.register_buffer('g1', filts[1])

    def forward(self, coeffs):
        x0, highs = coeffs
        assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        for x1 in highs[::-1]:
            if x1 is None:
                x1 = torch.zeros_like(x0)
            if x0.shape[-1] > x1.shape[-1]:
                x0 = x0[..., :-1]
            x0 = SFB1D.apply(x0, x1, self.g0, self.g1)
        return x0

class IWPT1D(torch.nn.Module):
    def __init__(self, iwt=DWT1DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt = iwt
        self.J = J

    def synthesis_one_level(self, X):
        X = einops.rearrange(X, 'b (c f) ℓ -> b c f ℓ', f=2)
        L, H = torch.split(X, [1, 1], dim=2)
        L = L.squeeze(2)
        H = [H.squeeze(2)]
        y = self.iwt((L, H))
        return y

    def wavelet_synthesis(self, x, J):
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x

    def forward(self, x):
        return self.wavelet_synthesis(x, J=self.J)

In [4]:
x1d = torch.randn(2, 3, 4096)
wt1d = DWT1DForward(wave='bior4.4')
wpt1d = WPT1D(wt=wt1d, J=3)
iwt1d = DWT1DInverse(wave='bior4.4')
iwpt1d = IWPT1D(iwt=iwt1d, J=3)
with torch.no_grad():
    X1d = wpt1d(x1d)
    xhat1d = iwpt1d(X1d)
assert (xhat1d - x1d).abs().max() < 1e-5

In [5]:
def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
    h0_col, h1_col = prep_filt_afb1d(h0_col, h1_col, device)
    if h0_row is None:
        h0_row, h1_row = h0_col, h1_col
    else:
        h0_row, h1_row = prep_filt_afb1d(h0_row, h1_row, device)

    h0_col = h0_col.reshape((1, 1, -1, 1))
    h1_col = h1_col.reshape((1, 1, -1, 1))
    h0_row = h0_row.reshape((1, 1, 1, -1))
    h1_row = h1_row.reshape((1, 1, 1, -1))
    return h0_col, h1_col, h0_row, h1_row

def afb2d(x, filts):
    tensorize = [not isinstance(f, torch.Tensor) for f in filts]
    if len(filts) == 2:
        h0, h1 = filts
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                h0, h1, device=x.device)
        else:
            h0_col = h0
            h0_row = h0.transpose(2,3)
            h1_col = h1
            h1_row = h1.transpose(2,3)
    elif len(filts) == 4:
        if True in tensorize:
            h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d(
                *filts, device=x.device)
        else:
            h0_col, h1_col, h0_row, h1_row = filts
    else:
        raise ValueError("Unknown form for input filts")

    lohi = afb1d(x, h0_row, h1_row, dim=3)
    y = afb1d(lohi, h0_col, h1_col, dim=2)

    return y

class AFB2D(Function):
    @staticmethod
    def forward(ctx, x, h0_row, h1_row, h0_col, h1_col):
        ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
        ctx.shape = x.shape[-2:]
        lohi = afb1d(x, h0_row, h1_row, dim=3)
        y = afb1d(lohi, h0_col, h1_col, dim=2)
        s = y.shape
        y = y.reshape(s[0], -1, 4, s[-2], s[-1])
        low = y[:,:,0].contiguous()
        highs = y[:,:,1:].contiguous()
        return low, highs

    @staticmethod
    def backward(ctx, low, highs):
        dx = None
        if ctx.needs_input_grad[0]:
            h0_row, h1_row, h0_col, h1_col = ctx.saved_tensors
            lh, hl, hh = torch.unbind(highs, dim=2)
            lo = sfb1d(low, lh, h0_col, h1_col, dim=2)
            hi = sfb1d(hl, hh, h0_col, h1_col, dim=2)
            dx = sfb1d(lo, hi, h0_row, h1_row, dim=3)
            if dx.shape[-2] > ctx.shape[-2] and dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:ctx.shape[-2], :ctx.shape[-1]]
            elif dx.shape[-2] > ctx.shape[-2]:
                dx = dx[:,:,:ctx.shape[-2]]
            elif dx.shape[-1] > ctx.shape[-1]:
                dx = dx[:,:,:,:ctx.shape[-1]]
        return dx, None, None, None, None, None

def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
    g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
    if g0_row is None:
        g0_row, g1_row = g0_col, g1_col
    else:
        g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)

    g0_col = g0_col.reshape((1, 1, -1, 1))
    g1_col = g1_col.reshape((1, 1, -1, 1))
    g0_row = g0_row.reshape((1, 1, 1, -1))
    g1_row = g1_row.reshape((1, 1, 1, -1))

    return g0_col, g1_col, g0_row, g1_row

def sfb2d(ll, lh, hl, hh, filts):
    tensorize = [not isinstance(x, torch.Tensor) for x in filts]
    if len(filts) == 2:
        g0, g1 = filts
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1)
        else:
            g0_col = g0
            g0_row = g0.transpose(2,3)
            g1_col = g1
            g1_row = g1.transpose(2,3)
    elif len(filts) == 4:
        if True in tensorize:
            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts)
        else:
            g0_col, g1_col, g0_row, g1_row = filts
    else:
        raise ValueError("Unknown form for input filts")

    lo = sfb1d(ll, lh, g0_col, g1_col, dim=2)
    hi = sfb1d(hl, hh, g0_col, g1_col, dim=2)
    y = sfb1d(lo, hi, g0_row, g1_row, dim=3)

    return y
        
class SFB2D(Function):
    @staticmethod
    def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col):
        ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)

        lh, hl, hh = torch.unbind(highs, dim=2)
        lo = sfb1d(low, lh, g0_col, g1_col, dim=2)
        hi = sfb1d(hl, hh, g0_col, g1_col, dim=2)
        y = sfb1d(lo, hi, g0_row, g1_row, dim=3)
        return y

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        if ctx.needs_input_grad[0]:
            g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
            dx = afb1d(dy, g0_row, g1_row, dim=3)
            dx = afb1d(dx, g0_col, g1_col, dim=2)
            s = dx.shape
            dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
            dlow = dx[:,:,0].contiguous()
            dhigh = dx[:,:,1:].contiguous()
        return dlow, dhigh, None, None, None, None, None

class DWT2DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        else:
            if len(wave) == 2:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = h0_col, h1_col
            elif len(wave) == 4:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = wave[2], wave[3]
        filts = prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J

    def forward(self, x):
        yh = []
        ll = x
        for j in range(self.J):
            ll, high = AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row)
            yh.append(high)
        return ll, yh

class WPT2D(torch.nn.Module):
    def __init__(self, wt=DWT2DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt  = wt
        self.J = J
    def analysis_one_level(self,x):
        L, H = self.wt(x)
        X = torch.cat([L.unsqueeze(2),H[0]],dim=2)
        X = einops.rearrange(X, 'b c f h w -> b (c f) h w')
        return X
    def wavelet_analysis(self,x,J):
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x
    def forward(self, x):
        return self.wavelet_analysis(x,J=self.J)

class DWT2DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        else:
            if len(wave) == 2:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = g0_col, g1_col
            elif len(wave) == 4:
                g0_col, g1_col = wave[0], wave[1]
                g0_row, g1_row = wave[2], wave[3]
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])

    def forward(self, coeffs):
        yl, yh = coeffs
        ll = yl
        for h in yh[::-1]:
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], ll.shape[-1], device=ll.device)
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[...,:-1,:]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[...,:-1]
            ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row)
        return ll

class IWPT2D(torch.nn.Module):
    def __init__(self, iwt=DWT2DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt  = iwt
        self.J = J
    def synthesis_one_level(self,X):
        X = einops.rearrange(X, 'b (c f) h w -> b c f h w', f=4)
        L, H = torch.split(X, [1, 3], dim=2)
        L = L.squeeze(2)
        H = [H]
        y = self.iwt((L, H))
        return y
    def wavelet_synthesis(self,x,J):
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x
    def forward(self, x):
        return self.wavelet_synthesis(x,J=self.J)

In [6]:
x2d = torch.randn(2, 3, 64, 64)
wt2d = DWT2DForward(wave='bior4.4')
wpt2d = WPT2D(wt=wt2d, J=3)
iwt2d = DWT2DInverse(wave='bior4.4')
iwpt2d = IWPT2D(iwt=iwt2d, J=3)
with torch.no_grad():
    X2d = wpt2d(x2d)
    xhat2d = iwpt2d(X2d)
assert (xhat2d - x2d).abs().max() < 1e-5

In [7]:
def prep_filt_afb3d(h0_x, h1_x,
                    h0_y=None, h1_y=None,
                    h0_z=None, h1_z=None,
                    device=None):
    # If not provided, default Y/Z filters to the X filters
    if h0_y is None or h1_y is None:
        h0_y, h1_y = h0_x, h1_x
    if h0_z is None or h1_z is None:
        h0_z, h1_z = h0_x, h1_x
    # Prepare 1D filters for each dimension
    h0_x, h1_x = prep_filt_afb1d(h0_x, h1_x, device=device)
    h0_y, h1_y = prep_filt_afb1d(h0_y, h1_y, device=device)
    h0_z, h1_z = prep_filt_afb1d(h0_z, h1_z, device=device)
    h0_x = h0_x.reshape(1, 1, -1, 1, 1)
    h1_x = h1_x.reshape(1, 1, -1, 1, 1)
    h0_y = h0_y.reshape(1, 1, 1, -1, 1)
    h1_y = h1_y.reshape(1, 1, 1, -1, 1)
    h0_z = h0_z.reshape(1, 1, 1, 1, -1)
    h1_z = h1_z.reshape(1, 1, 1, 1, -1)
    print("prep_filt_afb3d:", "h0_x shape:", h0_x.shape, 
          "h0_y shape:", h0_y.shape, "h0_z shape:", h0_z.shape)
    return h0_x, h1_x, h0_y, h1_y, h0_z, h1_z

def afb3d(x, filts):
    print("afb3d: input shape:", x.shape)
    # Unpack filters
    if len(filts) == 2:
        h0 = filts[0]
        h1 = filts[1]
        h0_x, h1_x = h0, h1
        h0_y, h1_y = h0, h1
        h0_z, h1_z = h0, h1
    elif len(filts) == 6:
        h0_x, h1_x, h0_y, h1_y, h0_z, h1_z = filts
    else:
        raise ValueError("Unknown form for input filts; expected length 2 or 6.")

    B, C, D, H, W = x.shape
    print("afb3d: using filters with shapes:",
          "h0_x:", h0_x.shape, "h0_y:", h0_y.shape, "h0_z:", h0_z.shape)

    def _afb1d_along_axis(x, h0, h1, axis):
        print(f"_afb1d_along_axis: axis {axis}, input shape: {x.shape}")
        B, C, D, H, W = x.shape
        if axis == 4:  # along z-axis (width)
            x_reshaped = x.reshape(B * D * H, C, W)
            x_reshaped = x_reshaped.unsqueeze(2)
            print(f"_afb1d_along_axis (axis 4): reshaped to {x_reshaped.shape}")
            out = afb1d(x_reshaped, h0, h1, dim=3)
            print(f"_afb1d_along_axis (axis 4): after afb1d, shape: {out.shape}")
            new_W = out.shape[-1]
            out = out.squeeze(2)
            out = out.reshape(B, D, H, 2 * C, new_W).permute(0, 3, 1, 2, 4)
            print(f"_afb1d_along_axis (axis 4): output reshaped to {out.shape}")
            return out
        elif axis == 3:  # along y-axis (height)
            x_reshaped = x.reshape(B * D * W, C, H)
            x_reshaped = x_reshaped.unsqueeze(2)
            print(f"_afb1d_along_axis (axis 3): reshaped to {x_reshaped.shape}")
            out = afb1d(x_reshaped, h0, h1, dim=3)
            print(f"_afb1d_along_axis (axis 3): after afb1d, shape: {out.shape}")
            new_H = out.shape[-1]
            out = out.squeeze(2)
            out = out.reshape(B, D, W, 2 * C, new_H).permute(0, 3, 1, 4, 2)
            print(f"_afb1d_along_axis (axis 3): output reshaped to {out.shape}")
            return out
        elif axis == 2:  # along x-axis (depth)
            x_reshaped = x.reshape(B * H * W, C, D)
            x_reshaped = x_reshaped.unsqueeze(2)
            print(f"_afb1d_along_axis (axis 2): reshaped to {x_reshaped.shape}")
            out = afb1d(x_reshaped, h0, h1, dim=3)
            print(f"_afb1d_along_axis (axis 2): after afb1d, shape: {out.shape}")
            new_D = out.shape[-1]
            out = out.squeeze(2)
            out = out.reshape(B, H, W, 2 * C, new_D).permute(0, 3, 4, 1, 2)
            print(f"_afb1d_along_axis (axis 2): output reshaped to {out.shape}")
            return out
        else:
            raise ValueError("Axis must be one of (2,3,4) for 3D input.")

    out_z = _afb1d_along_axis(x, h0_z, h1_z, axis=4)
    out_y = _afb1d_along_axis(out_z, h0_y, h1_y, axis=3)
    out_x = _afb1d_along_axis(out_y, h0_x, h1_x, axis=2)
    B2, ch8, new_D, new_H, new_W = out_x.shape
    C_orig = ch8 // 8
    out_x = out_x.reshape(B, C_orig, 8, new_D, new_H, new_W)
    low = out_x[:, :, 0, :, :, :]
    highs = out_x[:, :, 1:, :, :, :]
    print("afb3d: output low shape:", low.shape, "highs shape:", highs.shape)
    return low, highs

class AFB3D(Function):
    @staticmethod
    def forward(ctx, x, h0_x, h1_x, h0_y, h1_y, h0_z, h1_z):
        print("AFB3D.forward: input shape:", x.shape)
        ctx.save_for_backward(h0_x, h1_x, h0_y, h1_y, h0_z, h1_z)
        ctx.original_shape = x.shape[-3:]
        low, highs = afb3d(x, (h0_x, h1_x, h0_y, h1_y, h0_z, h1_z))
        print("AFB3D.forward: low shape:", low.shape, "highs shape:", highs.shape)
        return low, highs

    @staticmethod
    def backward(ctx, dlow, dhigh):
        h0_x, h1_x, h0_y, h1_y, h0_z, h1_z = ctx.saved_tensors
        print("AFB3D.backward: dlow shape:", dlow.shape, "dhigh shape:", dhigh.shape)
        dx = sfb3d(dlow, dhigh, (h0_x, h1_x, h0_y, h1_y, h0_z, h1_z))
        D, H, W = ctx.original_shape
        if dx.shape[-3] > D:
            dx = dx[..., :D, :, :]
        if dx.shape[-2] > H:
            dx = dx[..., :H, :]
        if dx.shape[-1] > W:
            dx = dx[..., :W]
        print("AFB3D.backward: dx shape after crop:", dx.shape)
        return dx, None, None, None, None, None, None

def prep_filt_sfb3d(g0_x, g1_x,
                    g0_y=None, g1_y=None,
                    g0_z=None, g1_z=None,
                    device=None):
    if g0_y is None or g1_y is None:
        g0_y, g1_y = g0_x, g1_x
    if g0_z is None or g1_z is None:
        g0_z, g1_z = g0_x, g1_x
    g0_x, g1_x = prep_filt_sfb1d(g0_x, g1_x, device=device)
    g0_y, g1_y = prep_filt_sfb1d(g0_y, g1_y, device=device)
    g0_z, g1_z = prep_filt_sfb1d(g0_z, g1_z, device=device)
    g0_x = g0_x.reshape(1, 1, -1, 1, 1)
    g1_x = g1_x.reshape(1, 1, -1, 1, 1)
    g0_y = g0_y.reshape(1, 1, 1, -1, 1)
    g1_y = g1_y.reshape(1, 1, 1, -1, 1)
    g0_z = g0_z.reshape(1, 1, 1, 1, -1)
    g1_z = g1_z.reshape(1, 1, 1, 1, -1)
    print("prep_filt_sfb3d:", "g0_x shape:", g0_x.shape, 
          "g0_y shape:", g0_y.shape, "g0_z shape:", g0_z.shape)
    return g0_x, g1_x, g0_y, g1_y, g0_z, g1_z

def _sfb1d_along_axis(x, g0, g1, axis):
    print(f"_sfb1d_along_axis: axis {axis}, input shape: {x.shape}")
    Bx, Cx, Dx, Hx, Wx = x.shape
    assert Cx % 2 == 0, "Channel dimension must be even for synthesis along an axis."
    if axis == 4:
        x_perm = x.permute(0, 2, 3, 1, 4).contiguous()
        x_reshaped = x_perm.view(Bx * Dx * Hx, Cx, Wx)
    elif axis == 3:
        x_perm = x.permute(0, 2, 4, 1, 3).contiguous()
        x_reshaped = x_perm.view(Bx * Dx * Wx, Cx, Hx)
    elif axis == 2:
        x_perm = x.permute(0, 3, 4, 1, 2).contiguous()
        x_reshaped = x_perm.view(Bx * Hx * Wx, Cx, Dx)
    else:
        raise ValueError("Axis must be one of 2, 3, or 4.")
    print(f"_sfb1d_along_axis: after reshaping, shape: {x_reshaped.shape}")
    lo = x_reshaped[:, :Cx // 2, :]
    hi = x_reshaped[:, Cx // 2:, :]
    lo = lo.unsqueeze(2)
    hi = hi.unsqueeze(2)
    y = sfb1d(lo, hi, g0, g1, dim=3)
    print(f"_sfb1d_along_axis: after sfb1d, shape: {y.shape}")
    y = y.squeeze(2)
    new_L = y.shape[-1]
    if axis == 4:
        y = y.view(Bx, Dx, Hx, Cx // 2, new_L).permute(0, 3, 1, 2, 4).contiguous()
    elif axis == 3:
        y = y.view(Bx, Dx, Wx, Cx // 2, new_L).permute(0, 3, 1, 4, 2).contiguous()
    elif axis == 2:
        y = y.view(Bx, Hx, Wx, Cx // 2, new_L).permute(0, 3, 4, 1, 2).contiguous()
    print(f"_sfb1d_along_axis: output reshaped, shape: {y.shape}")
    return y

def sfb3d(low, highs, filts):
    print("sfb3d: low shape:", low.shape, "highs shape:", highs.shape)
    Y = torch.cat([low.unsqueeze(2), highs], dim=2)
    print("sfb3d: after concatenation, Y shape:", Y.shape)
    B, C, eight, D, H, W = Y.shape
    Y = Y.view(B, C * eight, D, H, W)
    print("sfb3d: after merging subband axis, Y shape:", Y.shape)
    g0_x, g1_x, g0_y, g1_y, g0_z, g1_z = filts
    Y = _sfb1d_along_axis(Y, g0_x, g1_x, axis=2)
    print("sfb3d: after synthesis along depth, shape:", Y.shape)
    Y = _sfb1d_along_axis(Y, g0_y, g1_y, axis=3)
    print("sfb3d: after synthesis along height, shape:", Y.shape)
    Y = _sfb1d_along_axis(Y, g0_z, g1_z, axis=4)
    print("sfb3d: after synthesis along width, shape:", Y.shape)
    return Y

class SFB3D(Function):
    @staticmethod
    def forward(ctx, low, highs, g0_x, g1_x, g0_y, g1_y, g0_z, g1_z):
        print("SFB3D.forward: low shape:", low.shape, "highs shape:", highs.shape)
        ctx.save_for_backward(g0_x, g1_x, g0_y, g1_y, g0_z, g1_z)
        y = sfb3d(low, highs, (g0_x, g1_x, g0_y, g1_y, g0_z, g1_z))
        print("SFB3D.forward: output y shape:", y.shape)
        return y

    @staticmethod
    def backward(ctx, dy):
        g0_x, g1_x, g0_y, g1_y, g0_z, g1_z = ctx.saved_tensors
        print("SFB3D.backward: dy shape:", dy.shape)
        dlow, dhigh = afb3d(dy, (g0_x, g1_x, g0_y, g1_y, g0_z, g1_z))
        print("SFB3D.backward: dlow shape:", dlow.shape, "dhigh shape:", dhigh.shape)
        return dlow, dhigh, None, None, None, None, None, None

# (Optionally, add debug prints to the 3D forward/inverse modules)

class DWT3DForward(nn.Module):
    def __init__(self, J=1, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_x, h1_x = wave.dec_lo, wave.dec_hi
            h0_y, h1_y = h0_x, h1_x
            h0_z, h1_z = h0_x, h1_x
        else:
            if len(wave) == 2:
                h0_x, h1_x = wave[0], wave[1]
                h0_y, h1_y = h0_x, h1_x
                h0_z, h1_z = h0_x, h1_x
            elif len(wave) == 6:
                h0_x, h1_x = wave[0], wave[1]
                h0_y, h1_y = wave[2], wave[3]
                h0_z, h1_z = wave[4], wave[5]
        filts = prep_filt_afb3d(h0_x, h1_x, h0_y, h1_y, h0_z, h1_z)
        self.register_buffer('h0_x', filts[0])
        self.register_buffer('h1_x', filts[1])
        self.register_buffer('h0_y', filts[2])
        self.register_buffer('h1_y', filts[3])
        self.register_buffer('h0_z', filts[4])
        self.register_buffer('h1_z', filts[5])
        self.J = J

    def forward(self, x):
        print("DWT3DForward.forward: input shape:", x.shape)
        highs = []
        ll = x
        for j in range(self.J):
            ll, high = AFB3D.apply(ll, self.h0_x, self.h1_x,
                                        self.h0_y, self.h1_y,
                                        self.h0_z, self.h1_z)
            print(f"DWT3DForward.forward: level {j} low shape:", ll.shape, "high shape:", high.shape)
            highs.append(high)
        print("DWT3DForward.forward: final low shape:", ll.shape)
        return ll, highs

class WPT3D(torch.nn.Module):
    def __init__(self, wt=DWT3DForward(wave='bior4.4'), J=4):
        super().__init__()
        self.wt = wt
        self.J = J

    def analysis_one_level(self, x):
        L, H = self.wt(x)
        X = torch.cat([L.unsqueeze(2), H[0]], dim=2)
        X = einops.rearrange(X, 'b c f d h w -> b (c f) d h w')
        print("WPT3D.analysis_one_level: output shape:", X.shape)
        return X

    def wavelet_analysis(self, x, J):
        for j in range(J):
            print(f"WPT3D.wavelet_analysis: level {j} input shape:", x.shape)
            x = self.analysis_one_level(x)
        return x

    def forward(self, x):
        print("WPT3D.forward: input shape:", x.shape)
        out = self.wavelet_analysis(x, J=self.J)
        print("WPT3D.forward: output shape:", out.shape)
        return out

class DWT3DInverse(nn.Module):
    def __init__(self, wave='db1'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_x, g1_x = wave.rec_lo, wave.rec_hi
            g0_y, g1_y = g0_x, g1_x
            g0_z, g1_z = g0_x, g1_x
        else:
            if len(wave) == 2:
                g0_x, g1_x = wave[0], wave[1]
                g0_y, g1_y = g0_x, g1_x
                g0_z, g1_z = g0_x, g1_x
            elif len(wave) == 6:
                g0_x, g1_x = wave[0], wave[1]
                g0_y, g1_y = wave[2], wave[3]
                g0_z, g1_z = wave[4], wave[5]
        filts = prep_filt_sfb3d(g0_x, g1_x, g0_y, g1_y, g0_z, g1_z)
        self.register_buffer('g0_x', filts[0])
        self.register_buffer('g1_x', filts[1])
        self.register_buffer('g0_y', filts[2])
        self.register_buffer('g1_y', filts[3])
        self.register_buffer('g0_z', filts[4])
        self.register_buffer('g1_z', filts[5])

    def forward(self, coeffs):
        ll, yh = coeffs
        print("DWT3DInverse.forward: ll shape:", ll.shape)
        for i, h in enumerate(yh[::-1]):
            print(f"DWT3DInverse.forward: level {i} high shape:", h.shape if h is not None else "None")
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 7,
                                ll.shape[-3], ll.shape[-2], ll.shape[-1],
                                device=ll.device, dtype=ll.dtype)
            if ll.shape[-3] > h.shape[-3]:
                ll = ll[..., :-1, :, :]
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :, :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :, :, :-1]
            ll = SFB3D.apply(ll, h,
                             self.g0_x, self.g1_x,
                             self.g0_y, self.g1_y,
                             self.g0_z, self.g1_z)
            print(f"DWT3DInverse.forward: after level {i}, ll shape:", ll.shape)
        return ll

class IWPT3D(torch.nn.Module):
    def __init__(self, iwt=DWT3DInverse(wave='bior4.4'), J=4):
        super().__init__()
        self.iwt = iwt
        self.J = J

    def synthesis_one_level(self, X):
        print("IWPT3D.synthesis_one_level: input shape:", X.shape)
        X = einops.rearrange(X, 'b (c f) d h w -> b c f d h w', f=8)
        L, H = torch.split(X, [1, 7], dim=2)
        L = L.squeeze(2)
        H = H.squeeze(2)
        print("IWPT3D.synthesis_one_level: L shape:", L.shape, "H shape:", H.shape)
        y = self.iwt((L, [H]))
        print("IWPT3D.synthesis_one_level: output shape:", y.shape)
        return y

    def wavelet_synthesis(self, x, J):
        for j in range(J):
            print(f"IWPT3D.wavelet_synthesis: level {j} input shape:", x.shape)
            x = self.synthesis_one_level(x)
        return x

    def forward(self, x):
        print("IWPT3D.forward: input shape:", x.shape)
        out = self.wavelet_synthesis(x, J=self.J)
        print("IWPT3D.forward: output shape:", out.shape)
        return out


prep_filt_afb3d: h0_x shape: torch.Size([1, 1, 10, 1, 1]) h0_y shape: torch.Size([1, 1, 1, 10, 1]) h0_z shape: torch.Size([1, 1, 1, 1, 10])
prep_filt_sfb3d: g0_x shape: torch.Size([1, 1, 10, 1, 1]) g0_y shape: torch.Size([1, 1, 1, 10, 1]) g0_z shape: torch.Size([1, 1, 1, 1, 10])


In [8]:
x3d = torch.randn(2, 3, 16, 16, 16)
wt3d = DWT3DForward(wave='bior4.4')
wpt3d = WPT3D(wt=wt3d, J=3)
iwt3d = DWT3DInverse(wave='bior4.4')
iwpt3d = IWPT3D(iwt=iwt3d, J=3)
with torch.no_grad():
    X3d = wpt3d(x3d)
    xhat3d = iwpt3d(X3d)
assert (xhat3d - x3d).abs().max() < 1e-5

prep_filt_afb3d: h0_x shape: torch.Size([1, 1, 10, 1, 1]) h0_y shape: torch.Size([1, 1, 1, 10, 1]) h0_z shape: torch.Size([1, 1, 1, 1, 10])
prep_filt_sfb3d: g0_x shape: torch.Size([1, 1, 10, 1, 1]) g0_y shape: torch.Size([1, 1, 1, 10, 1]) g0_z shape: torch.Size([1, 1, 1, 1, 10])
WPT3D.forward: input shape: torch.Size([2, 3, 16, 16, 16])
WPT3D.wavelet_analysis: level 0 input shape: torch.Size([2, 3, 16, 16, 16])
DWT3DForward.forward: input shape: torch.Size([2, 3, 16, 16, 16])
AFB3D.forward: input shape: torch.Size([2, 3, 16, 16, 16])
afb3d: input shape: torch.Size([2, 3, 16, 16, 16])
afb3d: using filters with shapes: h0_x: torch.Size([1, 1, 10, 1, 1]) h0_y: torch.Size([1, 1, 1, 10, 1]) h0_z: torch.Size([1, 1, 1, 1, 10])
_afb1d_along_axis: axis 4, input shape: torch.Size([2, 3, 16, 16, 16])
_afb1d_along_axis (axis 4): reshaped to torch.Size([512, 3, 1, 16])
_afb1d_along_axis (axis 4): after afb1d, shape: torch.Size([512, 6, 1, 8])
_afb1d_along_axis (axis 4): output reshaped to torch.Siz

AssertionError: 