# 3D ResNet


In [None]:
# hide
import sys
sys.path.append("..")
from nbdev.showdoc import *

In [None]:
# default_exp models.resnet
# export
from fastai.basics import *
from fastai.layers import *
from warnings import warn
from torch.hub import load_state_dict_from_url

## ResNet 3D

Based on the ResNet implementation in torchvision, with all 2D modules replaced by 3D modules. The stem and the classification head were also modified (see the ResNet basic module below).

### Building blocks

In [None]:
# export
from torchvision.models.resnet import Bottleneck, BasicBlock

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3x3 `nn.Conv3d` whose padding equals `dilation`.

    With padding == dilation and stride 1, the spatial size of the input is preserved.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       groups=groups, bias=False, dilation=dilation)
    return nn.Conv3d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Pointwise (1x1x1) bias-free `nn.Conv3d`; changes the channel count and optionally strides."""
    return nn.Conv3d(in_planes, out_planes, 1, stride=stride, bias=False)

from torch import nn # prevent error in nbdev with re-importing nn (was already imported with fastai)
class BasicBlock3d(nn.Module):
    """Residual block of two 3x3x3 convolutions (3D counterpart of torchvision's `BasicBlock`).

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions (output = planes * expansion, expansion is 1).
        stride: stride of the first convolution; `downsample` must reshape the shortcut to match.
        downsample: optional module applied to the input to form the shortcut branch.
        groups, base_width: only the defaults (1 and 64) are supported.
        dilation: only 1 is supported.
        norm_layer: normalization factory called with a channel count (default `nn.BatchNorm3d`).
        act_layer: activation factory called without arguments (default in-place `nn.ReLU`).

    Raises:
        ValueError: if `groups != 1` or `base_width != 64`.
        NotImplementedError: if `dilation > 1`.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, act_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm3d if norm_layer is None else norm_layer
        act_layer = partial(nn.ReLU, inplace=True) if act_layer is None else act_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 carries the stride, so it shrinks the input together with `downsample`
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = act_layer()  # one module reused for both activations
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class Bottleneck3d(nn.Module):
    """Residual 1x1x1 -> 3x3x3 -> 1x1x1 bottleneck block (3D counterpart of torchvision's `Bottleneck`).

    The inner width is `int(planes * base_width / 64) * groups`; the block outputs
    `planes * expansion` (= 4 * planes) channels.

    Args:
        inplanes: channels of the incoming tensor.
        planes: base channel count of the block.
        stride: stride of the middle 3x3x3 convolution; `downsample` must reshape the shortcut to match.
        downsample: optional module applied to the input to form the shortcut branch.
        groups, base_width, dilation: forwarded to the middle 3x3x3 convolution / width formula.
        norm_layer: normalization factory called with a channel count (default `nn.BatchNorm3d`).
        act_layer: activation factory called without arguments (default in-place `nn.ReLU`).
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, act_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm3d if norm_layer is None else norm_layer
        act_layer = partial(nn.ReLU, inplace=True) if act_layer is None else act_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # conv2 carries the stride, so it shrinks the input together with `downsample`
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = act_layer()  # one module reused for every activation
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)

In [None]:
# example: a BasicBlock3d whose norm layers have no learnable affine parameters
BasicBlock3d(4, 64, norm_layer=partial(nn.BatchNorm3d, affine = False))

BasicBlock3d(
  (conv1): Conv3d(4, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
)

In [None]:
# example: a Bottleneck3d with default norm/activation; it outputs 64 * expansion (= 256) channels
Bottleneck3d(4, 64)

Bottleneck3d(
  (conv1): Conv3d(4, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

### Identity layer

Medical images, especially 3D images, are large, so the batch size is limited when training on normal consumer hardware. This can lead to problems with batch normalization layers, as performance can decrease for batch sizes under 32 and will likely decrease for batch sizes under 8. 
This is discussed [here](https://stackoverflow.com/questions/59648509/batch-normalization-when-batch-size-1), [here](https://luminovo.ai/blog/a-refresher-on-batch-re-normalization) and [here](https://www.alexirpan.com/2017/04/26/perils-batch-norm.html). Replacing the normalization layer with an identity layer might be a quick solution that does not require altering the whole architecture. 

In [None]:
# export
class IdentityLayer(nn.Module):
    "Returns input as is"

    def __init__(self, *args, **kwargs):
        # Any arguments (e.g. a channel count) are accepted and discarded, so this class can be
        # passed wherever a `norm_layer(num_features)` factory is expected.
        super().__init__()

    def forward(self, x):
        # no parameters, no computation: hand back the very same tensor object
        return x

### ResNet basic module
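Same as the ResNet module from torchvision, but all 2D submodules have been changed to 3D. There is no max-pooling layer. Instead, the stem is a single convolution with kernel size (2, 5, 5) and stride (1, 3, 3), so the depth dimension is not downsampled in the stem (the depth of medical images can be very small). The classification head is a small MLP (BatchNorm1d, Dropout, Linear) instead of a single linear layer.  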

In [None]:
# export
class ResNet3D(nn.Module):
    """3D ResNet based on torchvision's `ResNet`, with all 2D layers replaced by 3D ones.

    Unlike torchvision, the stem is a single (2, 5, 5) convolution with stride (1, 3, 3) and no
    max-pooling, and the head is a small MLP (BatchNorm1d / Dropout / Linear).

    Args:
        block: residual block class (`BasicBlock3d` or `Bottleneck3d`); must define `expansion`.
        layers: number of blocks in each of the four stages.
        n_channels: channels of the input volume.
        num_classes: size of the head's output.
        zero_init_residual: set the scale of the last norm layer in each residual branch to zero,
            so every block starts as an identity mapping (skipped for norm layers without a weight).
        groups, width_per_group: forwarded to the blocks as `groups` / `base_width`.
        replace_stride_with_dilation: three booleans (stages 2-4). True replaces that stage's
            stride with dilation.
        norm_layer: factory for the 3D normalization layers (default `nn.BatchNorm3d`).
        act_layer: activation factory used in the stem, the head and every residual block
            (default in-place `nn.ReLU`). Stem and head call it with `inplace=True`, so it must
            accept that keyword.
        final_softmax: append `nn.Softmax(1)` to the head.
        ps: dropout probability before the last linear layer; the first head dropout uses `ps/2`.

    Raises:
        ValueError: if `replace_stride_with_dilation` does not have three elements.
    """

    def __init__(self, block, layers, n_channels=3, num_classes=101, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, act_layer=None, final_softmax=False, ps=0.5):
        super(ResNet3D, self).__init__()
        if norm_layer is None: norm_layer = nn.BatchNorm3d
        if act_layer is None: act_layer = partial(nn.ReLU, inplace=True)
        self._norm_layer = norm_layer
        self._act_layer = act_layer
        # The stem width depends on the block type. Inspect the class instead of building a
        # throwaway block, which would allocate weights and consume the global RNG.
        is_bottleneck = isinstance(block, type) and issubclass(block, Bottleneck3d)
        self.inplanes = 128 if is_bottleneck else 32

        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element indicates whether the stride-2 convolution of stages 2-4
            # should be replaced with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # No max-pooling; the stem conv has stride 1 along depth, so depth is not downsampled here.
        self.stem = nn.Sequential(nn.Conv3d(n_channels, self.inplanes, kernel_size=(2, 5, 5), stride=(1, 3, 3), padding=1, bias=False),
                                  norm_layer(self.inplanes),
                                  act_layer(inplace=True))

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Sequential(
            nn.BatchNorm1d(512 * block.expansion, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Dropout(p=ps/2, inplace=False),
            nn.Linear(512 * block.expansion, 256),
            act_layer(inplace=True),
            nn.BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.Dropout(p=ps, inplace=False),
            nn.Linear(256, num_classes, bias=False))

        if final_softmax:
            self.fc = nn.Sequential(self.fc,
                                    nn.Softmax(1))

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)) and m.weight is not None:
                # norm layers built with affine=False have weight/bias set to None
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last norm layer in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck3d):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock3d):
                    last_norm = m.bn2
                else:
                    continue
                # affine=False norms and IdentityLayer have no learnable scale to zero
                weight = getattr(last_norm, 'weight', None)
                if weight is not None:
                    nn.init.constant_(weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        "Build one stage of `blocks` residual blocks; only the first block may change stride/channels."
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # project the shortcut so it matches the first block's output shape
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer,
                        act_layer=self._act_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, act_layer=self._act_layer))

        return nn.Sequential(*layers)

    def _encoder(self, x1):
        "Run the four residual stages on the stem output and return the output of each stage."
        x2 = self.layer1(x1)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        return x2, x3, x4, x5

    def _head(self, x5):
        "Global-average-pool the last feature map, flatten it and apply the MLP head."
        x = self.avgpool(x5)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _forward_impl(self, x):
        x = self.stem(x)
        *_, x5 = self._encoder(x)
        return self._head(x5)

    def forward(self, x):
        "`x` is laid out as (N, C, D, H, W), as expected by `nn.Conv3d`."
        return self._forward_impl(x)

In [None]:
# smoke test: LeakyReLU activation, softmax head, dropout ps=0.75 (the first head dropout uses ps/2)
ResNet3D(BasicBlock3d, [2, 2, 2, 2], final_softmax = True, act_layer=nn.LeakyReLU, ps = 0.75)(torch.randn(10, 3, 8, 64, 64)).size()

torch.Size([10, 101])

## ResNet architectures

Note that a pretrained ResNet18 for 3D already exists at `torchvision.models.video`

In [None]:
# export
# Checkpoint locations used by `_resnet_3d(pretrained=True)`. The cell must be exported,
# otherwise the generated module cannot resolve `model_urls`.
model_urls = {
           'resnet18_3d': 'https://rad-ai.charite.de/pretrained_models/resnet18_3d.pth', 
           'resnet34_3d': 'https://rad-ai.charite.de/pretrained_models/resnet34_3d.pth', 
           'resnet50_3d': 'https://rad-ai.charite.de/pretrained_models/resnet50_3d.pth', 
           'resnet101_3d': 'https://rad-ai.charite.de/pretrained_models/resnet101_3d.pth'
          }

In [None]:
# export

def _resnet_3d(arch, block, layers, pretrained=False, progress=False, **kwargs):
    """Build a `ResNet3D` and optionally load pretrained weights (mirrors torchvision's `_resnet`).

    Args:
        arch: key into `model_urls`; only consulted when `pretrained` is True.
        block, layers, **kwargs: forwarded to `ResNet3D`.
        pretrained: download the checkpoint for `arch` and load it into the model.
        progress: show a progress bar while downloading.

    Raises:
        KeyError: if `pretrained` is True but no checkpoint URL is registered for `arch`.
    """
    model = ResNet3D(block, layers, **kwargs)
    if pretrained:
        if arch not in model_urls:
            raise KeyError(f"No pretrained weights available for architecture {arch!r}")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # the downloaded checkpoint is expected to hold the weights under the 'model' key
        model.load_state_dict(state_dict['model'])
    return model

In [None]:
# export
def resnet18_3d(pretrained=False, progress=False, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    adapted to 3d.

    Uses `BasicBlock3d` with stage depths [2, 2, 2, 2]. `pretrained` loads the
    'resnet18_3d' checkpoint; remaining kwargs are forwarded to `ResNet3D`.
    """
    return _resnet_3d('resnet18_3d', BasicBlock3d, [2, 2, 2, 2], pretrained=pretrained, progress=progress, **kwargs)


def resnet34_3d(pretrained=False, progress=False, **kwargs):
    r"""3D ResNet-34: `BasicBlock3d` blocks with stage depths 3-4-6-3.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_,
    with the 2D layers swapped for 3D ones. Extra kwargs are forwarded to `ResNet3D`.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet_3d('resnet34_3d', BasicBlock3d, stage_depths,
                      pretrained=pretrained, progress=progress, **kwargs)


def resnet50_3d(pretrained=False, progress=False, **kwargs):
    r"""3D ResNet-50: `Bottleneck3d` blocks with stage depths 3-4-6-3.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_,
    with the 2D layers swapped for 3D ones. Extra kwargs are forwarded to `ResNet3D`.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet_3d('resnet50_3d', Bottleneck3d, stage_depths,
                      pretrained=pretrained, progress=progress, **kwargs)


def resnet101_3d(pretrained=False, progress=False, **kwargs):
    r"""3D ResNet-101: `Bottleneck3d` blocks with stage depths 3-4-23-3.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_,
    with the 2D layers swapped for 3D ones. Extra kwargs are forwarded to `ResNet3D`.
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet_3d('resnet101_3d', Bottleneck3d, stage_depths,
                      pretrained=pretrained, progress=progress, **kwargs)


def resnet152_3d(pretrained=False, progress=False, **kwargs):
    r"""3D ResNet-152: `Bottleneck3d` blocks with stage depths 3-8-36-3.

    Based on `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_,
    with the 2D layers swapped for 3D ones. No pretrained checkpoint exists yet:
    `pretrained=True` only emits a warning and the model stays randomly initialized.
    `progress` is ignored.
    """
    if pretrained:
        warn('Currently there is no pretrained version available for `resnet152_3d`. Will load randomly intilialized weights.')
    stage_depths = [3, 8, 36, 3]
    return _resnet_3d(None, Bottleneck3d, stage_depths, pretrained=False, progress=False, **kwargs)

In [None]:
# shape check: a batch of 2 volumes, 3 channels, depth 15, 80 x 80
model = resnet18_3d()
xb = torch.rand(2, 3, 15, 80, 80)  # named `xb` to avoid shadowing the builtin `input`
output = model(xb)
print(output.size())

torch.Size([2, 101])


In [None]:
# shape check: a batch of 2 volumes, 3 channels, depth 15, 64 x 64
model = resnet101_3d()
xb = torch.rand(2, 3, 15, 64, 64)  # named `xb` to avoid shadowing the builtin `input`
output = model(xb)
print(output.size())

torch.Size([2, 101])


### Resnet encoder

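`build_backbone` turns a ResNet3D into an encoder that returns the outputs of the stem and of all four stages, e.g. for use in a UNet or DeepLabV3.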

In [None]:
# export
def build_backbone(backbone, output_stride, norm_layer, n_channels, **kwargs):
    """Instantiate `backbone` and make its forward return every encoder feature map.

    Args:
        backbone: constructor such as `resnet34_3d`, called with `n_channels`, `norm_layer`
            and the extra `kwargs`.
        output_stride: accepted for API compatibility but currently unused.
        norm_layer: normalization factory forwarded to the backbone.
        n_channels: number of input channels forwarded to the backbone.

    Returns:
        The backbone instance, whose `forward` now returns the tuple
        (stem, layer1, layer2, layer3, layer4) outputs instead of class scores.
    """
    net = backbone(n_channels=n_channels, norm_layer=norm_layer, **kwargs)
    stage_names = ('stem', 'layer1', 'layer2', 'layer3', 'layer4')

    def _features(inp):
        feats = []
        for stage_name in stage_names:
            # each stage consumes the previous stage's output
            inp = getattr(net, stage_name)(inp)
            feats.append(inp)
        return tuple(feats)

    # instance attribute shadows the class `forward`, so `net(x)` runs `_features`
    net.forward = _features
    return net

In [None]:
# ResNet34 encoder without normalization (IdentityLayer) and 5 input channels;
# the output_stride argument (8) is currently ignored by build_backbone
m = build_backbone(resnet34_3d, 8, IdentityLayer, 5)
xb = m(torch.randn(10, 5, 10, 50, 50))
# print the shapes of the stem output followed by the outputs of layer1..layer4
for x in xb: print(x.size())

torch.Size([10, 32, 11, 16, 16])
torch.Size([10, 64, 11, 16, 16])
torch.Size([10, 128, 6, 8, 8])
torch.Size([10, 256, 3, 4, 4])
torch.Size([10, 512, 2, 2, 2])


In [None]:
# hide
# regenerate the library modules from the notebooks' exported cells (prints one line per notebook)
from nbdev.export import *
notebook2script()

Converted 01_basics.ipynb.
Converted 02_preprocessing.ipynb.
Converted 03_transforms.ipynb.
Converted 04_dataloaders.ipynb.
Converted 05_layers.ipynb.
Converted 06_learner.ipynb.
Converted 06a_models.alexnet.ipynb.
Converted 06b_models.resnet.ipynb.
Converted 06d_models.unet.ipynb.
Converted 06f_models.losses.ipynb.
Converted 07_callback.ipynb.
Converted index.ipynb.
