Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added implementation of UNet3+ #47

Merged
merged 13 commits into from
May 29, 2020
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ conda install -c frgfm pylocron

- Classification: [Res2Net](https://arxiv.org/abs/1904.01169) (based on the great [implementation](https://github.com/gasvn/Res2Net) from gasvn), [darknet24](https://pjreddie.com/media/files/papers/yolo_1.pdf), [darknet19](https://pjreddie.com/media/files/papers/YOLO9000.pdf), [darknet53](https://pjreddie.com/media/files/papers/YOLOv3.pdf).
- Detection: [YOLOv1](https://pjreddie.com/media/files/papers/yolo_1.pdf), [YOLOv2](https://pjreddie.com/media/files/papers/YOLO9000.pdf)
- Segmentation: [U-Net](https://arxiv.org/abs/1505.04597)
- Segmentation: [U-Net](https://arxiv.org/abs/1505.04597), [UNet++](https://arxiv.org/abs/1807.10165), [UNet3+](https://arxiv.org/abs/2004.08790)

### ops

Expand Down
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,5 @@ U-Net
.. autofunction:: unetp

.. autofunction:: unetpp

.. autofunction:: unet3p
257 changes: 176 additions & 81 deletions holocron/models/segmentation/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url

from ...nn.init import init_module

__all__ = ['UNet', 'unet', 'UNetp', 'unetp', 'UNetpp', 'unetpp']

__all__ = ['UNet', 'unet', 'UNetp', 'unetp', 'UNetpp', 'unetpp', 'UNet3p', 'unet3p']


default_cfgs = {
Expand All @@ -22,31 +24,42 @@
'layout': [64, 128, 256, 512, 1024],
'url': None},
'unetpp': {'arch': 'UNetpp',
'layout': [64, 128, 256, 512, 1024],
'url': None},
'unet3p': {'arch': 'UNet3p',
'layout': [64, 128, 256, 512, 1024],
'url': None}
}


def conv1x1(in_chan, out_chan):
    """Build a 1x1 (pointwise) 2D convolution.

    Args:
        in_chan (int): number of input channels
        out_chan (int): number of output channels

    Returns:
        torch.nn.Conv2d: convolution with a kernel size of 1 and default stride/padding
    """

    return nn.Conv2d(in_chan, out_chan, 1)


def conv3x3(in_chan, out_chan, padding=0):
    """Build a 3x3 2D convolution.

    Args:
        in_chan (int): number of input channels
        out_chan (int): number of output channels
        padding (int, optional): zero-padding applied on each spatial side

    Returns:
        torch.nn.Conv2d: convolution with a kernel size of 3
    """

    return nn.Conv2d(in_chan, out_chan, 3, padding=padding)


def conv_bn_act(in_chan, out_chan, kernel_size, padding=0, bn=False, act=True):
    """Assemble the layers of a convolution, optionally followed by batch norm and ReLU.

    The layers are returned as a plain list (not wrapped in a module) so that callers
    can unpack several of these groups into a single ``nn.Sequential``.

    Args:
        in_chan (int): number of input channels
        out_chan (int): number of output channels
        kernel_size (int): size of the convolution kernel
        padding (int, optional): zero-padding of the convolution
        bn (bool, optional): whether a ``nn.BatchNorm2d`` should follow the convolution
        act (bool, optional): whether an in-place ``nn.ReLU`` should be appended

    Returns:
        list<torch.nn.Module>: layers ordered as convolution, [batch norm], [ReLU]
    """
    layers = [nn.Conv2d(in_chan, out_chan, kernel_size, padding=padding)]
    if bn:
        layers.append(nn.BatchNorm2d(out_chan))
    if act:
        layers.append(nn.ReLU(inplace=True))

    return layers


class DownPath(nn.Sequential):
    """Contracting block: optional 2x2 max pooling followed by two 3x3 convolution stages.

    Each stage is a 3x3 convolution, an optional batch normalization and a ReLU.

    Args:
        in_chan (int): number of input channels
        out_chan (int): number of output channels of both convolutions
        downsample (bool, optional): whether to prepend a ``nn.MaxPool2d(2)``
        padding (int, optional): zero-padding used by both 3x3 convolutions
        bn (bool, optional): whether each convolution is followed by batch normalization
    """
    def __init__(self, in_chan, out_chan, downsample=True, padding=0, bn=False):
        # Layer order is kept stable since it determines the state_dict keys
        layers = [nn.MaxPool2d(2)] if downsample else []
        layers.extend([*conv_bn_act(in_chan, out_chan, 3, padding, bn),
                       *conv_bn_act(out_chan, out_chan, 3, padding, bn)])
        super().__init__(*layers)


class UpPath(nn.Module):
def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False):
def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False, padding=0, bn=False):
super().__init__()

if conv_transpose:
Expand All @@ -56,25 +69,25 @@ def __init__(self, in_chan, out_chan, num_skips=1, conv_transpose=False):

# Estimate the number of channels in the upsampled feature map
up_chan = in_chan // 2 if conv_transpose else in_chan
self.block = nn.Sequential(conv3x3(num_skips * in_chan // 2 + up_chan, out_chan),
nn.ReLU(inplace=True),
conv3x3(out_chan, out_chan),
nn.ReLU(inplace=True))
self.block = nn.Sequential(*conv_bn_act(num_skips * in_chan // 2 + up_chan, out_chan, 3, padding, bn),
*conv_bn_act(out_chan, out_chan, 3, padding, bn))
self.num_skips = num_skips

def forward(self, downfeats, upfeat):

if not isinstance(downfeats, list):
downfeats = [downfeats]
if len(downfeats) != self.num_skips:
raise ValueError
raise ValueError(f"Expected {self.num_skips} encoding feats, received {len(downfeats)}")
# Upsample expansive features
_upfeat = self.upsample(upfeat)
# Crop contracting path features
for idx, downfeat in enumerate(downfeats):
delta_w = downfeat.shape[-1] - _upfeat.shape[-1]
w_slice = slice(delta_w // 2, -delta_w // 2 if delta_w > 0 else downfeat.shape[-1])
delta_h = downfeat.shape[-2] - _upfeat.shape[-2]
downfeats[idx] = downfeat[..., delta_h // 2:-delta_h // 2, delta_w // 2:-delta_w // 2]
h_slice = slice(delta_h // 2, -delta_h // 2 if delta_h > 0 else downfeat.shape[-2])
downfeats[idx] = downfeat[..., h_slice, w_slice]
# Concatenate both feature maps and forward them
return self.block(torch.cat((*downfeats, _upfeat), dim=1))

Expand All @@ -91,34 +104,36 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool))
_pool = True

# Expansive path
_layout = layout[::-1]
for num, in_chan, out_chan in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:]):
self.add_module(f"up{num}", UpPath(in_chan, out_chan))
self.decoders = nn.ModuleList([])
for in_chan, out_chan in zip(layout[1:], layout[:-1]):
self.decoders.append(UpPath(in_chan, out_chan))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x = self.down5(x4)
for encoder in self.encoders[:-1]:
xs.append(encoder(xs[-1] if len(xs) > 0 else x))
x = self.encoders[-1](xs[-1])

# Expansive path
x = self.up4(x4, x)
x = self.up3(x3, x)
x = self.up2(x2, x)
x = self.up1(x1, x)
for idx in range(len(self.decoders) - 1, -1, -1):
x = self.decoders[idx](xs[idx], x)
# Release memory
del xs[idx]

# Classifier
x = self.classifier(x)
Expand All @@ -137,48 +152,40 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool, 1))
_pool = True

