[Dynamo] Add operator support to run UNet2DConditionModel from diffusers (#151)

* exp, float

* wip

* chunk, groupnorm, softmax, baddbmm, empty

* add interpolate, lint and format

* revert changes of import hidet at top level to minimize changes for PR

* typo

* trigger actions

* trigger actions

* dummy commit

* dummy commit

* add some optimizations to skip certain operations based on alpha and beta

* add group norm test

* format

* introduce a fix to torch.compile not dumping graph IR

* Revert "introduce a fix to torch.compile not dumping graph IR"

This reverts commit a1e8df0.

* add interpolate test and group norm test

* accidental push

* remove a stray newline

---------

Co-authored-by: Xin Li <xin@centml.ai>
xinli-git and xinli-centml committed Apr 7, 2023
1 parent 3cc75b6 commit 68faaa5
Showing 8 changed files with 261 additions and 2 deletions.
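
For orientation, here is a minimal sketch (not part of this commit) of the workflow these registrations target: compiling a diffusers `UNet2DConditionModel` through `torch.compile` with the hidet backend. The default-config model, input shapes, and CUDA device below are illustrative assumptions; real pipelines derive them from a pretrained checkpoint.

```python
import torch
from diffusers import UNet2DConditionModel

import hidet  # assumed to register the 'hidet' backend for torch.compile, as in the hidet docs

# Randomly initialized UNet with the default config; a pretrained checkpoint would also work.
unet = UNet2DConditionModel().cuda().eval()
unet_opt = torch.compile(unet, backend='hidet')

# Illustrative latent / conditioning shapes.
sample = torch.randn(1, 4, 64, 64, device='cuda')
timestep = torch.tensor([10], device='cuda')
encoder_hidden_states = torch.randn(1, 77, unet.config.cross_attention_dim, device='cuda')

with torch.no_grad():
    out = unet_opt(sample, timestep, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 64, 64])
```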
149 changes: 149 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -191,6 +191,67 @@ def avg_pool3d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_i
return y


@register_function(torch.nn.functional.interpolate)
def interpolate(
input: Tensor,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
if len(input.shape) != 4:
raise NotImplementedError("Currently only supports 4D inputs (NCHW)")

if antialias:
raise NotImplementedError("Currently does not support antialias=True")

if recompute_scale_factor:
raise NotImplementedError("Currently does not support recompute_scale_factor=True")

    if (size is None) == (scale_factor is None):
        raise ValueError("Exactly one of size or scale_factor must be specified")

target_size = None
if size is not None:
if isinstance(size, int):
target_size = [size, size]
else:
            if len(size) != 2:
                raise ValueError('"size" must be an int or a sequence of two ints')
target_size = list(size)
else:
if isinstance(scale_factor, (int, float)):
target_size = [int(i * scale_factor) for i in input.shape[2:]]
else:
            if len(scale_factor) != 2:
                raise ValueError('"scale_factor" must be a number or a sequence of two numbers')
            target_size = [int(a * b) for a, b in zip(input.shape[2:], scale_factor)]

supported_methods = {'nearest': 'nearest', 'bilinear': 'linear', 'bicubic': 'cubic'}
if mode not in supported_methods:
raise NotImplementedError("Mode not supported")

mode_hidet = supported_methods[mode]
if align_corners:
coordinate_transformation_mode = 'align_corners'
else:
coordinate_transformation_mode = 'pytorch_half_pixel'

return ops.resize2d(
input,
target_size,
mode_hidet,
coordinate_transformation_mode,
rounding_method='round_prefer_floor',
roi=None,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
)


@register_function(operator.truediv)
def truediv(x: Union[Tensor, int, float], y: Union[Tensor, int, float]):
import hidet
@@ -212,6 +273,7 @@ def sub(x: Tensor, y: Tensor):


@register_function(torch.nn.functional.softmax)
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, dtype=None):
if dtype is not None:
raise NotImplementedError("dtype is not None")
@@ -281,6 +343,30 @@ def layer_norm(
return y


@register_function(torch.nn.functional.group_norm)
def group_norm(
x: Tensor,
num_groups: int,
num_channels: int,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
):
    if x.shape[1] != num_channels:
        raise ValueError(
            "num_channels does not match tensor shape at index 1, expected {} but got {}".format(num_channels, x.shape[1])
        )
if num_channels % num_groups != 0:
raise ValueError("num_channels {} must be divisible by num_groups {}".format(num_channels, num_groups))

y = ops.group_norm(x, num_groups, epsilon=eps)
if weight is not None:
y = y * weight.reshape([num_channels, 1, 1])
if bias is not None:
y = y + bias.reshape([num_channels, 1, 1])
return y


@register_function(torch.tanh)
def tanh(x: Tensor):
return ops.tanh(x)
@@ -417,13 +503,69 @@ def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, re
return ops.full(size, fill_value, dtype=hidet_dtype, device=hidet_device)


@register_function(torch.empty)
def empty(
*size,
out=None,
dtype=None,
layout=torch.strided,
device=None,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
):
import hidet

if out is not None:
raise NotImplementedError("hidet: does not support torch.empty(..., out=..., ...)")
if layout not in [None, torch.strided]:
raise NotImplementedError("hidet: does not support torch.empty(..., layout=..., ...)")
if requires_grad and torch.is_grad_enabled():
warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
if pin_memory:
raise NotImplementedError("hidet: does not support torch.empty(..., pin_memory=True, ...)")
if memory_format != torch.contiguous_format:
raise NotImplementedError("hidet: does not support torch.empty(..., memory_format=..., ...)")

hidet_device: Device = device_from_torch(torch_device=device)
hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype)
if len(size) == 1 and isinstance(size[0], (tuple, list)):
size = size[0]
return hidet.empty(size, dtype=hidet_dtype, device=hidet_device)


@register_function(torch.bmm)
def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)")
return ops.matmul(input, mat2)


@register_function(torch.baddbmm)
def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
import hidet

if out is not None:
raise NotImplementedError("hidet: does not support torch.baddbmm(..., out=...)")

if alpha == 0 and beta == 0:
size = batch1.shape[0:2] + [batch2.shape[-1]]
return hidet.zeros(shape=size, dtype=input.dtype, device=input.device)
elif alpha == 0:
return beta * input
elif beta == 0:
return alpha * ops.matmul(batch1, batch2)

if alpha == 1 and beta == 1:
return input + ops.matmul(batch1, batch2)
elif alpha == 1:
return beta * input + ops.matmul(batch1, batch2)
elif beta == 1:
return input + alpha * ops.matmul(batch1, batch2)

return beta * input + alpha * ops.matmul(batch1, batch2)


@register_function(torch.tensor)
def torch_tensor(
data: Any, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False
@@ -445,6 +587,13 @@ def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
return ops.sigmoid(x)


@register_function(torch.exp)
def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
warnings.warn_once("hidet: does not support torch.exp(..., out=...)")
return ops.exp(x)


@register_function(torch.nn.functional.hardsigmoid)
def hardsigmoid(x: Tensor, inplace: bool):
if inplace:
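As a side note on the alpha/beta shortcuts in `baddbmm` above, a small standalone check in plain PyTorch (not part of this diff) confirms that every branch specializes the same reference formula `beta * input + alpha * (batch1 @ batch2)`:

```python
import torch

inp = torch.randn(2, 3, 5)
batch1 = torch.randn(2, 3, 4)
batch2 = torch.randn(2, 4, 5)

# Cases covered by the branches above: both zero, one zero, one or both equal to 1, and the general case.
for alpha, beta in [(0, 0), (0, 2.0), (2.0, 0), (1, 1), (1, 2.0), (2.0, 1), (2.0, 3.0)]:
    expected = torch.baddbmm(inp, batch1, batch2, beta=beta, alpha=alpha)
    reference = beta * inp + alpha * torch.bmm(batch1, batch2)
    torch.testing.assert_close(expected, reference)
```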
22 changes: 22 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import List, Union
import torch
from hidet.ir.type import DataType
@@ -33,6 +34,16 @@ def tensor_cpu(self: Tensor) -> Tensor:
return self.cpu()


@register_method(torch.Tensor.float)
def tensor_float(self: Tensor) -> Tensor:
return ops.cast(self, "float32")


@register_method(torch.Tensor.half)
def tensor_half(self: Tensor) -> Tensor:
return ops.cast(self, "float16")


@register_method(torch.Tensor.to)
def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
"""
@@ -132,6 +143,17 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts=parts)


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
dim_size = self.shape[dim]
chunk_size = math.ceil(dim_size / chunks)
parts = []
for start in range(0, dim_size, chunk_size):
parts.append(min(chunk_size, dim_size - start))
assert sum(parts) == self.shape[dim]
return ops.split(self, axis=dim, parts=parts)


@register_method(torch.Tensor.squeeze)
def tensor_squeeze(self: Tensor, dim=None) -> Tensor:
if dim is None:
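A quick worked example (illustrative only, not part of this diff) of the part sizes that `tensor_chunk` above passes to `ops.split`, mirroring `torch.Tensor.chunk` semantics:

```python
import math
import torch

def chunk_parts(dim_size: int, chunks: int):
    # Same computation as tensor_chunk: ceil-sized chunks, with a smaller final remainder.
    chunk_size = math.ceil(dim_size / chunks)
    return [min(chunk_size, dim_size - start) for start in range(0, dim_size, chunk_size)]

print(chunk_parts(10, 3))                                # [4, 4, 2]
print([t.shape[0] for t in torch.arange(10).chunk(3)])   # [4, 4, 2] -- matches torch
```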
16 changes: 15 additions & 1 deletion python/hidet/graph/frontend/torch/register_modules.py
@@ -100,7 +100,7 @@ def __init__(self, torch_module: torch.nn.Module):

def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Linear)
return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias'))
return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True))


@register_module(torch.nn.BatchNorm2d)
@@ -142,6 +142,20 @@ def __call__(self, x: Tensor) -> Tensor:
)


@register_module(torch.nn.GroupNorm)
class HidetGroupNorm(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.GroupNorm)
return regs.group_norm(
x=x,
num_groups=self.mod.num_groups,
num_channels=self.mod.num_channels,
weight=self.param('weight'),
bias=self.param('bias'),
eps=self.mod.eps,
)


@register_module(torch.nn.Tanh)
class HidetTanh(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -21,7 +21,7 @@
from .definitions.activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
from .definitions.activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink
from .definitions.activation import softshrink, softmax, softmin, hardtanh
from .definitions.norm import batch_norm_infer, instance_norm, layer_norm
from .definitions.norm import batch_norm_infer, instance_norm, layer_norm, group_norm
from .definitions.image import resize2d
from .definitions.create import full, arange, linspace
from .definitions.arithmetic import add, subtract, multiply, divide, mod, remainder, negative, positive, square
35 changes: 35 additions & 0 deletions python/hidet/graph/ops/definitions/norm.py
@@ -73,3 +73,38 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5) -> Tens
"""
dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
return normalize(x, dims=dims, epsilon=epsilon)


def group_norm(x: Tensor, num_groups, epsilon: float = 1e-5):
"""
Group norm.
Parameters
----------
x: Tensor
The data to be normalized.
num_groups: int
The number of groups
epsilon: float
The epsilon added to variance.
Returns
-------
ret: Tensor
The normalized tensor.
"""
# first split out the group dimension
x_shape = list(x.shape)
new_shape = x_shape[:]
grouped_rank = 1
grouped_dim = new_shape[grouped_rank]
assert grouped_dim % num_groups == 0

new_shape[grouped_rank] = int(grouped_dim // num_groups)
new_shape.insert(grouped_rank, num_groups)

x = x.reshape(new_shape)
dims = list(range(2, len(x.shape)))
normed = normalize(x, dims=dims, epsilon=epsilon)

return normed.reshape(x_shape)
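
For intuition on the reshape trick used by `group_norm` above, here is a numeric sketch in plain PyTorch (an illustration under assumed shapes, not part of this diff): split the channel dimension into groups, normalize over each group's channels and spatial positions, then restore the original shape.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16, 16)
num_groups, eps = 4, 1e-5

# Split channels into (num_groups, channels_per_group) and normalize over dims 2..4.
g = x.reshape(2, num_groups, 8 // num_groups, 16, 16)
mean = g.mean(dim=(2, 3, 4), keepdim=True)
var = g.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
manual = ((g - mean) / torch.sqrt(var + eps)).reshape(2, 8, 16, 16)

torch.testing.assert_close(manual, F.group_norm(x, num_groups, eps=eps), rtol=1e-5, atol=1e-5)
```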
7 changes: 7 additions & 0 deletions tests/frontends/torch/test_torch_norm.py
@@ -21,5 +21,12 @@ def test_layer_norm(shape, normalized_shape, dtype):
check_module(torch.nn.LayerNorm(normalized_shape=normalized_shape), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize('shape', [[1, 32, 128, 128]])
@pytest.mark.parametrize('num_groups', [1, 4, 32])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_group_norm(shape, num_groups, dtype):
check_module(torch.nn.GroupNorm(num_groups=num_groups, num_channels=shape[1]), [torch.randn(shape, dtype=dtype)])


if __name__ == '__main__':
pytest.main([__file__])
25 changes: 25 additions & 0 deletions tests/operators/test_image.py
@@ -13,6 +13,7 @@

import numpy as np
import torch
from torch.nn import functional as F
import torchvision as tv
import pytest

@@ -21,6 +22,8 @@
from hidet.testing import check_binary
from hidet.graph.tensor import asarray
from hidet.utils.ort_utils import create_ort_session, ort_inference
from hidet.testing import check_torch_unary
from hidet.graph.frontend.torch import register_functions as regs


class TorchResizeModel(torch.nn.Module):
@@ -116,5 +119,27 @@ def test_resize2d(
np.testing.assert_allclose(actual=hidet_result_cuda, desired=torch_result, atol=2e-5, rtol=2e-5)


@pytest.mark.parametrize(
"input_size, size, scale_factor, mode",
[
[[1, 3, 32, 32], [16, 16], None, 'nearest'], # 4D, resize down, nearest
[[1, 3, 32, 32], [16, 16], None, 'bilinear'], # 4D, resize down, bilinear
[[1, 3, 32, 32], [16, 16], None, 'bicubic'], # 4D, resize down, bicubic
[[1, 3, 32, 32], [64, 64], None, 'nearest'], # 4D, resize up, nearest
[[1, 3, 32, 32], None, 0.5, 'nearest'], # 4D, resize down, nearest
],
)
def test_interpolate(input_size, size, scale_factor, mode):
dtype = 'float32'
check_torch_unary(
input_size,
lambda x: F.interpolate(x, size, scale_factor, mode),
lambda x: regs.interpolate(x, size, scale_factor, mode),
dtype=dtype,
rtol=1e-5,
atol=1e-5,
)


if __name__ == '__main__':
pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/operators/test_norm.py
@@ -72,5 +72,12 @@ def test_instance_norm(shape):
check_torch_unary(shape, lambda x: F.instance_norm(x), lambda x: ops.instance_norm(x), atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize('shape, num_groups', [[[1, 32, 64], 4], [[32, 4, 32], 4], [[32, 4, 32], 1]])
def test_group_norm(shape, num_groups):
check_torch_unary(
shape, lambda x: F.group_norm(x, num_groups), lambda x: ops.group_norm(x, num_groups), atol=1e-4, rtol=1e-4
)


if __name__ == '__main__':
pytest.main([__file__])
