In [1]:
import os

# If using Apple MPS, fall back to CPU for unsupported ops.
# PyTorch reads this variable when it loads and registers the MPS backend, so it
# must be set BEFORE anything imports torch (torchinfo imports torch transitively).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from torchinfo import summary

# NOTE(review): `model` here is the torch3dseg.utils.model module; the name is rebound
# to the network instance by `model = get_model(...)`, after which this module
# reference is no longer reachable under that name.
from torch3dseg.utils import model
from torch3dseg.utils.model import get_model

In [2]:
# Network hyper-parameters; the "model" section is what get_model() consumes.
config = {
    "model": {
        "name": "ResidualUNet3D",
        # number of input channels to the model
        "in_channels": 1,
        # number of output channels
        "out_channels": 3,
        # order of operators in a single layer, one letter per op
        # (c = Conv3d, r = ReLU, g = GroupNorm); "gcr" = GroupNorm -> Conv3d -> ReLU
        "layer_order": "gcr",
        # initial number of feature maps
        "f_maps": 64,
        # number of groups in the GroupNorm
        "num_groups": 8,
        # if True, apply element-wise nn.Sigmoid after the final 1x1x1 convolution;
        # otherwise apply nn.Softmax
        "final_sigmoid": False,
        # if True, apply the final normalization layer (sigmoid or softmax); otherwise the
        # network returns the raw output of the final convolution.
        # Use False for regression problems, e.g. denoising.
        "is_segmentation": True,
    }
}

In [3]:
# Instantiate the network described by the "model" section of the config.
# NOTE: this rebinds `model`, shadowing the `torch3dseg.utils.model` module import.
model_config = config["model"]
model = get_model(model_config)

In [9]:
# Two-level layer summary for a single 128^3 volume, forward pass run on CPU.
# The channel dimension is taken from the config so it always matches the model's input.
# NOTE: this is memory-hungry — a single 64 x 128^3 activation is ~0.5 GiB in float32.
in_channels = config["model"]["in_channels"]
summary(model, input_size=(1, in_channels, 128, 128, 128), depth=2, device="cpu")

Layer (type:depth-idx)                        Output Shape              Param #
ResidualUNet3D                                [1, 3, 128, 128, 128]     --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 64, 128, 128, 128]    223,170
│    └─Encoder: 2-2                           [1, 128, 64, 64, 64]      1,106,560
│    └─Encoder: 2-3                           [1, 256, 32, 32, 32]      4,424,960
│    └─Encoder: 2-4                           [1, 512, 16, 16, 16]      17,697,280
│    └─Encoder: 2-5                           [1, 1024, 8, 8, 8]        70,784,000
├─ModuleList: 1-2                             --                        --
│    └─Decoder: 2-6                           [1, 512, 16, 16, 16]      35,393,024
│    └─Decoder: 2-7                           [1, 256, 32, 32, 32]      8,849,152
│    └─Decoder: 2-8                           [1, 128, 64, 64, 64]      2,212,736
│    └─Decoder: 2-9                   

In [7]:
# Layer summary (torchinfo default depth=3) on a small 16^3 volume: a cheap way to
# inspect per-stage output shapes and parameter counts.
# The device is pinned to CPU, matching the 128^3 summary. Without it, torchinfo picks CUDA
# when it is available and moves the model there, making results machine-dependent.
in_channels = config["model"]["in_channels"]
summary(model, input_size=(1, in_channels, 16, 16, 16), device="cpu")

Layer (type:depth-idx)                        Output Shape              Param #
ResidualUNet3D                                [1, 3, 16, 16, 16]        --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 64, 16, 16, 16]       --
│    │    └─ExtResNetBlock: 3-1               [1, 64, 16, 16, 16]       223,170
│    └─Encoder: 2-2                           [1, 128, 8, 8, 8]         --
│    │    └─MaxPool3d: 3-2                    [1, 64, 8, 8, 8]          --
│    │    └─ExtResNetBlock: 3-3               [1, 128, 8, 8, 8]         1,106,560
│    └─Encoder: 2-3                           [1, 256, 4, 4, 4]         --
│    │    └─MaxPool3d: 3-4                    [1, 128, 4, 4, 4]         --
│    │    └─ExtResNetBlock: 3-5               [1, 256, 4, 4, 4]         4,424,960
│    └─Encoder: 2-4                           [1, 512, 2, 2, 2]         --
│    │    └─MaxPool3d: 3-6                    [1, 256, 2, 2, 2]         --
│