expose AddOnesChannelNd
haowen-xu committed May 12, 2020
1 parent aef0688 · commit d508787
Showing 4 changed files with 62 additions and 54 deletions.
1 change: 1 addition & 0 deletions tensorkit/layers/__init__.py
@@ -3,6 +3,7 @@
from .composed import *
from .contextual import *
from .core import *
+from .edge_bias_conv_ import *
from .flow_layer import *
from .gated import *
from .pixelcnn import *
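(With this wildcard import in place, the names in the new module's `__all__` — AddOnesChannel1d, AddOnesChannel2d, AddOnesChannel3d — become reachable under the public `tk.layers` namespace, which is what the commit title refers to.)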
50 changes: 50 additions & 0 deletions tensorkit/layers/edge_bias_conv_.py
@@ -0,0 +1,50 @@
+from .. import tensor as T
+from ..tensor import Tensor, shape, concat, ones_like
+from .core import *
+
+__all__ = ['AddOnesChannel1d', 'AddOnesChannel2d', 'AddOnesChannel3d']
+
+
+class AddOnesChannelNd(BaseLayer):
+
+    __constants__ = ('_channel_axis', '_spatial_ndims')
+
+    _channel_axis: int
+    _spatial_ndims: int
+
+    def __init__(self):
+        super().__init__()
+        spatial_ndims = self._get_spatial_ndims()
+        self._spatial_ndims = spatial_ndims
+        if T.IS_CHANNEL_LAST:
+            self._channel_axis = -1
+        else:
+            self._channel_axis = -(spatial_ndims + 1)
+
+    def _get_spatial_ndims(self) -> int:
+        raise NotImplementedError()
+
+    def forward(self, input: Tensor) -> Tensor:
+        channel_shape = shape(input)
+        channel_shape[self._channel_axis] = 1
+
+        return concat([input, ones_like(input, shape=channel_shape)],
+                      axis=self._channel_axis)
+
+
+class AddOnesChannel1d(AddOnesChannelNd):
+
+    def _get_spatial_ndims(self) -> int:
+        return 1
+
+
+class AddOnesChannel2d(AddOnesChannelNd):
+
+    def _get_spatial_ndims(self) -> int:
+        return 2
+
+
+class AddOnesChannel3d(AddOnesChannelNd):
+
+    def _get_spatial_ndims(self) -> int:
+        return 3
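For context, `AddOnesChannelNd.forward` simply concatenates a constant all-ones channel onto the input along the channel axis. Below is a minimal self-contained sketch of the same behavior in plain PyTorch — a hand-written illustration, not tensorkit API, assuming the channels-first NCHW layout (i.e. `T.IS_CHANNEL_LAST` is false):

import torch

def add_ones_channel_2d(x: torch.Tensor) -> torch.Tensor:
    # For 2 spatial dims the channel axis is -(2 + 1) = -3 in NCHW layout.
    ones_shape = list(x.shape)
    ones_shape[-3] = 1
    ones = torch.ones(ones_shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, ones], dim=-3)

x = torch.zeros(2, 1, 4, 4)
print(add_ones_channel_2d(x).shape)  # torch.Size([2, 2, 4, 4])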
63 changes: 10 additions & 53 deletions tensorkit/layers/pixelcnn.py
@@ -3,9 +3,9 @@

from .. import tensor as T
from ..arg_check import *
-from ..tensor import Tensor, Module, rank, shift, shape, concat, ones_like
+from ..tensor import Tensor, Module, rank, shift
from ..typing_ import *
-from . import resnet, core, composed
+from . import resnet, core, composed, edge_bias_conv_
from .core import *
from .utils import flatten_nested_layers

@@ -109,51 +109,6 @@ def forward(self, input: Tensor) -> Tensor:
        return output


-class AddOnesChannelNd(BaseLayer):
-
-    __constants__ = ('_channel_axis', '_spatial_ndims')
-
-    _channel_axis: int
-    _spatial_ndims: int
-
-    def __init__(self):
-        super().__init__()
-        spatial_ndims = self._get_spatial_ndims()
-        self._spatial_ndims = spatial_ndims
-        if T.IS_CHANNEL_LAST:
-            self._channel_axis = -1
-        else:
-            self._channel_axis = -(spatial_ndims + 1)
-
-    def _get_spatial_ndims(self) -> int:
-        raise NotImplementedError()
-
-    def forward(self, input: Tensor) -> Tensor:
-        channel_shape = shape(input)
-        channel_shape[self._channel_axis] = 1
-
-        return concat([input, ones_like(input, shape=channel_shape)],
-                      axis=self._channel_axis)
-
-
-class AddOnesChannel1d(AddOnesChannelNd):
-
-    def _get_spatial_ndims(self) -> int:
-        return 1
-
-
-class AddOnesChannel2d(AddOnesChannelNd):
-
-    def _get_spatial_ndims(self) -> int:
-        return 2
-
-
-class AddOnesChannel3d(AddOnesChannelNd):
-
-    def _get_spatial_ndims(self) -> int:
-        return 3
-

class AddLeadingContext(BaseLayer):

    __constants__ = ('first_n',)
Expand Down Expand Up @@ -245,7 +200,7 @@ def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Sequence[int]],
-                 add_ones_channel: bool = True,
+                 edge_bias: bool = True,
                 weight_norm: WeightNormArgType = False,
                 weight_init: TensorInitArgType = DEFAULT_WEIGHT_INIT,
                 bias_init: TensorInitArgType = DEFAULT_BIAS_INIT,
@@ -263,8 +218,10 @@ def __init__(self,
                The actual kernel size used by the convolutional layers
                will be re-calculated based on this kernel size, in order
                to ensure causality between pixels.
-            add_ones_channel: Whether or not add a channel to the input,
-                with all elements set to `1`?
+            edge_bias: Whether or not to add a channel to the input,
+                with all elements set to `1`. This will effectively make
+                the padded edges (i.e., with zero values) distinguishable
+                from the true pixel values.
            weight_norm: The weight norm mode for the convolutional layers.
                If :obj:`True`, will use "full" weight norm for "conv1" and
                "shortcut". For "conv0", will use "full" if `normalizer`
@@ -278,16 +235,16 @@
"""
super().__init__()

globals_dict = globals()
spatial_ndims = self._get_spatial_ndims()
kernel_size = validate_conv_size('kernel_size', kernel_size, spatial_ndims)

# construct the layer
super().__init__()
self._spatial_ndims = spatial_ndims

-        if add_ones_channel:
-            self.add_ones_channel = globals_dict[f'AddOnesChannel{spatial_ndims}d']()
+        if edge_bias:
+            self.add_ones_channel = getattr(
+                edge_bias_conv_, f'AddOnesChannel{spatial_ndims}d')()
            in_channels += 1
        else:
            self.add_ones_channel = Identity()
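Two things are worth noting about the pixelcnn.py change. First, dispatching through `getattr` on the imported module replaces the old `globals()` lookup, which could no longer work once the AddOnesChannel classes moved into their own file. Second, the constructor flag is now named `edge_bias`. A hedged usage sketch, mirroring the constructor call in tests/layers/test_pixelcnn.py (assumes tensorkit is installed; the concrete kernel size and channel counts are illustrative):

import tensorkit as tk

# edge_bias=True (the default) appends the all-ones channel internally,
# so the layer bumps its own in_channels by one before building the convs.
input_layer = tk.layers.PixelCNNInput2d(1, 16, kernel_size=5)

# The test disables it to probe the raw receptive field:
bare_layer = tk.layers.PixelCNNInput2d(
    1, 1, kernel_size=5, edge_bias=False, weight_init=tk.init.ones)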
2 changes: 1 addition & 1 deletion tests/layers/test_pixelcnn.py
@@ -126,7 +126,7 @@ def test_causality_and_receptive_field(self):
        input_layer_cls = getattr(
            tk.layers, f'PixelCNNInput{spatial_ndims}d')
        input_layer = input_layer_cls(
-            1, 1, kernel_size=kernel_size, add_ones_channel=False,
+            1, 1, kernel_size=kernel_size, edge_bias=False,
            weight_init=tk.init.ones,
        )
        input_layer = tk.layers.jit_compile(input_layer)
