From 3a443d22f6ca73dc307025fccc4d8b821c04aed2 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Thu, 8 Feb 2024 19:33:24 +0100
Subject: [PATCH] Update UNETR upsampling (#211)

Use ConvTranspose in UNETR implementation.
---
 torch_em/model/unetr.py | 72 ++++++++++++++++++++++++-----------------
 1 file changed, 42 insertions(+), 30 deletions(-)

diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py
index 13b570b2..fbb0e085 100644
--- a/torch_em/model/unetr.py
+++ b/torch_em/model/unetr.py
@@ -72,6 +72,7 @@ def __init__(
         final_activation: Optional[Union[str, nn.Module]] = None,
         use_skip_connection: bool = True,
         embed_dim: Optional[int] = None,
+        use_conv_transpose=True,
     ) -> None:
         super().__init__()

@@ -117,28 +118,44 @@ def __init__(
             scale_factors = depth * [2]
         self.out_channels = out_channels

+        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
+        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
+
         if decoder is None:
             self.decoder = Decoder(
                 features=features_decoder,
                 scale_factors=scale_factors[::-1],
                 conv_block_impl=ConvBlock2d,
-                sampler_impl=Upsampler2d
+                sampler_impl=_upsampler
             )
         else:
             self.decoder = decoder

-        self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])
+        if use_skip_connection:
+            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
+            self.deconv2 = nn.Sequential(
+                Deconv2DBlock(embed_dim, features_decoder[0]),
+                Deconv2DBlock(features_decoder[0], features_decoder[1])
+            )
+            self.deconv3 = nn.Sequential(
+                Deconv2DBlock(embed_dim, features_decoder[0]),
+                Deconv2DBlock(features_decoder[0], features_decoder[1]),
+                Deconv2DBlock(features_decoder[1], features_decoder[2])
+            )
+            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
+        else:
+            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
+            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
+            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
+            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

         self.base = ConvBlock2d(embed_dim, features_decoder[0])

         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

-        self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
-        self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
-        self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
-        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
-
-        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+        self.deconv_out = _upsampler(
+            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
+        )

         self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

@@ -233,18 +250,11 @@ def forward(self, x):
         z12 = encoder_outputs

         if use_skip_connection:
-            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
             from_encoder = from_encoder[::-1]
             z9 = self.deconv1(from_encoder[0])
-
-            z6 = self.deconv1(from_encoder[1])
-            z6 = self.deconv2(z6)
-
-            z3 = self.deconv1(from_encoder[2])
-            z3 = self.deconv2(z3)
-            z3 = self.deconv3(z3)
-
-            z0 = self.z_inputs(x)
+            z6 = self.deconv2(from_encoder[1])
+            z3 = self.deconv3(from_encoder[2])
+            z0 = self.deconv4(x)

         else:
             z9 = self.deconv1(z12)

@@ -275,30 +285,31 @@ def forward(self, x):


 class SingleDeconv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes):
+    def __init__(self, scale_factor, in_channels, out_channels):
         super().__init__()
-        self.block = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
+        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

     def forward(self, x):
         return self.block(x)


 class SingleConv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size):
         super().__init__()
-        self.block = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                               padding=((kernel_size - 1) // 2))
+        self.block = nn.Conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
+        )

     def forward(self, x):
         return self.block(x)


 class Conv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3):
+    def __init__(self, in_channels, out_channels, kernel_size=3):
         super().__init__()
         self.block = nn.Sequential(
-            SingleConv2DBlock(in_planes, out_planes, kernel_size),
-            nn.BatchNorm2d(out_planes),
+            SingleConv2DBlock(in_channels, out_channels, kernel_size),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(True)
         )

@@ -307,12 +318,13 @@ def forward(self, x):


 class Deconv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3):
+    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
         super().__init__()
+        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
         self.block = nn.Sequential(
-            SingleDeconv2DBlock(in_planes, out_planes),
-            SingleConv2DBlock(out_planes, out_planes, kernel_size),
-            nn.BatchNorm2d(out_planes),
+            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
+            SingleConv2DBlock(out_channels, out_channels, kernel_size),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(True)
         )
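
Note on the upsampling choice: the patch routes all decoder upsampling through `_upsampler`, which is either `SingleDeconv2DBlock` (a learned `nn.ConvTranspose2d`) or `Upsampler2d` (interpolation followed by a convolution, per the in-diff comment). Both are called with the same `(scale_factor, in_channels, out_channels)` signature, which is what makes them interchangeable. Below is a minimal self-contained sketch of the two options; `BilinearUpsampler` is a hypothetical stand-in for torch_em's `Upsampler2d`, not its actual implementation.

import torch
import torch.nn as nn


class TransposeConvUpsampler(nn.Module):
    # Mirrors SingleDeconv2DBlock: learned upsampling via transposed convolution.
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        # kernel_size == stride == scale_factor upsamples without kernel overlap
        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor)

    def forward(self, x):
        return self.block(x)


class BilinearUpsampler(nn.Module):
    # Hypothetical stand-in for Upsampler2d: fixed interpolation, then a learned conv.
    def __init__(self, scale_factor, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(self.up(x))


x = torch.randn(1, 64, 32, 32)
for impl in (TransposeConvUpsampler, BilinearUpsampler):
    y = impl(scale_factor=2, in_channels=64, out_channels=32)(x)
    assert y.shape == (1, 32, 64, 64)  # both double H and W and remap channels

The trade-off is the usual one: a transposed convolution learns its upsampling kernel but can introduce checkerboard artifacts, while interpolation followed by a convolution avoids them at the cost of a fixed upsampling step.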
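The other fix in this patch removes the weight sharing that the deleted TODO flagged: `forward()` previously pushed every skip input through the same `self.deconv1`/`self.deconv2` modules, whereas `__init__` now builds a dedicated deconv stack per skip level (`deconv1` through `deconv4`). With the patch applied, the upsampling behaviour is selected at construction time. A hypothetical usage sketch follows; all arguments other than `use_conv_transpose` are left at their defaults and may need adjusting to the encoder in use.

from torch_em.model.unetr import UNETR

# New default: transposed-convolution upsampling in the decoder.
model = UNETR(out_channels=1, use_conv_transpose=True)

# Pre-patch behaviour: interpolation + convolution (Upsampler2d).
model = UNETR(out_channels=1, use_conv_transpose=False)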