Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Adding missing operators for llama #219

Merged
merged 1 commit into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/hidet/cuda/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def malloc_async(num_bytes: int, stream: Optional[Union[Stream, cudaStream_t, in
Returns
-------
addr: int
The address of the allocated memory.
The address of the allocated memory. When the allocation failed due to insufficient memory, 0 is returned.
"""
if stream is None:
stream = current_stream()
err, addr = cudart.cudaMallocAsync(num_bytes, int(stream))
if err == cudart.cudaError_t.cudaErrorMemoryAllocation:
return 0
assert err == 0, err
return addr

Expand Down Expand Up @@ -122,6 +124,8 @@ def malloc_host(num_bytes: int) -> int:
"""
err, addr = cudart.cudaMallocHost(num_bytes)
assert err == 0, err
if err == cudart.cudaError_t.cudaErrorMemoryAllocation:
return 0
return addr


Expand Down
8 changes: 5 additions & 3 deletions python/hidet/graph/frontend/torch/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def overload(self, func: Callable):
class Registry:
registered_modules: Dict[Type[torch.nn.Module], Type['HidetModule']] = {}
registered_functions: Dict[Callable, OverloadedFunction] = {}
registered_methods: Dict[Callable, Callable] = {}
registered_methods: Dict[Callable, OverloadedFunction] = {}


class ExpectedRegistry:
Expand Down Expand Up @@ -124,7 +124,9 @@ def decorator(hidet_func):

def register_method(method: Callable):
def decorator(hidet_method):
Registry.registered_methods[method] = hidet_method
if method not in Registry.registered_functions:
Registry.registered_methods[method] = OverloadedFunction()
Registry.registered_methods[method].overload(hidet_method)
return hidet_method

return decorator
Expand Down Expand Up @@ -306,7 +308,7 @@ def load_arg(a, env):
attr = self.graph_module
for i, atom in enumerate(target_atoms):
if not hasattr(attr, atom):
raise RuntimeError(f"Node referenced nonexistent target {target_atoms[:i]} not")
raise RuntimeError(f"Node referenced nonexistent target {target_atoms[:i]}")
attr = getattr(attr, atom)
hidet_env[node.name] = tensor_from_torch(attr) if isinstance(attr, torch.Tensor) else attr
elif node.op == "call_function":
Expand Down
121 changes: 104 additions & 17 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def neg(x: Tensor):
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, _stacklevel: int = 3, dtype=None):
if dtype is not None:
raise NotImplementedError("dtype is not None")
x = ops.cast(x, dtype_from_torch(dtype))
return ops.softmax(x, dim)


Expand Down Expand Up @@ -508,7 +508,7 @@ def relu6(x: Tensor, inplace: bool = False):
@register_function(torch.arange)
def arange(
start: Number,
end: Number,
end: Number = None,
step: Number = 1,
*,
out: Optional[Tensor] = None,
Expand All @@ -518,6 +518,9 @@ def arange(
pin_memory: Optional[bool] = False,
requires_grad: Optional[bool] = False,
):
if end is None:
end = start
start = 0
if out is not None:
raise NotImplementedError("hidet: does not support torch.arange(..., out=..., ...)")
if layout is not None:
Expand Down Expand Up @@ -661,39 +664,39 @@ def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:


@register_function(torch.nn.functional.hardsigmoid)
def hardsigmoid(x: Tensor, inplace: bool):
def hardsigmoid(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: hardsigmoid with inplace=True is not supported. Treat as inplace=False.')
return ops.hardsigmoid(x)


@register_function(torch.nn.functional.silu)
def silu(x: Tensor, inplace: bool):
def silu(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: silu with inplace=True is not supported. Treat as inplace=False.')
return ops.silu(x)


@register_function(torch.nn.functional.hardswish)
def hardswish(x: Tensor, inplace: bool):
def hardswish(x: Tensor, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: hardswish with inplace=True is not supported. Treat as inplace=False.')
return ops.hardswish(x)


@register_function(torch.nn.functional.softmin)
def softmin(x: Tensor, axis: int):
return ops.softmin(x, axis)
def softmin(x: Tensor, dim: int):
return ops.softmin(x, dim)


@register_function(torch.nn.functional.softplus)
def softplus(x: Tensor, beta: int, threshold: int):
def softplus(x: Tensor, beta: int = 1, threshold: int = 20):
return ops.softplus(x, beta, threshold)


@register_function(torch.nn.functional.softshrink)
def softshrink(x: Tensor, lambda_val: float):
return ops.softshrink(x, lambda_val)
def softshrink(x: Tensor, lambd=0.5):
return ops.softshrink(x, lambd)


@register_function(torch.nn.functional.tanhshrink)
Expand All @@ -702,8 +705,8 @@ def tanhshrink(x: Tensor):


@register_function(torch.nn.functional.hardshrink)
def hardshrink(x: Tensor, lambda_val: float):
return ops.hardshrink(x, lambda_val)
def hardshrink(x: Tensor, lambd=0.5):
return ops.hardshrink(x, lambd)


@register_function(torch.nn.functional.softsign)
Expand All @@ -712,7 +715,9 @@ def softsign(x: Tensor):


@register_function(torch.nn.functional.celu)
def celu(x: Tensor, alpha: float):
def celu(x: Tensor, alpha: float = 1.0, inplace: bool = False):
if inplace:
warnings.warn_once('hidet: celu with inplace=True is not supported. Treat as inplace=False.')
return ops.celu(x, alpha)


Expand All @@ -739,16 +744,26 @@ def gather(x: Tensor, dim: int, index: Tensor, *, sparse_grad=False, out=None):

@register_function(torch.maximum)
def maximum(x: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
a, b = x, other
if len(a.shape) == 0 and a.device.is_cpu() and b.device.is_cuda() and not a.is_symbolic():
a = a.cuda()
if len(b.shape) == 0 and b.device.is_cpu() and a.device.is_cuda() and not b.is_symbolic():
b = b.cuda()
if out is not None:
raise NotImplementedError("hidet: does not support torch.maximum(..., out=...)")
return ops.maximum(x, other)
return ops.maximum(a, b)


@register_function(torch.minimum)
def minimum(x: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
a, b = x, other
if len(a.shape) == 0 and a.device.is_cpu() and b.device.is_cuda() and not a.is_symbolic():
a = a.cuda()
if len(b.shape) == 0 and b.device.is_cpu() and a.device.is_cuda() and not b.is_symbolic():
b = b.cuda()
if out is not None:
raise NotImplementedError("hidet: does not support torch.minimum(..., out=...)")
return ops.minimum(x, other)
return ops.minimum(a, b)


@register_function(torch.max)
Expand All @@ -765,7 +780,7 @@ def torch_max_v2(
if out is not None:
raise NotImplementedError("hidet: does not support torch.max(..., out=...)")
if isinstance(other, Tensor):
return ops.maximum(x, other)
return maximum(x, other)
else:
return torch_max_v3(x, other)

Expand Down Expand Up @@ -795,7 +810,7 @@ def torch_min_v2(
if out is not None:
raise NotImplementedError("hidet: does not support torch.min(..., out=...)")
if isinstance(other, Tensor):
return ops.minimum(x, other)
return minimum(x, other)
else:
return torch_min_v3(x, other)

Expand All @@ -809,3 +824,75 @@ def torch_min_v3(
values = ops.min(x, dims=dim, keep_dim=keepdim)
indices = ops.argmin(x, dim=dim, keep_dim=keepdim)
return values, indices


@register_function(operator.lt)
def lt(a: Tensor, b: Tensor) -> Tensor:
return ops.less(a, b)


@register_function(operator.le)
def le(a: Tensor, b: Tensor) -> Tensor:
return ops.less_equal(a, b)


@register_function(operator.gt)
def gt(a: Tensor, b: Tensor) -> Tensor:
return ops.greater(a, b)


@register_function(operator.ge)
def ge(a: Tensor, b: Tensor) -> Tensor:
return ops.greater_equal(a, b)


@register_function(operator.eq)
def eq(a: Tensor, b: Tensor) -> Tensor:
return ops.equal(a, b)


@register_function(operator.ne)
def ne(a: Tensor, b: Tensor) -> Tensor:
return ops.not_equal(a, b)


@register_function(torch.rsqrt)
def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.rsqrt(..., out=...)")
return ops.rsqrt(x)


@register_function(torch.pow)
@register_method(torch.Tensor.pow)
def tensor_pow(self: Union[Tensor, Number], exponent: Union[Tensor, Number]) -> Tensor:
if isinstance(self, Tensor) and isinstance(exponent, Tensor):
return ops.pow(self, exponent)
elif isinstance(self, Tensor):
return ops.pow(self, ops.full([], value=exponent, dtype=self.dtype, device=self.device))
elif isinstance(exponent, Tensor):
return ops.pow(ops.full([], value=self, dtype=exponent.dtype, device=exponent.device), exponent)
else:
return operator.pow(self, exponent)


@register_function(torch.mean)
@register_method(torch.Tensor.mean)
def torch_mean_v1(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
output = ops.mean(x, dims=list(range(len(x.shape))), keep_dim=True)
if dtype:
output = output.astype(dtype_from_torch(dtype))
return output


@register_function(torch.mean)
@register_method(torch.Tensor.mean)
def torch_mean_v2(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.mean(..., out=...)")
output = ops.mean(x, dims=dim, keep_dim=keepdim)
if dtype:
output = output.astype(dtype_from_torch(dtype))
return output
10 changes: 10 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,13 @@ def tensor_masked_fill(self: Tensor, mask: Tensor, value: float) -> Tensor:
@register_method(torch.Tensor.flatten)
def tensor_flatten(self: Tensor, start_dim=0, end_dim=-1):
return ops.flatten(self, start_dim=start_dim, end_dim=end_dim)


@register_method(torch.Tensor.masked_fill_)
def tensor_masked_fill_(self: Tensor, mask: Tensor, value: float) -> Tensor:
return ops.where(mask, ops.full([], value, dtype=self.dtype, device=self.device), self)


@register_method(torch.Tensor.repeat)
def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
return ops.tile(self, sizes)
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/register_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __call__(self, x: Tensor) -> Tensor:

@register_module(torch.nn.AvgPool2d)
class HidetAvgPool2d(HidetModule):
def __call__(self, x=Tensor) -> Tensor:
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AvgPool2d)
return regs.avg_pool2d(
x=x,
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/graph/ir/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def build(self):
task_keys = set()
search_space = hidet.option.get_option('search_space')
for node in self.nodes:
if node.task_func is None:
if node._task_func is None:
task_key = hash(str(node.task))
if task_key in task_keys:
continue
Expand Down
6 changes: 6 additions & 0 deletions python/hidet/graph/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Dict, Any, Union
import logging

from hidet.ir.type import TensorType, DataType
from hidet.ir.expr import Var, Constant, var
Expand All @@ -22,6 +23,9 @@
from hidet.runtime.device import Device, instantiate_device


logger = logging.getLogger(__name__)


def get_operator_name(op, given_name: Optional[str] = None):
if given_name is not None:
return given_name
Expand Down Expand Up @@ -49,6 +53,8 @@ def __init__(self, inputs: List[Tensor], attributes: Dict[str, Any], task: Optio

self.outputs = self._run()

logger.debug('Operator: %s', self)

def __str__(self):
arguments = ['{}: {}{}'.format(i, t.dtype.name, t.shape) for i, t in enumerate(self.inputs)]
attributes = ['{}={}'.format(name, str(value)) for name, value in self.attrs.items()]
Expand Down
12 changes: 9 additions & 3 deletions python/hidet/graph/ops/definitions/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,18 @@ class FullOp(Operator):
def __init__(
self,
shape: Sequence[int],
value: Union[float, int, bool, Constant],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[DataType] = None,
device: Union[Device, str] = 'cpu',
):
shape = [int(v) for v in shape]
device: Device = instantiate_device(device)

if isinstance(value, Tensor):
if value.is_symbolic():
raise NotImplementedError('Currently, we do not support symbolic tensor as value in full op')
value = value.item()

if dtype is None:
if isinstance(value, int):
dtype = dtypes.int64
Expand All @@ -133,11 +139,11 @@ def __init__(

def full(
shape: Sequence[int],
value: Union[float, int, bool, Constant],
value: Union[float, int, bool, Constant, Tensor],
dtype: Optional[Union[DataType, str]] = None,
device: Union[Device, str] = 'cpu',
) -> Tensor:
return FullOp(shape, value, data_type(dtype), device).get_output(0)
return FullOp(shape, value, data_type(dtype) if dtype is not None else dtype, device).get_output(0)


def arange(start, /, stop=None, step=1, *, dtype=None, device='cpu') -> Tensor:
Expand Down