# Expansive path
_layout = layout[::-1]
for row, in_chan, out_chan, cols in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:],
range(1, len(layout))):
for col in range(1, cols + 1):
self.add_module(f"up{row}{col}", UpPath(in_chan, out_chan))
self.decoders = nn.ModuleList([])
for in_chan, out_chan, idx in zip(layout[1:], layout[:-1], range(len(layout))):
self.decoders.append(nn.ModuleList([UpPath(in_chan, out_chan, padding=1)
for _ in range(len(layout) - idx - 1)]))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x = self.down5(x4)

# Nested Expansive path
x1 = self.up11(x1, x2)
x2 = self.up21(x2, x3)
x3 = self.up31(x3, x4)
x = self.up41(x4, x)

x1 = self.up12(x1, x2)
x2 = self.up22(x2, x3)
x = self.up32(x3, x)

x1 = self.up13(x1, x2)
x = self.up23(x2, x)
for encoder in self.encoders:
xs.append(encoder(xs[-1] if len(xs) > 0 else x))

x = self.up14(x1, x)
# Nested expansive path
for j in range(len(self.decoders)):
for i in range(len(self.decoders) - j):
xs[i] = self.decoders[i][j](xs[i], xs[i + 1])
# Release memory
del xs[len(self.decoders) - j]

