Skip to content

Commit

Permalink
feat: Added StackUpsample2d (#132)
Browse files Browse the repository at this point in the history
* feat: Added stack upsampling

* test: Added unittest for functional form

* feat: Added upsample modules

* feat: Added module version

* test: Extended unittest

* docs: Added module to documentation

* docs: Updated documentation

* style: Fixed lint
  • Loading branch information
frgfm committed May 9, 2021
1 parent de25835 commit 6dc9448
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/source/nn.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ Downsampling
.. autofunction:: concat_downsample2d
.. autofunction:: z_pool

Upsampling
----------

.. autofunction:: stack_upsample2d
6 changes: 6 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ Attention
.. autoclass:: LambdaLayer

.. autoclass:: TripletAttention


Upsampling
------------

.. autoclass:: StackUpsample2d
12 changes: 12 additions & 0 deletions holocron/models/segmentation/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ def unet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNet
"""U-Net from
`"U-Net: Convolutional Networks for Biomedical Image Segmentation" <https://arxiv.org/pdf/1505.04597.pdf>`_
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unet.png
:align: center
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -449,6 +452,9 @@ def unetp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNe
"""UNet+ from `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
<https://arxiv.org/pdf/1912.05074.pdf>`_
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetp.png
:align: center
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -464,6 +470,9 @@ def unetpp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UN
"""UNet++ from `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
<https://arxiv.org/pdf/1912.05074.pdf>`_
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetpp.png
:align: center
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -479,6 +488,9 @@ def unet3p(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UN
"""UNet3+ from
`"UNet 3+: A Full-Scale Connected UNet For Medical Image Segmentation" <https://arxiv.org/pdf/2004.08790.pdf>`_
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unet3p.png
:align: center
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand Down
31 changes: 30 additions & 1 deletion holocron/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


__all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool']
'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool',
'concat_downsample2d', 'stack_upsample2d']


def silu(x: Tensor) -> Tensor:
Expand Down Expand Up @@ -154,6 +155,34 @@ def concat_downsample2d(x: Tensor, scale_factor: int) -> Tensor:
return x


def stack_upsample2d(x: Tensor, scale_factor: int) -> Tensor:
"""Implements a loss-less upsampling operation described in `"Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network" <https://arxiv.org/pdf/1609.05158.pdf>`_
by unstacking the channel axis into adjacent information.
Args:
x (torch.Tensor[N, C, H, W]): input tensor
scale_factor (int): spatial scaling factor
Returns:
torch.Tensor[N, C / scale_factor ** 2, H * scale_factor, W * scale_factor]: upsampled tensor
"""

b, c, h, w = x.shape

if (c % (scale_factor ** 2) != 0):
raise AssertionError("The number of channels in the input tensor must be a multiple of `scale_factor` squared")

