diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 3ccab5bfd9c3..b03c6f6f0258 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -313,6 +313,8 @@ def extern(shape, body = fcompute(input_placeholders, output_placeholders) if isinstance(body, tvm.tir.PrimExpr): body = tvm.tir.Evaluate(body) + if not isinstance(body, tvm.tir.Stmt): + raise ValueError("Function '{}' should return PrimExpr or Stmt".format(fcompute.__name__)) op = _ffi_api.ExternOp(name, tag, attrs, inputs, input_placeholders,