Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added StackUpsample2d #132

Merged
merged 8 commits into from
May 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/nn.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ Downsampling
.. autofunction:: concat_downsample2d
.. autofunction:: z_pool

Upsampling
----------

.. autofunction:: stack_upsample2d
6 changes: 6 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ Attention
.. autoclass:: LambdaLayer

.. autoclass:: TripletAttention


Upsampling
----------

.. autoclass:: StackUpsample2d
12 changes: 12 additions & 0 deletions holocron/models/segmentation/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ def unet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNet
"""U-Net from
`"U-Net: Convolutional Networks for Biomedical Image Segmentation" <https://arxiv.org/pdf/1505.04597.pdf>`_

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unet.png
:align: center

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -449,6 +452,9 @@ def unetp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UNe
"""UNet+ from `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
<https://arxiv.org/pdf/1912.05074.pdf>`_

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetp.png
:align: center

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -464,6 +470,9 @@ def unetpp(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UN
"""UNet++ from `"UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation"
<https://arxiv.org/pdf/1912.05074.pdf>`_

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unetpp.png
:align: center

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand All @@ -479,6 +488,9 @@ def unet3p(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> UN
"""UNet3+ from
`"UNet 3+: A Full-Scale Connected UNet For Medical Image Segmentation" <https://arxiv.org/pdf/2004.08790.pdf>`_

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/unet3p.png
:align: center

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
Expand Down
31 changes: 30 additions & 1 deletion holocron/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


__all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool']
'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool',
'concat_downsample2d', 'stack_upsample2d']


def silu(x: Tensor) -> Tensor:
Expand Down Expand Up @@ -154,6 +155,34 @@ def concat_downsample2d(x: Tensor, scale_factor: int) -> Tensor:
return x


def stack_upsample2d(x: Tensor, scale_factor: int) -> Tensor:
    """Upsamples a feature map without losing information by moving groups of channels into
    neighbouring spatial positions, as described in `"Real-Time Single Image and Video Super-Resolution
    Using an Efficient Sub-Pixel Convolutional Neural Network" <https://arxiv.org/pdf/1609.05158.pdf>`_.

    Args:
        x (torch.Tensor[N, C, H, W]): input tensor
        scale_factor (int): spatial scaling factor

    Returns:
        torch.Tensor[N, C / scale_factor ** 2, H * scale_factor, W * scale_factor]: upsampled tensor

    Raises:
        AssertionError: if C is not a multiple of `scale_factor ** 2`
    """

    batch_size, num_chan, height, width = x.shape
    sq_factor = scale_factor ** 2

    # Each output channel consumes `scale_factor ** 2` input channels
    if num_chan % sq_factor != 0:
        raise AssertionError("The number of channels in the input tensor must be a multiple of `scale_factor` squared")

    out_chan = num_chan // int(sq_factor)

    # Split the channel axis: (N, C, H, W) --> (N, scale_factor, scale_factor, C / scale_factor ** 2, H, W)
    unstacked = x.view(batch_size, scale_factor, scale_factor, out_chan, height, width)
    # Put each scaling axis right after the spatial axis it expands:
    # --> (N, C / scale_factor ** 2, H, scale_factor, W, scale_factor)
    interleaved = unstacked.permute(0, 3, 4, 1, 5, 2).contiguous()
    # Merge each (spatial, scaling) pair: --> (N, C / scale_factor ** 2, H * scale_factor, W * scale_factor)
    return interleaved.view(batch_size, out_chan, height * scale_factor, width * scale_factor)


def z_pool(x: Tensor, dim: int) -> Tensor:
"""Z-pool layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
<https://arxiv.org/pdf/2010.03045.pdf>`_.
Expand Down
1 change: 1 addition & 0 deletions holocron/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .dropblock import *
from .attention import *
from .lambda_layer import *
from .upsample import *
6 changes: 3 additions & 3 deletions holocron/nn/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def forward(self, x: Tensor) -> Tensor:

class TripletAttention(nn.Module):
"""Triplet attention layer from `"Rotate to Attend: Convolutional Triplet Attention Module"
<https://arxiv.org/pdf/2010.03045.pdf>`_. This implementation is based on the pytorch
`implementation <https://github.com/LandskapeAI/triplet-attention/blob/master/MODELS/triplet_attention.py> `
by the paper's authors.
<https://arxiv.org/pdf/2010.03045.pdf>`_. This implementation is based on the
`one <https://github.com/LandskapeAI/triplet-attention/blob/master/MODELS/triplet_attention.py>`_
from the paper's authors.
"""
def __init__(self) -> None:
super().__init__()
Expand Down
12 changes: 12 additions & 0 deletions holocron/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class Add2d(_NormConvNd):
:math:`H` is a height of input planes in pixels, and :math:`W` is
width in pixels.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/add2d.png
:align: center
:alt: Add2D schema

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
Expand Down Expand Up @@ -229,6 +233,10 @@ class SlimConv2d(nn.Module):

where :math:`\\oplus` is the channel-wise concatenation.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/slimconv2d.png
:align: center
:alt: SlimConv2D schema


Args:
in_channels (int): Number of channels in the input image
Expand Down Expand Up @@ -296,6 +304,10 @@ class PyConv2d(nn.ModuleList):
"""Implements the convolution module from `"Pyramidal Convolution: Rethinking Convolutional Neural Networks for
Visual Recognition" <https://arxiv.org/pdf/2006.11538.pdf>`_.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/pyconv2d.png
:align: center
:alt: PyConv2D schema

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
Expand Down
3 changes: 3 additions & 0 deletions holocron/nn/modules/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class BlurPool2d(nn.Module):
module as described in `"Making Convolutional Networks Shift-Invariant Again"
<https://arxiv.org/pdf/1904.11486.pdf>`_.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/blurpool.png
:align: center

