[Graph] Translate softmax and reduce to hidet script (#242)
Convert the IRBuilder-based schedule templates for the softmax and reduce operators into hidet script ones; a short usage sketch follows the change summary below.

---------

Co-authored-by: Allan Lin <allan.lin@centml.ai>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
3 people committed Jun 2, 2023
1 parent a1706b2 commit 59e2eae
Showing 6 changed files with 210 additions and 275 deletions.
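As the commit message notes, both operators' CUDA schedules are now plain hidet script kernels. Below is a minimal usage sketch of how the rewritten schedules would be exercised end to end, assuming the public `hidet.randn`, `hidet.ops.sum`, and `hidet.ops.softmax` entry points and a CUDA-enabled build; it is illustrative only and not part of this diff.

```python
# Usage sketch (not part of this commit): exercise the rewritten CUDA schedules
# through the public hidet operator API.
import hidet

x = hidet.randn([32, 1024], dtype='float32', device='cuda')

# Reducing over the last dimension should dispatch to the warp-based schedule
# (cuda_schedule_reduce_by_warp); reducing over dim 0 falls back to
# cuda_schedule_reduce_by_default.
row_sum = hidet.ops.sum(x, dims=[1])
col_sum = hidet.ops.sum(x, dims=[0])

# Softmax over the last axis runs the warp-per-row hidet script kernel.
probs = hidet.ops.softmax(x, axis=1)

print(row_sum.shape, col_sum.shape, probs.shape)
```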
162 changes: 147 additions & 15 deletions python/hidet/graph/ops/definitions/reduce/reduce.py
@@ -10,9 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Optional, Sequence
import builtins

from hidet.ir.expr import is_constant
from ..arithmetic import square, sqrt
from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
from ..utils import compute, reduce, input_like, normalize_dim, arg_reduce
@@ -70,23 +68,157 @@ def reduce_fcompute(*reduce_indices):
},
)

def implement_cuda(self, working_dir: str) -> IRModule:
# pylint: disable=import-outside-toplevel
from ...schedules import cuda_schedule_reduce_by_default, cuda_schedule_reduce_by_warp_reduce

if self.inputs[0].type.dtype.name == 'float64':
return NotImplemented # use auto-scheduler

if not builtins.all(is_constant(dim) for dim in self.inputs[0].shape):
return NotImplemented
def allow_epilogue(self) -> bool:
rank = len(self.inputs[0].shape)
if rank - 1 in self.dims: # pylint: disable=simplifiable-if-statement
# use self.cuda_schedule_reduce_by_warp
return True
else:
# use self.cuda_schedule_reduce_by_default
return False

    def implement_cuda(self, working_dir: str) -> IRModule:
        rank = len(self.inputs[0].shape)
        if rank - 1 in self.dims:
            # reduce over last dimension
            return self.cuda_schedule_reduce_by_warp()
        else:
            return self.cuda_schedule_reduce_by_default()

def cuda_schedule_reduce_by_warp(self) -> IRModule:
import hidet
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from hidet.ir.compute import ReduceOperation
from hidet.ir.type import data_type
from hidet.ir.layout import DataLayout
from hidet.lang import spatial, repeat, attrs, cast
from hidet.lang.cuda import blockIdx, threadIdx

row_major = DataLayout.row_major

warp_size = 32
block_size = warp_size
x, y = self.inputs[0], self.outputs[0]
xdtype = x.type.dtype
shape: List[int] = list(x.const_shape)
dims = self.dims
        if self.keep_dim:
            remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
        else:
            remain_shape = [v for i, v in enumerate(shape) if i not in dims]
reduce_shape = [shape[i] for i in dims]
reduce_extent = hidet.utils.prod(reduce_shape)

remain_layout = spatial(*remain_shape)
layout = row_major(shape)

spatial_shape = []
repeat_shape = []
for i in range(len(shape)):
if i == len(shape) - 1:
spatial_shape.append(warp_size)
repeat_shape.append((shape[i] + warp_size - 1) // warp_size) # num warps per row
elif i in dims:
spatial_shape.append(1)
repeat_shape.append(shape[i])
else:
spatial_shape.append(shape[i])
repeat_shape.append(1)
task_layout = repeat(*repeat_shape) * spatial(*spatial_shape)
grid_size = remain_layout.num_workers
accumulate_dtype = self.attrs['accumulate_dtype']
reduce_type = self.attrs['reduce_type']
ro = ReduceOperation.from_name(reduce_type)

with hidet.script_module() as module:

@hidet.script
def reduce_kernel(x: xdtype[x.const_shape], y: xdtype[y.const_shape]):
attrs.cuda.grid_dim = grid_size
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1

rv = ro.initial_value(data_type(accumulate_dtype))
for indices in task_layout.on(threadIdx.x + blockIdx.x * block_size):
if layout.within_bound(indices):
k = x[indices]
rv = ro.combine(rv, cast(k, accumulate_dtype))
# Warp reduce by shuffle down
mask = active_mask()
rv = ro.combine(rv, shfl_down_sync(mask, rv, 16, 32))
rv = ro.combine(rv, shfl_down_sync(mask, rv, 8, 32))
rv = ro.combine(rv, shfl_down_sync(mask, rv, 4, 32))
rv = ro.combine(rv, shfl_down_sync(mask, rv, 2, 32))
rv = ro.combine(rv, shfl_down_sync(mask, rv, 1, 32))
rv = shfl_sync(mask, rv, 0, 32)
rv = ro.finalize(acc=rv, size=reduce_extent)

if threadIdx.x == 0:
for indices in remain_layout.on(blockIdx.x):
y[indices] = cast(rv, xdtype)

ir_module = module.ir_module()
return ir_module

def cuda_schedule_reduce_by_default(self) -> IRModule:
import hidet
from hidet.ir.compute import ReduceOperation
from hidet.ir.type import data_type
from hidet.lang import spatial, repeat, attrs
from hidet.lang.cuda import blockIdx, threadIdx, register_tensor

x, y = self.inputs[0], self.outputs[0]
dims = self.dims
shape: List[int] = list(x.const_shape)
xdtype = x.type.dtype

if self.keep_dim:
remain_shape = [v if i not in dims else 1 for i, v in enumerate(shape)]
else:
remain_shape = [v for i, v in enumerate(shape) if i not in dims]

remain_extent = hidet.utils.prod(remain_shape)
reduce_shape = [shape[i] for i in dims]
reduce_extent = hidet.utils.prod(reduce_shape)
block_size = 256 if 256 < remain_extent else remain_extent
remain_layout = spatial(*remain_shape)

spatial_shape = []
repeat_shape = []
for i in range(len(shape)):
if i in dims:
spatial_shape.append(1)
repeat_shape.append(shape[i])
else:
spatial_shape.append(shape[i])
repeat_shape.append(1)
task_layout = repeat(*repeat_shape) * spatial(*spatial_shape)

grid_size = (remain_layout.num_workers + block_size - 1) // block_size
accumulate_dtype = self.attrs['accumulate_dtype']
reduce_type = self.attrs['reduce_type']
ro = ReduceOperation.from_name(reduce_type)

with hidet.script_module() as module:

@hidet.script
def reduce_kernel(x: xdtype[x.const_shape], y: xdtype[y.const_shape]):
                # each thread reduces one element of the remaining (non-reduced) shape;
                # a block of block_size threads covers block_size output positions
attrs.cuda.grid_dim = grid_size
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1

rv = register_tensor(accumulate_dtype, [1])
rv[0] = ro.initial_value(data_type(accumulate_dtype))

if threadIdx.x + blockIdx.x * block_size < remain_extent:
for indices in task_layout.on(threadIdx.x + blockIdx.x * block_size):
rv[0] = ro.combine(rv[0], x[indices])
rv[0] = ro.finalize(acc=rv[0], size=reduce_extent)
for indices in remain_layout.on(threadIdx.x + blockIdx.x * block_size):
y[indices] = rv[0]

ir_module = module.ir_module()
return ir_module


class ArgReduceTask(Task):
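The shuffle cascade in `reduce_kernel` above is the classic warp-level tree reduction: offsets 16, 8, 4, 2, 1 halve the number of active partial results at each step, leaving the combined value in lane 0. A small pure-Python model of that pattern (illustrative only, not hidet code):

```python
# Pure-Python model of the shfl_down_sync reduction cascade used in reduce_kernel:
# 32 lane values are combined in log2(32) steps, leaving the result in lane 0.
def warp_reduce_model(lane_values, combine):
    vals = list(lane_values)  # one partial result per lane
    for offset in (16, 8, 4, 2, 1):
        vals = [
            combine(vals[i], vals[i + offset]) if i + offset < len(vals) else vals[i]
            for i in range(len(vals))
        ]
    return vals[0]  # lane 0 now holds the combined value

assert warp_reduce_model(range(32), lambda a, b: a + b) == sum(range(32))
assert warp_reduce_model(range(32), max) == 31
```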
66 changes: 63 additions & 3 deletions python/hidet/graph/ops/definitions/softmax.py
@@ -60,9 +60,69 @@ def __init__(self, x: TensorNode, axis: int):
super().__init__(name='softmax', inputs=[x], outputs=[out])

    def implement_cuda(self, working_dir: str) -> IRModule:
        if not all(is_constant(dim) for dim in self.inputs[0].shape):
            return NotImplemented  # use auto-scheduler

import math
import hidet
from hidet.lang import tensor
from hidet.lang import attrs

        from hidet.ir import primitives as prim
        from hidet.ir.mapping import TaskMapping

from hidet.lang.cuda import blockIdx, threadIdx
from hidet.graph.ops.schedules.cuda.common import warp_reduce

shape = self.inputs[0].const_shape
axis = self.axis
reduce_extent = shape[axis]
reduced_shape = shape[:axis] + shape[axis + 1 :]
n_reduce = math.prod(reduced_shape)
warp_size = 32
outer_extent = (reduce_extent + warp_size - 1) // warp_size
grid_layout = TaskMapping.row_major(reduced_shape)
other_inds = list(grid_layout.worker2task(blockIdx.x)[0])
xdtype = self.inputs[0].type.dtype

def sum_expr(a, b):
return a + b

with hidet.script_module() as module:

@hidet.script
def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
attrs.cuda.block_dim = warp_size
attrs.cuda.grid_dim = n_reduce

temp = tensor('register', xdtype, shape=[outer_extent])

rv = -xdtype.max_value

# compute maximum in the dimension to be softmaxed across
for k in range(outer_extent):
idx = threadIdx.x + k * warp_size
if idx < reduce_extent:
temp[k] = xs[other_inds[:axis] + [idx] + other_inds[axis:]]
rv = prim.max(rv, temp[k])
warp_reduce(rv, prim.max)

# exp
for k in range(outer_extent):
temp[k] = prim.exp(temp[k] - rv)

rv = xdtype.zero
for k in range(outer_extent):
idx = threadIdx.x + k * warp_size
if idx < reduce_extent:
rv += temp[k]
warp_reduce(rv, sum_expr)

for k in range(outer_extent):
idx = threadIdx.x + k * warp_size
if idx < reduce_extent:
ys[other_inds[:axis] + [idx] + other_inds[axis:]] = temp[k] / rv

assert isinstance(softmax_kernel, hidet.ir.Function)
ir_module = module.ir_module()

return ir_module
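For reference, the kernel above computes the standard numerically stable softmax: subtract the per-row maximum, exponentiate, then normalize by the row sum. A NumPy sketch of the same three passes (illustrative only, not part of the commit):

```python
# NumPy reference mirroring the three passes of softmax_kernel above.
import numpy as np

def softmax_reference(x: np.ndarray, axis: int = -1) -> np.ndarray:
    row_max = x.max(axis=axis, keepdims=True)   # warp max-reduce
    e = np.exp(x - row_max)                     # exp pass over temp[k]
    return e / e.sum(axis=axis, keepdims=True)  # warp sum-reduce + divide

x = np.random.randn(8, 1000).astype(np.float32)
y = softmax_reference(x, axis=-1)
assert np.allclose(y.sum(axis=-1), 1.0, atol=1e-5)
```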
3 changes: 0 additions & 3 deletions python/hidet/graph/ops/schedules/__init__.py
@@ -12,7 +12,4 @@
from . import cpu
from . import cuda

from .cuda.softmax import softmax_cuda_schedule
from .cuda.reduce import cuda_schedule_reduce_by_default, cuda_schedule_reduce_by_warp_reduce

from .common import Schedule
3 changes: 0 additions & 3 deletions python/hidet/graph/ops/schedules/cuda/__init__.py
@@ -9,6 +9,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .softmax import softmax_cuda_schedule
from .reduce import cuda_schedule_reduce_by_default
