In [2]:
import torch
from wt import DWT1DForward, DWT1DInverse, DWT2DForward, DWT2DInverse, DWT3DForward, DWT3DInverse 
from tft.transforms import WPT1D, WPT2D, IWPT1D, IWPT2D

import einops
class WPT3D(torch.nn.Module):
    """Multi-level 3D wavelet packet analysis built from a one-level 3D DWT.

    Each level calls ``wt`` once, stacks the lowpass band together with the
    first highpass entry along a new axis of size 8, and folds that axis into
    the channel axis. The channel count therefore grows by a factor of 8 per
    level.

    Args:
        wt: Module called as ``L, H = wt(x)``. ``L`` must be 5D and ``H[0]``
            must be 6D with 7 entries on axis 2 (so the stacked axis has
            size 8). When ``None`` (the default), a fresh
            ``DWT3DForward(J=1, mode='periodization', wave='bior4.4')`` is
            created for this instance.
        J: Number of analysis levels applied by ``forward``.
    """

    def __init__(self, wt=None, J=4):
        super().__init__()
        # Build the default lazily: a module instance used directly as a
        # default argument is created once at definition time and would be
        # registered as the same submodule in every default-constructed
        # WPT3D.
        if wt is None:
            wt = DWT3DForward(J=1, mode='periodization', wave='bior4.4')
        self.wt = wt
        self.J = J

    def analysis_one_level(self, x: torch.Tensor) -> torch.Tensor:
        """Apply one decomposition level and fold the 8 subbands into channels.

        Args:
            x: Input passed straight to ``self.wt``.
               NOTE(review): presumably (b, c, d, h, w) -- confirm against
               the ``wt`` implementation.

        Returns:
            Tensor laid out as ``b (c f) d h w`` with ``f = 8``.
        """
        L, H = self.wt(x)
        # Give the lowpass band its own subband axis so it can sit in front
        # of the highpass subbands held in H[0].
        X = torch.cat([L.unsqueeze(2), H[0]], dim=2)
        # Merge the subband axis into the channel axis; f=8 makes einops
        # verify that exactly eight subbands are present.
        X = einops.rearrange(X, 'b c f d h w -> b (c f) d h w', f=8)
        return X

    def wavelet_analysis(self, x: torch.Tensor, J: int) -> torch.Tensor:
        """Apply ``analysis_one_level`` ``J`` times; ``J == 0`` returns ``x``."""
        for _ in range(J):
            x = self.analysis_one_level(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full ``self.J``-level analysis on ``x``."""
        return self.wavelet_analysis(x, self.J)

class IWPT3D(torch.nn.Module):
    """Multi-level 3D wavelet packet synthesis built from a one-level inverse DWT.

    Each level splits the channel axis into groups of 8 subbands, treats the
    first subband of every group as the lowpass band and the remaining seven
    as the highpass bands, and hands both to ``iwt``.

    Args:
        iwt: Module called as ``iwt((L, [H]))`` where ``L`` is 5D and ``H``
            is 6D with 7 entries on axis 2. When ``None`` (the default), a
            fresh ``DWT3DInverse(mode='periodization', wave='bior4.4')`` is
            created for this instance.
        J: Number of synthesis levels applied by ``forward``.
    """

    def __init__(self, iwt=None, J=4):
        super().__init__()
        # Build the default lazily: a module instance used directly as a
        # default argument is created once at definition time and would be
        # registered as the same submodule in every default-constructed
        # IWPT3D.
        if iwt is None:
            iwt = DWT3DInverse(mode='periodization', wave='bior4.4')
        self.iwt = iwt
        self.J = J

    def synthesis_one_level(self, X: torch.Tensor) -> torch.Tensor:
        """Undo one decomposition level.

        Args:
            X: Tensor laid out as ``b (c f) d h w`` with ``f = 8``; the
               channel count must be divisible by 8.

        Returns:
            Whatever ``self.iwt`` returns for the regrouped subbands.
        """
        # Recover the explicit subband axis from the channel axis.
        X = einops.rearrange(X, 'b (c f) d h w -> b c f d h w', f=8)
        # Subband 0 is the lowpass band, subbands 1..7 are the highpass bands.
        L, H = torch.split(X, [1, 7], dim=2)
        L = L.squeeze(2)
        # The inverse transform takes the highpass bands as a list; a single
        # entry is passed here for this one level.
        H = [H]
        y = self.iwt((L, H))
        return y

    def wavelet_synthesis(self, x: torch.Tensor, J: int) -> torch.Tensor:
        """Apply ``synthesis_one_level`` ``J`` times; ``J == 0`` returns ``x``."""
        for _ in range(J):
            x = self.synthesis_one_level(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full ``self.J``-level synthesis on ``x``."""
        return self.wavelet_synthesis(x, self.J)

# Round-trip check: analysis followed by synthesis should reproduce the input
# for the 1D, 2D and 3D wavelet packet transforms.

# Every transform below uses the same boundary mode and wavelet.
_dwt_kwargs = {'mode': 'periodization', 'wave': 'bior4.4'}

# 1D: 8 packet levels.
wt1d = DWT1DForward(J=1, **_dwt_kwargs)
wpt1d = WPT1D(wt=wt1d, J=8)
iwt1d = DWT1DInverse(**_dwt_kwargs)
iwpt1d = IWPT1D(iwt=iwt1d, J=8)

# 2D: 4 packet levels.
wt2d = DWT2DForward(J=1, **_dwt_kwargs)
wpt2d = WPT2D(wt=wt2d, J=4)
iwt2d = DWT2DInverse(**_dwt_kwargs)
iwpt2d = IWPT2D(iwt=iwt2d, J=4)

# 3D: 3 packet levels.
wt3d = DWT3DForward(J=1, **_dwt_kwargs)
wpt3d = WPT3D(wt=wt3d, J=3)
iwt3d = DWT3DInverse(**_dwt_kwargs)
iwpt3d = IWPT3D(iwt=iwt3d, J=3)

# Batch size, channels, cube edge, image height, image width, signal length.
b, c, d, h, w, ℓ = 4, 12, 64, 384, 256, 49152

# Random test signals; drawn in 1D, 2D, 3D order.
x1d = torch.randn(b, c, ℓ)
x2d = torch.randn(b, c, h, w)
x3d = torch.randn(b, c, d, d, d)

# No gradients are needed for a pure reconstruction check.
with torch.no_grad():
    X1d = wpt1d(x1d)
    xhat1d = iwpt1d(X1d)

    X2d = wpt2d(x2d)
    xhat2d = iwpt2d(X2d)

    X3d = wpt3d(x3d)
    xhat3d = iwpt3d(X3d)

# The largest absolute reconstruction error must stay below 1e-5.
assert torch.abs(xhat1d - x1d).max() < 1e-5
assert torch.abs(xhat2d - x2d).max() < 1e-5
assert torch.abs(xhat3d - x3d).max() < 1e-5

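# Recorded cell output:
# RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [4, 12, 64, 64, 64]
# (The reported size equals the shape of x3d, i.e. (b, c, d, d, d) = (4, 12, 64, 64, 64).)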