some optimization on core ops
haowen-xu committed Feb 28, 2020
1 parent 1a74537 commit fda99fd
Showing 2 changed files with 29 additions and 53 deletions.
79 changes: 29 additions & 50 deletions tensorkit/backend/pytorch_/core.py
@@ -667,36 +667,23 @@ def reshape(input: Tensor, shape: List[int]) -> Tensor:
 
 @jit
 def repeat(input: Tensor, repeats: List[int]) -> Tensor:
-    x_shape = input.shape
-    x_rank = len(x_shape)
-    repeats_len = len(repeats)
-    extra_len = repeats_len - x_rank
-
-    # argument check
-    if extra_len < 0:
-        repeats = [1] * (len(x_shape) - len(repeats)) + repeats
-        extra_len = 0
-
-    # detect the repeat mode
-    mode = 0  # 0 = return directly, 1 = expand, 2 = repeat
-    if extra_len > 0:
-        mode = 1
-
-    for i in range(len(x_shape)):
-        a = x_shape[i]
-        b = repeats[i + extra_len]
-        if b != 1:
-            if a != 1:
-                mode = 2
-            else:
-                mode = max(1, mode)
-
-    # do repeat the tensor according to different mode
-    if mode == 0:
+    in_shape = list(input.shape)
+    in_shape_len, repeats_len = len(in_shape), len(repeats)
+    max_length = max(in_shape_len, repeats_len)
+    in_shape = [1] * (max_length - in_shape_len) + in_shape
+    repeats = [1] * (max_length - repeats_len) + repeats
+
+    mode = max([
+        (1 if repeats[i] != 1 else 0) + (1 if in_shape[i] != 1 else 0)
+        for i in range(max_length)
+    ])
+
+    if mode == 0 and in_shape_len == max_length:
         return input
-    elif mode == 1:
+    elif mode < 2:
+        extra_len = max_length - in_shape_len
         expands = repeats[:extra_len] + \
-            list([-1 if a == 1 else a for a in repeats[extra_len:]])
+            [-1 if a == 1 else a for a in repeats[extra_len:]]
        return input.expand(expands)
     else:
         return input.repeat(repeats)
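
For context, a minimal sketch (not part of the commit) of the behaviour the new mode value selects between: Tensor.repeat always copies data, while Tensor.expand only creates a view, so routing to expand whenever only size-1 source dimensions are tiled avoids an allocation.

    # Illustrative only; mirrors the dispatch above, not taken from the repository.
    import torch

    x = torch.randn(1, 3)

    # repeats = [4, 1]: only the size-1 dim is tiled -> mode 1, expand (no copy)
    expanded = x.expand(4, 3)
    assert expanded.shape == (4, 3)
    assert expanded.data_ptr() == x.data_ptr()    # a view sharing x's storage

    # repeats = [1, 2]: the size-3 dim is tiled -> mode 2, repeat (materializes a copy)
    repeated = x.repeat(1, 2)
    assert repeated.shape == (1, 6)
    assert repeated.data_ptr() != x.data_ptr()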
@@ -721,7 +708,7 @@ def squeeze(input: Tensor, axis: Optional[List[int]] = None) -> Tensor:
             else:
                 raise ValueError('Axis {} cannot be squeezed, since its '
                                  'size is {} != 1'.format(a, old_shape[a]))
-        new_shape = torch.jit.annotate(List[int], [])
+        new_shape: List[int] = []
         for i in range(len(old_shape)):
             if new_shape_mask[i]:
                 new_shape.append(old_shape[i])
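
The one-line change above swaps torch.jit.annotate for a PEP 526 variable annotation; both tell TorchScript the element type of an otherwise-empty list. A small sketch of the same idiom (the function name is made up for the demo, not taken from the repository):

    from typing import List

    import torch

    @torch.jit.script
    def keep_non_unit_dims(shape: List[int]) -> List[int]:
        out: List[int] = []   # a bare out = [] would be inferred as List[Tensor]
        for s in shape:
            if s != 1:
                out.append(s)
        return out

    print(keep_non_unit_dims([1, 3, 1, 5]))   # [3, 5]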
@@ -747,27 +734,19 @@ def transpose(input: Tensor, axis: List[int]) -> Tensor:
 
 @jit
 def broadcast_shape(x: List[int], y: List[int]) -> List[int]:
-    common_len = min(len(x), len(y))
-
-    right = torch.jit.annotate(List[int], [])
-    for i in range(common_len):
-        a = x[i - common_len]
-        b = y[i - common_len]
-        if a == 1:
-            right.append(b)
-        elif b == 1:
-            right.append(a)
-        elif a != b:
-            raise ValueError('Shape x and y cannot broadcast against '
-                             'each other: {} vs {}.'.format(x, y))
-        else:
-            right.append(a)
-
-    if len(x) > common_len:
-        left = x[:len(x)-common_len]
-    else:
-        left = y[:len(y)-common_len]
-    return left + right
+    x_len, y_len = len(x), len(y)
+    max_length = max(x_len, y_len)
+    x_ex = [1] * (max_length - x_len) + x
+    y_ex = [1] * (max_length - y_len) + y
+    for i in range(max_length):
+        a, b = x_ex[i], y_ex[i]
+        if b != a and b != 1:
+            if a != 1:
+                raise ValueError('Shape x and y cannot broadcast against '
+                                 'each other: {} vs {}.'.format(x, y))
+            else:
+                x_ex[i] = b
+    return x_ex
 
 
 @jit
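
For reference, a quick sketch (not part of the commit) of the rule the rewritten broadcast_shape implements: right-align the two shapes, pad the shorter one with 1s, then keep the non-1 size at each position, raising if two differing non-1 sizes meet. Checked here against torch's own broadcasting:

    # Illustrative only; assumes the shapes are broadcast-compatible.
    import torch

    x_shape = [2, 1, 3]
    y_shape = [5, 1]

    n = max(len(x_shape), len(y_shape))
    x_ex = [1] * (n - len(x_shape)) + x_shape
    y_ex = [1] * (n - len(y_shape)) + y_shape
    out = [a if b == 1 else b for a, b in zip(x_ex, y_ex)]
    print(out)                                   # [2, 5, 3]

    # torch broadcasts the same way
    t = torch.zeros(x_shape) + torch.zeros(y_shape)
    assert list(t.shape) == out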
3 changes: 0 additions & 3 deletions tests/flows/test_act_norm.py
@@ -1,6 +1,3 @@
import unittest
from itertools import product

import pytest

import tensorkit as tk