# N * C * H * W --> N * scale_factor * scale_factor * (C / scale_factor ** 2) * H * W
x = x.view(b, scale_factor, scale_factor, c // int(scale_factor ** 2), h, w)
# --> N * (C / scale_factor ** 2) * H * scale_factor * W * scale_factor
x = x.permute(0, 3, 4, 1, 5, 2).contiguous()
# --> N * (C / scale_factor ** 2) * (H * scale_factor) * (W * scale_factor)
x = x.view(b, c // int(scale_factor ** 2), h * scale_factor, w * scale_factor)

return x


def z_pool(x: Tensor, dim: int) -> Tensor:
"""Z-pool layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
<https://arxiv.org/pdf/2010.03045.pdf>`_.
Expand Down
1 change: 1 addition & 0 deletions holocron/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .dropblock import *
from .attention import *
from .lambda_layer import *
from .upsample import *
6 changes: 3 additions & 3 deletions holocron/nn/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def forward(self, x: Tensor) -> Tensor:

class TripletAttention(nn.Module):
"""Triplet attention layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
<https://arxiv.org/pdf/2010.03045.pdf>`_. This implementation is based on the pytorch
`implementation <https://github.com/LandskapeAI/triplet-attention/blob/master/MODELS/triplet_attention.py> `
by the paper's authors.
<https://arxiv.org/pdf/2010.03045.pdf>`_. This implementation is based on the
`one <https://github.com/LandskapeAI/triplet-attention/blob/master/MODELS/triplet_attention.py>`_
from the paper's authors.
"""
def __init__(self) -> None:
super().__init__()
Expand Down
12 changes: 12 additions & 0 deletions holocron/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class Add2d(_NormConvNd):
:math:`H` is a height of input planes in pixels, and :math:`W` is
width in pixels.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/add2d.png
:align: center
:alt: Add2D schema
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
Expand Down Expand Up @@ -229,6 +233,10 @@ class SlimConv2d(nn.Module):
where :math:`\\oplus` is the channel-wise concatenation.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/slimconv2d.png
:align: center
:alt: SlimConv2D schema
Args:
in_channels (int): Number of channels in the input image
Expand Down Expand Up @@ -296,6 +304,10 @@ class PyConv2d(nn.ModuleList):
"""Implements the convolution module from `"Pyramidal Convolution: Rethinking Convolutional Neural Networks for
Visual Recognition" <https://arxiv.org/pdf/2006.11538.pdf>`_.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/pyconv2d.png
:align: center
:alt: PyConv2D schema
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
Expand Down
3 changes: 3 additions & 0 deletions holocron/nn/modules/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class BlurPool2d(nn.Module):
module as described in `"Making Convolutional Networks Shift-Invariant Again"
<https://arxiv.org/pdf/1904.11486.pdf>`_.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/blurpool.png
:align: center
Args:
channels (int): Number of input channels
kernel_size (int, optional): binomial filter size for blurring. currently supports 3 (default) and 5.
Expand Down
3 changes: 3 additions & 0 deletions holocron/nn/modules/dropblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class DropBlock2d(nn.Module):
"""Implements the DropBlock module from `"DropBlock: A regularization method for convolutional networks"
<https://arxiv.org/pdf/1810.12890.pdf>`_
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/dropblock.png
:align: center
Args:
p (float, optional): probability of dropping activation value
block_size (int, optional): size of each block that is expended from the sampled mask
Expand Down
5 changes: 4 additions & 1 deletion holocron/nn/modules/lambda_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
class LambdaLayer(nn.Module):
"""Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
<https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
<https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
<https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`_.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/lambdalayer.png
:align: center
Args:
in_channels (int): input channels
Expand Down
34 changes: 34 additions & 0 deletions holocron/nn/modules/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (C) 2019-2021, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from torch import Tensor
import torch.nn as nn
from .. import functional as F

__all__ = ['StackUpsample2d']


class StackUpsample2d(nn.Module):
"""Implements a loss-less upsampling operation described in `"Real-Time Single Image and Video Super-Resolution
Using an Efficient Sub-Pixel Convolutional Neural Network" <https://arxiv.org/pdf/1609.05158.pdf>`_
by unstacking the channel axis into adjacent information.
.. image:: https://docs.fast.ai/images/pixelshuffle.png
:align: center
Args:
scale_factor (int): spatial scaling factor
"""

def __init__(self, scale_factor: int) -> None:
super().__init__()
self.scale_factor = scale_factor

def forward(self, x: Tensor) -> Tensor:

return F.stack_upsample2d(x, self.scale_factor)

def extra_repr(self) -> str:
return f"scale_factor={self.scale_factor}"
3 changes: 3 additions & 0 deletions holocron/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def diou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor:
:math:`c` c is the diagonal length of the smallest enclosing box covering the two boxes,
and :math:`\\rho(.)` is the Euclidean distance.
.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/diou_loss.png
:align: center
Args:
boxes1 (torch.Tensor[M, 4]): bounding boxes
boxes2 (torch.Tensor[N, 4]): bounding boxes
Expand Down
31 changes: 31 additions & 0 deletions test/test_nn_upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (C) 2019-2021, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import pytest
import torch
from holocron.nn.modules import upsample
from holocron.nn import functional as F


def test_stackupsample2d():

num_batches = 2
num_chan = 4
x = torch.arange(num_batches * num_chan * 4 ** 2).view(num_batches, num_chan, 4, 4)

# Test functional API
with pytest.raises(AssertionError):
F.stack_upsample2d(x, 3)

# Check that it's the inverse of concat_downsample2d
x = torch.rand((num_batches, num_chan, 32, 32))
down = F.concat_downsample2d(x, scale_factor=2)
up = F.stack_upsample2d(down, scale_factor=2)
assert torch.equal(up, x)

# module interface
mod = upsample.StackUpsample2d(scale_factor=2)
assert torch.equal(mod(down), up)
assert repr(mod) == "StackUpsample2d(scale_factor=2)"

0 comments on commit 6dc9448

Please sign in to comment.