In [5]:
import os

# Ops that Apple MPS does not support fall back to CPU instead of raising.
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when it loads its MPS backend, so this
# must be set BEFORE torch is first imported -- including indirectly, e.g. via
# torchinfo or torch3dseg. Setting it after those imports has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from torchinfo import summary

# NOTE(review): the name `model` is rebound to the network instance by a later
# `model = get_model(...)` cell, so this module import is shadowed from then on;
# `get_model` is what this notebook actually uses.
from torch3dseg.utils import model
from torch3dseg.utils.model import get_model

In [13]:
config = {
    "model": {
        "name": "UNet3D",
        # number of input channels to the model
        "in_channels": 1,
        # number of output channels
        "out_channels": 3,
        # order of operators inside a single conv layer, one letter per op:
        # g = GroupNorm, c = Conv3d, r = ReLU  ->  "gcr" = GroupNorm + Conv3d + ReLU
        "layer_order": "gcr",
        # initial number of feature maps
        "f_maps": 64,
        # number of groups in the GroupNorm layers
        "num_groups": 8,
        # number of levels in the encoder/decoder path (applied only if f_maps is an int)
        "num_levels": 3,
        # downsampling type between encoder levels; "conv" presumably selects a learned
        # convolutional downsampling instead of max/avg pooling -- TODO confirm in torch3dseg
        "pool_type": "conv",
        # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
        "final_sigmoid": False,
        # if True, apply the final normalization layer (sigmoid or softmax); otherwise the
        # network returns the output of the final convolution layer. Use False for
        # regression problems, e.g. de-noising.
        "is_segmentation": True,
    }
}

In [14]:
# Instantiate the network described by the "model" section of the config.
model = get_model(
    config["model"],
)

In [15]:
# Trace one dummy volume through the network on CPU and print the layer table.
# input_size follows PyTorch's Conv3d layout (N, C, D, H, W); the channel count is
# taken from the config so the summary stays valid if `in_channels` is changed.
summary(
    model,
    input_size=(1, config["model"]["in_channels"], 64, 64, 64),
    depth=5,
    device="cpu",
)

Layer (type:depth-idx)                        Output Shape              Param #
UNet3D                                        [1, 3, 64, 64, 64]        --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 64, 64, 64, 64]       --
│    │    └─DoubleConv: 3-1                   [1, 64, 64, 64, 64]       --
│    │    │    └─SingleConv: 4-1              [1, 32, 64, 64, 64]       --
│    │    │    │    └─GroupNorm: 5-1          [1, 1, 64, 64, 64]        2
│    │    │    │    └─Conv3d: 5-2             [1, 32, 64, 64, 64]       864
│    │    │    │    └─ReLU: 5-3               [1, 32, 64, 64, 64]       --
│    │    │    └─SingleConv: 4-2              [1, 64, 64, 64, 64]       --
│    │    │    │    └─GroupNorm: 5-4          [1, 32, 64, 64, 64]       64
│    │    │    │    └─Conv3d: 5-5             [1, 64, 64, 64, 64]       55,296
│    │    │    │    └─ReLU: 5-6               [1, 64, 64, 64, 64]       --
│    └─Encoder: 