In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary 

In [2]:

class SFCN(nn.Module):
    """SFCN-style fully convolutional network for 3D volumes.

    The network is a stack of ``len(channel_number)`` convolutional stages
    followed by a fully convolutional head (no ``nn.Linear`` layers):

    * stages ``0 .. n-2``: 3x3x3 Conv3d (padding 1) -> BatchNorm3d -> MaxPool3d(2)
      -> ReLU, so each of these stages halves every spatial dimension (floor);
    * stage ``n-1``: 1x1x1 Conv3d -> BatchNorm3d -> ReLU (no pooling);
    * head: AvgPool3d(avg_pool_shape) -> optional Dropout -> 1x1x1 Conv3d to
      ``output_dim`` channels.

    Submodule names (``conv_0`` ... ``conv_n``, ``average_pool``, ``dropout``)
    and their internal indices are stable, so existing state dicts keep loading.

    Args:
        channel_number: output channels of each convolutional stage, in order.
            Must contain at least one entry.
        output_dim: number of output channels of the final 1x1x1 convolution.
        dropout_prob: if not ``None``, an ``nn.Dropout(dropout_prob)`` is inserted
            between the average pool and the final convolution.
        in_channels: number of channels of the input volume.
        avg_pool_shape: kernel size (and, by PyTorch default, stride) of the head's
            average pool. With the default six stages, ``(5, 6, 5)`` reduces the
            final feature map of a 160x192x160 or 182x218x182 input to 1x1x1.
            For other input sizes adjust it, otherwise ``forward`` returns more
            than ``output_dim`` values per sample after flattening.

    Raises:
        ValueError: if ``channel_number`` is empty.
    """

    def __init__(
        self,
        channel_number=(32, 64, 128, 256, 256, 64),
        output_dim=1,
        dropout_prob=None,
        in_channels=1,
        avg_pool_shape=(5, 6, 5),
    ):
        super().__init__()
        # Copy so a caller-supplied list is never aliased (and a tuple default is safe).
        channel_number = list(channel_number)
        if not channel_number:
            raise ValueError("channel_number must contain at least one entry")
        n_layer = len(channel_number)
        # Stage i reads the channels produced by stage i-1; stage 0 reads the input.
        stage_inputs = [in_channels] + channel_number[:-1]

        self.feature_extractor = nn.Sequential()
        for i, (in_channel, out_channel) in enumerate(zip(stage_inputs, channel_number)):
            is_last = i == n_layer - 1
            # Every stage but the last downsamples with a 3x3x3 conv + max pool;
            # the last stage is a pointwise (1x1x1) channel projection.
            self.feature_extractor.add_module(
                f"conv_{i}",
                self.conv_layer(
                    in_channel,
                    out_channel,
                    maxpool=not is_last,
                    kernel_size=1 if is_last else 3,
                    padding=0 if is_last else 1,
                ),
            )

        self.classifier = nn.Sequential()
        self.classifier.add_module("average_pool", nn.AvgPool3d(avg_pool_shape))
        if dropout_prob is not None:
            self.classifier.add_module("dropout", nn.Dropout(dropout_prob))
        self.classifier.add_module(
            f"conv_{n_layer}",
            nn.Conv3d(channel_number[-1], output_dim, padding=0, kernel_size=1),
        )

    @staticmethod
    def conv_layer(
        in_channel,
        out_channel,
        maxpool=True,
        kernel_size=3,
        padding=0,
        maxpool_stride=2,
        maxpool_kernel_size=2,
    ):
        """Build one stage: Conv3d -> BatchNorm3d -> [MaxPool3d] -> ReLU.

        Submodule indices are 0: conv, 1: batch norm, then 2: max pool and
        3: ReLU when ``maxpool`` is truthy, or 2: ReLU otherwise. This matches
        the original layout, so state-dict keys are unchanged.

        Args:
            in_channel: input channels of the convolution.
            out_channel: output channels of the convolution and batch norm.
            maxpool: whether to insert a max-pool layer before the ReLU.
            kernel_size: convolution kernel size.
            padding: convolution padding.
            maxpool_stride: stride of the max-pool layer.
            maxpool_kernel_size: kernel size of the max-pool layer.

        Returns:
            An ``nn.Sequential`` containing the stage's layers.
        """
        layers = [
            nn.Conv3d(in_channel, out_channel, padding=padding, kernel_size=kernel_size),
            nn.BatchNorm3d(out_channel),
        ]
        if maxpool:
            layers.append(nn.MaxPool3d(maxpool_kernel_size, stride=maxpool_stride))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and flatten everything but the batch dimension.

        Args:
            x: 5-D tensor ``(batch, in_channels, D, H, W)``. BatchNorm3d rejects
                input of any other rank.

        Returns:
            Tensor of shape ``(batch, output_dim * P)``, where ``P`` is the number
            of voxels left after the average pool. ``P`` is 1 when
            ``avg_pool_shape`` matches the final feature-map size.
        """
        features = self.feature_extractor(x)
        return self.classifier(features).flatten(start_dim=1)


In [3]:
# NOTE(review): the meaning of the 1421 output units is not stated in this notebook
# (number of target classes/regressors?) — document it here.
N_OUTPUTS = 1421
model = SFCN(output_dim=N_OUTPUTS)

In [6]:
# torchsummary builds its dummy input on CUDA by default whenever a GPU is available,
# but `model` was never moved off the CPU; pin device="cpu" so this cell runs anywhere.
# NOTE(review): 182x218x182 presumably is the volume grid of the dataset — confirm.
INPUT_SHAPE = (1, 182, 218, 182)  # (channels, D, H, W); torchsummary adds the batch dim
summary(model, INPUT_SHAPE, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 182, 218, 182]             896
       BatchNorm3d-2    [-1, 32, 182, 218, 182]              64
         MaxPool3d-3      [-1, 32, 91, 109, 91]               0
              ReLU-4      [-1, 32, 91, 109, 91]               0
            Conv3d-5      [-1, 64, 91, 109, 91]          55,360
       BatchNorm3d-6      [-1, 64, 91, 109, 91]             128
         MaxPool3d-7       [-1, 64, 45, 54, 45]               0
              ReLU-8       [-1, 64, 45, 54, 45]               0
            Conv3d-9      [-1, 128, 45, 54, 45]         221,312
      BatchNorm3d-10      [-1, 128, 45, 54, 45]             256
        MaxPool3d-11      [-1, 128, 22, 27, 22]               0
             ReLU-12      [-1, 128, 22, 27, 22]               0
           Conv3d-13      [-1, 256, 22, 27, 22]         884,992
      BatchNorm3d-14      [-1, 256, 22,