[Frontend][Operator] Add missing operators for dinov2 (#206)
* add support for dinov2

* fix bugs
yaoyaoding committed May 2, 2023
1 parent dbfc57d commit ec3bc79
Showing 9 changed files with 289 additions and 159 deletions.
python/hidet/graph/frontend/onnx/onnx.py (8 additions, 8 deletions)

@@ -632,14 +632,14 @@ def run(self, inputs: List[Tensor]) -> List[Tensor]:
             return [
                 ops.resize2d(
                     x,
-                    target_size[2:],
-                    mode,
-                    coordinate_transformation_mode,
-                    nearest_mode,
-                    roi,
-                    cubic_coeff_a,
-                    exclude_outside,
-                    extrapolation_value,
+                    size=target_size[2:],
+                    method=mode,
+                    coordinate_transformation_mode=coordinate_transformation_mode,
+                    rounding_method=nearest_mode,
+                    roi=roi,
+                    cubic_alpha=cubic_coeff_a,
+                    cubic_exclude=exclude_outside,
+                    extrapolation_value=extrapolation_value,
                 )
             ]
         else:
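The change above is an API alignment: every ONNX Resize attribute is now bound to hidet's resize2d parameter by name instead of by position. A minimal usage sketch, with the keyword names taken from the new call site; the input shape and attribute values here are illustrative, not from the commit:

```python
import hidet
from hidet.graph import ops

# Resize a 32x32 NCHW tensor to 64x64, spelling out each parameter by
# keyword exactly as the new onnx.py call site does.
x = hidet.randn([1, 3, 32, 32])
y = ops.resize2d(
    x,
    size=[64, 64],                                 # ONNX: computed from 'sizes'/'scales'
    method='linear',                               # ONNX attribute: mode
    coordinate_transformation_mode='half_pixel',   # ONNX attribute of the same name
    rounding_method='round_prefer_floor',          # ONNX attribute: nearest_mode
    roi=None,
    cubic_alpha=-0.75,                             # ONNX attribute: cubic_coeff_a
    cubic_exclude=0,                               # ONNX attribute: exclude_outside
    extrapolation_value=0.0,
)
assert y.shape == [1, 3, 64, 64]
```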
python/hidet/graph/frontend/torch/interpreter.py (9 additions, 3 deletions)

@@ -215,10 +215,16 @@ def _raise_exception(exception: Exception, caused_callable: Any, args, kwargs):
         callable_name = caused_callable.__class__.__qualname__
 
     filename, lineno = code.co_filename, code.co_firstlineno
+    lines = []
+    lines.append(f'{exception}, occurred when calling ')
+    argument_strings = []
+    for arg in args:
+        argument_strings.append(arg.signature() if isinstance(arg, Tensor) else repr(arg))
+    for key, value in kwargs.items():
+        argument_strings.append(f'{key}={value.signature() if isinstance(value, Tensor) else repr(value)}')
     raise type(exception)(
-        f'{exception}, occurred when calling {callable_name} with \n'
-        f'    args: {args}\n'
-        f'    kwargs: {kwargs}\n'
+        f'{exception}, occurred when calling\n'
+        f'    {callable_name}({", ".join(argument_strings)})\n'
         f'{callable_name} is defined at\n'
         f'    File "{filename}", line {lineno}'
     )
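The new message renders the failing call as a call expression instead of dumping raw args/kwargs tuples. A standalone sketch of that formatting (not from the commit; the stub signature() stands in for hidet's Tensor.signature()):

```python
class StubTensor:
    """Stands in for hidet.Tensor, whose signature() prints shape and dtype."""

    def signature(self) -> str:
        return 'Tensor(shape=[1, 3, 224, 224], dtype=float32)'


def format_call(callable_name: str, args: tuple, kwargs: dict) -> str:
    # Same logic as the added lines above, with StubTensor in place of Tensor.
    argument_strings = []
    for arg in args:
        argument_strings.append(arg.signature() if isinstance(arg, StubTensor) else repr(arg))
    for key, value in kwargs.items():
        argument_strings.append(f'{key}={value.signature() if isinstance(value, StubTensor) else repr(value)}')
    return f'{callable_name}({", ".join(argument_strings)})'


print(format_call('interpolate', (StubTensor(),), {'scale_factor': 2.0}))
# interpolate(Tensor(shape=[1, 3, 224, 224], dtype=float32), scale_factor=2.0)
```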
python/hidet/graph/frontend/torch/register_functions.py (25 additions, 35 deletions)

@@ -239,61 +239,50 @@ def avg_pool3d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_i
 @register_function(torch.nn.functional.interpolate)
 def interpolate(
     input: Tensor,
-    size=None,
+    size: Union[int, Sequence[int]] = None,
     scale_factor=None,
     mode='nearest',
     align_corners=None,
     recompute_scale_factor=None,
     antialias=False,
 ):
     # please refer to the way that pytorch converts its interpolate function to onnx's resize operator
     # https://github.com/pytorch/pytorch/blob/940662c4dcaa090f20e39a63a8e319a58ca1460f/torch/onnx/symbolic_helper.py#L1133
     # for the details of how to convert pytorch's interpolate to hidet's resize operator as we are similar to onnx
-    if len(input.shape) != 4:
-        raise NotImplementedError("Currently only supports 4D inputs (NCHW)")
-
     if antialias:
         raise NotImplementedError("Currently does not support antialias=True")
 
-    if recompute_scale_factor:
-        raise NotImplementedError("Currently does not support recompute_scale_factor=True")
-
-    if size is None == scale_factor is None:
+    if (size is None) == (scale_factor is None):
        raise ValueError("Exactly one of size or scale_factor can be None")
 
-    target_size = None
-    if size is not None:
-        if isinstance(size, int):
-            target_size = [size, size]
-        else:
-            if len(size) != 2:
-                raise ValueError("Length of \"size\" must be of type int or tuple([int, int])")
-            target_size = list(size)
-    else:
-        if isinstance(scale_factor, (int, float)):
-            target_size = [int(i * scale_factor) for i in input.shape[2:]]
-        else:
-            if len(scale_factor) != 2:
-                raise ValueError("Length of \"scale_factor\" must be of type int or tuple([int, int])")
-            target_size = [a * b for a, b in zip(input.shape[2:], scale_factor)]
-
-    supported_methods = {'nearest': 'nearest', 'bilinear': 'linear', 'bicubic': 'cubic'}
-    if mode not in supported_methods:
-        raise NotImplementedError("Mode not supported")
-
-    mode_hidet = supported_methods[mode]
-    if align_corners:
+    mode_hidet = mode
+    if 'cubic' in mode:
+        mode_hidet = 'cubic'
+    if 'linear' in mode:
+        mode_hidet = 'linear'
 
     if mode == 'nearest':
         coordinate_transformation_mode = 'asymmetric'
     elif align_corners:
         coordinate_transformation_mode = 'align_corners'
     else:
-        coordinate_transformation_mode = 'pytorch_half_pixel'
+        coordinate_transformation_mode = 'half_pixel'
 
     return ops.resize2d(
         input,
-        target_size,
-        mode_hidet,
-        coordinate_transformation_mode,
-        rounding_method='round_prefer_floor',
+        size=size,
+        scale_factor=scale_factor,
+        method=mode_hidet,
+        coordinate_transformation_mode=coordinate_transformation_mode,
+        rounding_method='floor',
         roi=None,
         cubic_alpha=-0.75,
-        cubic_exclude=0,
+        cubic_exclude=False,
         extrapolation_value=0.0,
+        recompute_scale_factor=recompute_scale_factor,
     )
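Reading the new control flow end to end (illustrative, not part of the commit): a 2x bilinear upsample no longer precomputes target_size; the scale factor is forwarded and resize2d resolves the output size itself:

```python
import hidet
from hidet.graph import ops

x = hidet.randn([1, 3, 16, 16])
# interpolate(x, scale_factor=2.0, mode='bilinear') now reduces to the call
# below: 'bilinear' -> method='linear', and with align_corners falsy the
# coordinate transform is 'half_pixel'.
y = ops.resize2d(
    x,
    size=None,
    scale_factor=2.0,
    method='linear',
    coordinate_transformation_mode='half_pixel',
    rounding_method='floor',
    roi=None,
    cubic_alpha=-0.75,
    cubic_exclude=False,
    extrapolation_value=0.0,
    recompute_scale_factor=None,
)
assert y.shape == [1, 3, 32, 32]
```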


@@ -330,6 +319,7 @@ def softmax(x: Tensor, dim: int, _stacklevel: int = 3, dtype=None):
     return ops.softmax(x, dim)
 
 
+@register_function(operator.matmul)
 @register_function(torch.matmul)
 def matmul(x: Tensor, y: Tensor):
     return ops.matmul(x, y)
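The extra registration matters because torch.fx records the infix @ operator as operator.matmul, not torch.matmul; without it, traced code that writes `a @ b` (as a model like dinov2 might) would have no hidet mapping. A quick demonstration in plain torch.fx:

```python
import operator
import torch
import torch.fx


def f(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b  # the infix form, traced as operator.matmul


traced = torch.fx.symbolic_trace(f)
targets = [n.target for n in traced.graph.nodes if n.op == 'call_function']
assert targets == [operator.matmul]
```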
@@ -393,8 +383,8 @@ def ones(

 @register_function(torch.nn.functional.gelu)
 def gelu(x: Tensor, approximate: Optional[str] = "none"):
-    if approximate is not None:
-        warnings.warn_once("approximate is not None")
+    if approximate is not None and approximate != "none":
+        warnings.warn_once(f"hidet: gelu with approximate {repr(approximate)} is not supported. Treat as 'none'.")
     return ops.gelu(x)
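Call-by-call, the fixed guard behaves as follows (illustrative trace, assuming x is a torch tensor being interpreted; not part of the commit):

```python
gelu(x)                        # default "none": silent, exact gelu
gelu(x, approximate="none")    # silent (previously warned despite being exact)
gelu(x, approximate="tanh")    # warns once, then falls back to exact gelu
```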


python/hidet/graph/frontend/torch/register_methods.py (12 additions, 0 deletions)

@@ -188,9 +188,21 @@ def tensor_type(self: Tensor, dtype: Union[str, torch.dtype], non_blocking: bool

 @register_method(torch.Tensor.expand)
 def tensor_expand(self: Tensor, *sizes: int) -> Tensor:
+    sizes: List[int] = list(sizes)
+    assert len(sizes) >= len(self.shape)
+    for i in range(len(sizes)):
+        if sizes[i] == -1:
+            ri = len(sizes) - 1 - i
+            assert ri < len(self.shape)
+            sizes[i] = int(self.shape[len(self.shape) - 1 - ri])
     return ops.broadcast(self, sizes)
 
 
+@register_method(torch.Tensor.masked_fill)
+def tensor_masked_fill(self: Tensor, mask: Tensor, value: float) -> Tensor:
+    return ops.where(mask, ops.full([], value, dtype=self.dtype, device=self.device), self)
+
+
 @register_method(torch.Tensor.flatten)
 def tensor_flatten(self: Tensor, start_dim=0, end_dim=-1):
     return ops.flatten(self, start_dim=start_dim, end_dim=end_dim)
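tensor_expand right-aligns the requested sizes against the input shape so that a -1 keeps the existing extent, mirroring torch.Tensor.expand. A standalone sketch of that resolution rule (not from the commit):

```python
from typing import List, Sequence


def resolve_expand_sizes(shape: Sequence[int], sizes: Sequence[int]) -> List[int]:
    # Same arithmetic as tensor_expand above: index i in sizes lies ri
    # positions from the right, and -1 copies the input dim at that offset.
    sizes = list(sizes)
    assert len(sizes) >= len(shape)
    for i in range(len(sizes)):
        if sizes[i] == -1:
            ri = len(sizes) - 1 - i
            assert ri < len(shape)  # a -1 must line up with a real input dim
            sizes[i] = int(shape[len(shape) - 1 - ri])
    return sizes


print(resolve_expand_sizes([1, 3, 1, 5], [2, -1, 4, -1]))  # [2, 3, 4, 5]
```

tensor_masked_fill, for its part, lowers to ops.where over ops.full([], value, ...): the empty shape gives a 0-d scalar that broadcasts against the mask while inheriting self's dtype and device.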
python/hidet/graph/frontend/torch/register_modules.py (21 additions, 0 deletions)

@@ -380,3 +380,24 @@ class HidetMish(HidetModule):
     def __call__(self, x: Tensor) -> Tensor:
         assert isinstance(self.mod, torch.nn.Mish)
         return regs.mish(x, self.mod.inplace)
+
+
+@register_module(torch.nn.Identity)
+class HidetIdentity(HidetModule):
+    def __call__(self, x: Tensor) -> Tensor:
+        assert isinstance(self.mod, torch.nn.Identity)
+        return x
+
+
+@register_module(torch.nn.Upsample)
+class HidetUpsample(HidetModule):
+    def __call__(self, x: Tensor) -> Tensor:
+        assert isinstance(self.mod, torch.nn.Upsample)
+        return regs.interpolate(
+            x,
+            size=self.mod.size,
+            scale_factor=self.mod.scale_factor,
+            mode=self.mod.mode,
+            align_corners=self.mod.align_corners,
+            recompute_scale_factor=self.mod.recompute_scale_factor,
+        )
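HidetUpsample works because nn.Upsample keeps its whole configuration as module attributes, so the wrapper can forward them verbatim to the functional interpolate. The same delegation, checked in plain torch (a sketch, not from the commit):

```python
import torch
import torch.nn.functional as F

up = torch.nn.Upsample(scale_factor=2.0, mode='nearest')
x = torch.randn(1, 3, 8, 8)
y = F.interpolate(
    x,
    size=up.size,
    scale_factor=up.scale_factor,
    mode=up.mode,
    align_corners=up.align_corners,
    recompute_scale_factor=up.recompute_scale_factor,
)
assert torch.equal(up(x), y)
```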
