In [1]:
import os

# If using Apple MPS, fall back to CPU for unsupported ops.
# This must be set BEFORE torch is imported (directly or via torchinfo /
# torch3dseg): PyTorch reads the variable when the library is loaded, so
# setting it afterwards has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from torchinfo import summary

from torch3dseg.utils.model import get_model

In [2]:
# Model configuration consumed by get_model(); only the "model" section is used here.
config = {
    "model": {
        "name": "ResidualUNet3D",
        # number of input channels to the model
        "in_channels": 1,
        # number of output channels
        "out_channels": 1,
        # determines the order of operators in a single layer, one letter per op
        # (e.g. "crg" = Conv3d+ReLU+GroupNorm); "gcr" = GroupNorm+Conv3d+ReLU
        "layer_order": "gcr",
        # initial number of feature maps or list of feature maps per block
        "f_maps": 32,
        # number of groups in the groupnorm
        "num_groups": 8,
        # number of levels in the encoder/decoder path (applied only if f_maps is an int)
        "num_levels": 3,
        # down-pooling type for encoder branch: ["max", "avg", "conv"]
        "pool_type": "conv",
        # presumably selects transposed-convolution upsampling in the decoder -- TODO confirm
        "transposed_conv": True,
        # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax.
        # With a single output channel, a softmax taken over the channel dimension is
        # identically 1.0 everywhere, so sigmoid is required when out_channels == 1.
        "final_sigmoid": True,
        # if True applies the final normalization layer (sigmoid or softmax), otherwise the
        # network returns the output from the final convolution layer; use False for
        # regression problems, e.g. de-noising
        "is_segmentation": True,
    }
}

In [3]:
# Instantiate the network from the "model" section of the config above.
model = get_model(config["model"])

In [5]:
# Layer-by-layer output shapes and parameter counts for one single-channel
# 32x32x32 input, computed on the CPU and expanded four nesting levels deep.
summary(
    model,
    input_size=(1, 1, 32, 32, 32),
    depth=4,
    device="cpu",
)

Layer (type:depth-idx)                        Output Shape              Param #
ResidualUNet3D                                [1, 1, 32, 32, 32]        --
├─ModuleList: 1-1                             --                        --
│    └─Encoder: 2-1                           [1, 32, 32, 32, 32]       --
│    │    └─ExtResNetBlock: 3-1               [1, 32, 32, 32, 32]       --
│    │    │    └─SingleConv: 4-1              [1, 32, 32, 32, 32]       866
│    │    │    └─SingleConv: 4-2              [1, 32, 32, 32, 32]       27,712
│    │    │    └─SingleConv: 4-3              [1, 32, 32, 32, 32]       27,712
│    │    │    └─ReLU: 4-4                    [1, 32, 32, 32, 32]       --
│    └─Encoder: 2-2                           [1, 64, 16, 16, 16]       --
│    │    └─Conv3d: 3-2                       [1, 32, 16, 16, 16]       256
│    │    └─ExtResNetBlock: 3-3               [1, 64, 16, 16, 16]       --
│    │    │    └─SingleConv: 4-5              [1, 64, 16, 16, 16]       55,360
│    │

In [8]:
# Smoke-test the forward pass on a random single-channel 32x32x32 volume.
# Only the output shape is inspected, so run under no_grad to avoid building
# an autograd graph (and holding activations in memory) for nothing.
with torch.no_grad():
    output, logits = model(torch.rand(1, 1, 32, 32, 32), return_logits=True)

print(output.shape)

torch.Size([1, 1, 32, 32, 32])
