[Dynamo] minor enhancements to attention and register a few functions #345

Merged: 2 commits, Aug 14, 2023
22 changes: 22 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -981,6 +981,28 @@ def torch_mean(
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=dim, keep_dim=keepdim)
return output


Comment on lines +984 to +1005
Member:
Why do we need two torch_sums here?

Collaborator Author:
I followed the convention of the mean method above. I'm not entirely sure either, as I thought Python does not support overloading. Perhaps @yaoyaoding knows?

Member:
Python itself does not support function overloading. We use the inspect module to support overloading in hidet; this is needed because some PyTorch functions/methods have multiple signatures.
The implementation can be found here and here.
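[Editor's note: a minimal sketch of that mechanism, for readers following along. The names below are illustrative, not hidet's actual API; only inspect.signature(...).bind is a real standard-library call.]

import inspect

def dispatch_overload(overloads, *args, **kwargs):
    # Try each registered variant in order; the first signature that the
    # caller's arguments successfully bind to is the one that gets called.
    for func in overloads:
        try:
            inspect.signature(func).bind(*args, **kwargs)
        except TypeError:
            continue
        return func(*args, **kwargs)
    raise TypeError('no registered overload matches the given arguments')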

Collaborator Author:
nice!

Member (@hjjq, Aug 10, 2023):
I see. But do torch.Tensor.sum and torch.sum have the same signature? If they do, then there is no need for overloading? https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum
Also, it doesn't seem that either of them has the out argument.

Collaborator Author:
> Also, it doesn't seem that either of them has the out argument.
Right, let me fix this.

I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?
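[Editor's note: for reference, the two call patterns those overloads are meant to cover, in plain PyTorch (illustrative):]

import torch

x = torch.randn(2, 3)
torch.sum(x)                       # full reduction, no dim -> first variant
torch.sum(x, dim=1, keepdim=True)  # reduce over dim 1      -> second variant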

Member:
I usually jump to the signatures in the Python code to check the variants of the torch functions:
[screenshot: torch.sum signature stubs]
It's interesting that the code has an out parameter but the documentation does not.

Member:
> I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?

I see.
Also, keepdim in the first case (L989) should default to False?
Lastly, torch.Tensor.sum seems to have a slightly different signature, where dim has a default value (whereas torch.sum doesn't have a default, making dim mandatory). So in the case below, it will resolve to the first case because of missing dim, and possibly produce wrong results?

a = torch.randn(...)
b = a.sum(keepdim=True)

Member:
> I usually jump to the signatures in the Python code to check the variants of the torch functions: [screenshot] It's interesting that the code has an out parameter but the documentation does not.

[screenshot: torch.sum signature stub without the out parameter]
Interestingly, my PyTorch code doesn't have out. Maybe we have a different version/build of torch.

Collaborator Author:
It is weird; I see two places where Torch generates stubs like

@overload
def xxx(args, )

One, under ./_C/__init__.pyi, does not have "out"; the other, under ./_C/_VariableFunctions.pyi, has "out". However, both are just generated signatures that don't really represent the actual implementation; I think they are there to make IDEs work.

The actual implementation should be at aten/src/ATen/native/native_functions.yaml, which has "out".

Even if the actual op does not support "out", having an optional out argument should not break the inspect.Signature.bind function, so we should still be fine, and it would be better to include "out" here.
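[Editor's note: a quick standalone check of that claim (a sketch, not hidet code): an optional keyword-only out that the caller never passes does not prevent the bind from succeeding.]

import inspect

def sum_variant(x, dim, keepdim=False, *, dtype=None, out=None):
    pass

sig = inspect.signature(sum_variant)
# Binding a call that omits `out` (and `dtype`) still succeeds, so keeping
# the optional parameter in the registered signature is harmless.
bound = sig.bind('x', 1, keepdim=True)
print(bound.arguments)  # {'x': 'x', 'dim': 1, 'keepdim': True}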

@register_function(torch.cumsum)
def torch_cumsum(x: Tensor, dim, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
10 changes: 9 additions & 1 deletion python/hidet/graph/frontend/torch/register_methods.py
@@ -90,6 +90,11 @@ def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
if self.is_symbolic() and instantiate_device(device_from_torch(arg)) != self.device:
raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.')
device = arg
elif isinstance(arg, Tensor):
dtype = arg.dtype
if self.is_symbolic() and arg.device != self.device:
raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.')
device = arg.device
else:
raise ValueError(f'Unsupported argument type: {type(arg)}')

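[Editor's note: the new elif branch above mirrors PyTorch's Tensor.to(other) overload, which converts to the dtype and device of the given tensor. A small usage example in plain PyTorch (illustrative):]

import torch

a = torch.randn(2, 2)                    # float32
b = torch.zeros(1, dtype=torch.float16)  # float16
c = a.to(b)                              # takes b's dtype (and device)
assert c.dtype == torch.float16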
@@ -222,7 +227,10 @@ def tensor_type(self: Tensor, dtype: Union[str, torch.dtype], non_blocking: bool

@register_method(torch.Tensor.expand)
def tensor_expand(self: Tensor, *sizes: int) -> Tensor:
sizes: List[int] = list(sizes)
if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
sizes = sizes[0]
else:
sizes: List[int] = list(sizes)
assert len(sizes) >= len(self.shape)
for i in range(len(sizes)):
if sizes[i] == -1:
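[Editor's note: the tensor_expand change above makes the method accept both calling conventions that PyTorch allows. Roughly, the size normalization works like this (a standalone sketch, not the full hidet method):]

def normalize_expand_sizes(*sizes):
    # t.expand(2, 3) passes the sizes as separate ints, while
    # t.expand([2, 3]) / t.expand((2, 3)) passes a single sequence.
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        return list(sizes[0])
    return list(sizes)

assert normalize_expand_sizes(2, 3) == [2, 3]
assert normalize_expand_sizes([2, 3]) == [2, 3]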
20 changes: 10 additions & 10 deletions python/hidet/graph/ops/attention/attention.py
@@ -397,7 +397,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type):
def copy_k_g2s_sm80(
k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32
):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j_seg in k_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -411,7 +411,7 @@ def copy_k_g2s_sm80(

@hidet.script
def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j_seg in v_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -421,7 +421,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):

@hidet.script
def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j_seg in q_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -433,7 +433,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
def copy_k_g2s_sm75(
k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32
):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j in k_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]:
@@ -444,7 +444,7 @@ def copy_k_g2s_sm75(

@hidet.script
def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j in v_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]:
@@ -455,7 +455,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):

@hidet.script
def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j in q_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]:
@@ -488,7 +488,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
@hidet.script
def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_o = o[o_head_index][offset_i:, :]
for k_round in range(warp_count_k):
for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id):
@@ -652,12 +652,12 @@ def attn_kernel(
v: f16[v_head + [n_kv_size, d_size]],
o: f16[o_head + [n_size, d_size]],
):
attrs.cuda.grid_dim = (i_split, bs)
attrs.cuda.grid_dim = i_split * bs
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1
attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes

offset_i = blockIdx.x * i_rows_per_tb
offset_i = (blockIdx.x % i_split) * i_rows_per_tb

smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout)
smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout)
@@ -702,7 +702,7 @@ def attn_kernel(

j_tiles = cdiv(n_kv_size, block_j)
if is_causal:
j_tiles = min(cdiv((blockIdx.x + 1) * block_i, block_j), j_tiles)
j_tiles = min(cdiv(((blockIdx.x % i_split) + 1) * block_i, block_j), j_tiles)
for j in range(j_tiles):
offset_j = block_j * j

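[Editor's note: the changes in this file (and the mirrored ones in attention_mask.py below) flatten the former two-dimensional launch grid (i_split, bs) into a one-dimensional grid of i_split * bs blocks; each block then recovers its head index and row-tile index arithmetically. A plain-Python sketch of that mapping, with illustrative values:]

i_split, bs = 4, 8  # example values: row tiles per sequence, batch * heads
for block_idx in range(i_split * bs):  # stand-in for blockIdx.x
    head_index = block_idx // i_split  # previously blockIdx.y
    row_tile = block_idx % i_split     # previously blockIdx.x
    # in the kernel: offset_i = row_tile * i_rows_per_tb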
20 changes: 10 additions & 10 deletions python/hidet/graph/ops/attention/attention_mask.py
@@ -425,7 +425,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type):
def copy_k_g2s_sm80(
k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32
):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j_seg in k_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -439,7 +439,7 @@ def copy_k_g2s_sm80(

@hidet.script
def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j_seg in v_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -449,7 +449,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):

@hidet.script
def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j_seg in q_g2s_layout.on(threadIdx.x):
j = j_seg * 8
@@ -461,7 +461,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
def copy_k_g2s_sm75(
k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32
):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:]
for i, j in k_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]:
@@ -472,7 +472,7 @@ def copy_k_g2s_sm75(

@hidet.script
def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :]
for i, j in v_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]:
@@ -483,7 +483,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32):

@hidet.script
def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :]
for i, j in q_g2s_layout_sm75.on(threadIdx.x):
if threadIdx.x < q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]:
@@ -516,7 +516,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32):
@hidet.script
def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
o_head_index = spatial(*o_head).map(blockIdx.y)
o_head_index = spatial(*o_head).map(blockIdx.x // i_split)
gmem_o = o[o_head_index][offset_i:, :]
for k_round in range(warp_count_k):
for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id):
@@ -681,12 +681,12 @@ def attn_kernel(
mask: f16[mask_shape],
o: f16[o_head + [n_size, d_size]],
):
attrs.cuda.grid_dim = (i_split, bs)
attrs.cuda.grid_dim = i_split * bs
attrs.cuda.block_dim = block_size
attrs.cuda.min_blocks = 1
attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes

offset_i = blockIdx.x * i_rows_per_tb
offset_i = (blockIdx.x % i_split) * i_rows_per_tb

smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout)
smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout)
@@ -773,7 +773,7 @@ def attn_kernel(
copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j)

# Apply Masking
qk_head_index = list(spatial(*qk_head).map(blockIdx.y))
qk_head_index = list(spatial(*qk_head).map(blockIdx.x // i_split))
for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
wi, wj, wk = spatial(warp_count_m, warp_count_n, warp_count_k).map(warp_id)