[Operator] Adding missing operators for llama (#219)
adding missing operators for llama
yaoyaoding committed May 9, 2023
1 parent b8d5e0e commit e6d89a7
Showing 12 changed files with 166 additions and 41 deletions.
6 changes: 5 additions & 1 deletion python/hidet/cuda/memory.py
@@ -78,11 +78,13 @@ def malloc_async(num_bytes: int, stream: Optional[Union[Stream, cudaStream_t, in
Returns
-------
addr: int
The address of the allocated memory.
The address of the allocated memory. When the allocation failed due to insufficient memory, 0 is returned.
"""
if stream is None:
stream = current_stream()
err, addr = cudart.cudaMallocAsync(num_bytes, int(stream))
if err == cudart.cudaError_t.cudaErrorMemoryAllocation:
return 0
assert err == 0, err
return addr

@@ -122,6 +124,8 @@ def malloc_host(num_bytes: int) -> int:
"""
err, addr = cudart.cudaMallocHost(num_bytes)
if err == cudart.cudaError_t.cudaErrorMemoryAllocation:
return 0
assert err == 0, err
return addr
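With this change a caller can detect out-of-memory without tripping the assertion. A minimal caller-side sketch (not part of the commit; assumes a CUDA-enabled hidet build, using the module path of the file above):

from hidet.cuda.memory import malloc_async

def checked_malloc_async(num_bytes: int) -> int:
    addr = malloc_async(num_bytes)  # stream defaults to the current stream
    if addr == 0:                   # new convention: 0 signals cudaErrorMemoryAllocation
        raise MemoryError(f'cudaMallocAsync could not allocate {num_bytes} bytes')
    return addr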


8 changes: 5 additions & 3 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -58,7 +58,7 @@ def overload(self, func: Callable):
class Registry:
registered_modules: Dict[Type[torch.nn.Module], Type['HidetModule']] = {}
registered_functions: Dict[Callable, OverloadedFunction] = {}
registered_methods: Dict[Callable, Callable] = {}
registered_methods: Dict[Callable, OverloadedFunction] = {}


class ExpectedRegistry:
@@ -124,7 +124,9 @@ def decorator(hidet_func):

def register_method(method: Callable):
def decorator(hidet_method):
Registry.registered_methods[method] = hidet_method
if method not in Registry.registered_methods:
Registry.registered_methods[method] = OverloadedFunction()
Registry.registered_methods[method].overload(hidet_method)
return hidet_method

return decorator
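This turns registered_methods into the same overload-dispatch mechanism already used for functions, so several hidet implementations can back one torch method and the variant whose signature matches the call is used. A hedged sketch of the pattern it enables (mirroring the two torch.Tensor.mean registrations added later in this commit; the import paths are assumptions based on the file locations in this diff):

import torch
from hidet import ops
from hidet.graph import Tensor
from hidet.graph.frontend.torch.interpreter import register_method

@register_method(torch.Tensor.mean)
def mean_all(x: Tensor):                       # picked for x.mean()
    return ops.mean(x, dims=list(range(len(x.shape))), keep_dim=True)

@register_method(torch.Tensor.mean)
def mean_dim(x: Tensor, dim, keepdim=False):   # picked for x.mean(dim, keepdim=...)
    return ops.mean(x, dims=dim, keep_dim=keepdim)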
@@ -306,7 +308,7 @@ def load_arg(a, env):
attr = self.graph_module
for i, atom in enumerate(target_atoms):
if not hasattr(attr, atom):
raise RuntimeError(f"Node referenced nonexistent target {target_atoms[:i]} not")
raise RuntimeError(f"Node referenced nonexistent target {target_atoms[:i]}")
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
121 changes: 104 additions & 17 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -315,7 +315,7 @@ def neg(x: Tensor):
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, _stacklevel: int = 3, dtype=None):
if dtype is not None:
raise NotImplementedError("dtype is not None")
x = ops.cast(x, dtype_from_torch(dtype))
return ops.softmax(x, dim)
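The dtype argument is now honored by casting instead of raising. A small end-to-end sketch (assumes a CUDA build and that importing hidet registers the 'hidet' torch.compile backend):

import torch
import hidet  # importing hidet registers the 'hidet' dynamo backend (assumption)

def stable_softmax(x: torch.Tensor) -> torch.Tensor:
    # previously rejected by the frontend when dtype was given; now cast, then softmax
    return x.softmax(dim=-1, dtype=torch.float32)

compiled = torch.compile(stable_softmax, backend='hidet')
y = compiled(torch.randn(4, 8, dtype=torch.float16, device='cuda'))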


@@ -508,7 +508,7 @@ def relu6(x: Tensor, inplace: bool = False):
@register_function(torch.arange)
def arange(
start: Number,
end: Number,
end: Number = None,
step: Number = 1,
*,
out: Optional[Tensor] = None,
@@ -518,6 +518,9 @@ def arange(
pin_memory: Optional[bool] = False,
requires_grad: Optional[bool] = False,
):
if end is None:
end = start
start = 0
if out is not None:
raise NotImplementedError("hidet: does not support torch.arange(..., out=..., ...)")
if layout is not None:
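With end defaulting to None, the one-argument form torch.arange(n) is now read as start=0, end=n. An illustrative sketch (assumes the 'hidet' torch.compile backend and that the elided part of the signature accepts the device keyword, as torch.arange does):

import torch
import hidet  # registers the 'hidet' dynamo backend (assumption)

def position_ids(x: torch.Tensor) -> torch.Tensor:
    # one-argument arange: start defaults to 0
    return torch.arange(x.shape[-1], device=x.device)

ids = torch.compile(position_ids, backend='hidet')(torch.empty(2, 16, device='cuda'))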
@@ -661,39 +664,39 @@ def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:


@register_function(torch.nn.functional.hardsigmoid)
def hardsigmoid(x: Tensor, inplace: bool):
def hardsigmoid(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: hardsigmoid with inplace=True is not supported. Treat as inplace=False.')
return ops.hardsigmoid(x)


@register_function(torch.nn.functional.silu)
def silu(x: Tensor, inplace: bool):
def silu(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: silu with inplace=True is not supported. Treat as inplace=False.')
return ops.silu(x)


@register_function(torch.nn.functional.hardswish)
def hardswish(x: Tensor, inplace: bool):
def hardswish(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: hardswish with inplace=True is not supported. Treat as inplace=False.')
return ops.hardswish(x)


@register_function(torch.nn.functional.softmin)
def softmin(x: Tensor, axis: int):
return ops.softmin(x, axis)
def softmin(x: Tensor, dim: int):
return ops.softmin(x, dim)


@register_function(torch.nn.functional.softplus)
def softplus(x: Tensor, beta: int, threshold: int):
def softplus(x: Tensor, beta: int = 1, threshold: int = 20):
return ops.softplus(x, beta, threshold)


@register_function(torch.nn.functional.softshrink)
def softshrink(x: Tensor, lambda_val: float):
return ops.softshrink(x, lambda_val)
def softshrink(x: Tensor, lambd=0.5):
return ops.softshrink(x, lambd)


@register_function(torch.nn.functional.tanhshrink)
@@ -702,8 +705,8 @@ def tanhshrink(x: Tensor):


@register_function(torch.nn.functional.hardshrink)
def hardshrink(x: Tensor, lambda_val: float):
return ops.hardshrink(x, lambda_val)
def hardshrink(x: Tensor, lambd=0.5):
return ops.hardshrink(x, lambd)


@register_function(torch.nn.functional.softsign)
@@ -712,7 +715,9 @@ def softsign(x: Tensor):


@register_function(torch.nn.functional.celu)
def celu(x: Tensor, alpha: float):
def celu(x: Tensor, alpha: float = 1.0, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: celu with inplace=True is not supported. Treat as inplace=False.')
return ops.celu(x, alpha)


@@ -739,16 +744,26 @@ def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):

@register_function(torch.maximum)
def maximum(x: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
a, b = x, other
if len(a.shape) == 0 and a.device.is_cpu() and b.device.is_cuda() and not a.is_symbolic():
a = a.cuda()
if len(b.shape) == 0 and b.device.is_cpu() and a.device.is_cuda() and not b.is_symbolic():
b = b.cuda()
if out is not None:
raise NotImplementedError("hidet: does not support torch.maximum(..., out=...)")
return ops.maximum(x, other)
return ops.maximum(a, b)


@register_function(torch.minimum)
def minimum(x: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
a, b = x, other
if len(a.shape) == 0 and a.device.is_cpu() and b.device.is_cuda() and not a.is_symbolic():
a = a.cuda()
if len(b.shape) == 0 and b.device.is_cpu() and a.device.is_cuda() and not b.is_symbolic():
b = b.cuda()
if out is not None:
raise NotImplementedError("hidet: does not support torch.minimum(..., out=...)")
return ops.minimum(x, other)
return ops.minimum(a, b)
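The new prologue moves a concrete 0-dim CPU tensor onto CUDA when the other operand lives there, which is the situation that arises when a scalar constant is baked into a traced graph. A hedged sketch (assumes the 'hidet' torch.compile backend):

import torch
import hidet  # registers the 'hidet' dynamo backend (assumption)

ZERO = torch.tensor(0.0)  # 0-dim CPU constant captured by the traced graph

def clamp_min0(x: torch.Tensor) -> torch.Tensor:
    return torch.maximum(x, ZERO)  # CUDA input vs. CPU 0-dim constant is now handled

compiled = torch.compile(clamp_min0, backend='hidet')
y = compiled(torch.randn(16, device='cuda'))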


@register_function(torch.max)
@@ -765,7 +780,7 @@ def torch_max_v2(
if out is not None:
raise NotImplementedError("hidet: does not support torch.max(..., out=...)")
if isinstance(other, Tensor):
return ops.maximum(x, other)
return maximum(x, other)
else:
return torch_max_v3(x, other)

@@ -795,7 +810,7 @@ def torch_min_v2(
if out is not None:
raise NotImplementedError("hidet: does not support torch.min(..., out=...)")
if isinstance(other, Tensor):
return ops.minimum(x, other)
return minimum(x, other)
else:
return torch_min_v3(x, other)

@@ -809,3 +824,75 @@ def torch_min_v3(
values = ops.min(x, dims=dim, keep_dim=keepdim)
indices = ops.argmin(x, dim=dim, keep_dim=keepdim)
return values, indices


@register_function(operator.lt)
def lt(a: Tensor, b: Tensor) -> Tensor:
return ops.less(a, b)


@register_function(operator.le)
def le(a: Tensor, b: Tensor) -> Tensor:
return ops.less_equal(a, b)


@register_function(operator.gt)
def gt(a: Tensor, b: Tensor) -> Tensor:
return ops.greater(a, b)


@register_function(operator.ge)
def ge(a: Tensor, b: Tensor) -> Tensor:
return ops.greater_equal(a, b)


@register_function(operator.eq)
def eq(a: Tensor, b: Tensor) -> Tensor:
return ops.equal(a, b)


@register_function(operator.ne)
def ne(a: Tensor, b: Tensor) -> Tensor:
return ops.not_equal(a, b)
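These registrations let Python comparison operators inside a traced model lower to the corresponding hidet ops. A minimal sketch (assumes the 'hidet' torch.compile backend):

import torch
import hidet  # registers the 'hidet' dynamo backend (assumption)

def masks(a: torch.Tensor, b: torch.Tensor):
    # each comparison below now maps to ops.less / less_equal / greater / greater_equal / equal / not_equal
    return a < b, a <= b, a > b, a >= b, a == b, a != b

compiled = torch.compile(masks, backend='hidet')
out = compiled(torch.randn(8, device='cuda'), torch.randn(8, device='cuda'))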


@register_function(torch.rsqrt)
def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.rsqrt(..., out=...)")
return ops.rsqrt(x)


@register_function(torch.pow)
@register_method(torch.Tensor.pow)
def tensor_pow(self: Union[Tensor, Number], exponent: Union[Tensor, Number]) -> Tensor:
if isinstance(self, Tensor) and isinstance(exponent, Tensor):
return ops.pow(self, exponent)
elif isinstance(self, Tensor):
return ops.pow(self, ops.full([], value=exponent, dtype=self.dtype, device=self.device))
elif isinstance(exponent, Tensor):
return ops.pow(ops.full([], value=self, dtype=exponent.dtype, device=exponent.device), exponent)
else:
return operator.pow(self, exponent)


@register_function(torch.mean)
@register_method(torch.Tensor.mean)
def torch_mean_v1(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
output = ops.mean(x, dims=list(range(len(x.shape))), keep_dim=True)
if dtype:
output = output.astype(dtype_from_torch(dtype))
return output


@register_function(torch.mean)
@register_method(torch.Tensor.mean)
def torch_mean_v2(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.mean(..., out=...)")
output = ops.mean(x, dims=dim, keep_dim=keepdim)
if dtype:
output = output.astype(dtype_from_torch(dtype))
return output
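Together with rsqrt and pow above, the two mean overloads cover the RMSNorm-style reduction used in llama. An illustrative, not model-derived, sketch (assumes the 'hidet' torch.compile backend and that elementwise add/mul were already registered):

import torch
import hidet  # registers the 'hidet' dynamo backend (assumption)

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)      # Tensor.pow plus the dim/keepdim mean overload
    return weight * (x * torch.rsqrt(variance + eps))

compiled = torch.compile(rms_norm, backend='hidet')
y = compiled(torch.randn(2, 16, device='cuda'), torch.ones(16, device='cuda'))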
10 changes: 10 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -208,3 +208,13 @@ def tensor_masked_fill(self: Tensor, mask: Tensor, value: float) -> Tensor:
@register_method(torch.Tensor.flatten)
def tensor_flatten(self: Tensor, start_dim=0, end_dim=-1):
return ops.flatten(self, start_dim=start_dim, end_dim=end_dim)


@register_method(torch.Tensor.masked_fill_)
def tensor_masked_fill_(self: Tensor, mask: Tensor, value: float) -> Tensor:
return ops.where(mask, ops.full([], value, dtype=self.dtype, device=self.device), self)


@register_method(torch.Tensor.repeat)
def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
return ops.tile(self, sizes)
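Both methods map to functional hidet ops (where and tile), which is sufficient inside a traced graph where only the returned value is consumed. A hedged sketch of the attention-mask pattern they unblock (assumes the 'hidet' torch.compile backend):

import torch
import hidet  # registers the 'hidet' dynamo backend (assumption)

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # masked_fill_ is traced as an out-of-place ops.where under hidet
    scores = scores.masked_fill_(mask, float('-inf'))
    return scores.softmax(dim=-1)

scores = torch.randn(4, 4, device='cuda')
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool, device='cuda'), diagonal=1)
out = torch.compile(masked_softmax, backend='hidet')(scores, mask)
tiled = torch.compile(lambda m: m.repeat(2, 1), backend='hidet')(mask)  # Tensor.repeat -> ops.tile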
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/register_modules.py
@@ -271,7 +271,7 @@ def __call__(self, x: Tensor) -> Tensor:

@register_module(torch.nn.AvgPool2d)
class HidetAvgPool2d(HidetModule):
def __call__(self, x=Tensor) -> Tensor:
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AvgPool2d)
return regs.avg_pool2d(
x=x,
2 changes: 1 addition & 1 deletion python/hidet/graph/ir/flow_graph.py
@@ -193,7 +193,7 @@ def build(self):
task_keys = set()
search_space = hidet.option.get_option('search_space')
for node in self.nodes:
if node.task_func is None:
if node._task_func is None:
task_key = hash(str(node.task))
if task_key in task_keys:
continue
6 changes: 6 additions & 0 deletions python/hidet/graph/operator.py
@@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Dict, Any, Union
import logging

from hidet.ir.type import TensorType, DataType
from hidet.ir.expr import Var, Constant, var
@@ -22,6 +23,9 @@
from hidet.runtime.device import Device, instantiate_device


logger = logging.getLogger(__name__)


def get_operator_name(op, given_name: Optional[str] = None):
if given_name is not None:
return given_name
@@ -49,6 +53,8 @@ def __init__(self, inputs: List[Tensor], attributes: Dict[str, Any], task: Optio

self.outputs = self._run()

logger.debug('Operator: %s', self)

def __str__(self):
arguments = ['{}: {}{}'.format(i, t.dtype.name, t.shape) for i, t in enumerate(self.inputs)]
attributes = ['{}={}'.format(name, str(value)) for name, value in self.attrs.items()]
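The new debug line makes it possible to trace every operator as it is constructed, which helps when porting a model such as llama. A how-to sketch (standard-library logging only; the logger name follows the module path of this file):

import logging

logging.basicConfig(level=logging.WARNING)
logging.getLogger('hidet.graph.operator').setLevel(logging.DEBUG)  # prints each Operator as it is created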
12 changes: 9 additions & 3 deletions python/hidet/graph/ops/definitions/create.py
@@ -105,12 +105,18 @@ class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
value: Union[float, int, bool, Constant],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
if value.is_symbolic():
raise NotImplementedError('Currently, we do not support symbolic tensor as value in full op')
value = value.item()

if dtype is None:
if isinstance(value, int):
dtype = dtypes.int64
@@ -133,11 +139,11 @@ def __init__(

def full(
shape: Sequence[int],
value: Union[float, int, bool, Constant],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[Union[DataType, str]] = None,
device: Union[Device, str] = 'cpu',
) -> Tensor:
return FullOp(shape, value, data_type(dtype), device).get_output(0)
return FullOp(shape, value, data_type(dtype) if dtype is not None else dtype, device).get_output(0)


def arange(start, /, stop=None, step=1, *, dtype=None, device='cpu') -> Tensor:
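full() now also accepts a concrete 0-dim Tensor as the fill value (converted via .item()), and the dtype argument may be omitted. A minimal sketch (assumes hidet.asarray yields a concrete 0-dim tensor):

import hidet
from hidet import ops

scalar = hidet.asarray(2.5)            # concrete 0-dim tensor
t = ops.full([2, 3], value=scalar)     # dtype inferred from the value; device defaults to 'cpu'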
