In [1]:
import torch
from models.winnet.lifting import RevNetLiftingPair
from models.winnet.splitmerge import DCTSplitMerge, HaarSplitMerge

In [2]:
B, H, W = 8, 32, 32
x = torch.randn(B, 3, H, W)
sigma = torch.zeros(B, 1)

In [3]:
haar_sm = HaarSplitMerge()

c, d = haar_sm(x)
print(c.shape, d.shape)

torch.Size([8, 3, 16, 16]) torch.Size([8, 9, 16, 16])


In [4]:
split_merge = DCTSplitMerge(
    in_channels=3,
    coarse_to_in_ch_ratio=1,
    patch_size=5,
)

revnet_lift = RevNetLiftingPair(
    coarse_ch=3, 
    detail_ch=72, 
    hidden_ch=96,
)

In [5]:
c, d = split_merge.forward(x)
print(c.shape, d.shape)

torch.Size([8, 3, 32, 32]) torch.Size([8, 72, 32, 32])


In [6]:
c, d = revnet_lift(c, d, sigma)
print(c.shape, d.shape)

torch.Size([8, 3, 32, 32]) torch.Size([8, 72, 32, 32])


In [7]:
c, d = revnet_lift.inverse(c, d, sigma)
x_rec  = split_merge.inverse(c, d)

print((x - x_rec).abs().max())            # â‰ˆ 1e-6 (floating-point noise)

tensor(15.7661, grad_fn=<MaxBackward1>)


In [8]:
x_rec.shape

torch.Size([8, 3, 32, 32])

In [9]:
from models.lifting_denoiser import LiftingDenoiser

In [40]:
# new LiftingDenoiser init params

unet_params = {
    "base_channels": 16,
    "channel_multipliers": (1, 2, 4, 8),
}

clista_params = {
    "latent_channels": 128,
}

init_params = {
    "input_channels": 1,
    "coarse_channels": 3,
    "hidden_channels": 128,
    "num_lifting_steps": 4,
    "lifting_type": "revnet",
    "detail_denoiser": "clista",
    "split_merge_type": "haar_redundant",
    "do_convert_t_to_sigma": True,
    "num_haar_scales": 4,
}

if init_params["detail_denoiser"] == "unet":
    init_params.update(unet_params)
elif init_params["detail_denoiser"] == "clista":
    init_params.update(clista_params)
    
for key, val in init_params.items():
    print(f"{key}: {val}")

input_channels: 1
coarse_channels: 3
hidden_channels: 128
num_lifting_steps: 4
lifting_type: revnet
detail_denoiser: clista
split_merge_type: haar_redundant
do_convert_t_to_sigma: True
num_haar_scales: 4
latent_channels: 128


In [41]:
model = LiftingDenoiser(
    **init_params,
)

In [42]:
param_count = sum(p.numel() for p in model.parameters())
param_count

1616032

In [43]:
model.coarse_channels

3

In [44]:
model.detail_channels

10

In [45]:
model

LiftingDenoiser(
  (split_merge): MultiScaleStationaryHaar()
  (lifting_steps): ModuleList(
    (0-3): 4 x RevNetLiftingPair(
      (f): Conditioner(
        (b1): ConvFiLMBlock(
          (conv): Conv2d(10, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (gn): GroupNorm(8, 128, eps=1e-05, affine=False)
          (film): FiLM(
            (mlp): Sequential(
              (0): Linear(in_features=1, out_features=32, bias=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=256, bias=True)
            )
          )
        )
        (b2): ConvFiLMBlock(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (gn): GroupNorm(8, 128, eps=1e-05, affine=False)
          (film): FiLM(
            (mlp): Sequential(
              (0): Linear(in_features=1, out_features=32, bias=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=256, bias=True)
 

In [46]:
B, H, W = 8, 32, 32
x = torch.randn(B, 1, H, W)
t = torch.zeros(B, dtype=torch.long)
sigma = torch.zeros(B, dtype=torch.float32)

In [47]:
model.update_sigmas_t(sigma, "Successfully updated sigma values.")

Successfully updated sigma values.


In [48]:
out = model(x, t)
out.shape

torch.Size([8, 1, 32, 32])

In [17]:
from models.winnet.splitmerge import StationaryHaarSplitMerge
import torch

In [18]:
stationary_haar_sm = StationaryHaarSplitMerge(
    coarse_to_in_ch_ratio=1, 
    in_channels=3
)

TypeError: StationaryHaarSplitMerge.__init__() got an unexpected keyword argument 'coarse_to_in_ch_ratio'

In [None]:
B, H, W = 8, 32, 32
x = torch.randn(B, 3, H, W)
sigma = torch.zeros(B, 1)

In [None]:
c, d = stationary_haar_sm(x)
print(c.shape, d.shape)

In [None]:
from models.unet import UNet

In [None]:
model = UNet(
    in_channels=3,
    channel_multipliers=(1, 2, 4, 4),
    base_channels=256,
    # attention_heads=8,
    # attention_head_dim=64,
)

In [None]:
param_count = sum(p.numel() for p in model.parameters())
print(f"no. of parameters: {param_count // 1e6} M")