Skip to content

Commit

Permalink
feat: Added implementation of UNet3+ (#47)
Browse files Browse the repository at this point in the history
* refactor: Refactored UNet components

* feat: Added UNet parameter initialization

* feat: Added implementation of UNet3+

* test: Added unittest for UNet3+

* docs: Added UNet3+ to documentation

* docs: Updated README

* refactor: Refactored UNets to allow layout flexibility

* refactor: Refactored Unet forward

* refactor: Improved memory efficiency of UNet

* fix: Fixed UpPath cropping

* fix: Fixed padding for Unet+ and Unet++

* fix: Fixed cropping in skip connections

* test: Lowered input size to avoid OOM with unets
  • Loading branch information
frgfm committed May 29, 2020
1 parent 0a56cae commit 511c1b1
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 83 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ conda install -c frgfm pylocron

- Classification: [Res2Net](https://arxiv.org/abs/1904.01169) (based on the great [implementation](https://github.com/gasvn/Res2Net) from gasvn), [darknet24](https://pjreddie.com/media/files/papers/yolo_1.pdf), [darknet19](https://pjreddie.com/media/files/papers/YOLO9000.pdf), [darknet53](https://pjreddie.com/media/files/papers/YOLOv3.pdf).
- Detection: [YOLOv1](https://pjreddie.com/media/files/papers/yolo_1.pdf), [YOLOv2](https://pjreddie.com/media/files/papers/YOLO9000.pdf)
- Segmentation: [U-Net](https://arxiv.org/abs/1505.04597)
- Segmentation: [U-Net](https://arxiv.org/abs/1505.04597), [UNet++](https://arxiv.org/abs/1807.10165), [UNet3+](https://arxiv.org/abs/2004.08790)

### ops

Expand Down
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,5 @@ U-Net
.. autofunction:: unetp

.. autofunction:: unetpp

.. autofunction:: unet3p
257 changes: 176 additions & 81 deletions holocron/models/segmentation/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url

from ...nn.init import init_module

__all__ = ['UNet', 'unet', 'UNetp', 'unetp', 'UNetpp', 'unetpp']

__all__ = ['UNet', 'unet', 'UNetp', 'unetp', 'UNetpp', 'unetpp', 'UNet3p', 'unet3p']


default_cfgs = {
Expand All @@ -22,31 +24,42 @@
'layout': [64, 128, 256, 512, 1024],
'url': None},
'unetpp': {'arch': 'UNetpp',
'layout': [64, 128, 256, 512, 1024],
'url': None},
'unet3p': {'arch': 'UNet3p',
'layout': [64, 128, 256, 512, 1024],
'url': None}
}


def conv1x1(in_chan, out_chan):

return nn.Conv2d(in_chan, out_chan, 1)


def conv3x3(in_chan, out_chan, padding=0):

return nn.Conv2d(in_chan, out_chan, 3, padding=padding)


def conv_bn_act(in_chan, out_chan, kernel_size, padding=0, bn=False, act=True):
layers = [nn.Conv2d(in_chan, out_chan, kernel_size, padding=padding)]
if bn:
layers.append(nn.BatchNorm2d(out_chan))
if act:
layers.append(nn.ReLU(inplace=True))

return layers


class DownPath(nn.Sequential):
def __init__(self, in_chan, out_chan, downsample=True):
def __init__(self, in_chan, out_chan, downsample=True, padding=0, bn=False):
layers = [nn.MaxPool2d(2)] if downsample else []
layers.extend([conv3x3(in_chan, out_chan), nn.ReLU(inplace=True),
conv3x3(out_chan, out_chan), nn.ReLU(inplace=True)])
layers.extend([*conv_bn_act(in_chan, out_chan, 3, padding, bn),
*conv_bn_act(out_chan, out_chan, 3, padding, bn)])
super().__init__(*layers)


class UpPath(nn.Module):
def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False):
def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False, padding=0, bn=False):
super().__init__()

if conv_transpose:
Expand All @@ -56,25 +69,25 @@ def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False):

# Estimate the number of channels in the upsampled feature map
up_chan = in_chan // 2 if conv_transpose else in_chan
self.block = nn.Sequential(conv3x3(num_skips * in_chan // 2 + up_chan, out_chan),
nn.ReLU(inplace=True),
conv3x3(out_chan, out_chan),
nn.ReLU(inplace=True))
self.block = nn.Sequential(*conv_bn_act(num_skips * in_chan // 2 + up_chan, out_chan, 3, padding, bn),
*conv_bn_act(out_chan, out_chan, 3, padding, bn))
self.num_skips = num_skips

def forward(self, downfeats, upfeat):

if not isinstance(downfeats, list):
downfeats = [downfeats]
if len(downfeats) != self.num_skips:
raise ValueError
raise ValueError(f"Expected {self.num_skips} encoding feats, received {len(downfeats)}")
# Upsample expansive features
_upfeat = self.upsample(upfeat)
# Crop contracting path features
for idx, downfeat in enumerate(downfeats):
delta_w = downfeat.shape[-1] - _upfeat.shape[-1]
w_slice = slice(delta_w // 2, -delta_w // 2 if delta_w > 0 else downfeat.shape[-1])
delta_h = downfeat.shape[-2] - _upfeat.shape[-2]
downfeats[idx] = downfeat[..., delta_h // 2:-delta_h // 2, delta_w // 2:-delta_w // 2]
h_slice = slice(delta_h // 2, -delta_h // 2 if delta_h > 0 else downfeat.shape[-2])
downfeats[idx] = downfeat[..., h_slice, w_slice]
# Concatenate both feature maps and forward them
return self.block(torch.cat((*downfeats, _upfeat), dim=1))

Expand All @@ -91,34 +104,36 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool))
_pool = True

# Expansive path
_layout = layout[::-1]
for num, in_chan, out_chan in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:]):
self.add_module(f"up{num}", UpPath(in_chan, out_chan))
self.decoders = nn.ModuleList([])
for in_chan, out_chan in zip(layout[1:], layout[:-1]):
self.decoders.append(UpPath(in_chan, out_chan))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x = self.down5(x4)
for encoder in self.encoders[:-1]:
xs.append(encoder(xs[-1] if len(xs) > 0 else x))
x = self.encoders[-1](xs[-1])

# Expansive path
x = self.up4(x4, x)
x = self.up3(x3, x)
x = self.up2(x2, x)
x = self.up1(x1, x)
for idx in range(len(self.decoders) - 1, -1, -1):
x = self.decoders[idx](xs[idx], x)
# Release memory
del xs[idx]

# Classifier
x = self.classifier(x)
Expand All @@ -137,48 +152,40 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool, 1))
_pool = True

