
refactor: Removed implementations of nn that are now integrated in PyTorch #157

Merged · 7 commits · Nov 6, 2021
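A minimal before/after sketch of the swap this PR applies throughout the codebase, assuming PyTorch >= 1.9 (`nn.Mish` was added in 1.9; `nn.SiLU` has been available since 1.7):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 32, 32)

# Before: custom modules shipped with holocron.nn
# from holocron.nn import SiLU, Mish
# act = Mish()

# After: the native PyTorch modules, with optional in-place computation
act = nn.Mish(inplace=True)   # x * tanh(softplus(x))
out = act(x)
act = nn.SiLU(inplace=True)   # x * sigmoid(x), a.k.a. Swish
out = act(out)
```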
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ Implementations of recent Deep Learning tricks in Computer Vision, easily paired
## Quick Tour

### PyTorch layers for every need
-- Activation: [SiLU/Swish](https://arxiv.org/abs/1606.08415), [Mish](https://arxiv.org/abs/1908.08681), [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824)
+- Activation: [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824)
- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189), [MutualChannelLoss](https://arxiv.org/abs/2002.04264)
- Convolutions: [NormConv2d](https://arxiv.org/pdf/2005.05274v2.pdf), [Add2d](https://arxiv.org/pdf/1912.13200.pdf), [SlimConv2d](https://arxiv.org/pdf/2003.07469.pdf), [PyConv2d](https://arxiv.org/abs/2006.11538), [Involution](https://arxiv.org/abs/2103.06255)
- Regularization: [DropBlock](https://arxiv.org/abs/1810.12890)
4 changes: 0 additions & 4 deletions docs/source/nn.functional.rst
@@ -11,10 +11,6 @@ holocron.nn.functional
Non-linear activations
----------------------

-.. autofunction:: silu
-
-.. autofunction:: mish
-
.. autofunction:: hard_mish

.. autofunction:: nl_relu
4 changes: 0 additions & 4 deletions docs/source/nn.rst
@@ -9,10 +9,6 @@ An addition to the :mod:`torch.nn` module of Pytorch to extend the range of neur
Non-linear activations
----------------------

-.. autoclass:: SiLU
-
-.. autoclass:: Mish
-
.. autoclass:: HardMish

.. autoclass:: NLReLU
4 changes: 2 additions & 2 deletions holocron/models/darknetv4.py
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn

-from holocron.nn import DropBlock2d, GlobalAvgPool2d, Mish
+from holocron.nn import DropBlock2d, GlobalAvgPool2d

from ..nn.init import init_module
from .darknetv3 import ResBlock
@@ -178,7 +178,7 @@ def cspdarknet53_mish(pretrained: bool = False, progress: bool = True, **kwargs:
torch.nn.Module: classification model
"""

-kwargs['act_layer'] = Mish()
+kwargs['act_layer'] = nn.Mish(inplace=True)
kwargs['drop_layer'] = DropBlock2d

return _darknet('cspdarknet53_mish', pretrained, progress, **kwargs) # type: ignore[return-value]
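With this change, the Mish variants of the Darknet models rely on the upstream module. A quick sanity check, assuming `cspdarknet53_mish` is re-exported from `holocron.models` as the repository's public API suggests:

```python
import torch.nn as nn
from holocron.models import cspdarknet53_mish  # import path assumed from the repo layout

model = cspdarknet53_mish(pretrained=False)
# After this PR, the activations inside the backbone are native torch.nn.Mish modules
print(any(isinstance(m, nn.Mish) for m in model.modules()))  # expected: True
```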
4 changes: 2 additions & 2 deletions holocron/models/detection/yolov4.py
@@ -12,7 +12,7 @@
from torchvision.ops.boxes import box_iou, nms
from torchvision.ops.misc import FrozenBatchNorm2d

-from holocron.nn import SPP, Mish
+from holocron.nn import SPP
from holocron.nn.init import init_module
from holocron.ops.boxes import ciou_loss

@@ -456,7 +456,7 @@ def __init__(
backbone_norm_layer = norm_layer

# backbone
-self.backbone = DarknetBodyV4(layout, in_channels, stem_channels, 3, Mish(),
+self.backbone = DarknetBodyV4(layout, in_channels, stem_channels, 3, nn.Mish(inplace=True),
backbone_norm_layer, drop_layer, conv_layer)
# neck
self.neck = Neck([1024, 512, 256], act_layer, norm_layer, drop_layer, conv_layer)
8 changes: 4 additions & 4 deletions holocron/models/rexnet.py
@@ -9,7 +9,7 @@

import torch.nn as nn

-from holocron.nn import GlobalAvgPool2d, SiLU, init
+from holocron.nn import GlobalAvgPool2d, init

from .utils import conv_sequence, load_pretrained_params

@@ -66,8 +66,8 @@ def __init__(self, in_channels, channels, t, stride, use_se=True, se_ratio=12,
_layers = []
if t != 1:
dw_channels = in_channels * t
-_layers.extend(conv_sequence(in_channels, dw_channels, SiLU(), norm_layer, drop_layer, kernel_size=1,
-stride=1, bias=False))
+_layers.extend(conv_sequence(in_channels, dw_channels, nn.SiLU(inplace=True), norm_layer, drop_layer,
+kernel_size=1, stride=1, bias=False))
else:
dw_channels = in_channels

@@ -98,7 +98,7 @@ def __init__(self, width_mult=1.0, depth_mult=1.0, num_classes=1000, in_channels
super().__init__()

if act_layer is None:
-act_layer = SiLU()
+act_layer = nn.SiLU(inplace=True)
if norm_layer is None:
norm_layer = nn.BatchNorm2d

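One behavioural note: the removed `SiLU`/`Mish` modules exposed no `inplace` option, while the native replacements do, and the calls above opt into it. A minimal sketch of what that buys:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 32, 32)
out = nn.SiLU(inplace=True)(x)
print(out.data_ptr() == x.data_ptr())  # True: the input buffer is reused, saving one allocation
```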
6 changes: 3 additions & 3 deletions holocron/models/segmentation/unet.py
@@ -13,7 +13,7 @@
from torchvision.models import resnet34, vgg11
from torchvision.models._utils import IntermediateLayerGetter

-from ...nn import GlobalAvgPool2d, SiLU
+from ...nn import GlobalAvgPool2d
from ...nn.init import init_module
from ..rexnet import rexnet1_3x
from ..utils import conv_sequence, load_pretrained_params
@@ -503,8 +503,8 @@ def unet_rexnet13(

backbone = rexnet1_3x(pretrained=pretrained_backbone and not pretrained, in_channels=in_channels).features
kwargs['final_upsampling'] = kwargs.get('final_upsampling', True)
-kwargs['act_layer'] = kwargs.get('act_layer', SiLU())
+kwargs['act_layer'] = kwargs.get('act_layer', nn.SiLU(inplace=True))
# hotfix of https://github.com/pytorch/vision/issues/3802
-backbone[21] = SiLU()
+backbone[21] = nn.SiLU(inplace=True)

return _dynamic_unet('unet_rexnet13', backbone, pretrained, progress, **kwargs)
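The `backbone[21]` line patches a single activation inside the pretrained ReXNet feature extractor (the hotfix for pytorch/vision#3802); the pattern is plain index-based module replacement on a sequential container. An illustrative sketch with made-up layer sizes:

```python
import torch.nn as nn

# Illustrative only: swap an activation module by index in a sequential feature extractor
features = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)
features[2] = nn.SiLU(inplace=True)  # drop-in replacement of the activation
```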
26 changes: 1 addition & 25 deletions holocron/nn/functional.py
@@ -10,35 +10,11 @@
import torch.nn.functional as F
from torch import Tensor

-__all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
+__all__ = ['hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool',
'concat_downsample2d']


-def silu(x: Tensor) -> Tensor:
-"""Implements the SiLU activation function
-
-Args:
-x: input tensor
-Returns:
-output tensor
-"""
-
-return x * torch.sigmoid(x)
-
-
-def mish(x: Tensor) -> Tensor:
-"""Implements the Mish activation function
-
-Args:
-x: input tensor
-Returns:
-output tensor
-"""
-
-return x * torch.tanh(F.softplus(x))


def hard_mish(x: Tensor, inplace: bool = False) -> Tensor:
"""Implements the HardMish activation function

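For reference, the removed helpers are numerically equivalent to the functional ops that ship with PyTorch (`F.silu` since 1.7, `F.mish` since 1.9); a quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)

# Formulas of the removed holocron.nn.functional helpers
silu_manual = x * torch.sigmoid(x)
mish_manual = x * torch.tanh(F.softplus(x))

# Native equivalents
print(torch.allclose(silu_manual, F.silu(x), atol=1e-6))  # True
print(torch.allclose(mish_manual, F.mish(x), atol=1e-6))  # True
```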
41 changes: 1 addition & 40 deletions holocron/nn/modules/activation.py
@@ -11,7 +11,7 @@

from .. import functional as F

-__all__ = ['SiLU', 'Mish', 'HardMish', 'NLReLU', 'FReLU']
+__all__ = ['HardMish', 'NLReLU', 'FReLU']


class _Activation(nn.Module):
@@ -27,45 +27,6 @@ def extra_repr(self) -> str:
return inplace_str


-class _SiLU(torch.autograd.Function):
-@staticmethod
-def forward(ctx, x):
-ctx.save_for_backward(x)
-return F.silu(x)
-
-@staticmethod
-def backward(ctx, grad_output):
-x = ctx.saved_tensors[0]
-sig = torch.sigmoid(x)
-return grad_output * sig * (1 + x * (1 - sig))
-
-
-class SiLU(nn.Module):
-"""Implements the SiLU activation from `"Gaussian Error Linear Units (GELUs)"
-<https://arxiv.org/pdf/1606.08415.pdf>`_ (also known as Swish).
-
-This activation is computed as follows:
-
-.. math::
-f(x) = x \\cdot \\sigma(x)
-"""
-def forward(self, x: Tensor) -> Tensor:
-return _SiLU.apply(x)
-
-
-class Mish(nn.Module):
-"""Implements the Mish activation module from `"Mish: A Self Regularized Non-Monotonic Neural Activation Function"
-<https://arxiv.org/pdf/1908.08681.pdf>`_
-
-This activation is computed as follows:
-
-.. math::
-f(x) = x \\cdot \\tanh(ln(1 + e^x))
-"""
-def forward(self, x: Tensor) -> Tensor:
-return F.mish(x)


class HardMish(_Activation):
"""Implements the Had Mish activation module from `"H-Mish" <https://github.com/digantamisra98/H-Mish>`_

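The removed `_SiLU` autograd Function hard-coded its backward pass as `sig * (1 + x * (1 - sig))`; autograd through the native op yields the same gradient, so nothing is lost by dropping it. A small check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(16, dtype=torch.double, requires_grad=True)

# Gradient from autograd through the native op
F.silu(x).sum().backward()

# Backward formula that was hard-coded in the removed _SiLU.backward
with torch.no_grad():
    sig = torch.sigmoid(x)
    manual = sig * (1 + x * (1 - sig))

print(torch.allclose(x.grad, manual))  # True
```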
10 changes: 0 additions & 10 deletions test/test_nn_activation.py
@@ -33,16 +33,6 @@ def _test_activation_function(fn, input_shape):
assert x.data_ptr() == out.data_ptr()


-def test_silu():
-_test_activation_function(F.silu, (4, 3, 32, 32))
-assert repr(activation.SiLU()) == "SiLU()"
-
-
-def test_mish():
-_test_activation_function(F.mish, (4, 3, 32, 32))
-assert repr(activation.Mish()) == "Mish()"


def test_hard_mish():
_test_activation_function(F.hard_mish, (4, 3, 32, 32))
assert repr(activation.HardMish()) == "HardMish()"
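If coverage of the native modules is still desired, an illustrative smoke test (hypothetical, not part of this PR) could look like:

```python
import torch
import torch.nn as nn

def test_native_activations():
    # Hypothetical test: exercises the activations now provided by PyTorch itself
    x = torch.rand(4, 3, 32, 32)
    for mod in (nn.SiLU(), nn.Mish()):
        out = mod(x)
        assert out.shape == x.shape
        assert not torch.isnan(out).any()
```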