Args:
channels (int): Number of input channels
kernel_size (int, optional): binomial filter size for blurring. currently supports 3 (default) and 5.
Expand Down
3 changes: 3 additions & 0 deletions holocron/nn/modules/dropblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class DropBlock2d(nn.Module):
"""Implements the DropBlock module from `"DropBlock: A regularization method for convolutional networks"
<https://arxiv.org/pdf/1810.12890.pdf>`_

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/dropblock.png
:align: center

Args:
p (float, optional): probability of dropping activation value
block_size (int, optional): size of each block that is expanded from the sampled mask
Expand Down
5 changes: 4 additions & 1 deletion holocron/nn/modules/lambda_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
class LambdaLayer(nn.Module):
"""Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
<https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
<https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
<https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`_.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/lambdalayer.png
:align: center

Args:
in_channels (int): input channels
Expand Down
34 changes: 34 additions & 0 deletions holocron/nn/modules/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (C) 2019-2021, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from torch import Tensor
import torch.nn as nn
from .. import functional as F

__all__ = ['StackUpsample2d']


class StackUpsample2d(nn.Module):
    """Module wrapper around the loss-less upsampling operation from `"Real-Time Single Image and Video
    Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"
    <https://arxiv.org/pdf/1609.05158.pdf>`_, which unstacks the channel axis into adjacent spatial information.

    .. image:: https://docs.fast.ai/images/pixelshuffle.png
        :align: center

    Args:
        scale_factor (int): spatial scaling factor
    """

    def __init__(self, scale_factor: int) -> None:
        super().__init__()
        # Stored as a plain attribute: the layer holds no parameters or buffers
        self.scale_factor = scale_factor

    def forward(self, x: Tensor) -> Tensor:
        # All the work is delegated to the functional implementation
        upsampled = F.stack_upsample2d(x, self.scale_factor)
        return upsampled

    def extra_repr(self) -> str:
        # Shown between the parentheses of the module's repr
        return f"scale_factor={self.scale_factor}"
3 changes: 3 additions & 0 deletions holocron/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def diou_loss(boxes1: Tensor, boxes2: Tensor) -> Tensor:
:math:`c` is the diagonal length of the smallest enclosing box covering the two boxes,
and :math:`\\rho(.)` is the Euclidean distance.

.. image:: https://github.com/frgfm/Holocron/releases/download/v0.1.3/diou_loss.png
:align: center

Args:
boxes1 (torch.Tensor[M, 4]): bounding boxes
boxes2 (torch.Tensor[N, 4]): bounding boxes
Expand Down
31 changes: 31 additions & 0 deletions test/test_nn_upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (C) 2019-2021, François-Guillaume Fernandez.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import pytest
import torch
from holocron.nn.modules import upsample
from holocron.nn import functional as F


def test_stackupsample2d():
    """Checks the functional and module upsampling APIs against `concat_downsample2d`."""

    batch_size, channels = 2, 4
    x = torch.arange(batch_size * channels * 4 ** 2).view(batch_size, channels, 4, 4)

    # Functional API: 4 channels cannot be split by a scale factor of 3 (3 ** 2 = 9)
    with pytest.raises(AssertionError):
        F.stack_upsample2d(x, 3)

    # Downsampling then upsampling with the same factor must give back the input exactly
    x = torch.rand((batch_size, channels, 32, 32))
    downsampled = F.concat_downsample2d(x, scale_factor=2)
    restored = F.stack_upsample2d(downsampled, scale_factor=2)
    assert torch.equal(restored, x)

    # Module API: same output as the functional call, and an informative repr
    layer = upsample.StackUpsample2d(scale_factor=2)
    assert torch.equal(layer(downsampled), restored)
    assert repr(layer) == "StackUpsample2d(scale_factor=2)"