From f00b262418f936e7830f0c88a828af4f6cc1d552 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 3 Feb 2024 18:31:26 +0100 Subject: [PATCH 1/3] Update UNETR - to use bilinear interpolation for upsampling --- torch_em/model/unetr.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index 13b570b2..a460ae96 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -138,7 +138,9 @@ def __init__( self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2]) self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3]) - self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1]) + self.deconv_out = Upsampler2d( + scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] + ) self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) @@ -274,15 +276,6 @@ def forward(self, x): # -class SingleDeconv2DBlock(nn.Module): - def __init__(self, in_planes, out_planes): - super().__init__() - self.block = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0) - - def forward(self, x): - return self.block(x) - - class SingleConv2DBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size): super().__init__() @@ -310,7 +303,7 @@ class Deconv2DBlock(nn.Module): def __init__(self, in_planes, out_planes, kernel_size=3): super().__init__() self.block = nn.Sequential( - SingleDeconv2DBlock(in_planes, out_planes), + Upsampler2d(scale_factor=2, in_channels=in_planes, out_channels=out_planes), SingleConv2DBlock(out_planes, out_planes, kernel_size), nn.BatchNorm2d(out_planes), nn.ReLU(True) From 87491c89dd102745a8f41b24aa52850deb01e00f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sat, 3 Feb 2024 20:39:11 +0100 Subject: [PATCH 2/3] Update deconv blocks to avoid sharing parameters --- torch_em/model/unetr.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index a460ae96..d8c70907 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -127,17 +127,28 @@ def __init__( else: self.decoder = decoder - self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1]) + if use_skip_connection: + self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0]) + self.deconv2 = nn.Sequential( + Deconv2DBlock(embed_dim, features_decoder[0]), + Deconv2DBlock(features_decoder[0], features_decoder[1]) + ) + self.deconv3 = nn.Sequential( + Deconv2DBlock(embed_dim, features_decoder[0]), + Deconv2DBlock(features_decoder[0], features_decoder[1]), + Deconv2DBlock(features_decoder[1], features_decoder[2]) + ) + self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) + else: + self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0]) + self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1]) + self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2]) + self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3]) self.base = ConvBlock2d(embed_dim, features_decoder[0]) self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) - self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0]) - self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1]) - self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2]) - self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3]) - self.deconv_out = Upsampler2d( scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] ) @@ -235,18 +246,11 @@ def forward(self, x): z12 = encoder_outputs if use_skip_connection: - # TODO: we share the weights in the deconv(s), and should preferably avoid doing that from_encoder = from_encoder[::-1] z9 = self.deconv1(from_encoder[0]) - - z6 = self.deconv1(from_encoder[1]) - z6 = self.deconv2(z6) - - z3 = self.deconv1(from_encoder[2]) - z3 = self.deconv2(z3) - z3 = self.deconv3(z3) - - z0 = self.z_inputs(x) + z6 = self.deconv2(from_encoder[1]) + z3 = self.deconv3(from_encoder[2]) + z0 = self.deconv4(x) else: z9 = self.deconv1(z12) From bda7604233c286b29c9aa9d81f39ba9f186663d1 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 6 Feb 2024 19:42:03 +0100 Subject: [PATCH 3/3] Make UNETR upsampling choice modular (default: conv transpose) --- torch_em/model/unetr.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index d8c70907..fbb0e085 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -72,6 +72,7 @@ def __init__( final_activation: Optional[Union[str, nn.Module]] = None, use_skip_connection: bool = True, embed_dim: Optional[int] = None, + use_conv_transpose=True, ) -> None: super().__init__() @@ -117,12 +118,15 @@ def __init__( scale_factors = depth * [2] self.out_channels = out_channels + # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose + _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d + if decoder is None: self.decoder = Decoder( features=features_decoder, scale_factors=scale_factors[::-1], conv_block_impl=ConvBlock2d, - sampler_impl=Upsampler2d + sampler_impl=_upsampler ) else: self.decoder = decoder @@ -149,7 +153,7 @@ def __init__( self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) - self.deconv_out = Upsampler2d( + self.deconv_out = _upsampler( scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] ) @@ -280,22 +284,32 @@ def forward(self, x): # +class SingleDeconv2DBlock(nn.Module): + def __init__(self, scale_factor, in_channels, out_channels): + super().__init__() + self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0) + + def forward(self, x): + return self.block(x) + + class SingleConv2DBlock(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size): + def __init__(self, in_channels, out_channels, kernel_size): super().__init__() - self.block = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=1, - padding=((kernel_size - 1) // 2)) + self.block = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2) + ) def forward(self, x): return self.block(x) class Conv2DBlock(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size=3): + def __init__(self, in_channels, out_channels, kernel_size=3): super().__init__() self.block = nn.Sequential( - SingleConv2DBlock(in_planes, out_planes, kernel_size), - nn.BatchNorm2d(out_planes), + SingleConv2DBlock(in_channels, out_channels, kernel_size), + nn.BatchNorm2d(out_channels), nn.ReLU(True) ) @@ -304,12 +318,13 @@ def forward(self, x): class Deconv2DBlock(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size=3): + def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True): super().__init__() + _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d self.block = nn.Sequential( - Upsampler2d(scale_factor=2, in_channels=in_planes, out_channels=out_planes), - SingleConv2DBlock(out_planes, out_planes, kernel_size), - nn.BatchNorm2d(out_planes), + _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels), + SingleConv2DBlock(out_channels, out_channels, kernel_size), + nn.BatchNorm2d(out_channels), nn.ReLU(True) )