# Expansive path
_layout = layout[::-1]
for row, in_chan, out_chan, cols in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:],
range(1, len(layout))):
for col in range(1, cols + 1):
self.add_module(f"up{row}{col}", UpPath(in_chan, out_chan))
self.decoders = nn.ModuleList([])
for in_chan, out_chan, idx in zip(layout[1:], layout[:-1], range(len(layout))):
self.decoders.append(nn.ModuleList([UpPath(in_chan, out_chan, padding=1)
for _ in range(len(layout) - idx - 1)]))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x = self.down5(x4)

# Nested Expansive path
x1 = self.up11(x1, x2)
x2 = self.up21(x2, x3)
x3 = self.up31(x3, x4)
x = self.up41(x4, x)

x1 = self.up12(x1, x2)
x2 = self.up22(x2, x3)
x = self.up32(x3, x)

x1 = self.up13(x1, x2)
x = self.up23(x2, x)
for encoder in self.encoders:
xs.append(encoder(xs[-1] if len(xs) > 0 else x))

x = self.up14(x1, x)
# Nested expansive path
for j in range(len(self.decoders)):
for i in range(len(self.decoders) - j):
xs[i] = self.decoders[i][j](xs[i], xs[i + 1])
# Release memory
del xs[len(self.decoders) - j]

# Classifier
x = self.classifier(x)
x = self.classifier(xs[0])
return x


Expand All @@ -194,48 +201,121 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool, 1))
_pool = True

# Expansive path
_layout = layout[::-1]
for row, in_chan, out_chan, cols in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:],
range(1, len(layout))):
for col in range(1, cols + 1):
self.add_module(f"up{row}{col}", UpPath(in_chan, out_chan, num_skips=col))
self.decoders = nn.ModuleList([])
for in_chan, out_chan, idx in zip(layout[1:], layout[:-1], range(len(layout))):
self.decoders.append(nn.ModuleList([UpPath(in_chan, out_chan, num_skips, padding=1)
for num_skips in range(1, len(layout) - idx)]))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x10 = self.down1(x)
x20 = self.down2(x10)
x30 = self.down3(x20)
x40 = self.down4(x30)
x = self.down5(x40)
for encoder in self.encoders:
xs.append([encoder(xs[-1][0] if len(xs) > 0 else x)])

