Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hidet script] Add hidet.lang.types submodule #340

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/hidet/apps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
11 changes: 11 additions & 0 deletions python/hidet/apps/compile_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .compilation import remote_build
from .core import init_api
11 changes: 11 additions & 0 deletions python/hidet/apps/compile_server/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
from .core import api_url

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/apps/compile_server/compilation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zipfile
import shutil
import tempfile
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/apps/compile_server/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import hidet

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/apps/compile_server/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
from .core import access_token, api_url

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/cuda/capability.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from dataclasses import dataclass

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/graph/ops/conv1d/resolve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
Expand Down
8 changes: 4 additions & 4 deletions python/hidet/graph/ops/matmul/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from hidet.ir import IRModule
from hidet.ir.compute import reduce
from hidet.ir.expr import is_constant, cast
from hidet.ir.layout import StridesLayout, data_layout, row_major, column_major, local_layout
from hidet.ir.layout import StridesLayout, strided_layout, row_major, column_major, local_layout
from hidet.ir.type import data_type, TensorType, DataType, void_p
from hidet.lang import i32, spatial, repeat, register_tensor, shared_tensor, attrs, grid, tensor_pointer
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
Expand Down Expand Up @@ -437,9 +437,9 @@ def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):
task_shape=[block_k, block_n], num_workers=num_threads, ranks=[0, 1]
)

smem_a_layout = data_layout([2, block_m, block_k], ranks=[0, 1, 2])
smem_b_layout = data_layout([2, block_k, block_n], ranks=[0, 1, 2])
smem_c_layout = data_layout([block_m, block_n], ranks=[0, 1])
smem_a_layout = strided_layout([2, block_m, block_k], ranks=[0, 1, 2])
smem_b_layout = strided_layout([2, block_k, block_n], ranks=[0, 1, 2])
smem_c_layout = strided_layout([block_m, block_n], ranks=[0, 1])
regs_a_layout = row_major(2, mma_count_m, mma_config.a_elements)
regs_b_layout = row_major(2, mma_count_n, mma_config.b_elements)
regs_c_layout = row_major(mma_count_m, mma_count_n, mma_config.c_elements)
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/graph/transforms/graph_patterns/quant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from ..base import SubgraphRewriteRule

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=unused-import
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/graph/transforms/selective_quantize.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from hidet.graph.flow_graph import FlowGraph
Expand Down
7 changes: 6 additions & 1 deletion python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel, useless-parent-delegation, redefined-outer-name, redefined-builtin
# pylint: disable=useless-super-delegation
# pylint: disable=useless-super-delegation, protected-access
from __future__ import annotations
from typing import Optional, Union, Sequence, Tuple, Dict, Type, Callable
import string
Expand Down Expand Up @@ -815,6 +815,11 @@ def logical_not(a: Union[Expr, PyScalar]):
return Expr._unary(LogicalNot, a)


def bitwise_not(a: Union[Expr, PyScalar]):
a = convert(a)
return Expr._unary(BitwiseNot, a)


def equal(a: Union[Expr, PyScalar], b: Union[Expr, PyScalar]):
a = convert(a)
b = convert(b)
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def local_layout(*shape: Int):
return LocalLayout(shape)


def data_layout(shape: Sequence[Int], ranks: Optional[List[int]] = None):
def strided_layout(shape: Sequence[Int], ranks: Optional[List[int]] = None):
if ranks is None:
ranks = list(range(len(shape)))
return StridesLayout.from_shape(shape, ranks)
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/library/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import cuda
11 changes: 11 additions & 0 deletions python/hidet/ir/library/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .matmul import matmul_simt
11 changes: 11 additions & 0 deletions python/hidet/ir/library/cuda/matmul/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .simt import matmul_simt
17 changes: 14 additions & 3 deletions python/hidet/ir/library/cuda/matmul/simt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import hidet
from hidet.ir.type import DataType, TensorType, tensor_type, tensor_pointer_type
from hidet.ir.dtypes import vectorize, i32
from hidet.ir.expr import Expr, is_true
from hidet.ir.layout import row_major, data_layout, local_layout
from hidet.ir.layout import row_major, strided_layout, local_layout
from hidet.ir.library import tune
from hidet.ir.library.utils import get_tensor_type
from hidet.ir.primitives.runtime import request_cuda_workspace
Expand Down Expand Up @@ -157,8 +168,8 @@ def matmul_simt(
# prepare data layout
tune.check(block_k % lanes == 0)
block_k_vectors = block_k // lanes
smem_a_layout = data_layout([2, block_k_vectors, block_m + 1])
smem_b_layout = data_layout([2, block_k_vectors, block_n + 1])
smem_a_layout = strided_layout([2, block_k_vectors, block_m + 1])
smem_b_layout = strided_layout([2, block_k_vectors, block_n + 1])
regs_a_layout = ( # 2 x block_m
row_major(2, 1)
.local(1, block_warps[0])
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/library/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.tools import infer_type
from hidet.ir.type import TensorType, TensorPointerType
from hidet.ir.expr import Expr
Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/primitives/cuda/errchk.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.stmt import BlackBoxStmt


Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/target.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List, Dict

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/utils/broadcast_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, List
from hidet.ir.expr import Expr, Int, is_constant, if_then_else

Expand Down
2 changes: 1 addition & 1 deletion python/hidet/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from hidet.lang.constructs import meta

from hidet.ir.type import BaseType, DataType, TensorType, PointerType, VoidType, ReferenceType, void_p, data_type
from hidet.ir.expr import Expr, Var, cast, view, address, Dereference
from hidet.ir.expr import Expr, Var, cast, view, address, Dereference, bitwise_not
from hidet.ir.mapping import row_spatial, row_repeat, col_repeat, col_spatial, TaskMapping, auto_map
from hidet.ir.layout import DataLayout
from hidet.ir.primitives import printf
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/lang/constructs/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def grid(*dim_extents, attrs: Optional[str] = None, bind_tuple=False):

Parameters
----------
dim_extents: Sequence[Expr or int or list or tuple or str]
dim_extents: Expr or int or list or tuple or str
The length of each dimension. The last one can be the attrs.

attrs: Optional[str]
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/lang/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=unused-import
from hidet.ir.layout import row_major, column_major, local_layout, DataLayout
from hidet.ir.layout import row_major, column_major, local_layout, strided_layout, DataLayout
19 changes: 19 additions & 0 deletions python/hidet/lang/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=unused-import
from hidet.ir.dtypes import i8, i16, i32, i64, u8, u16, u32, u64, f16, f32, f64, bf16, tf32
from hidet.ir.dtypes import int8, int16, int32, int64, uint8, uint32, uint64, float16, float32, float64, bfloat16
from hidet.ir.dtypes import tfloat32

from hidet.ir.type import void_p, void, byte_p

from hidet.lang.constructs.declare import register_tensor, shared_tensor, tensor_pointer, tensor, DeclareScope