Skip to content

Commit

Permalink
feat: Added DropBlock implementation (#53)
Browse files Browse the repository at this point in the history
* feat: Added DropBlock implementation

* docs: Updated documentation

* test: Added unittest
  • Loading branch information
frgfm authored Jul 8, 2020
1 parent a901973 commit a786789
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 3 deletions.
7 changes: 6 additions & 1 deletion docs/source/nn.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ Loss functions
.. autofunction:: ls_cross_entropy

Convolutions
--------------
------------

.. autofunction:: norm_conv2d

.. autofunction:: add2d

Regularization layers
---------------------

.. autofunction:: dropblock2d


Downsampling
------------
Expand Down
5 changes: 5 additions & 0 deletions docs/source/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Convolution layers

.. autoclass:: SlimConv2d

Regularization layers
---------------------

.. autoclass:: DropBlock2d


Downsampling
------------
Expand Down
34 changes: 33 additions & 1 deletion holocron/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


__all__ = ['mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
'norm_conv2d', 'add2d']
'norm_conv2d', 'add2d', 'dropblock2d']


def mish(x):
Expand Down Expand Up @@ -322,3 +322,35 @@ def add2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, norma
"""

return _xcorrNd(_addNd, x, weight, bias, stride, padding, dilation, groups, normalize_slices, eps)


def dropblock2d(x, drop_prob, block_size, inplace=False):
"""Implements the dropblock operation from `"DropBlock: A regularization method for convolutional networks"
<https://arxiv.org/pdf/1810.12890.pdf>`_
Args:
drop_prob (float): probability of dropping activation value
block_size (int): size of each block that is expended from the sampled mask
inplace (bool, optional): whether the operation should be done inplace
"""

# Sample a mask for the centers of blocks that will be dropped
mask = (torch.rand((x.shape[0], *x.shape[2:]), device=x.device) <= drop_prob).to(dtype=torch.float32)

# Expand zero positions to block size
mask = 1 - F.max_pool2d(mask, kernel_size=(block_size, block_size),
stride=(1, 1), padding=block_size // 2)

# Avoid NaNs
one_count = mask.sum()
if inplace:
x *= mask.unsqueeze(1)
if one_count > 0:
x *= mask.numel() / one_count
return x

out = x * mask.unsqueeze(1)
if one_count > 0:
out *= mask.numel() / one_count

return out
1 change: 1 addition & 0 deletions holocron/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .loss import *
from .downsample import *
from .conv import *
from .dropblock import *
38 changes: 38 additions & 0 deletions holocron/nn/modules/dropblock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

'''
Regularization modules
'''

import torch
import torch.nn as nn
from .. import functional as F

__all__ = ['DropBlock2d']


class DropBlock2d(nn.Module):
"""Implements the DropBlock module from `"DropBlock: A regularization method for convolutional networks"
<https://arxiv.org/pdf/1810.12890.pdf>`_
Args:
p (float): probability of dropping activation value
block_size (int): size of each block that is expended from the sampled mask
inplace (bool, optional): whether the operation should be done inplace
"""

def __init__(self, p, block_size, inplace=False):
super().__init__()
self.p = p
self.block_size = block_size
self.inplace = inplace

@property
def drop_prob(self):
return self.p / self.block_size ** 2

def forward(self, x):
return F.dropblock2d(x, self.drop_prob, self.block_size, self.inplace)

def extra_repr(self):
return f"p={self.p}, block_size={self.block_size}, inplace={self.inplace}"
28 changes: 27 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
from holocron.nn import functional as F
from holocron.nn.init import init_module
from holocron.nn.modules import activation, conv, loss, downsample
from holocron.nn.modules import activation, conv, loss, downsample, dropblock


class Tester(unittest.TestCase):
Expand Down Expand Up @@ -245,6 +245,32 @@ def test_slimconv2d(self):
out = mod(x)
self.assertEqual(out.shape, (2, 6, 19, 19))

def test_dropblock2d(self):

x = torch.rand(2, 8, 19, 19)

# Drop probability of 1
mod = dropblock.DropBlock2d(1., 1, inplace=False)

with torch.no_grad():
out = mod(x)
self.assertTrue(torch.equal(out, torch.zeros_like(x)))

# Drop probability of 0
mod = dropblock.DropBlock2d(0., 3, inplace=False)

with torch.no_grad():
out = mod(x)
self.assertTrue(torch.equal(out, x))
self.assertNotEqual(out.data_ptr, x.data_ptr)

# Check inplace
mod = dropblock.DropBlock2d(1., 3, inplace=True)

with torch.no_grad():
out = mod(x)
self.assertEqual(out.data_ptr, x.data_ptr)


act_fns = ['mish', 'nl_relu']

Expand Down

0 comments on commit a786789

Please sign in to comment.