Update UNETR upsampling (#211)
Use ConvTranspose in UNETR implementation.
anwai98 committed Feb 8, 2024
1 parent 81fb019 commit 3a443d2
Showing 1 changed file with 42 additions and 30 deletions.
torch_em/model/unetr.py (42 additions, 30 deletions)
@@ -72,6 +72,7 @@ def __init__(
         final_activation: Optional[Union[str, nn.Module]] = None,
         use_skip_connection: bool = True,
         embed_dim: Optional[int] = None,
+        use_conv_transpose=True,
     ) -> None:
         super().__init__()

@@ -117,28 +118,44 @@ def __init__(
         scale_factors = depth * [2]
         self.out_channels = out_channels

+        # choice of upsampler - to use (bilinear interpolation + conv) or conv transpose
+        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
+
         if decoder is None:
             self.decoder = Decoder(
                 features=features_decoder,
                 scale_factors=scale_factors[::-1],
                 conv_block_impl=ConvBlock2d,
-                sampler_impl=Upsampler2d
+                sampler_impl=_upsampler
             )
         else:
             self.decoder = decoder

-        self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])
+        if use_skip_connection:
+            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
+            self.deconv2 = nn.Sequential(
+                Deconv2DBlock(embed_dim, features_decoder[0]),
+                Deconv2DBlock(features_decoder[0], features_decoder[1])
+            )
+            self.deconv3 = nn.Sequential(
+                Deconv2DBlock(embed_dim, features_decoder[0]),
+                Deconv2DBlock(features_decoder[0], features_decoder[1]),
+                Deconv2DBlock(features_decoder[1], features_decoder[2])
+            )
+            self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1])
+        else:
+            self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
+            self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
+            self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
+            self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

         self.base = ConvBlock2d(embed_dim, features_decoder[0])

         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

-        self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
-        self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
-        self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
-        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])
-
-        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
+        self.deconv_out = _upsampler(
+            scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
+        )

         self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

@@ -233,18 +250,11 @@ def forward(self, x):
         z12 = encoder_outputs

         if use_skip_connection:
-            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
             from_encoder = from_encoder[::-1]
             z9 = self.deconv1(from_encoder[0])
-
-            z6 = self.deconv1(from_encoder[1])
-            z6 = self.deconv2(z6)
-
-            z3 = self.deconv1(from_encoder[2])
-            z3 = self.deconv2(z3)
-            z3 = self.deconv3(z3)
-
-            z0 = self.z_inputs(x)
+            z6 = self.deconv2(from_encoder[1])
+            z3 = self.deconv3(from_encoder[2])
+            z0 = self.deconv4(x)

         else:
             z9 = self.deconv1(z12)
@@ -275,30 +285,31 @@ def forward(self, x):


 class SingleDeconv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes):
+    def __init__(self, scale_factor, in_channels, out_channels):
         super().__init__()
-        self.block = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)
+        self.block = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)

     def forward(self, x):
         return self.block(x)


 class SingleConv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size):
         super().__init__()
-        self.block = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
-                               padding=((kernel_size - 1) // 2))
+        self.block = nn.Conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=((kernel_size - 1) // 2)
+        )

     def forward(self, x):
         return self.block(x)


 class Conv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3):
+    def __init__(self, in_channels, out_channels, kernel_size=3):
         super().__init__()
         self.block = nn.Sequential(
-            SingleConv2DBlock(in_planes, out_planes, kernel_size),
-            nn.BatchNorm2d(out_planes),
+            SingleConv2DBlock(in_channels, out_channels, kernel_size),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(True)
         )

@@ -307,12 +318,13 @@ def forward(self, x):


 class Deconv2DBlock(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size=3):
+    def __init__(self, in_channels, out_channels, kernel_size=3, use_conv_transpose=True):
         super().__init__()
+        _upsampler = SingleDeconv2DBlock if use_conv_transpose else Upsampler2d
         self.block = nn.Sequential(
-            SingleDeconv2DBlock(in_planes, out_planes),
-            SingleConv2DBlock(out_planes, out_planes, kernel_size),
-            nn.BatchNorm2d(out_planes),
+            _upsampler(scale_factor=2, in_channels=in_channels, out_channels=out_channels),
+            SingleConv2DBlock(out_channels, out_channels, kernel_size),
+            nn.BatchNorm2d(out_channels),
             nn.ReLU(True)
         )
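For context on the change: the new use_conv_transpose flag selects between a learned transposed convolution and torch_em's interpolation-based Upsampler2d. The following is an editor's sketch of the two options, not code from this commit; the channel counts and input shape are made up for illustration.

# Editor's sketch (not part of the commit): what use_conv_transpose chooses between.
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

# use_conv_transpose=True -> learned upsampling; kernel_size=2 with stride=2
# doubles H and W exactly, with no overlap between output patches.
deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=0, output_padding=0)
print(deconv(x).shape)  # torch.Size([1, 32, 64, 64])

# use_conv_transpose=False -> fixed interpolation followed by a convolution,
# roughly what an interpolation-based sampler such as Upsampler2d does.
up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
)
print(up(x).shape)  # torch.Size([1, 32, 64, 64])

Note that SingleDeconv2DBlock appears to accept the new scale_factor argument only so that both samplers share the call signature _upsampler(scale_factor=2, in_channels=..., out_channels=...); its ConvTranspose2d is hard-coded to kernel_size=2, stride=2, i.e. always a 2x upsample.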

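The forward-pass changes also remove the weight sharing flagged by the old TODO: self.deconv1 used to be applied to every skip feature, with deconv2 and deconv3 chained on top, so one set of parameters served several decoder paths. Now each path registers its own (deeper) stack. A toy sketch of the difference, using made-up modules rather than repo code:

# Editor's sketch of the weight-sharing issue the commit fixes.
import torch.nn as nn

shared = nn.Conv2d(8, 8, kernel_size=3, padding=1)

# Old pattern: the same module reused on every skip path, so one weight
# tensor receives gradients from all paths.
old_path_a = shared
old_path_b = nn.Sequential(shared, nn.Conv2d(8, 8, kernel_size=3, padding=1))
assert old_path_a.weight is old_path_b[0].weight

# New pattern: each skip path owns its own stack of blocks, so the
# parameters are independent (as in the new deconv2/deconv3 Sequentials).
new_path_a = nn.Conv2d(8, 8, kernel_size=3, padding=1)
new_path_b = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
assert new_path_a.weight is not new_path_b[0].weight

This also increases the parameter count slightly, since the duplicated Deconv2DBlocks no longer alias each other.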
