Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #137 from DerThorsten/master
Browse files Browse the repository at this point in the history
unet cleanup
  • Loading branch information
DerThorsten committed Aug 15, 2018
2 parents ef7e404 + 572e197 commit cb2ac7f
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 141 deletions.
8 changes: 4 additions & 4 deletions examples/plot_unet_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,24 @@ def conv_op_factory(self, in_channels, out_channels, part, index):
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
)
), False
elif part == 'bottom':
return torch.nn.Sequential(
ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
)
), False
elif part == 'up':
# are we in the very last block?
if index + 1 == self.depth:
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
)
), False
else:
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
)
), False
else:
raise RuntimeError("something is wrong")

Expand Down
7 changes: 1 addition & 6 deletions inferno/extensions/containers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...utils import python_utils as pyu
from ...utils.exceptions import assert_
from ..layers.device import OnDevice

from ..layers.identity import Identity

__all__ = ['NNGraph', 'Graph']

Expand Down Expand Up @@ -48,11 +48,6 @@ def copy(self, **init_kwargs):
return new


class Identity(nn.Module):
"""A torch.nn.Module to do nothing."""
def forward(self, input):
return input


class Graph(nn.Module):
"""
Expand Down
11 changes: 8 additions & 3 deletions inferno/extensions/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from .convolutional import *
from .device import *
from .reshape import *
from .unet import *
from .unet_base import *
from .res_unet import *
from .building_blocks import *

#######################################################
Expand All @@ -13,14 +14,18 @@
from .convolutional import _all as _convolutional_all
from .device import _all as _device_all
from .reshape import _all as _reshape_all
from .unet import _all as _unet_all
from .unet_base import _all as _unet_base_all
from .res_unet import _all as _res_unet_all
from .building_blocks import _all as _building_blocks_all
from .identity import _all as _identity_all

__all__.extend(_activations_all)
__all__.extend(_convolutional_all)
__all__.extend(_device_all)
__all__.extend(_reshape_all)
__all__.extend(_unet_all)
__all__.extend(_unet_base_all)
__all__.extend(_res_unet_all)
__all__.extend(_building_blocks_all)
__all__.extend(_identity_all)

_all = __all__
10 changes: 10 additions & 0 deletions inferno/extensions/layers/identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch.nn as nn
__all__ = ['identity']
_all = __all__

class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()

def forward(self, x):
return x
65 changes: 65 additions & 0 deletions inferno/extensions/layers/res_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn

from .building_blocks import ResBlock
from .unet_base import UNetBase
from ...utils.python_utils import require_dict_kwagrs


__all__ = ['ResBlockUNet']
_all = __all__

class ResBlockUNet(UNetBase):
"""TODO.
ACCC
Attributes:
activated (TYPE): Description
dim (TYPE): Description
res_block_kwargs (TYPE): Description
side_out_parts (TYPE): Description
unet_kwargs (TYPE): Description
"""
def __init__(self, in_channels, dim, out_channels, unet_kwargs=None,
res_block_kwargs=None, activated=True,
side_out_parts=None
):

self.dim = dim
self.unet_kwargs = require_dict_kwagrs(unet_kwargs, "unet_kwargs must be a dict or None")
self.res_block_kwargs = require_dict_kwagrs(res_block_kwargs, "res_block_kwargs must be a dict or None")
self.activated = activated
if isinstance(side_out_parts, str):
self.side_out_parts = set([side_out_parts])
elif isinstance(side_out_parts, (tuple,list)):
self.side_out_parts = set(side_out_parts)
else:
self.side_out_parts = set()

super(ResBlockUNet, self).__init__(
in_channels=in_channels,
dim=dim,
out_channels=out_channels,
**self.unet_kwargs
)



def conv_op_factory(self, in_channels, out_channels, part, index):

# is this the very last convolutional block?
very_last = (part == 'up' and index + 1 == self.depth)


# should the residual block be activated?
activated = not very_last or self.activated

# should the output be part of the overall
# return-list in the forward pass of the UNet
use_as_output = part in self.side_out_parts

# residual block used within the UNet
return ResBlock(in_channels=in_channels, out_channels=out_channels,
dim=self.dim, activated=activated,
**self.res_block_kwargs), use_as_output
4 changes: 4 additions & 0 deletions inferno/extensions/layers/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ def __init__(self, downscale_factor):
super(AnisotropicPool, self).__init__(kernel_size=(1, ds + 1, ds + 1),
stride=(1, ds, ds),
padding=(0, 1, 1))




0 comments on commit cb2ac7f

Please sign in to comment.