[Dynamo] Add operator support to run UNet2DConditionModel from diffusers #151

Merged: 22 commits, merged on Apr 7, 2023
Changes shown from 13 of 22 commits

Commits (22)
fef2a6b  exp, float  (xinli-centml, Mar 24, 2023)
83d553c  Merge branch 'main' of xinli-git:xinli-git/hidet into dynamo_unet  (xinli-centml, Mar 24, 2023)
16a7cb8  wip  (xinli-centml, Mar 24, 2023)
02d76eb  chunk, groupnorm, softmax, baddbmm, emmpty  (xinli-centml, Mar 27, 2023)
0daa3cf  add interpolate, lint and format  (xinli-centml, Mar 28, 2023)
7b16524  Merge branch 'hidet-org:main' into dynamo_unet  (xinli-git, Mar 28, 2023)
ece24a0  revert changes of import hidet at top level to minimize changes for PR  (xinli-centml, Mar 28, 2023)
eff27bb  Merge branch 'dynamo_unet' of xinli-git:xinli-git/hidet into dynamo_unet  (xinli-centml, Mar 28, 2023)
846077d  typo  (xinli-centml, Mar 28, 2023)
75a101b  trigger actions  (xinli-centml, Mar 28, 2023)
48c4914  trigger actions  (xinli-centml, Mar 28, 2023)
b39e91d  dummy commit  (xinli-centml, Mar 28, 2023)
9c7fbd0  dummy commit  (xinli-centml, Mar 28, 2023)
bb35fa3  add some optimizations to skip certain operations based on alpha beta  (xinli-centml, Apr 2, 2023)
299720a  add group norm test  (xinli-centml, Apr 3, 2023)
05ea406  format  (xinli-centml, Apr 3, 2023)
a1e8df0  introduce a fix to torch.compile not dumping graph IR  (xinli-centml, Apr 3, 2023)
e2655ee  Revert "introduce a fix to torch.compile not dumping graph IR"  (xinli-centml, Apr 3, 2023)
59b5b87  Merge branch 'main' of xinli-git:xinli-git/hidet into dynamo_unet  (xinli-centml, Apr 6, 2023)
601c47b  add interlolate test and group norm test  (xinli-centml, Apr 7, 2023)
47b4b18  accidental push  (xinli-centml, Apr 7, 2023)
24a9699  remove a random newline added  (xinli-centml, Apr 7, 2023)
131 changes: 131 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -191,6 +191,67 @@ def avg_pool3d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_i
    return y


@register_function(torch.nn.functional.interpolate)
def interpolate(
    input: Tensor,
    size=None,
    scale_factor=None,
    mode='nearest',
    align_corners=None,
    recompute_scale_factor=None,
    antialias=False,
):
    if len(input.shape) != 4:
        raise NotImplementedError("Currently only supports 4D inputs (NCHW)")

    if antialias:
        raise NotImplementedError("Currently does not support antialias=True")

    if recompute_scale_factor:
        raise NotImplementedError("Currently does not support recompute_scale_factor=True")

    if (size is None) == (scale_factor is None):
        raise ValueError("Exactly one of size and scale_factor must be specified")

    target_size = None
    if size is not None:
        if isinstance(size, int):
            target_size = [size, size]
        else:
            if len(size) != 2:
                raise ValueError('"size" must be an int or a tuple of two ints')
            target_size = list(size)
    else:
        if isinstance(scale_factor, float):
            target_size = [int(i * scale_factor) for i in input.shape[2:]]
        else:
            if len(scale_factor) != 2:
                raise ValueError('"scale_factor" must be a float or a tuple of two floats')
            target_size = [int(a * b) for a, b in zip(input.shape[2:], scale_factor)]

    supported_methods = {'nearest': 'nearest', 'bilinear': 'linear', 'bicubic': 'cubic'}
    if mode not in supported_methods:
        raise NotImplementedError("Mode {} is not supported".format(mode))

    mode_hidet = supported_methods[mode]
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
    else:
        coordinate_transformation_mode = 'pytorch_half_pixel'

    return ops.resize2d(
        input,
        target_size,
        mode_hidet,
        coordinate_transformation_mode,
        rounding_method='round',
        roi=None,
        cubic_alpha=-0.75,
        cubic_exclude=0,
        extrapolation_value=0.0,
    )
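A quick way to exercise this registration end to end (a hedged sketch, not code from this PR: it assumes a CUDA device and that importing hidet makes the 'hidet' backend available to torch.compile, which is how the project is normally used):

import torch
import torch.nn.functional as F
import hidet  # noqa: F401  -- importing hidet exposes the 'hidet' backend to torch.compile

def upsample(x: torch.Tensor) -> torch.Tensor:
    # hits the interpolate registration above through the scale_factor branch
    return F.interpolate(x, scale_factor=2.0, mode='nearest')

compiled = torch.compile(upsample, backend='hidet')
x = torch.randn(1, 3, 32, 32, device='cuda')
y = compiled(x)
assert y.shape == (1, 3, 64, 64)  # 32x32 spatial dims scaled by 2.0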


@register_function(operator.truediv)
def truediv(x: Union[Tensor, int, float], y: Union[Tensor, int, float]):
    import hidet
@@ -212,6 +273,7 @@ def sub(x: Tensor, y: Tensor):


@register_function(torch.nn.functional.softmax)
@register_method(torch.Tensor.softmax)
def softmax(x: Tensor, dim: int, dtype=None):
    if dtype is not None:
        raise NotImplementedError("dtype is not None")
@@ -281,6 +343,30 @@ def layer_norm(
    return y


@register_function(torch.nn.functional.group_norm)
def group_norm(
    x: Tensor,
    num_groups: int,
    num_channels: int,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
):
    if x.shape[1] != num_channels:
        raise ValueError(
            "num_channels does not match tensor shape at index 1, expect {} but got {}".format(num_channels, x.shape[1])
        )
    if num_channels % num_groups != 0:
        raise ValueError("num_channels {} must be divisible by num_groups {}".format(num_channels, num_groups))

    y = ops.group_norm(x, num_groups, num_last_dims=len(x.shape) - 2, epsilon=eps)
    if weight is not None:
        y = y * weight.reshape([num_channels, 1, 1])
    if bias is not None:
        y = y + bias.reshape([num_channels, 1, 1])
    return y


@register_function(torch.tanh)
def tanh(x: Tensor):
    return ops.tanh(x)
@@ -417,13 +503,51 @@ def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, re
    return ops.full(size, fill_value, dtype=hidet_dtype, device=hidet_device)


@register_function(torch.empty)
def empty(
    *size,
    out=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
):
    import hidet

    if out is not None:
        raise NotImplementedError("hidet: does not support torch.empty(..., out=..., ...)")
    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.empty(..., layout=..., ...)")
    if requires_grad and torch.is_grad_enabled():
        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
    if pin_memory:
        raise NotImplementedError("hidet: does not support torch.empty(..., pin_memory=True, ...)")
    if memory_format != torch.contiguous_format:
        raise NotImplementedError("hidet: does not support torch.empty(..., memory_format=..., ...)")

    hidet_device: Device = device_from_torch(torch_device=device)
    hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype)
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = size[0]
    return hidet.empty(size, dtype=hidet_dtype, device=hidet_device)


@register_function(torch.bmm)
def bmm(input: Tensor, mat2: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)")
    return ops.matmul(input, mat2)


@register_function(torch.baddbmm)
def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.baddbmm(..., out=...)")
    return beta * input + alpha * ops.matmul(batch1, batch2)

Review comment (Member):
Better to check whether alpha == 1 and beta == 1, and skip the corresponding multiplications whenever possible.

Otherwise, we need to write some graph-level pattern rewrite rules to do this simplification.
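A minimal sketch of that suggestion, written as a drop-in variant of the registration above (it reuses this file's imports and decorator; the actual optimization appears to land later, in commit bb35fa3 "add some optimizations to skip certain operations based on alpha beta"):

@register_function(torch.baddbmm)
def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.baddbmm(..., out=...)")
    prod = ops.matmul(batch1, batch2)
    if alpha != 1:
        prod = alpha * prod  # skip the scaling when alpha is the identity
    if beta == 0:
        return prod  # the bias term drops out entirely
    bias = input if beta == 1 else beta * input
    return bias + prod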



@register_function(torch.tensor)
def torch_tensor(
    data: Any, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = False
@@ -445,6 +569,13 @@ def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    return ops.sigmoid(x)


@register_function(torch.exp)
def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        warnings.warn_once("hidet: does not support torch.exp(..., out=...)")
    return ops.exp(x)


@register_function(torch.nn.functional.hardsigmoid)
def hardsigmoid(x: Tensor, inplace: bool):
    if inplace:
22 changes: 22 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import List, Union
import torch
from hidet.ir.type import DataType
@@ -33,6 +34,16 @@ def tensor_cpu(self: Tensor) -> Tensor:
    return self.cpu()


@register_method(torch.Tensor.float)
def tensor_float(self: Tensor) -> Tensor:
    return ops.cast(self, "float32")


@register_method(torch.Tensor.half)
def tensor_half(self: Tensor) -> Tensor:
    return ops.cast(self, "float16")


@register_method(torch.Tensor.to)
def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
"""
Expand Down Expand Up @@ -132,6 +143,17 @@ def tensor_split(self: Tensor, split_size, dim=0) -> List[Tensor]:
return ops.split(self, axis=dim, parts=parts)


@register_method(torch.Tensor.chunk)
def tensor_chunk(self: Tensor, chunks, dim=0) -> List[Tensor]:
    dim_size = self.shape[dim]
    chunk_size = math.ceil(dim_size / chunks)
    parts = []
    for start in range(0, dim_size, chunk_size):
        parts.append(min(chunk_size, dim_size - start))
    assert sum(parts) == self.shape[dim]
    return ops.split(self, axis=dim, parts=parts)
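As a quick sanity check of the splitting logic above, the same computation in a standalone snippet (illustrative only, not part of the diff):

import math

def chunk_parts(dim_size: int, chunks: int):
    # mirrors tensor_chunk: ceil-sized chunks, with a smaller final chunk if needed
    chunk_size = math.ceil(dim_size / chunks)
    return [min(chunk_size, dim_size - start) for start in range(0, dim_size, chunk_size)]

assert chunk_parts(10, 3) == [4, 4, 2]      # same sizes torch.Tensor.chunk produces for a length-10 dim
assert chunk_parts(8, 4) == [2, 2, 2, 2]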


@register_method(torch.Tensor.squeeze)
def tensor_squeeze(self: Tensor, dim=None) -> Tensor:
    if dim is None:
16 changes: 15 additions & 1 deletion python/hidet/graph/frontend/torch/register_modules.py
@@ -100,7 +100,7 @@ def __init__(self, torch_module: torch.nn.Module):

    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.Linear)
-        return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias'))
+        return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True))


@register_module(torch.nn.BatchNorm2d)
@@ -142,6 +142,20 @@ def __call__(self, x: Tensor) -> Tensor:
        )


@register_module(torch.nn.GroupNorm)
class HidetGroupNorm(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.GroupNorm)
        return regs.group_norm(
            x=x,
            num_groups=self.mod.num_groups,
            num_channels=self.mod.num_channels,
            weight=self.param('weight'),
            bias=self.param('bias'),
            eps=self.mod.eps,
        )
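To exercise this module path end to end, a hedged sketch along the lines of the "add group norm test" commit (it assumes a CUDA device and the 'hidet' torch.compile backend; this is not the actual test code):

import torch
import hidet  # noqa: F401  -- importing hidet exposes the 'hidet' backend to torch.compile

gn = torch.nn.GroupNorm(num_groups=8, num_channels=32).cuda().eval()
x = torch.randn(2, 32, 16, 16, device='cuda')

with torch.no_grad():
    compiled = torch.compile(gn, backend='hidet')
    torch.testing.assert_close(compiled(x), gn(x), rtol=1e-4, atol=1e-4)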


@register_module(torch.nn.Tanh)
class HidetTanh(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -21,7 +21,7 @@
from .definitions.activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
from .definitions.activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink
from .definitions.activation import softshrink, softmax, softmin, hardtanh
-from .definitions.norm import batch_norm_infer, instance_norm, layer_norm
+from .definitions.norm import batch_norm_infer, instance_norm, layer_norm, group_norm
from .definitions.image import resize2d
from .definitions.create import full, arange, linspace
from .definitions.arithmetic import add, subtract, multiply, divide, mod, remainder, negative, positive, square
37 changes: 37 additions & 0 deletions python/hidet/graph/ops/definitions/norm.py
@@ -73,3 +73,40 @@ def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5) -> Tens
    """
    dims = list(range(len(x.shape) - num_last_dims, len(x.shape)))
    return normalize(x, dims=dims, epsilon=epsilon)


def group_norm(x: Tensor, num_groups, num_last_dims: int = 1, epsilon: float = 1e-5):
    """
    Group norm.

    Parameters
    ----------
    x: Tensor
        The data to be normalized.
    num_groups: int
        The number of groups.
    num_last_dims: int
        The number of trailing dimensions to be normalized; the channel dimension preceding them is split into groups.
    epsilon: float
        The epsilon added to variance.

    Returns
    -------
    ret: Tensor
        The normalized tensor.
    """
    # first split out the group dimension
    x_shape = list(x.shape)
    new_shape = x_shape[:]
    grouped_rank = 1
    grouped_dim = new_shape[grouped_rank]
    assert grouped_dim % num_groups == 0

    new_shape[grouped_rank] = int(grouped_dim // num_groups)
    new_shape.insert(grouped_rank, num_groups)

    x = x.reshape(new_shape)
    dims = list(range(2, len(x.shape)))
    normed = normalize(x, dims=dims, epsilon=epsilon)

    return normed.reshape(x_shape)
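For intuition about the reshape-then-normalize trick above, a minimal NumPy reference of the same computation (illustrative sketch with made-up shapes; not part of the diff):

import numpy as np

def group_norm_ref(x: np.ndarray, num_groups: int, eps: float = 1e-5) -> np.ndarray:
    # split the channel dim into (num_groups, channels_per_group), normalize over
    # everything except (batch, group), then restore the original shape
    n, c = x.shape[:2]
    grouped = x.reshape((n, num_groups, c // num_groups) + x.shape[2:])
    axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)

x = np.random.randn(2, 32, 16, 16).astype(np.float32)
y = group_norm_ref(x, num_groups=8)
# should match torch.nn.functional.group_norm(torch.from_numpy(x), 8) up to floating-point tolerance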