In [1]:
import os

# On Apple MPS, let ops without an MPS kernel fall back to the CPU.
# PyTorch registers the MPS fallback when the library is first loaded, so this
# variable must be set BEFORE anything imports torch -- including torchinfo,
# which imports torch itself. Setting it after those imports has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (must come after the env var above)
from torchinfo import summary  # noqa: E402

from torch3dseg.utils.model import get_model  # noqa: E402

In [2]:
config = {
    "model": {
        "name": "UNet3D",
        # number of input channels to the model
        "in_channels": 1,
        # number of output channels
        "out_channels": 3,
        # order of operators in a single layer, one letter per op:
        # g = GroupNorm, c = Conv3d, r = ReLU.
        # "gcr" -> GroupNorm, then Conv3d, then ReLU (matches the summary below);
        # e.g. "crg" would be Conv3d, then ReLU, then GroupNorm.
        "layer_order": "gcr",
        # initial number of feature maps
        "f_maps": 32,
        # number of groups in the GroupNorm layers
        # NOTE(review): the first GroupNorm in the summary sees only 1 channel, so the
        # library presumably reduces the group count there -- confirm in torch3dseg.
        "num_groups": 8,
        # number of levels in the encoder/decoder path (applied only if f_maps is an int)
        "num_levels": 4,
        # down-pooling type for the encoder branch: one of "max", "avg", "conv"
        "pool_type": "conv",
        # if True apply element-wise nn.Sigmoid after the final 1x1x1 convolution,
        # otherwise apply nn.Softmax
        "final_sigmoid": False,
        # if True apply the final normalization layer (sigmoid or softmax); otherwise the
        # network returns the raw output of the final convolution layer.
        # Use False for regression problems, e.g. de-noising.
        "is_segmentation": True,
    }
}

In [3]:
# Build the network described by the "model" section of the config cell above.
model = get_model(config["model"])

In [4]:
# Layer-by-layer summary for one dummy volume.
# The channel count is read from the config so it cannot drift out of sync with
# "in_channels"; the spatial size is just an example patch. Conv3d inputs are laid
# out as (N, C, D, H, W).
# torchinfo runs a real forward pass, so a large volume on CPU can be slow and
# memory-hungry -- shrink SUMMARY_PATCH_SHAPE for a quicker look.
# NOTE(review): each spatial dim presumably needs to be divisible by the total
# down-sampling factor (likely 2 ** (num_levels - 1)) -- confirm in torch3dseg.
SUMMARY_BATCH_SIZE = 1
SUMMARY_PATCH_SHAPE = (128, 128, 128)  # (D, H, W)

summary(
    model,
    input_size=(SUMMARY_BATCH_SIZE, config["model"]["in_channels"], *SUMMARY_PATCH_SHAPE),
    depth=5,
    device="cpu",
)

Layer (type:depth-idx)                        Output Shape              Param #
UNet3D                                        [1, 3, 128, 128, 128]     --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 32, 128, 128, 128]    --
│    │    └─DoubleConv: 3-1                   [1, 32, 128, 128, 128]    --
│    │    │    └─SingleConv: 4-1              [1, 16, 128, 128, 128]    --
│    │    │    │    └─GroupNorm: 5-1          [1, 1, 128, 128, 128]     2
│    │    │    │    └─Conv3d: 5-2             [1, 16, 128, 128, 128]    432
│    │    │    │    └─ReLU: 5-3               [1, 16, 128, 128, 128]    --
│    │    │    └─SingleConv: 4-2              [1, 32, 128, 128, 128]    --
│    │    │    │    └─GroupNorm: 5-4          [1, 16, 128, 128, 128]    32
│    │    │    │    └─Conv3d: 5-5             [1, 32, 128, 128, 128]    13,824
│    │    │    │    └─ReLU: 5-6               [1, 32, 128, 128, 128]    --
│    └─Encoder: 