Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model Support] Add support for wav2vec #303

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions gallery/how-to-guides/add-subgraph-rewrite-rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:

# %%
# We can check that the rewrite rule has been registered:
from hidet.graph.transforms import registered_rewrite_rules
from hidet.graph.transforms import (
registered_rewrite_rules,
clear_registered_rewrite_rules,
)

print('Registered rewrite rules:')
for rule in registered_rewrite_rules:
for rule in registered_rewrite_rules():
assert isinstance(rule, SubgraphRewriteRule)
print(rule.name)

Expand All @@ -146,7 +149,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
# Besides the predefined rewrite rules, we can see that the rewrite rule we just registered is also included at the
# last line. In this tutorial, to prevent the default rewrite rules from being applied, we first clear the registered
# rewrite rules and then register the rewrite rule we just defined:
registered_rewrite_rules.clear()
clear_registered_rewrite_rules()
register_rewrite_rule(
FuseTwoMatmulRewriteRule()
) # a second way to register the rewrite rule
Expand Down
22 changes: 22 additions & 0 deletions python/hidet/graph/frontend/torch/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,33 @@ def __init__(self, torch_module: torch.nn.Module):
def __call__(self, *args, **kwargs):
raise NotImplementedError()

def _get_weight_norm_hook(self, name: str):
from torch.nn.utils.weight_norm import WeightNorm

for hook in self.mod._forward_pre_hooks.values(): # pylint: disable=protected-access
if isinstance(hook, WeightNorm) and hook.name == name:
return hook
return None

def _used_weight_norm(self, name: str) -> bool:
return self._get_weight_norm_hook(name) is not None

def _compute_weight_norm(self, name: str) -> Tensor:
hook = self._get_weight_norm_hook(name)
return hook.compute_weight(self.mod)

def param(self, name: str, optional=False) -> Optional[Tensor]:
if name not in self.torch_params:
# see https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
# to learn more about weight norm.
if self._used_weight_norm(name):
self.torch_params[name] = self._compute_weight_norm(name)
return self.param(name, optional)

if optional:
return None
raise RuntimeError(f"hidet: {self.mod} has no parameter/buffer {name}")