# Classifier
x = self.classifier(x)
x = self.classifier(xs[0])
return x


Expand All @@ -194,48 +201,121 @@ def __init__(self, layout, in_channels=1, num_classes=10):
super().__init__()

# Contracting path
self.encoders = nn.ModuleList([])
_layout = [in_channels] + layout
_pool = False
for num, in_chan, out_chan in zip(range(1, len(_layout)), _layout[:-1], _layout[1:]):
self.add_module(f"down{num}", DownPath(in_chan, out_chan, _pool))
for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
self.encoders.append(DownPath(in_chan, out_chan, _pool, 1))
_pool = True

# Expansive path
_layout = layout[::-1]
for row, in_chan, out_chan, cols in zip(range(len(layout) - 1, 0, -1), _layout[:-1], _layout[1:],
range(1, len(layout))):
for col in range(1, cols + 1):
self.add_module(f"up{row}{col}", UpPath(in_chan, out_chan, num_skips=col))
self.decoders = nn.ModuleList([])
for in_chan, out_chan, idx in zip(layout[1:], layout[:-1], range(len(layout))):
self.decoders.append(nn.ModuleList([UpPath(in_chan, out_chan, num_skips, padding=1)
for num_skips in range(1, len(layout) - idx)]))

# Classifier
self.classifier = conv1x1(64, num_classes)
self.classifier = conv1x1(layout[0], num_classes)

init_module(self, 'relu')

def forward(self, x):

xs = []
# Contracting path
x10 = self.down1(x)
x20 = self.down2(x10)
x30 = self.down3(x20)
x40 = self.down4(x30)
x = self.down5(x40)
for encoder in self.encoders:
xs.append([encoder(xs[-1][0] if len(xs) > 0 else x)])

# Nested expansive path
for j in range(len(self.decoders)):
for i in range(len(self.decoders) - j):
xs[i].append(self.decoders[i][j](xs[i], xs[i + 1][-1]))
# Release memory
del xs[len(self.decoders) - j]

# Nested Expansive path
x11 = self.up11(x10, x20)
x21 = self.up21(x20, x30)
x31 = self.up31(x30, x40)
x = self.up41(x40, x)
# Classifier
x = self.classifier(xs[0][-1])
return x


class FSAggreg(nn.Module):
    """Full-scale aggregation block: brings encoder, same-scale and decoder features to a
    common channel count and resolution, concatenates them and fuses them with a
    conv + batch norm + ReLU stage.

    Args:
        e_chans (list<int>): channels of the shallower encoder features (to be downsampled)
        skip_chan (int): channels of the same-scale feature map
        d_chans (list<int>): channels of the deeper features (to be upsampled)
    """
    def __init__(self, e_chans, skip_chan, d_chans):
        super().__init__()
        # Check stem conv channels
        base_chan = e_chans[0] if len(e_chans) > 0 else skip_chan
        # Get UNet depth
        depth = len(e_chans) + 1 + len(d_chans)
        # Downsample = max pooling + conv for channel reduction
        # The first (shallowest) encoder feature gets the largest pooling factor
        self.downsamples = nn.ModuleList([nn.Sequential(nn.MaxPool2d(2 ** (len(e_chans) - idx)),
                                                        conv3x3(e_chan, base_chan, 1))
                                          for idx, e_chan in enumerate(e_chans)])
        # Without encoder features, base_chan equals skip_chan, so no projection is needed
        self.skip = conv3x3(skip_chan, base_chan, 1) if len(e_chans) > 0 else nn.Identity()
        # Upsample = bilinear interpolation + conv for channel reduction
        # The scale factor doubles with each further entry of d_chans
        self.upsamples = nn.ModuleList([nn.Sequential(nn.Upsample(scale_factor=2 ** (idx + 1),
                                                                  mode='bilinear', align_corners=True),
                                                      conv3x3(d_chan, base_chan, 1))
                                        for idx, d_chan in enumerate(d_chans)])

        # Fusion of the `depth` concatenated maps of `base_chan` channels each
        self.block = nn.Sequential(*conv_bn_act(depth * base_chan, depth * base_chan, 3, 1, True))

    def forward(self, downfeats, feat, upfeats):
        """Aggregate the full-scale features.

        Args:
            downfeats (list<torch.Tensor>): one feature map per entry of ``e_chans``
            feat (torch.Tensor): same-scale feature map
            upfeats (list<torch.Tensor>): one feature map per entry of ``d_chans``

        Returns:
            torch.Tensor: fused feature map

        Raises:
            ValueError: if the number of received feature maps does not match the
                number of down/up-sampling branches
        """

        if len(downfeats) != len(self.downsamples) or len(upfeats) != len(self.upsamples):
            raise ValueError(f"Expected {len(self.downsamples)} encoding & {len(self.upsamples)} decoding features, "
                             f"received: {len(downfeats)} & {len(upfeats)}")

        # Concatenate full-scale features
        x = torch.cat((*[downsample(downfeat) for downsample, downfeat in zip(self.downsamples, downfeats)],
                       self.skip(feat),
                       *[upsample(upfeat) for upsample, upfeat in zip(self.upsamples, upfeats)]), dim=1)

        return self.block(x)

