Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #138 from DerThorsten/master
Browse files Browse the repository at this point in the history
fixed unet examples
  • Loading branch information
DerThorsten committed Aug 15, 2018
2 parents cb2ac7f + b96a335 commit 90179a8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
16 changes: 7 additions & 9 deletions examples/res_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
"""
import torch
import inferno.extensions.layers.unet as unet

from inferno.extensions.layers import ConvELU2D, ConvSigmoid2D
from inferno.extensions.layers import (ResBlockUNet, UNetBase, ConvELU2D, ConvSigmoid2D)



Expand All @@ -19,7 +17,7 @@
x = torch.autograd.Variable(x)

# a unet with resiudal blocks
model = unet.ResBlockUNet(in_channels=10, out_channels=20, dim=2)
model = ResBlockUNet(in_channels=10, out_channels=20, dim=2)

# pass x trough unet
out = model(x)
Expand All @@ -43,14 +41,14 @@
x = torch.autograd.Variable(x)

# a unet with resiudal blocks
model_a = unet.ResBlockUNet(in_channels=5, out_channels=12, dim=3,
model_a = ResBlockUNet(in_channels=5, out_channels=12, dim=3,
unet_kwargs=dict(depth=3))

# if the last layer in the second unet
# shall be non-activated we set
# activated to False, this will only affect the
# very last convolution of the net
model_b = unet.ResBlockUNet(in_channels=12, out_channels=2, dim=3,
model_b = ResBlockUNet(in_channels=12, out_channels=2, dim=3,
activated=False,
unet_kwargs=dict(depth=3))

Expand All @@ -72,7 +70,7 @@
a custom UNet by deriving from UNetBase
"""

class MySimple2DUnet(unet.UNetBase):
class MySimple2DUnet(UNetBase):
def __init__(self, **kwargs):
super(MySimple2DUnet, self).__init__(dim=2, **kwargs)

Expand All @@ -85,13 +83,13 @@ def conv_op_factory(self, in_channels, out_channels, part, index):
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
ConvSigmoid2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
)
), False
else:
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
)
), False



Expand Down
2 changes: 1 addition & 1 deletion examples/train_side_loss_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from inferno.extensions.layers.convolutional import Conv2D
from inferno.extensions.layers.building_blocks import ResBlock
from inferno.extensions.layers.unet import ResBlockUNet
from inferno.extensions.layers import ResBlockUNet
from inferno.utils.torch_utils import unwrap
from inferno.utils.python_utils import ensure_dir
import pylab
Expand Down
2 changes: 1 addition & 1 deletion tests/extensions/layers/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from inferno.extensions.layers import UNetBase
from inferno.extensions.layers import ResBlockUNet
from inferno.extensions.layers.convolutional import ConvELU2D

-



Expand Down

0 comments on commit 90179a8

Please sign in to comment.