if name not in self.hidet_params:
if self.torch_params[name] is None:
self.hidet_params[name] = None
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-builtin
from .conv1d import conv1d
from .matmul import batch_matmul, matmul, matmul_x86
from .conv1d import conv1d, conv1d_gemm
from .conv1d_transpose import conv1d_transpose
from .conv2d import (
conv2d,
Expand All @@ -24,7 +25,6 @@
from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .conv3d import conv3d, conv3d_gemm
from .conv3d_transpose import conv3d_transpose
from .matmul import batch_matmul, matmul, matmul_x86
from .pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
from .activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/conv1d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# limitations under the License.
from .conv1d import conv1d
from .conv1d import Conv1dOp
from .conv1d_gemm import conv1d_gemm

from . import resolve
93 changes: 93 additions & 0 deletions python/hidet/graph/ops/conv1d/conv1d_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.graph.ops.utils import Operator, input_like
from hidet.graph.ops.utils import normalize_kernel, normalize_stride
from hidet.graph.tensor import Tensor
from hidet.ir.compute import TensorNode
from hidet.ir.compute import compute
from hidet.ir.expr import is_constant
from hidet.ir.task import Task
from .utils import infer_conv1d_shape


class Conv1dGemmImageTransformTask(Task):
def __init__(self, x: TensorNode, kernel: int, stride: int, dilation: int, groups: int):
n, c, h = x.shape
kx = kernel
sx = stride
dilx = dilation
p = (h - dilx * (kx - 1) - 1) // sx + 1
self._assert(
c % groups == 0,
msg='Conv1d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups),
)
gc = c // groups # group channels
gemm_x = compute(
name='gemm_x',
shape=[groups, n * p, gc * kx],
fcompute=lambda g, i, k: x[i // p, g * gc + k // kx, i % p * sx + k % kx * dilx],
)
super().__init__(name='conv1d_gemm_image_transform', inputs=[x], outputs=[gemm_x])


class Conv1dGemmImageTransformOp(Operator):
def __init__(self, x: Tensor, kernel, stride, dilations, groups):
(kernel,) = normalize_kernel(kernel, dim=1)
(stride,) = normalize_stride(stride, dim=1)
super().__init__(
inputs=[x],
attributes={'kernel': kernel, 'stride': stride, 'groups': groups, 'dilations': dilations},
task=Conv1dGemmImageTransformTask(input_like(x, 'x'), kernel, stride, dilations, groups),
)


def conv1d_gemm_image_transform(x: Tensor, kernel: int, stride: int, dilation: int, groups: int = 1) -> Tensor:
return Conv1dGemmImageTransformOp(x, kernel, stride, dilation, groups).outputs[0]


def conv1d_gemm_filter_transform(w: Tensor, groups: int = 1) -> Tensor:
# weight shape: [oc, c, kx]
# output shape: [groups, c * kx, ogc] where ogc = oc // groups
oc, c, kx = w.shape
# TODO: current assertion mechanism does not cover this use case (only on the task-level)
if is_constant(oc, groups) and oc % groups != 0:
raise ValueError('invalid conv1d groups {} for out channels {}'.format(groups, oc))
ogc = oc // groups
w = w.reshape([groups, ogc, c, kx]) # [groups, ogc, c, kx]
w = w.rearrange([[0], [2, 3], [1]]) # [groups, c * kx, ogc]
return w


def conv1d_gemm_inverse_transform(gemm_y: Tensor, out_height) -> Tensor:
# gemm_y shape: [groups, n * p, ogc]
# output shape: [n, oc, p] where oc = groups * ogc
p = out_height
groups, npq, ogc = gemm_y.shape
# TODO: current assertion mechanism does not cover this use case (only on the task-level)
if is_constant(npq, p) and npq % p != 0:
raise ValueError('invalid conv1d output shape {} for dimension {}'.format(npq, p))
n = npq // p
y = gemm_y.reshape([groups, n, p, ogc])
y = y.rearrange([[1], [0, 3], [2]])
return y


def conv1d_gemm(data: Tensor, weight: Tensor, stride, dilation: int = 1, groups: int = 1) -> Tensor:
from hidet import ops

gemm_x = conv1d_gemm_image_transform(data, kernel=weight.shape[2], stride=stride, dilation=dilation, groups=groups)
gemm_w = conv1d_gemm_filter_transform(weight, groups=groups)
gemm_y = ops.matmul(gemm_x, gemm_w, require_prologue=True)

y_shape = infer_conv1d_shape(data.shape, weight.shape, stride, groups, dilation)
y = conv1d_gemm_inverse_transform(gemm_y, out_height=y_shape[2])
return y
28 changes: 28 additions & 0 deletions python/hidet/graph/ops/conv1d/resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List, Optional
from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
from hidet.graph import ops
from hidet.ir.expr import is_constant

from .conv1d import Conv1dOp


@register_resolve_rule(Conv1dOp)
class Conv1dResolveRule(ResolveRule):
def __init__(self, enable_winograd=False):
self.enable_winograd = enable_winograd

def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, Conv1dOp)
(stride,) = ops.utils.normalize_stride(op.attrs['stride'], dim=1)
groups = op.attrs['groups']
(dilations,) = op.attrs['dilations']
channels = op.inputs[1].shape[0]

if is_constant(channels) and groups == channels:
return None # use depthwise schedule in the default Task

data, weight = op.inputs
# implicit gemm algorithm
out = ops.conv1d_gemm(data, weight, stride, dilations, groups)
return [out]
31 changes: 31 additions & 0 deletions python/hidet/graph/ops/conv1d/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence
from hidet.ir.expr import is_constant
from ..utils import normalize_stride


def infer_conv1d_shape(
x_shape: Sequence[int], w_shape: Sequence[int], stride: int, groups: int, dilation: int
) -> List[int]:
n, c, d = x_shape
oc, gc, kd = w_shape
(sx,) = normalize_stride(stride, dim=1)
dilx = dilation
if is_constant(c) and gc * groups != c:
msg = 'Conv2d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
raise ValueError(msg)
if oc % groups != 0:
msg = 'Conv2d expects out_channels % groups == 0, got out_channels {} and groups {}'.format(oc, groups)
raise ValueError(msg)
p = (d - dilx * (kd - 1) - 1) // sx + 1
return [n, oc, p]
2 changes: 1 addition & 1 deletion python/hidet/graph/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .resolve_variant import ResolveRule, register_resolve_rule, get_resolve_chain
from .graph_patterns import TensorPattern, OperatorPattern, SubgraphRewriteRule, register_rewrite_rule, op_pattern
from .graph_patterns import registered_rewrite_rules
from .graph_patterns import registered_rewrite_rules, clear_registered_rewrite_rules


def optimize(graph: FlowGraph) -> FlowGraph:
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __enter__(self) -> PassContext:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
from ..transforms.graph_patterns import deregister_attn_patterns
from ..transforms.graph_patterns.attn_patterns import deregister_attn_patterns

deregister_attn_patterns()
popped = self._stack.pop()
Expand Down Expand Up @@ -166,7 +166,7 @@ def set_use_attention(self, flag=False) -> PassContext:
if cc < (7, 5):
return self

from ..transforms.graph_patterns import register_attn_patterns, deregister_attn_patterns
from ..transforms.graph_patterns.attn_patterns import register_attn_patterns, deregister_attn_patterns

self.configs['use_attention'] = flag
if flag:
Expand Down
6 changes: 1 addition & 5 deletions python/hidet/graph/transforms/graph_patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@
# limitations under the License.
from .base import TensorPattern, OperatorPattern, SubgraphRewriteRule, MatchDict, Usage, graph_pattern_match
from .base import register_rewrite_rule, op_pattern, registered_rewrite_rules, deregister_rewrite_rule
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
from .conv2d_patterns import conv2d_patterns
from .matmul_patterns import matmul_patterns
from .base import clear_registered_rewrite_rules
20 changes: 16 additions & 4 deletions python/hidet/graph/transforms/graph_patterns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,19 @@ def graph_pattern_match(pattern: TensorPattern, target: Tensor, usage: Usage) ->
return None


registered_rewrite_rules: List[SubgraphRewriteRule] = []
_registered_rewrite_rules: List[SubgraphRewriteRule] = []


def registered_rewrite_rules():
# pylint: disable=unused-import

from . import register_all_patterns # register on demand

return list(_registered_rewrite_rules)


def clear_registered_rewrite_rules():
_registered_rewrite_rules.clear()


def register_rewrite_rule(rule: Union[SubgraphRewriteRule, Type[SubgraphRewriteRule]]):
Expand All @@ -300,10 +312,10 @@ def register_rewrite_rule(rule: Union[SubgraphRewriteRule, Type[SubgraphRewriteR
should be an instance of SubgraphRewriteRule.
"""
if isinstance(rule, SubgraphRewriteRule):
registered_rewrite_rules.append(rule)
_registered_rewrite_rules.append(rule)
return None
elif issubclass(rule, SubgraphRewriteRule):
registered_rewrite_rules.append(rule())
_registered_rewrite_rules.append(rule())
return rule
else:
raise TypeError('rule should be a SubgraphRewriteRule or a subclass of SubgraphRewriteRule')
Expand All @@ -319,7 +331,7 @@ def deregister_rewrite_rule(rule: SubgraphRewriteRule):
The rule to be deregistered.
"""
if isinstance(rule, SubgraphRewriteRule):
registered_rewrite_rules.remove(rule)
_registered_rewrite_rules.remove(rule)
return None
else:
raise TypeError('rule should be a SubgraphRewriteRule')
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# pylint: disable=unused-import
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
from .conv2d_patterns import conv2d_patterns
from .matmul_patterns import matmul_patterns
2 changes: 1 addition & 1 deletion python/hidet/graph/transforms/subgraph_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SubgraphRewritePass(GraphPass):
def process_graph(self, graph: FlowGraph) -> FlowGraph:
graph = graph_utils.functors.clone(graph)
for _ in range(self.max_num_transforms):
updated, graph = self.try_transform(graph, registered_rewrite_rules)
updated, graph = self.try_transform(graph, registered_rewrite_rules())
if not updated:
graph.update_nodes()
return graph
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate(model, text, num_hidden_layers, num_heads, head_dim, device, tokens

@pytest.mark.parametrize('device,opt', [('cpu', False), ('cpu', True), ('cuda', False), ('cuda', True)])
def test_gpt2(device: str, opt: bool):
gpt2_module = hidet.testing.models.gpt2.model()
gpt2_module = hidet.testing.models.gpt2.model(disable_cache=True)

if device == 'cuda':
gpt2_module.cuda()
Expand Down