
Commit 5a0d633
updated
haowen-xu committed Apr 10, 2020
1 parent cf81823 commit 5a0d633
Showing 5 changed files with 60 additions and 9 deletions.
47 changes: 45 additions & 2 deletions tensorkit/backend/pytorch_/core.py
@@ -71,8 +71,10 @@
     'add_n',
 
     # reduce operators
-    'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
-    'argmax', 'argmin', 'log_sum_exp', 'log_mean_exp',
+    'reduce_sum', 'reduce_sum_axis', 'reduce_mean', 'reduce_mean_axis',
+    'reduce_max', 'reduce_max_axis', 'reduce_min', 'reduce_min_axis',
+    'argmax', 'argmin', 'log_sum_exp', 'log_sum_exp_axis',
+    'log_mean_exp', 'log_mean_exp_axis',
     # 'all', 'any',
     'calculate_mean_and_var', 'l1_norm', 'l2_norm', 'norm',
     'norm_except_axis', 'global_norm',
@@ -1252,6 +1254,11 @@ def reduce_sum(input: Tensor,
         return torch.sum(input, dim=axis, keepdim=keepdims)
 
 
+@jit
+def reduce_sum_axis(input: Tensor, axis: int, keepdims: bool = False) -> Tensor:
+    return torch.sum(input, dim=axis, keepdim=keepdims)
+
+
 @jit
 def reduce_mean(input: Tensor,
                 axis: Optional[List[int]] = None,
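Note: each new `*_axis` function takes a single `int` axis, while the original keeps its `Optional[List[int]]` signature. A plausible reason (my inference, not stated in the commit) is that `@jit` routes these through TorchScript, which favors fixed, fully-typed signatures over overloads, so a separate single-axis entry point avoids wrapping every scalar axis in a list. A pure-PyTorch sketch of the two call shapes, outside tensorkit's API:

```python
import torch

# List-based vs. single-axis reduction, mirroring reduce_sum vs. reduce_sum_axis.
x = torch.randn(3, 4, 5)

s_list = torch.sum(x, dim=[0, 2])                # list of axes -> shape (4,)
s_axis = torch.sum(x, dim=-1, keepdim=True)      # single axis  -> shape (3, 4, 1)

assert s_list.shape == (4,)
assert s_axis.shape == (3, 4, 1)
```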
@@ -1268,6 +1275,11 @@ def reduce_mean(input: Tensor,
         return torch.mean(input, dim=axis, keepdim=keepdims)
 
 
+@jit
+def reduce_mean_axis(input: Tensor, axis: int, keepdims: bool = False) -> Tensor:
+    return torch.mean(input, dim=axis, keepdim=keepdims)
+
+
 @jit
 def reduce_max(input: Tensor,
                axis: Optional[List[int]] = None,
@@ -1290,6 +1302,11 @@ def reduce_max(input: Tensor,
         return input
 
 
+@jit
+def reduce_max_axis(input: Tensor, axis: int, keepdims: bool = False) -> Tensor:
+    return torch.max(input, dim=axis, keepdim=keepdims)[0]
+
+
 @jit
 def reduce_min(input: Tensor,
                axis: Optional[List[int]] = None,
@@ -1311,6 +1328,11 @@ def reduce_min(input: Tensor,
             input = squeeze(input, axis)
         return input
 
 
+@jit
+def reduce_min_axis(input: Tensor, axis: int, keepdims: bool = False) -> Tensor:
+    return torch.min(input, dim=axis, keepdim=keepdims)[0]
+
+
 @jit
 def argmax(input: Tensor, axis: int, keepdims: bool = False) -> Tensor:
@@ -1338,6 +1360,13 @@ def log_sum_exp(input: Tensor,
         return torch.logsumexp(input, dim=axis, keepdim=keepdims)
 
 
+@jit
+def log_sum_exp_axis(input: Tensor,
+                     axis: int,
+                     keepdims: bool = False) -> Tensor:
+    return torch.logsumexp(input, dim=axis, keepdim=keepdims)
+
+
 @jit
 def log_mean_exp(input: Tensor,
                  axis: Optional[List[int]] = None,
@@ -1355,6 +1384,20 @@ def log_mean_exp(input: Tensor,
         return x_max + torch.log(mean_exp)
 
 
+@jit
+def log_mean_exp_axis(input: Tensor,
+                      axis: int,
+                      keepdims: bool = False) -> Tensor:
+    x_max_keepdims = reduce_max_axis(input, axis=axis, keepdims=True)
+    if not keepdims:
+        x_max = torch.squeeze(x_max_keepdims, dim=axis)
+    else:
+        x_max = x_max_keepdims
+    mean_exp = reduce_mean_axis(
+        torch.exp(input - x_max_keepdims), axis=axis, keepdims=keepdims)
+    return x_max + torch.log(mean_exp)
+
+
 @jit
 def calculate_mean_and_var(input: Tensor,
                            axis: Optional[List[int]] = None,
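`log_mean_exp_axis` reuses the standard max-shift trick from `log_mean_exp`: log(mean(exp(x))) = max(x) + log(mean(exp(x − max(x)))), which keeps `exp` from overflowing for large inputs. A NumPy reference sketch (function name mine, not part of the commit):

```python
import numpy as np

def log_mean_exp_axis_ref(x, axis, keepdims=False):
    """NumPy reference for the max-shift trick used above:
    log(mean(exp(x))) = max(x) + log(mean(exp(x - max(x))))."""
    x_max_keep = np.max(x, axis=axis, keepdims=True)
    x_max = x_max_keep if keepdims else np.squeeze(x_max_keep, axis=axis)
    mean_exp = np.mean(np.exp(x - x_max_keep), axis=axis, keepdims=keepdims)
    return x_max + np.log(mean_exp)

x = np.array([[1000.0, 1000.0], [-1000.0, -1000.0]])
print(log_mean_exp_axis_ref(x, axis=-1))   # [ 1000. -1000.]; naive exp(x) overflows
```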
4 changes: 2 additions & 2 deletions tensorkit/distributions/mixture.py
@@ -195,7 +195,7 @@ def _sample(self,
         c_samples = [c.sample(n_samples).tensor for c in self.components]
         c_samples = T.stack(c_samples, axis=-1)
 
-        samples = T.reduce_sum(mask * c_samples, axis=[-1])
+        samples = T.reduce_sum_axis(mask * c_samples, axis=-1)
 
         if not reparameterized:
             samples = T.stop_grad(samples)
@@ -217,7 +217,7 @@ def _log_prob(self,
             [c.log_prob(given) for c in self.components],
             axis=-1
         )
-        log_prob = T.log_sum_exp(cat_log_prob + c_log_prob, axis=[-1])
+        log_prob = T.log_sum_exp_axis(cat_log_prob + c_log_prob, axis=-1)
         if reduce_ndims > 0:
             log_prob = T.reduce_sum(log_prob, axis=T.int_range(-reduce_ndims, 0))
         return log_prob
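With the components stacked on the trailing axis, the `_log_prob` change computes the standard mixture density, log p(x) = logsumexp_k(log π_k + log p_k(x)), with a single-axis log-sum-exp. A self-contained NumPy sketch of that computation (names mine):

```python
import numpy as np

# Mixture log-density: log p(x) = logsumexp_k( log pi_k + log p_k(x) ).
log_pi = np.log(np.array([0.3, 0.7]))   # categorical log-probs, K = 2
c_log_prob = np.random.randn(4, 2)      # per-component log p_k(x) for 4 points

joint = log_pi + c_log_prob             # broadcasts over the batch axis
m = joint.max(axis=-1, keepdims=True)   # max-shift for numerical stability
log_prob = np.squeeze(m, -1) + np.log(np.exp(joint - m).sum(axis=-1))
print(log_prob.shape)                   # (4,)
```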
6 changes: 3 additions & 3 deletions tensorkit/gnn/adj/gcn_layers.py
@@ -469,7 +469,7 @@ class GCNDense(GCNLayer):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 use_self_loop: bool = False,
+                 use_self_loop: bool = True,
                  self_weight: float = 1.,
                  merge_mode: Union[str, GCNMergeMode] = 'add',
                  use_bias: Optional[bool] = None,
@@ -520,7 +520,7 @@ class GCNDense(GCNLayer):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 use_self_loop: bool = False,
+                 use_self_loop: bool = True,
                  self_weight: float = 1.,
                  merge_mode: Union[str, GCNMergeMode] = 'add',
                  use_bias: Optional[bool] = None,
@@ -575,7 +575,7 @@ def __init__(self,
                  in_features: int,
                  out_features: int,
                  n_partitions: int,
-                 use_self_loop: bool = False,
+                 use_self_loop: bool = True,
                  self_weight: float = 1.,
                  merge_mode: Union[str, GCNMergeMode] = 'add',
                  use_bias: Optional[bool] = None,
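Behavioral note: all three GCN layer constructors now default `use_self_loop` to `True`, so the self-loop term (scaled by `self_weight`) is included unless callers opt out. Code that relied on the old default would need to pass the flag explicitly. A hedged sketch, assuming the remaining constructor arguments are optional as the defaults in the hunks suggest, with arbitrary feature sizes:

```python
from tensorkit.gnn.adj.gcn_layers import GCNDense

# After this commit, self-loops are on by default; pass use_self_loop=False
# to keep the previous behavior.
layer = GCNDense(in_features=64, out_features=32)                        # self-loop added
legacy = GCNDense(in_features=64, out_features=32, use_self_loop=False)  # old default
```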
5 changes: 3 additions & 2 deletions tensorkit/tensor/losses.py
@@ -1,4 +1,5 @@
-from .core import Tensor, jit, rank, length, shape, reduce_sum, reduce_mean
+from .core import (Tensor, jit, rank, length, shape, reduce_sum,
+                   reduce_sum_axis, reduce_mean, reduce_mean_axis)
 from .nn import log_sigmoid, softplus
 
 __all__ = ['negative_sampling']
@@ -18,7 +19,7 @@ def negative_sampling(pos_logits: Tensor,
                    format(shape(pos_logits), shape(neg_logits)))
 
     pos_logits = log_sigmoid(pos_logits)
-    neg_logits = reduce_sum(softplus(neg_logits), axis=[-1])
+    neg_logits = reduce_sum_axis(softplus(neg_logits), axis=-1)
 
     if negative:
         output = neg_logits - pos_logits
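The `losses.py` change only swaps the list-based `reduce_sum` for `reduce_sum_axis`; the objective itself is the usual negative-sampling bound, log σ(s⁺) − Σⱼ softplus(sⱼ⁻), using the identity log σ(−s) = −softplus(s). A NumPy reference sketch (function name mine; the non-negative branch is inferred from context, since the hunk cuts off before the `else`):

```python
import numpy as np

def negative_sampling_ref(pos_logits, neg_logits, negative=False):
    """NumPy sketch of the objective in this file:
    log sigma(pos) - sum_j softplus(neg_j)."""
    log_sig_pos = -np.logaddexp(0.0, -pos_logits)          # log sigma(pos)
    neg_sum = np.logaddexp(0.0, neg_logits).sum(axis=-1)   # summed softplus
    return neg_sum - log_sig_pos if negative else log_sig_pos - neg_sum

print(negative_sampling_ref(np.array([2.0]), np.array([[0.5, -1.0]])))
```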
7 changes: 7 additions & 0 deletions tests/tensor/test_core.py
@@ -1342,6 +1342,13 @@ def log_f_exp(f, x, axis=None, keepdims=False):
                     match='`axis` must not be an empty list'):
                 _ = T_op(t, axis=[])
 
+            # test the per-axis version
+            T_op = getattr(T, f'reduce_{name}_axis', getattr(T, f'{name}_axis', None))
+            if T_op is not None:
+                assert_allclose(T_op(t, axis=-1), np_op(x, axis=-1))
+                assert_allclose(T_op(t, axis=-1, keepdims=True),
+                                np_op(x, axis=-1, keepdims=True))
+
             # test argmax, argmin
             def np_argmaxmin(fn, x, axis, keepdims=False):
                 r_shape = list(x.shape)
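The new test resolves the per-axis op by name, trying `reduce_{name}_axis` first and falling back to `{name}_axis`, skipping ops that have neither. A tiny standalone illustration of that getattr fallback (the namespace is a stand-in, not tensorkit):

```python
from types import SimpleNamespace

# `reduce_{name}_axis` is tried first, then `{name}_axis`, else None.
T = SimpleNamespace(reduce_sum_axis='reduce_sum_axis',
                    log_sum_exp_axis='log_sum_exp_axis')

for name in ('sum', 'log_sum_exp', 'prod'):
    T_op = getattr(T, f'reduce_{name}_axis', getattr(T, f'{name}_axis', None))
    print(name, '->', T_op)
# sum -> reduce_sum_axis, log_sum_exp -> log_sum_exp_axis, prod -> None
```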
