In [None]:
from mamba_ssm import Mamba
import ptwt, pywt

In [None]:
class MambaLayer(nn.Module):
    """Mamba block applied to the 2-D wavelet coefficients of a feature map.

    ``forward`` decomposes the input with a ``'db2'`` wavelet transform,
    flattens every sub-band into one token sequence, runs LayerNorm + Mamba
    over that sequence with a residual connection, reconstructs the spatial
    map with the inverse transform, and adds the original input back.

    Args:
        dim: Channel count of the input; used as ``LayerNorm`` width and as
            Mamba's ``d_model``.
        d_state: Forwarded to ``Mamba`` as ``d_state``.
        d_conv: Forwarded to ``Mamba`` as ``d_conv``.
        expand: Forwarded to ``Mamba`` as ``expand``.
    """

    def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim,      # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,    # Local convolution width
            expand=expand,    # Block expansion factor
        )

    def forward(self, x):
        """Run the wavelet-domain Mamba block.

        Args:
            x: 4-D tensor unpacked as ``(B, C, H, W)``; ``C`` must equal
                ``self.dim``. ``float16`` input is upcast to ``float32`` and
                is not cast back on return.

        Returns:
            Tensor of shape ``(B, C, H, W)``: the (possibly upcast) input plus
            the reconstructed, Mamba-processed signal.

        Raises:
            AssertionError: If the channel dimension differs from ``self.dim``.
        """
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C, H, W = x.shape
        assert C == self.dim

        # Use two decomposition levels when the input allows more than two,
        # otherwise fall back to a single level.
        level = pywt.dwtn_max_level((H, W), 'db2')
        level = 2 if level > 2 else 1

        # Zero-pad odd spatial sizes up to the next even size. The buffer is
        # allocated on the input's own device so the layer also works when the
        # input lives on a non-default GPU (or any other device).
        padded = torch.zeros(B, C, H + H % 2, W + W % 2, device=x.device)
        padded[..., :H, :W] = x

        coeff = ptwt.wavedec2(padded, 'db2', level=level)

        # Gather the approximation band followed by the three detail bands of
        # every level, in the order the transform returned them.
        bands = [coeff[0]]
        for detail in coeff[1:]:
            bands.extend(detail[j] for j in range(3))

        # Remember each band's spatial shape and token count so the flat
        # sequence can be split and reshaped back after Mamba.
        band_shapes = [band.shape[2:] for band in bands]
        band_sizes = [shape.numel() for shape in band_shapes]

        # Flatten each band to (B, C, tokens), concatenate once along the
        # token axis, then move channels last: (B, total_tokens, C).
        tokens = torch.cat(
            [band.reshape(B, C, -1) for band in bands], dim=2
        ).transpose(-1, -2)

        # Normalisation and Mamba with a residual over the token sequence.
        x_norm = self.norm(tokens)
        mixed = tokens + self.mamba(x_norm)

        # Back to (B, C, total_tokens) and split into per-band chunks.
        chunks = torch.split(mixed.transpose(-1, -2), band_sizes, dim=2)
        restored = [
            chunk.reshape(B, C, *shape)
            for chunk, shape in zip(chunks, band_shapes)
        ]

        # Rebuild the coefficient layout: the approximation tensor first,
        # then one 3-tuple of detail tensors per level.
        rec_x = [restored[0]]
        for start in range(1, len(restored), 3):
            rec_x.append(tuple(restored[start:start + 3]))

        # Inverse transform, then crop away the padding added above.
        rec_coeff = ptwt.waverec2(rec_x, 'db2')
        out = rec_coeff[..., :H, :W].reshape(B, C, H, W)

        # Outer residual connection with the original input.
        return x + out