
Commit

Merge 08f1c37 into 87810d0
haowen-xu committed Feb 10, 2020
2 parents 87810d0 + 08f1c37 commit 0d928fa
Showing 18 changed files with 1,688 additions and 77 deletions.
2 changes: 1 addition & 1 deletion tensorkit/arg_check.py
@@ -65,7 +65,7 @@ def validate_conv_size(name: str,
    if not value_ok:
        raise ValueError(
            f'`{name}` must be either a positive integer, or a sequence of '
-            f'positive integers with length `{spatial_ndims}`: got {value}.'
+            f'positive integers of length `{spatial_ndims}`: got {value}.'
        )
    return value

167 changes: 149 additions & 18 deletions tensorkit/backend/pytorch_/core.py
@@ -42,12 +42,13 @@

    # shape utils
    'shape', 'rank', 'reshape', 'repeat', 'expand', 'squeeze', 'expand_dim',
-    'swap_axes', 'transpose', 'pad',
+    'swap_axes', 'transpose',
    'broadcast_shape', 'broadcast_to', 'explicit_broadcast', 'flatten_to_ndims',
    'unflatten_from_ndims', 'reshape_tail',

    # split / join / indexing / gathering ...
-    'index_select', 'concat', 'split', 'stack', 'unstack',
+    'index_select', 'concat', 'split', 'stack', 'unstack', 'slice', 'slice_axis',
+    'pad', 'pad_axis', 'shift', 'shift_axis',

    # math operators
    'abs', 'neg', 'square', 'exp', 'log', 'log1p', 'sin', 'cos', 'tan',
@@ -620,22 +621,6 @@ def transpose(input: Tensor, axis: List[int]) -> Tensor:
    return input.permute(axis)


-@jit
-def pad(input: Tensor,
-        padding: List[Tuple[int, int]],
-        value: float = 0.) -> Tensor:
-    if len(padding) > input.dim():
-        raise ValueError(
-            'The length of `padding` must not be larger than `rank(input)`: '
-            '`padding` is {}, while `shape(input)` is {}'.
-            format(padding, shape(input))
-        )
-    pad: List[int] = []
-    for i in range(len(padding) - 1, -1, -1):
-        pad.extend(padding[i])
-    return torch.nn.functional.pad(input, pad=pad, value=value)


@jit
def broadcast_shape(x: List[int], y: List[int]) -> List[int]:
    common_len = min(len(x), len(y))
@@ -818,6 +803,152 @@ def unstack(input: Tensor, axis: int) -> List[Tensor]:
    return outputs


@jit
def slice_axis(input: Tensor,
               axis: int,
               start: int,
               length: Optional[int] = None) -> Tensor:
    if length is None:
        if start < 0:
            length = -start
        else:
            length = input.shape[axis] - start
    return torch.narrow(input, axis, start, length)
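
A usage sketch of the new `slice_axis` (hypothetical values, assuming the definitions above are in scope): a non-negative `start` slices toward the end of the axis, and omitting `length` keeps everything from `start` onward.

import torch

x = torch.arange(12).reshape(3, 4)
y = slice_axis(x, axis=1, start=1, length=2)   # columns 1..2 -> shape (3, 2)
z = slice_axis(x, axis=1, start=1)             # to the end   -> shape (3, 3)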


@jit
def slice(input: Tensor,
          slice_start: List[int],
          slice_length: Optional[List[Optional[int]]] = None
          ) -> Tensor:
    slice_count = len(slice_start)
    if slice_count > input.dim():
        raise ValueError(
            '`len(slice_start)` must be less than or equal to `rank(input)`: '
            'got input shape {}, slice_start {}, slice_length {}.'.
            format(shape(input), slice_start, slice_length)
        )
    if slice_length is None:
        output = input
        for i in range(-1, -(slice_count + 1), -1):
            output = slice_axis(output, i, slice_start[i])
    else:
        if slice_count != len(slice_length):
            raise ValueError('`len(slice_start)` != `len(slice_length)`: '
                             'got slice_start {}, slice_length {}.'.
                             format(slice_start, slice_length))
        output = input
        for i in range(-1, -(slice_count + 1), -1):
            output = slice_axis(output, i, slice_start[i], slice_length[i])
    return output
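
A usage sketch of `slice` (hypothetical values): `slice_start` aligns with the trailing axes, so `slice_start[-1]` addresses the last axis.

x = torch.arange(24).reshape(2, 3, 4)
y = slice(x, slice_start=[1, 0], slice_length=[2, 2])
# axis -2 sliced from 1 with length 2, axis -1 from 0 with length 2
# -> y.shape == (2, 2, 2)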


@jit
def pad_axis(input: Tensor,
             axis: int,
             padding: Tuple[int, int],
             value: float = 0.) -> Tensor:
    r = input.dim()
    if axis < -r or axis >= r:
        raise ValueError('`axis` out of range: expected to be >= {} and '
                         '< {}, got {}.'.format(-r, r, axis))
    if axis < 0:
        axis = axis + r
    pad: List[int] = []
    for i in range(r - 1, axis, -1):
        pad.extend((0, 0))
    pad.extend(padding)
    return torch.nn.functional.pad(input, pad=pad, value=value)
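
A usage sketch of `pad_axis` (hypothetical values): `padding` is a `(before, after)` pair for the single chosen axis.

x = torch.ones(2, 3)
y = pad_axis(x, axis=-1, padding=(1, 2))            # shape (2, 6), new cells are 0.
z = pad_axis(x, axis=0, padding=(1, 0), value=5.)   # shape (3, 3), new row is 5.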


@jit
def pad(input: Tensor,
        padding: List[Tuple[int, int]],
        value: float = 0.) -> Tensor:
    if len(padding) > input.dim():
        raise ValueError(
            'The length of `padding` must not be larger than `rank(input)`: '
            '`padding` is {}, while `shape(input)` is {}'.
            format(padding, shape(input))
        )
    pad: List[int] = []
    for i in range(len(padding) - 1, -1, -1):
        pad.extend(padding[i])
    return torch.nn.functional.pad(input, pad=pad, value=value)
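
A usage sketch of the relocated `pad` (hypothetical values): one `(before, after)` pair per trailing axis, with `padding[-1]` applying to the last axis.

x = torch.zeros(2, 3)
y = pad(x, [(1, 1), (0, 2)], value=9.)
# axis -2 padded by (1, 1), axis -1 padded by (0, 2) -> y.shape == (4, 5)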


@jit
def shift_axis(input: Tensor,
               axis: int,
               shift: int,
               fill_value: float = 0.) -> Tensor:
    size = input.shape[axis]
    if shift < -size or shift > size:
        raise ValueError('`shift` out of range: expected to be >= {} '
                         'and <= {}.'.format(-size, size))
    if shift < 0:
        output = pad_axis(
            torch.narrow(input, axis, -shift, size + shift),
            axis,
            (0, -shift),
            fill_value
        )
    elif shift > 0:
        output = pad_axis(
            torch.narrow(input, axis, 0, size - shift),
            axis,
            (shift, 0),
            fill_value
        )
    else:
        output = input
    return output
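
A usage sketch of `shift_axis` (hypothetical values): positive shifts move content toward higher indices, negative ones toward lower indices, and vacated cells take `fill_value`.

x = torch.tensor([1., 2., 3., 4.])
shift_axis(x, axis=0, shift=2)    # -> tensor([0., 0., 1., 2.])
shift_axis(x, axis=0, shift=-1)   # -> tensor([2., 3., 4., 0.])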


@jit
def shift(input: Tensor,
          shift: List[int],
          fill_value: float = 0.) -> Tensor:
    shift_length = len(shift)
    if shift_length > input.dim():
        raise ValueError('`len(shift) <= rank(input)` does not hold: '
                         'got `shift` {}, and `shape(input)` {}.'.
                         format(shift, shape(input)))

    padding: List[int] = []
    output = input
    need_pad: bool = False

    for axis in range(-1, -(shift_length + 1), -1):
        s = shift[axis]
        size = input.shape[axis]
        if s < -size or s > size:
            raise ValueError(
                '`shift` out of range at axis {}: expected to be >= {} '
                'and <= {}.'.format(axis, -size, size)
            )
        if s < 0:
            padding.append(0)
            padding.append(-s)
            output = torch.narrow(output, axis, -s, size + s)
            need_pad = True
        elif s > 0:
            padding.append(s)
            padding.append(0)
            output = torch.narrow(output, axis, 0, size - s)
            need_pad = True
        else:
            padding.append(0)
            padding.append(0)

    if need_pad:
        output = torch.nn.functional.pad(
            output, padding, mode='constant', value=fill_value)

    return output
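
A usage sketch of the multi-axis `shift` (hypothetical values): like `slice`, the `shift` list aligns with the trailing axes, so `shift[-1]` moves the last axis.

x = torch.arange(6.).reshape(2, 3)   # [[0., 1., 2.], [3., 4., 5.]]
y = shift(x, [1, -1])                # rows shifted down by 1, columns left by 1
# -> [[0., 0., 0.], [1., 2., 0.]]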


# ---- univariate element-wise math operations ----
@jit
def identity(input: Tensor) -> Tensor:
Expand Down
8 changes: 7 additions & 1 deletion tensorkit/backend/pytorch_/init.py
@@ -13,7 +13,8 @@
    'calculate_fan_in_and_fan_out', 'get_activation_gain', 'apply_initializer',

    # data-independent tensor initializers
-    'zeros', 'fill', 'uniform', 'normal', 'xavier_uniform', 'xavier_normal',
+    'zeros', 'ones', 'fill', 'uniform', 'normal',
+    'xavier_uniform', 'xavier_normal',
    'kaming_uniform', 'kaming_normal',

    # data-dependent layer initializers
@@ -184,6 +185,11 @@ def zeros(tensor: Tensor, **kwargs):
        core.fill_zeros(tensor)


def ones(tensor: Tensor, **kwargs):
    with no_grad():
        core.fill(tensor, fill_value=1.)
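
A usage sketch of the new `ones` initializer (hypothetical tensor): it fills in place with gradients disabled, mirroring `zeros` above.

w = torch.empty(3, 4)
ones(w)   # w is now all 1., with no autograd history recorded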


def fill(tensor: Tensor, fill_value: Union[int, float, np.ndarray], **kwargs):
    with no_grad():
        core.fill(tensor, fill_value=float(fill_value))
31 changes: 26 additions & 5 deletions tensorkit/backend/pytorch_/layers.py
@@ -30,7 +30,7 @@
    'BaseLayer', 'BaseSingleVariateLayer', 'BaseMultiVariateLayer',
    'BaseSplitLayer', 'BaseMergeLayer',
    'ModuleList', 'Sequential',
-    'BaseContextualLayer',
+    'BaseContextualLayer', 'BaseMultiVariateContextualLayer',

    # linear layers
    'CoreLinear', 'Linear',
@@ -378,17 +378,38 @@ def forward(self, inputs: List[Tensor]) -> Tensor:

class BaseContextualLayer(BaseLayer):
"""
Base class layers that produces the output according to the input tensor
and potentially a contextual tensor.
Base class layers that produces the output according to the input tensor
and contextual tensors.
"""

def _call(self, input: Tensor, context: Optional[Tensor]) -> Tensor:
def _call(self, input: Tensor, context: List[Tensor]) -> Tensor:
raise NotImplementedError()

def forward(self, input: Tensor, context: Optional[Tensor] = None) -> Tensor:
def forward(self,
input: Tensor,
context: Optional[List[Tensor]] = None) -> Tensor:
if context is None:
context = []
return self._call(input, context)


class BaseMultiVariateContextualLayer(BaseLayer):
    """
    Base class for layers that produce the output tensors according to the
    input tensors and contextual tensors.
    """

    def _call(self, inputs: List[Tensor], context: List[Tensor]) -> List[Tensor]:
        raise NotImplementedError()

    def forward(self,
                inputs: List[Tensor],
                context: Optional[List[Tensor]] = None) -> List[Tensor]:
        if context is None:
            context = []
        return self._call(inputs, context)
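
A minimal subclass sketch (hypothetical layer, assuming `BaseLayer` behaves like a standard `torch.nn.Module`): `forward` normalizes a missing context to an empty list before delegating to `_call`.

class AddContexts(BaseContextualLayer):
    # toy layer: add every context tensor onto the input
    def _call(self, input: Tensor, context: List[Tensor]) -> Tensor:
        output = input
        for t in context:
            output = output + t
        return output

layer = AddContexts()
y1 = layer(torch.zeros(2, 3))                      # no context: unchanged
y2 = layer(torch.zeros(2, 3), [torch.ones(2, 3)])  # adds the context tensor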


class Sequential(torch_nn.Sequential):

    def __init__(self, *layers: Union[Module, Sequence[Module]]):
17 changes: 13 additions & 4 deletions tensorkit/backend/pytorch_/nn.py
@@ -5,8 +5,10 @@
from .core import *

__all__ = [
+    # constants
+    'LEAKY_RELU_DEFAULT_SLOPE', 'AVG_POOL_DEFAULT_COUNT_PADDED_ZEROS',
+
    # activation functions
-    'LEAKY_RELU_DEFAULT_SLOPE',
    'relu', 'leaky_relu',
    'sigmoid', 'log_sigmoid',
    'softmax', 'log_softmax',
@@ -30,6 +32,7 @@

# ---- activation functions ----
LEAKY_RELU_DEFAULT_SLOPE = 0.01
AVG_POOL_DEFAULT_COUNT_PADDED_ZEROS = False


@jit
@@ -418,39 +421,43 @@ def depth_to_space3d(input: Tensor, block_size: int) -> Tensor:


# ---- pooling functions ----
@jit
def avg_pool1d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
               padding: List[int],
-               count_padded_zeros: bool = True):
+               count_padded_zeros: bool = AVG_POOL_DEFAULT_COUNT_PADDED_ZEROS):
    return torch.nn.functional.avg_pool1d(
        input, kernel_size=kernel_size, stride=stride, padding=padding,
        count_include_pad=count_padded_zeros,
    )
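
A sketch of what the new default changes (hypothetical values; `count_include_pad` is the underlying PyTorch argument): with nonzero `padding`, border windows overlap padded zeros, and the flag decides whether those zeros enter the average's denominator.

x = torch.ones(1, 1, 4)
avg_pool1d(x, kernel_size=[3], stride=[1], padding=[1])
# count_padded_zeros=False (the new default): border outputs are 2/2 = 1.0
# count_padded_zeros=True: border outputs are 2/3 ≈ 0.667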


@jit
def avg_pool2d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
               padding: List[int],
-               count_padded_zeros: bool = True):
+               count_padded_zeros: bool = AVG_POOL_DEFAULT_COUNT_PADDED_ZEROS):
    return torch.nn.functional.avg_pool2d(
        input, kernel_size=kernel_size, stride=stride, padding=padding,
        count_include_pad=count_padded_zeros,
    )


@jit
def avg_pool3d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
               padding: List[int],
-               count_padded_zeros: bool = True):
+               count_padded_zeros: bool = AVG_POOL_DEFAULT_COUNT_PADDED_ZEROS):
    return torch.nn.functional.avg_pool3d(
        input, kernel_size=kernel_size, stride=stride, padding=padding,
        count_include_pad=count_padded_zeros,
    )


@jit
def max_pool1d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
@@ -459,6 +466,7 @@ def max_pool1d(input: Tensor,
        input, kernel_size=kernel_size, stride=stride, padding=padding)


@jit
def max_pool2d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
@@ -467,6 +475,7 @@ def max_pool2d(input: Tensor,
        input, kernel_size=kernel_size, stride=stride, padding=padding)


@jit
def max_pool3d(input: Tensor,
               kernel_size: List[int],
               stride: List[int],
1 change: 1 addition & 0 deletions tensorkit/layers/__init__.py
@@ -4,6 +4,7 @@
from .core import *
from .flow_layer import *
from .gated import *
+from .pixelcnn import *
from .pool import *
from .resnet import *
from .shape_ import *
