Skip to content

Commit

Permalink
[PYTORCH]Unary Ops (apache#5378)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent 8c8b5b6 commit 5266afc
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 123 deletions.
96 changes: 26 additions & 70 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,16 @@ def _impl(inputs, input_types):
return get_relay_op(name)(data0, data1)
return _impl

def _abs():

def _unary(name):
def _impl(inputs, input_types):
data = inputs[0]
return _op.abs(data)
input_type = input_types[0]
data = _convert_elemwise_input(inputs[0], input_type)

return get_relay_op(name)(data)
return _impl


def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
Expand Down Expand Up @@ -1260,26 +1264,6 @@ def _impl(inputs, input_types):
return _op.nn.pad(data, pad_width, pad_value)
return _impl

def _sqrt():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.sqrt(data)
return _impl


def _rsqrt():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.rsqrt(data)
return _impl


def _ceil():
def _impl(inputs, input_types):
data = inputs[0]
return _op.ceil(data)
return _impl


def _clamp():
def _impl(inputs, input_types):
Expand All @@ -1290,20 +1274,6 @@ def _impl(inputs, input_types):
return _impl


def _floor():
def _impl(inputs, input_types):
data = inputs[0]
return _op.floor(data)
return _impl


def _round():
def _impl(inputs, input_types):
data = inputs[0]
return _op.round(data)
return _impl


def _to():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1381,17 +1351,6 @@ def _impl(inputs, input_types):
return inputs[0]
return _impl

def _neg():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.negative(data)
return _impl

def _tanh():
def _impl(inputs, input_types):
data = inputs[0]
return _op.tensor.tanh(data)
return _impl

def _Bool():
def _impl(inputs, input_types):
Expand Down Expand Up @@ -1473,18 +1432,6 @@ def _impl(inputs, input_types):
return _impl


def _isfinite():
def _impl(inputs, input_types):
return _op.isfinite(inputs[0])
return _impl


def _isnan():
def _impl(inputs, input_types):
return _op.isnan(inputs[0])
return _impl


def _list_getitem(prelude):
def _impl(inputs, input_types):
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
Expand Down Expand Up @@ -1607,7 +1554,6 @@ def _get_convert_map(prelude):
"aten::mul" : _elemwise("multiply"),
"aten::mul_" : _elemwise("multiply"),
"aten::pow" : _elemwise("power"),
"aten::abs" : _abs(),
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
Expand Down Expand Up @@ -1689,12 +1635,26 @@ def _get_convert_map(prelude):
"aten::argmax" : _reduce("argmax"),
"aten::std" : _std(),
"aten::var" : _variance(),
"aten::sqrt" : _sqrt(),
"aten::rsqrt" : _rsqrt(),
"aten::ceil" : _ceil(),
"aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"),
"aten::sin" : _unary("sin"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"),
"aten::sign" : _unary("sign"),
"aten::sqrt" : _unary("sqrt"),
"aten::rsqrt" : _unary("rsqrt"),
"aten::ceil" : _unary("ceil"),
"aten::floor" : _unary("floor"),
"aten::round" : _unary("round"),
"aten::isfinite" : _unary("isfinite"),
"aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
"aten::floor" : _floor(),
"aten::round" : _round(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
Expand All @@ -1709,12 +1669,8 @@ def _get_convert_map(prelude):
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
"aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
"aten::Float" : _Float(),
"aten::neg" : _neg(),
"aten::tanh" : _tanh(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::mm" : _matmul(),
Expand Down
141 changes: 88 additions & 53 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,30 +1508,6 @@ def forward(self, *args):
verify_model(IsInf1().float().eval(), input_data=input_data)


def test_forward_rsqrt():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Rsqrt1(Module):
def forward(self, *args):
return torch.rsqrt(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Rsqrt1().float().eval(), input_data=input_data)


def test_forward_ceil():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Ceil1(Module):
def forward(self, *args):
return torch.ceil(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Ceil1().float().eval(), input_data=input_data)


def test_forward_clamp():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand All @@ -1554,30 +1530,6 @@ def forward(self, *args):
verify_model(Clamp3().float().eval(), input_data=input_data)


def test_forward_floor():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Floor1(Module):
def forward(self, *args):
return torch.floor(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Floor1().float().eval(), input_data=input_data)


def test_forward_round():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]

class Round1(Module):
def forward(self, *args):
return torch.round(args[0])

input_data = torch.rand(input_shape).float()
verify_model(Round1().float().eval(), input_data=input_data)


def test_forward_ones():
torch.set_grad_enabled(False)

Expand Down Expand Up @@ -1860,6 +1812,93 @@ def forward(self, *args):
verify_model(LogicalXor2().float().eval(), input_data=[lhs])


def test_forward_unary():
torch.set_grad_enabled(False)

class Sqrt1(Module):
def forward(self, *args):
return torch.sqrt(args[0])

class RSqrt1(Module):
def forward(self, *args):
return torch.rsqrt(args[0])

class Ceil1(Module):
def forward(self, *args):
return torch.ceil(args[0])

class Floor1(Module):
def forward(self, *args):
return torch.floor(args[0])

class Round1(Module):
def forward(self, *args):
return torch.round(args[0])

class Cos1(Module):
def forward(self, *args):
return torch.cos(args[0])

class Sin1(Module):
def forward(self, *args):
return torch.sin(args[0])

class Tan1(Module):
def forward(self, *args):
return torch.tan(args[0])

class Tanh1(Module):
def forward(self, *args):
return torch.tanh(args[0])

class ATanh1(Module):
def forward(self, *args):
return torch.atan(args[0])

class Log1(Module):
def forward(self, *args):
return torch.log(args[0])

class Exp1(Module):
def forward(self, *args):
return torch.exp(args[0])

class Erf1(Module):
def forward(self, *args):
return torch.erf(args[0])

class Trunc1(Module):
def forward(self, *args):
return torch.trunc(args[0])

class Sign1(Module):
def forward(self, *args):
return torch.sign(args[0])

class Neg1(Module):
def forward(self, *args):
return torch.neg(args[0])

input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data)
verify_model(RSqrt1().float().eval(), input_data=input_data)
verify_model(Ceil1().float().eval(), input_data=input_data)
verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data)
verify_model(Sign1().float().eval(), input_data=input_data)
verify_model(Neg1().float().eval(), input_data=input_data)


if __name__ == "__main__":
# Single operator tests
test_forward_add()
Expand Down Expand Up @@ -1918,12 +1957,8 @@ def forward(self, *args):
test_forward_mean()
test_forward_expand()
test_forward_pow()
test_forward_abs()
test_forward_rsqrt()
test_forward_ceil()
test_forward_unary()
test_forward_clamp()
test_forward_floor()
test_forward_round()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
Expand Down

0 comments on commit 5266afc

Please sign in to comment.