# Nested expansive path
for j in range(len(self.decoders)):
for i in range(len(self.decoders) - j):
xs[i].append(self.decoders[i][j](xs[i], xs[i + 1][-1]))
# Release memory
del xs[len(self.decoders) - j]

# Nested Expansive path
x11 = self.up11(x10, x20)
x21 = self.up21(x20, x30)
x31 = self.up31(x30, x40)
x = self.up41(x40, x)
# Classifier
x = self.classifier(xs[0][-1])
return x


class FSAggreg(nn.Module):
def __init__(self, e_chans, skip_chan, d_chans):
super().__init__()
# Check stem conv channels
base_chan = e_chans[0] if len(e_chans) > 0 else skip_chan
# Get UNet depth
depth = len(e_chans) + 1 + len(d_chans)
# Downsample = max pooling + conv for channel reduction
self.downsamples = nn.ModuleList([nn.Sequential(nn.MaxPool2d(2 ** (len(e_chans) - idx)),
conv3x3(e_chan, base_chan, 1))
for idx, e_chan in enumerate(e_chans)])
self.skip = conv3x3(skip_chan, base_chan, 1) if len(e_chans) > 0 else nn.Identity()
# Upsample = bilinear interpolation + conv for channel reduction
self.upsamples = nn.ModuleList([nn.Sequential(nn.Upsample(scale_factor=2 ** (idx + 1),
mode='bilinear', align_corners=True),
conv3x3(d_chan, base_chan, 1))
for idx, d_chan in enumerate(d_chans)])

self.block = nn.Sequential(*conv_bn_act(depth * base_chan, depth * base_chan, 3, 1, True))

def forward(self, downfeats, feat, upfeats):

if len(downfeats) != len(self.downsamples) or len(upfeats) != len(self.upsamples):
raise ValueError(f"Expected {len(self.downsamples)} encoding & {len(self.upsamples)} decoding features, "
f"received: {len(downfeats)} & {len(upfeats)}")

x12 = self.up12([x10, x11], x21)
x22 = self.up22([x20, x21], x31)
x = self.up32([x30, x31], x)
# Concatenate full-scale features
x = torch.cat((*[downsample(downfeat) for downsample, downfeat in zip(self.downsamples, downfeats)],
self.skip(feat),
*[upsample(upfeat) for upsample, upfeat in zip(self.upsamples, upfeats)]), dim=1)

x13 = self.up13([x10, x11, x12], x22)
x = self.up23([x20, x21, x22], x)
return self.block(x)

x = self.up14([x10, x11, x12, x13], x)

class UNet3p(nn.Module):
"""Implements a UNet3+ architecture
Args:
layout (list<int>): number of channels after each contracting block
in_channels (int, optional): number of channels in the input tensor
num_classes (int, optional): number of output classes
"""
def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool, 1, True))
_pool = True

# Expansive path
self.decoders = nn.ModuleList([])
for row in range(len(layout) - 1):
self.decoders.append(FSAggreg(layout[:row],
layout[row],
[len(layout) * layout[0]] * (len(layout) - 2 - row) + layout[-1:]))

# Classifier
x = self.classifier(x)
self.classifier = conv1x1(len(layout) * layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
for encoder in self.encoders:
xs.append(encoder(xs[-1] if len(xs) > 0 else x))

# Full-scale expansive path
for idx in range(len(self.decoders) - 1, -1, -1):
xs[idx] = self.decoders[idx](xs[:idx], xs[idx], xs[idx + 1:])

# Classifier
x = self.classifier(xs[0])
return x


Expand Down Expand Up @@ -299,3 +379,18 @@ def unetpp(pretrained=False, progress=True, **kwargs):
"""

return _unet('unetpp', pretrained, progress, **kwargs)


def unet3p(pretrained=False, progress=True, **kwargs):
"""UNet3+ from
`"UNet 3+: A Full-Scale Connected UNet For Medical Image Segmentation" <https://arxiv.org/pdf/2004.08790.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Returns:
torch.nn.Module: semantic segmentation model
"""

return _unet('unet3p', pretrained, progress, **kwargs)
Loading

0 comments on commit 511c1b1

Please sign in to comment.