x = self.up14([x10, x11, x12, x13], x)

class UNet3p(nn.Module):
    """Implements a UNet3+ architecture

    Args:
        layout (list<int>): number of channels after each contracting block
        in_channels (int, optional): number of channels in the input tensor
        num_classes (int, optional): number of output classes
    """
    def __init__(self, layout, in_channels=1, num_classes=10):
        super().__init__()

        # Contracting path
        self.encoders = nn.ModuleList([])
        _layout = [in_channels] + layout
        _pool = False
        for in_chan, out_chan in zip(_layout[:-1], _layout[1:]):
            # Padded convolutions with batch norm; every block but the first one downsamples
            self.encoders.append(DownPath(in_chan, out_chan, _pool, 1, True))
            _pool = True

        # Expansive path
        self.decoders = nn.ModuleList([])
        for row in range(len(layout) - 1):
            # Deeper inputs of each decoder: the outputs of the deeper decoders
            # (len(layout) * layout[0] channels each), then the deepest encoder output
            self.decoders.append(FSAggreg(layout[:row],
                                          layout[row],
                                          [len(layout) * layout[0]] * (len(layout) - 2 - row) + layout[-1:]))

        # Classifier
        self.classifier = conv1x1(len(layout) * layout[0], num_classes)

        init_module(self, 'relu')

    def forward(self, x):
        """Run the encoders, then the full-scale decoders, and classify the shallowest output.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output of the 1x1 classifier
        """

        xs = []
        # Contracting path
        for encoder in self.encoders:
            xs.append(encoder(xs[-1] if len(xs) > 0 else x))

        # Full-scale expansive path
        # Deepest decoder first: each decoder overwrites its own scale in place, so
        # shallower decoders receive decoded features from every deeper scale
        for idx in range(len(self.decoders) - 1, -1, -1):
            xs[idx] = self.decoders[idx](xs[:idx], xs[idx], xs[idx + 1:])

        # Classifier
        x = self.classifier(xs[0])
        return x


Expand Down Expand Up @@ -299,3 +379,18 @@ def unetpp(pretrained=False, progress=True, **kwargs):
"""

return _unet('unetpp', pretrained, progress, **kwargs)


def unet3p(pretrained=False, progress=True, **kwargs):
    """UNet3+ from
    `"UNet 3+: A Full-Scale Connected UNet For Medical Image Segmentation" <https://arxiv.org/pdf/2004.08790.pdf>`_

    Args:
        pretrained (bool): If True, returns a model with pretrained parameters
            (NOTE(review): the 'unet3p' entry of ``default_cfgs`` has no checkpoint URL,
            confirm how ``_unet`` handles this case)
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: extra keyword arguments forwarded to ``_unet``

    Returns:
        torch.nn.Module: semantic segmentation model
    """

    return _unet('unet3p', pretrained, progress, **kwargs)
Loading