In [1]:
import torch
import cmz.torch.unet as unet

# Smoke test: push dummy data through each U-Net component and print the
# input -> output shapes. Only shapes are inspected, so autograd tracking is
# disabled to avoid building (and keeping alive) computation graphs.
# NOTE(review): expected channel/spatial changes are read off this cell's
# printed output, not from the cmz.torch.unet source — confirm there.

# dummy data: batch of 4 three-channel 128x128 tensors
X = torch.rand((4, 3, 128, 128))

with torch.no_grad():
    net = unet.DownBlock(3, 6)
    Y = net(X)
    print('Down block:')
    print(X.shape, '->', Y.shape)

    # UpBlock is called with two tensors; judging by the printed output the
    # second (X) looks like a higher-resolution skip input — TODO confirm.
    net = unet.UpBlock(6, 3)
    print('\nUp block:')
    print(Y.shape, '+', X.shape, '->', net(Y, X).shape)

    # UBlock's second argument is passed as None here; its meaning is defined
    # in cmz.torch.unet (not visible from this cell).
    net = unet.UBlock(3, 6)
    print('\nU-block:')
    print(X.shape, '->', net(X, None).shape)

    net = unet.UNet(3, 1, hidden_channels=[4, 8, 16])
    print('\nU-Net:')
    print(X.shape, '->', net(X).shape)

    # UNet(3, 1, ...) should keep batch and spatial size and emit 1 channel.
    expected_shape = torch.Size([4, 1, 128, 128])

    print('\nTesting different hidden_channels:')
    fmt = "  {:<15} {:<12} {}"
    for h, label in [
        ([4, 6, 8], "not doubling"),
        ([4, 2, 4], "going down"),
        ([1, 1], "ones"),
    ]:
        net = unet.UNet(3, 1, hidden_channels=h)
        out_shape = net(X).shape
        verdict = "good! " if out_shape == expected_shape else "weird shape: "
        # Report the shape actually produced: printing expected_shape on a
        # mismatch (as before) would hide what the network really returned.
        print(fmt.format(f"{label}:", f"{h}", verdict + str(out_shape)))

Down block:
torch.Size([4, 3, 128, 128]) -> torch.Size([4, 6, 64, 64])

Up block:
torch.Size([4, 6, 64, 64]) + torch.Size([4, 3, 128, 128]) -> torch.Size([4, 3, 128, 128])

U-block:
torch.Size([4, 3, 128, 128]) -> torch.Size([4, 3, 128, 128])

U-Net:
torch.Size([4, 3, 128, 128]) -> torch.Size([4, 1, 128, 128])

Testing different hidden_channels:
  not doubling:   [4, 6, 8]    good! torch.Size([4, 1, 128, 128])
  going down:     [4, 2, 4]    good! torch.Size([4, 1, 128, 128])
  ones:           [1, 1]       good! torch.Size([4, 1, 128, 128])


  cpu = _conversion_method_template(device=torch.device("cpu"))
