In [1]:
import torch
from unet import UNet

model = UNet(
    in_channels=1,
    out_channels=2,
    n_blocks=4,
    start_filters=32,
    activation="relu",
    normalization="batch",
    conv_mode="same",
    dim=2,
)

x = torch.randn(size=(1, 1, 512, 512), dtype=torch.float32)
with torch.no_grad():
    out = model(x)

print(f"Out: {out.shape}")


Out: torch.Size([1, 2, 512, 512])


In [2]:
from torchinfo import summary

summary = summary(model=model, input_size=(1, 1, 512, 512), device="cpu")
summary

Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     --                        --
├─ModuleList: 1-1                        --                        --
├─ModuleList: 1-2                        --                        --
├─ModuleList: 1-1                        --                        --
│    └─DownBlock: 2-1                    [1, 32, 256, 256]         --
│    │    └─Conv2d: 3-1                  [1, 32, 512, 512]         320
│    │    └─ReLU: 3-2                    [1, 32, 512, 512]         --
│    │    └─BatchNorm2d: 3-3             [1, 32, 512, 512]         64
│    │    └─Conv2d: 3-4                  [1, 32, 512, 512]         9,248
│    │    └─ReLU: 3-5                    [1, 32, 512, 512]         --
│    │    └─BatchNorm2d: 3-6             [1, 32, 512, 512]         64
│    │    └─MaxPool2d: 3-7               [1, 32, 256, 256]         --
│    └─DownBlock: 2-2                    [1, 64, 128, 128]         --
│    │    └

In [3]:
shape = 1920


def compute_max_depth(shape, max_depth=10, print_out=True):
    shapes = []
    shapes.append(shape)
    for level in range(1, max_depth):
        if shape % 2 ** level == 0 and shape / 2 ** level > 1:
            shapes.append(shape / 2 ** level)
            if print_out:
                print(f"Level {level}: {shape / 2 ** level}")
        else:
            if print_out:
                print(f"Max-level: {level - 1}")
            break

    return shapes


out = compute_max_depth(shape, print_out=True, max_depth=10)


Level 1: 960.0
Level 2: 480.0
Level 3: 240.0
Level 4: 120.0
Level 5: 60.0
Level 6: 30.0
Level 7: 15.0
Max-level: 7


In [4]:
low = 10
high = 512
depth = 8


def compute_possible_shapes(low, high, depth):
    possible_shapes = {}
    for shape in range(low, high + 1):
        shapes = compute_max_depth(shape, max_depth=depth, print_out=False)
        if len(shapes) == depth:
            possible_shapes[shape] = shapes

    return possible_shapes


possible_shapes = compute_possible_shapes(low, high, depth)
possible_shapes


{256: [256, 128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0],
 384: [384, 192.0, 96.0, 48.0, 24.0, 12.0, 6.0, 3.0],
 512: [512, 256.0, 128.0, 64.0, 32.0, 16.0, 8.0, 4.0]}

In [5]:
low = 10
high = 512
depth = 8


def compute_possible_shapes(low, high, depth):
    possible_shapes = {}
    for shape in range(low, high + 1):
        shapes = compute_max_depth(shape, max_depth=depth, print_out=False)
        if len(shapes) == depth:
            possible_shapes[shape] = shapes

    return possible_shapes


possible_shapes = compute_possible_shapes(low, high, depth)
possible_shapes


{256: [256, 128.0, 64.0, 32.0, 16.0, 8.0, 4.0, 2.0],
 384: [384, 192.0, 96.0, 48.0, 24.0, 12.0, 6.0, 3.0],
 512: [512, 256.0, 128.0, 64.0, 32.0, 16.0, 8.0, 4.0]}