In [1]:
import os

# If using Apple MPS, let ops without an MPS kernel fall back to the CPU.
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when torch is first imported, so it
# must be set BEFORE anything that imports torch (torchinfo and torch3dseg both do).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from torchinfo import summary

# NOTE(review): the `model` module imported here is shadowed by the `model`
# network instance created further down; kept only in case other cells use it.
from torch3dseg.utils import model
from torch3dseg.utils.model import get_model

In [2]:
# Model configuration consumed by torch3dseg's `get_model(config["model"])`.
config = {
    "model": {
        "name": "ResidualUNet3D",
        # number of input channels to the model
        "in_channels": 1,
        # number of output channels
        "out_channels": 3,
        # order of operators in a single layer: "gcr" = GroupNorm -> Conv3d -> ReLU
        "layer_order": "gcr",
        # initial number of feature maps
        "f_maps": 64,
        # number of groups in the GroupNorm layers
        "num_groups": 8,
        # if True apply element-wise nn.Sigmoid after the final 1x1x1 convolution,
        # otherwise apply nn.Softmax
        "final_sigmoid": False,
        # if True apply the final normalization layer (sigmoid or softmax); otherwise the
        # network returns the raw output of the final convolution layer.
        # Use False for regression problems, e.g. de-noising.
        "is_segmentation": True,
    }
}

In [7]:
# Instantiate the network described by the "model" section of the config dict.
model = get_model(
    config["model"]
)

In [8]:
# Layer-by-layer summary for a batch of one single-channel 128^3 volume,
# expanded to 5 levels of module nesting and computed on the CPU.
summary(
    model,
    input_size=(1, 1, 128, 128, 128),
    depth=5,
    device="cpu",
)

Layer (type:depth-idx)                        Output Shape              Param #
ResidualUNet3D                                [1, 3, 128, 128, 128]     --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 64, 128, 128, 128]    --
│    │    └─ExtResNetBlock: 3-1               [1, 64, 128, 128, 128]    --
│    │    │    └─SingleConv: 4-1              [1, 64, 128, 128, 128]    --
│    │    │    │    └─GroupNorm: 5-1          [1, 1, 128, 128, 128]     2
│    │    │    │    └─Conv3d: 5-2             [1, 64, 128, 128, 128]    1,728
│    │    │    │    └─ReLU: 5-3               [1, 64, 128, 128, 128]    --
│    │    │    └─SingleConv: 4-2              [1, 64, 128, 128, 128]    --
│    │    │    │    └─GroupNorm: 5-4          [1, 64, 128, 128, 128]    128
│    │    │    │    └─Conv3d: 5-5             [1, 64, 128, 128, 128]    110,592
│    │    │    │    └─ReLU: 5-6               [1, 64, 128, 128, 128]    --
│    │    │ 

In [6]:
# Compact (default-depth) summary for a batch of one single-channel 64^3 volume.
# device="cpu" is passed explicitly, matching the 128^3 summary: without it
# torchinfo defaults to CUDA when available and moves `model` there, silently
# changing the model's device for every later cell.
summary(model, input_size=(1, 1, 64, 64, 64), device="cpu")

Layer (type:depth-idx)                        Output Shape              Param #
ResidualUNet3D                                [1, 1, 64, 64, 64]        --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 64, 64, 64, 64]       --
│    │    └─ExtResNetBlock: 3-1               [1, 64, 64, 64, 64]       223,170
│    └─Encoder: 2-2                           [1, 128, 32, 32, 32]      --
│    │    └─MaxPool3d: 3-2                    [1, 64, 32, 32, 32]       --
│    │    └─ExtResNetBlock: 3-3               [1, 128, 32, 32, 32]      1,106,560
│    └─Encoder: 2-3                           [1, 256, 16, 16, 16]      --
│    │    └─MaxPool3d: 3-4                    [1, 128, 16, 16, 16]      --
│    │    └─ExtResNetBlock: 3-5               [1, 256, 16, 16, 16]      4,424,960
│    └─Encoder: 2-4                           [1, 512, 8, 8, 8]         --
│    │    └─MaxPool3d: 3-6                    [1, 256, 8, 8, 8]         --
│