In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.winnet.splitmerge import (
    DCTSplitMerge,
    StationaryHaarSplitMerge,
    MultiScaleStationaryHaar,
    LearnableHaarSplitMerge,
)

In [2]:
def verify_forward_inverse_pass(splitmerge, in_ch=1, atol=None):
    """Round-trip a random batch through ``splitmerge`` and report the error.

    A random tensor of shape ``(8, in_ch, 32, 32)`` is split with
    ``splitmerge(x)``, merged back with ``splitmerge.inverse(...)``, and the
    shapes of both parts plus the maximum absolute difference between the
    input and its reconstruction are printed.

    Args:
        splitmerge: Callable returning a ``(coarse, detail)`` pair and
            exposing ``inverse(coarse, detail)``.
        in_ch: Number of channels of the random test input.
        atol: Optional tolerance on the reconstruction error. ``None``
            (the default) keeps the check print-only; otherwise the call
            fails when the error exceeds ``atol`` (or is NaN).

    Raises:
        AssertionError: If ``atol`` is given and the reconstruction error is
            not within it.
    """
    x = torch.randn(8, in_ch, 32, 32)

    # This is a pure sanity check, so skip building an autograd graph.
    with torch.no_grad():
        x_coarse, x_detail = splitmerge(x)
        x_rec = splitmerge.inverse(x_coarse, x_detail)
    reconstruction_error = (x - x_rec).abs().max().item()

    print(f"coarse, detail shapes: {x_coarse.shape}, {x_detail.shape}")
    print("reconstruction error:", reconstruction_error)

    # `not (err <= atol)` rather than `err > atol` so that a NaN error fails too.
    if atol is not None and not reconstruction_error <= atol:
        raise AssertionError(
            f"reconstruction error {reconstruction_error} exceeds atol={atol}"
        )

## DCTSplitMerge

In [3]:
# Round-trip check for DCTSplitMerge on a 3-channel input.
splitmerge = DCTSplitMerge(in_channels=3, coarse_to_in_ch_ratio=2, patch_size=4)
verify_forward_inverse_pass(splitmerge, in_ch=3)

coarse, detail shapes: torch.Size([8, 6, 32, 32]), torch.Size([8, 42, 32, 32])
reconstruction error: 6.348522186279297


## StationaryHaarSplitMerge

In [4]:
# Round-trip check for StationaryHaarSplitMerge on a single-channel input.
splitmerge = StationaryHaarSplitMerge(in_channels=1)
verify_forward_inverse_pass(splitmerge, in_ch=1)

coarse, detail shapes: torch.Size([8, 1, 32, 32]), torch.Size([8, 3, 32, 32])
reconstruction error: 2.9933762550354004


## MultiScaleStationaryHaar

In [11]:
# Round-trip check for MultiScaleStationaryHaar (4 scales) on a single-channel input.
splitmerge = MultiScaleStationaryHaar(in_channels=1, coarse_to_in_ch_ratio=3, num_scales=4)
verify_forward_inverse_pass(splitmerge, in_ch=1)

coarse, detail shapes: torch.Size([8, 3, 32, 32]), torch.Size([8, 10, 32, 32])
reconstruction error: 3.67927885055542


## LearnableHaarSplitMerge

In [16]:
# Round-trip check for LearnableHaarSplitMerge (16 filters) on a single-channel input.
splitmerge = LearnableHaarSplitMerge(in_channels=1, num_filters=16, coarse_to_in_ch_ratio=4)
verify_forward_inverse_pass(splitmerge, in_ch=1)

coarse, detail shapes: torch.Size([8, 4, 32, 32]), torch.Size([8, 12, 32, 32])
reconstruction error: 2.6168160438537598


In [None]:
# Walk-through: split and merge a mini-batch with a learnable layer, then
# compute the loss terms a training step would use.
layer = LearnableHaarSplitMerge(in_channels=3, num_filters=6, coarse_to_in_ch_ratio=1)

x = torch.randn(8, 3, 32, 32)    # mini-batch
x_c, x_d = layer(x)              # split
x_rec = layer.inverse(x_c, x_d)  # merge

# Example training step: the two loss terms are computed, but the
# optimisation itself is left commented out.
recon_loss = F.mse_loss(x_rec, x)
ortho_loss = 1e-3 * layer.orthogonality_loss()
# loss = recon_loss + ortho_loss
# loss.backward()
# optimizer.step()

# Largest absolute difference between the input and its reconstruction.
reconstruction_error = torch.abs(x - x_rec).max().item()
# Intended to be 0 (up to fp round-off).
# NOTE(review): the recorded round-trip runs in this notebook report errors > 2 — confirm.
print("error:", reconstruction_error)

In [19]:
# Show the shapes of the two split outputs side by side.
print(*(part.shape for part in (x_c, x_d)))

torch.Size([8, 3, 32, 32]) torch.Size([8, 15, 32, 32])


In [46]:
def print_loss_components(loss_components, end=", "):
    """Print the strictly positive entries of ``loss_components`` on one line.

    The line starts with a tab; each kept entry is written as
    ``name: value`` with four decimals and followed by ``end``. Entries whose
    value is not greater than zero are skipped. A newline closes the line.

    Args:
        loss_components: Mapping from loss name to numeric value.
        end: Text emitted after every printed entry (including the last one).
    """
    print("\t", end="")
    # Lazy filter: entries are tested and printed one at a time, in mapping order.
    shown = (
        (label, value)
        for label, value in loss_components.items()
        if value > 0.0
    )
    for label, value in shown:
        print(f"{label}: {value:.4f}", end=end)
    print()

In [47]:
# Hard-coded sample of loss components used to try out print_loss_components.
# NOTE(review): 'spitmerge_orthogonal' and 'splitmerge_orthogonal' differ by one
# letter, and 'clista_orthgonal' also looks misspelled — keys are kept verbatim
# here; confirm against whatever produces this dict.
comps = {
    'reconstruction_loss': 0.014870839193463326,
    'total_loss': 0.015197951346635818,
    'clista_orthgonal': 0.0,
    'splitmerge_orthogonal': 0.0,
    'lifting_spectral_norm': 3.271117925643921,
    'spitmerge_orthogonal': 0.0962127223610878,
}

In [48]:
# Two back-to-back calls: each call ends its own output line.
for _run in range(2):
    print_loss_components(comps)

	reconstruction_loss: 0.0149, total_loss: 0.0152, lifting_spectral_norm: 3.2711, spitmerge_orthogonal: 0.0962, 
	reconstruction_loss: 0.0149, total_loss: 0.0152, lifting_spectral_norm: 3.2711, spitmerge_orthogonal: 0.0962, 
