[Torch][Graph][Operator] Add and fix various items for torchvision model support #347

Merged: 14 commits, Aug 12, 2023
8 changes: 8 additions & 0 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -359,6 +359,10 @@ def load_arg(a, env):
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if exec_func.functions[0] is setitem:
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
@@ -448,6 +452,10 @@ def load_arg(a, env):

try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if hidet_func.functions[0] is setitem:
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)

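For context, a minimal sketch of why this rebinding is needed: hidet tensors are immutable, so operator.setitem is executed out of place and returns a new tensor, and the environment entry for the mutated argument has to be redirected to that result so later nodes observe the write. The Node class, run_node helper, and list-based stand-in below are illustrative only, not the real interpreter types.

class Node:
    """Simplified stand-in for a torch.fx call_function node."""
    def __init__(self, name, args):
        self.name, self.args = name, args

def setitem_out_of_place(x, idx, value):
    y = list(x)          # emulate an immutable tensor: write into a copy
    y[idx] = value
    return y

def run_node(env, node, func):
    args = [env[a] if isinstance(a, str) else a for a in node.args]
    env[node.name] = func(*args)
    if func is setitem_out_of_place:             # real check: exec_func.functions[0] is setitem
        env[str(node.args[0])] = env[node.name]  # rebind the mutated input name

env = {'x': [1, 2, 3]}
run_node(env, Node('setitem_1', ('x', 0, 99)), setitem_out_of_place)
print(env['x'])  # [99, 2, 3]: nodes that read 'x' afterwards see the updated value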
107 changes: 102 additions & 5 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -18,10 +18,9 @@
from hidet.graph import ops
from hidet.utils import same_list
from hidet.ir.type import DataType
from hidet.ir.expr import Expr
from hidet.ir import expr
from hidet.ir.dtypes import promote_type
from hidet.ir.expr import Int
from hidet.ir.expr import Expr, Int, is_constant
from hidet.runtime.device import Device
from .interpreter import register_function, register_method
from .interpreter import warnings
@@ -97,6 +96,11 @@ def adaptive_avg_pool2d(x: Tensor, output_size):
return ops.adaptive_avg_pool2d(x, output_size)


@register_function(torch.nn.functional.adaptive_avg_pool3d)
def adaptive_avg_pool3d(x: Tensor, output_size):
return ops.adaptive_avg_pool3d(x, output_size)


@register_function(torch.nn.functional.relu)
def relu(x: Tensor, inplace: bool):
# if inplace:
@@ -130,7 +134,9 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=


@register_function(torch.nn.functional.linear)
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]):
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False):
if len(weight.shape) > 1 and not weight_is_transposed:
weight = ops.transpose(weight, [1, 0])
y = ops.matmul(x, weight)
if bias is not None:
y = y + bias
@@ -205,10 +211,18 @@ def batch_norm(
)
y = ops.batch_norm_infer(x, running_mean, running_var, epsilon=eps)
_ = momentum # unused
if len(x.shape) == 3:
dims = [0, 2]
elif len(x.shape) == 4:
dims = [0, 2, 3]
elif len(x.shape) == 5:
dims = [0, 2, 3, 4]
else:
raise NotImplementedError("batch_norm only accepts 3D, 4D, 5D input")
if weight is not None:
y = y * weight.unsqueeze([0, 2, 3])
y = y * weight.unsqueeze(dims)
if bias is not None:
y = y + bias.unsqueeze([0, 2, 3])
y = y + bias.unsqueeze(dims)
return y
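For intuition, the dims list above turns the per-channel weight/bias of shape (C,) into a shape that broadcasts against NC(D)HW input, e.g. (1, C, 1, 1, 1) for 5D input. A torch-only sketch of the same computation (hidet's unsqueeze accepts the whole dims list at once; torch's takes one dim at a time):

import torch

x = torch.randn(2, 8, 4, 6, 6)        # 5D input: N, C, D, H, W
weight = torch.randn(8)               # per-channel scale, shape (C,)

dims = [0] + list(range(2, x.ndim))   # [0, 2, 3, 4], matching the 5D branch above
w = weight
for d in dims:
    w = w.unsqueeze(d)                # (8,) -> (1, 8, 1, 1, 1)
print(w.shape)
y = x * w                             # broadcasts over N, D, H, W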


@@ -222,6 +236,70 @@ def getitem(x: Tensor, index):
return x[index]


@register_function(operator.setitem)
def setitem(x: Tensor, item, setvalue):

if isinstance(item, list):
item = tuple(item)
if not isinstance(item, tuple):
item = tuple([item])

if not isinstance(setvalue, (int, float)):
raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values')

# now, the item could have
# 1. integer index
# 2. slice
# 3. Ellipsis
# 4. None
# e.g., [1, 3:5, ..., None]

# process Ellipsis
# e.g., x[1, ..., 2] -> x[1, :, :, 2]
if Ellipsis in item:
if item.count(Ellipsis) > 1:
raise ValueError('Only one ellipsis allowed in index.')
ellipsis_index = item.index(Ellipsis)
ellipsis_ndim = len(x.shape) - sum([1 if axis not in [None, Ellipsis] else 0 for axis in item])
ellipsis_ndim = max(ellipsis_ndim, 0)
item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :]

# normalize index
normalized_item = []
for i, v in enumerate(item):
if isinstance(v, int):
if v < 0:
v = v + x.shape[i]
if is_constant(v, x.shape[i]) and (v < 0 or v >= x.shape[i]):
raise IndexError('index {} is out of bound for dimension {} with size {}'.format(v, i, x.shape[i]))
normalized_item.append(v)
elif v is not None:
# None affects getitem, but is ignored in setitem
normalized_item.append(v)
item = tuple(normalized_item)

# process slice and integer index
rank = len(x.shape)
while len(item) < rank:
item = item + (slice(None),)
starts, ends, steps = [], [], []
squeeze_dims = []
for dim, v in enumerate(item):
if isinstance(v, (int, Expr)):
squeeze_dims.append(dim)
starts.append(v)
ends.append(v + 1)
steps.append(1)
else:
assert isinstance(v, slice)
starts.append(v.start)
ends.append(v.stop)
steps.append(v.step)

out = ops.set_strided_slice(x, starts, ends, steps, setvalue)
return out
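A self-contained sketch of the index normalization performed above, using plain Python ints and slices (the normalize_index helper and its output format are illustrative, not part of hidet):

def normalize_index(shape, item):
    """Expand Ellipsis, drop None, resolve negative ints, and emit the
    starts/ends/steps lists used for the strided-slice write."""
    if not isinstance(item, tuple):
        item = (item,)
    if Ellipsis in item:
        i = item.index(Ellipsis)
        explicit = sum(1 for v in item if v not in (None, Ellipsis))
        item = item[:i] + (slice(None),) * max(len(shape) - explicit, 0) + item[i + 1:]
    item = tuple(v for v in item if v is not None)            # None is ignored for setitem
    item = item + (slice(None),) * (len(shape) - len(item))   # pad missing trailing dims
    starts, ends, steps = [], [], []
    for dim, v in enumerate(item):
        if isinstance(v, int):
            v = v + shape[dim] if v < 0 else v
            starts.append(v)
            ends.append(v + 1)
            steps.append(1)
        else:
            starts.append(v.start)
            ends.append(v.stop)
            steps.append(v.step)
    return starts, ends, steps

# x[1, 2:5] = 0.0 on a tensor of shape (4, 8):
print(normalize_index((4, 8), (1, slice(2, 5))))   # ([1, 2], [2, 5], [1, None])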


@register_function(operator.mul)
@register_function(torch.mul)
@register_function(torch.ops.aten.mul.Tensor)
@@ -931,6 +1009,13 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:

@register_function(operator.eq)
def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
if isinstance(a, Tensor) or isinstance(b, Tensor):
from hidet.graph.ops.utils import convert_to_tensor

if isinstance(a, Tensor):
return ops.equal(a, convert_to_tensor(b, a))
else:
return ops.equal(b, convert_to_tensor(a, b))
return a == b


@@ -1104,3 +1189,15 @@ def clamp(
@register_function(torch.isinf)
def isinf(x: Tensor) -> Tensor:
return ops.isinf(x)


@register_function(torch.nn.functional.pad)
def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0):
if isinstance(pad, tuple):
pad = list(pad)
return ops.pad(x, pads=pad, mode=mode, value=value)


@register_function(torch.roll)
def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None):
return ops.roll(x, shifts, dims)
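For reference on torch_pad above: torch orders the pad list from the last dimension inward, in (before, after) pairs, and the frontend forwards that list to ops.pad unchanged. A torch-only illustration of the convention:

import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 4, 4)
# (left, right, top, bottom): last dim grows by 1+1, second-to-last by 2+2
y = F.pad(x, [1, 1, 2, 2], mode='constant', value=0)
print(y.shape)   # torch.Size([1, 3, 8, 6])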
24 changes: 24 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -265,3 +265,27 @@ def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor:
@register_method(torch.Tensor.all)
def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor:
return ops.all(self, axis=dim, keepdims=keepdim)


@register_method(torch.Tensor.matmul)
def tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
return ops.matmul(self, other)


@register_method(torch.Tensor.new_zeros)
def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
if layout is not None:
raise NotImplementedError("layout is not None")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = size
if dtype is None:
dtype = self.dtype
if device is None:
device = self.device

_ = pin_memory
_ = requires_grad

return ops.full(shape, dtype=dtype, device=device, value=dtype.zero)
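tensor_new_zeros mirrors the two calling conventions of torch.Tensor.new_zeros (varargs or a single size sequence) and the inheritance of dtype/device from the source tensor; a torch-only illustration:

import torch

a = torch.ones(2, 3, dtype=torch.float16)
print(a.new_zeros(4).dtype)        # torch.float16, inherited from `a`
print(a.new_zeros((2, 2)).shape)   # torch.Size([2, 2]), single-sequence form
print(a.new_zeros(2, 2).shape)     # torch.Size([2, 2]), varargs form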
85 changes: 80 additions & 5 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -11,6 +11,7 @@
# limitations under the License.
from __future__ import annotations
import torch
from hidet.graph import ops
from hidet.graph.tensor import Tensor
from .interpreter import HidetModule, register_module
from . import register_functions as regs
@@ -117,6 +118,13 @@ def __call__(self, x: Tensor) -> Tensor:
return regs.adaptive_avg_pool2d(x, self.mod.output_size)


@register_module(torch.nn.AdaptiveAvgPool3d)
class HidetAdaptiveAvgPool3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AdaptiveAvgPool3d)
return regs.adaptive_avg_pool3d(x, self.mod.output_size)


@register_module(torch.nn.ReLU)
class HidetReLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
@@ -158,21 +166,21 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetLinear(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)
from hidet import ops

steal = dynamo_config['steal_weights']

self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])

def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Linear)
return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True))
return regs.linear(
x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True
)
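For reference, nn.Linear stores its weight as (out_features, in_features) and F.linear computes x @ weight.T; HidetLinear transposes the weight once at load time and passes weight_is_transposed=True so regs.linear skips the per-call transpose. A torch-only check of the equivalence:

import torch

x = torch.randn(5, 3)
w = torch.randn(4, 3)                    # nn.Linear weight layout: (out_features, in_features)
ref = torch.nn.functional.linear(x, w)   # x @ w.T
alt = x @ w.t()                          # what regs.linear computes after transposing w once
print(torch.allclose(ref, alt))          # True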


@register_module(torch.nn.BatchNorm2d)
@register_module(torch.nn.BatchNorm3d)
class HidetBatchNorm2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.BatchNorm2d)
assert isinstance(self.mod, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d))
return regs.batch_norm(
x=x,
running_mean=self.param('running_mean'),
@@ -404,3 +412,70 @@ def __call__(self, x: Tensor) -> Tensor:
align_corners=self.mod.align_corners,
recompute_scale_factor=self.mod.recompute_scale_factor,
)


@register_module(torch.nn.MultiheadAttention)
class HidetMultiheadAttention(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)
steal = dynamo_config['steal_weights']
self.in_proj_weight_transposed = ops.transpose(self.param('in_proj_weight', steal=steal), [1, 0])
self.out_proj_weight_transposed = ops.transpose(self.param('out_proj.weight', steal=steal), [1, 0])

def __call__(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
average_attn_weights=True,
is_causal=False,
) -> Tensor:
assert isinstance(self.mod, torch.nn.MultiheadAttention)
supported = (
self.mod._qkv_same_embed_dim
and self.mod.bias_k is None
and self.mod.bias_v is None
and not self.mod.add_zero_attn
and self.mod.batch_first
and key_padding_mask is None
and not need_weights
)
if not supported:
raise NotImplementedError(
"Hidet Multihead Attention currently only supports "
"kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False, "
"batch_first=True, forward(key_padding_mask=None, need_weights=False)."
)

# Input feed forward
wq, wk, wv = ops.split(self.in_proj_weight_transposed, parts_or_sections=3, axis=1)
query = ops.matmul(query, wq)
key = ops.matmul(key, wk)
value = ops.matmul(value, wv)
if self.mod.in_proj_bias is not None:
bq, bk, bv = ops.split(self.param('in_proj_bias'), parts_or_sections=3, axis=0)
query = ops.add(query, bq)
key = ops.add(key, bk)
value = ops.add(value, bv)

# Split heads
split_head_dims = [query.shape[0], query.shape[1], self.mod.num_heads, query.shape[2] // self.mod.num_heads]
query = ops.transpose(query.reshape(split_head_dims), [0, 2, 1, 3])
key = ops.transpose(key.reshape(split_head_dims), [0, 2, 1, 3])
value = ops.transpose(value.reshape(split_head_dims), [0, 2, 1, 3])

# fmha
out = regs.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=self.mod.dropout, is_causal=is_causal
)

# Output feed forward
merge_head_dims = [out.shape[0], out.shape[2], self.mod.embed_dim]
out = ops.transpose(out, [0, 2, 1, 3]).reshape(merge_head_dims)
out = ops.matmul(out, self.out_proj_weight_transposed)
if self.mod.out_proj.bias is not None:
out = ops.add(out, self.param('out_proj.bias'))
return out
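A usage sketch under the constraints listed above (batch_first=True, need_weights=False); it assumes a CUDA device and that hidet's dynamo backend is registered under the name 'hidet':

import torch

mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True).cuda().eval()

def run(q):
    out, _ = mha(q, q, q, need_weights=False)   # need_weights=False is required by the mapping above
    return out

compiled = torch.compile(run, backend='hidet')
q = torch.randn(2, 16, 64, device='cuda')
print(compiled(q).shape)   # torch.Size([2, 16, 64])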
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -32,7 +32,7 @@
from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf
from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, bitwise_invert, bitwise_or
from .arithmetic import bitwise_xor, maximum, minimum, clamp
from .arithmetic import isfinite, isinf, isnan, sign, where
from .arithmetic import isfinite, isinf, isnan, sign, where, set_strided_slice, roll
from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2
from .complex import real, imag, conj, make_complex
from .compare import equal, not_equal, less, greater, less_equal, greater_equal