-
Notifications
You must be signed in to change notification settings - Fork 552
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Adding a dedicated sum kernel when on a strided dimension
- adding a dedicated benchmark, better unit test - moving to a tile based approach to better handle big buffers - trying to find better scheduling defaults
- Loading branch information
1 parent
e75ec39
commit 00aebbc
Showing
12 changed files
with
241 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,4 +51,4 @@ examples/data | |
|
||
# Hydra default output dir | ||
multirun | ||
outputs | ||
outputs |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Any, Dict, List | ||
|
||
import torch | ||
import triton | ||
|
||
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print | ||
from xformers.triton.sum_strided import sum_2d_dim_0 | ||
|
||
SHAPES = [ | ||
(128, 128), | ||
(384, 128), | ||
(784, 512), | ||
(1024, 768), | ||
(2048, 1024), | ||
(4096, 4096), | ||
] | ||
|
||
|
||
def to_gbs(a, ms):
    """Convert a measured runtime (milliseconds) into achieved bandwidth in GB/s.

    Traffic model for a dim-0 reduction: the full tensor is read once and the
    reduced row (one element per column, ``a.shape[1]``) is written once.
    """
    # Total bytes moved = read everything + write the non-reduced dimension
    bytes_moved = (a.numel() + a.shape[1]) * a.element_size()
    return bytes_moved * 1e-9 / (ms * 1e-3)
|
||
|
||
def bench_functions(
    test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
    """Benchmark every test case over the given shapes and report a derived metric.

    For each dtype (fp16, fp32) and each (M, N) shape, times every case with
    ``triton.testing.do_bench``, converts the runtime via ``metric_transform``,
    then prints a table and saves a plot per dtype.

    Args:
        test_cases: named callables to benchmark; each is invoked with one tensor.
        shapes: iterable of (M, N) tensor shapes to sweep.
        metric_transform: maps (tensor, time_ms) to the reported metric value.
        unit: label for the metric (e.g. "GB/s").
        title: prefix for the saved plot's title.
    """
    device = torch.device("cuda")

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for M, N in shapes:
            a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)

            for case in test_cases:
                # do_bench returns (median, min, max); keep the median time
                elapsed_ms = triton.testing.do_bench(lambda: case.function(a))[0]
                metric = metric_transform(a, elapsed_ms)

                row = results.setdefault(f"M={M}, N={N}", {})
                row[case.name] = f"{metric:.1f}"

        _type = " fp16" if dtype == torch.float16 else " fp32"

        pretty_print(
            results,
            title=" ------------- Type: {} ------------- ".format(_type),
            units=unit,
        )
        pretty_plot(results, title + _type, unit, dash_key="pytorch")
|
||
|
||
if __name__ == "__main__":
    # Guarded so importing this module does not launch the benchmark as a
    # side effect. Compares the dedicated Triton strided-sum kernel against
    # torch.sum along dim 0, reporting achieved bandwidth in GB/s.
    bench_functions(
        [
            TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
            TestCase(sum_2d_dim_0, "triton"),
        ],
        SHAPES,
        to_gbs,
        "GB/s",
        "Strided_sum",
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
|
||
# fmt: off
@triton.jit
def k_sum_0(
    Y, X,
    stride_xm,
    M, N,
    is_fp16,
    **meta,
):
    # fmt: on  (fixed: was "# fmt: om", which never re-enabled the formatter)

    """
    Sum a 2d tensor over the first (strided) dimension.
    This extracts some speed through a parallel sum across the second dimension

    Y: output pointer, one float32 accumulation result per column (N values)
    X: input pointer, logically (M, N) with row stride ``stride_xm``
    is_fp16: currently unused by the kernel body; kept for interface stability
    meta: must provide BLOCK_M (tile height) and BLOCK_N (columns per program)
    """
    BLOCK_M = meta["BLOCK_M"]
    BLOCK_N = meta["BLOCK_N"]

    # partial row indices. We'll reduce over this dimension
    m = tl.arange(0, BLOCK_M)

    # To get some extra parallelization, we handle several columns in the same thread block
    rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)

    # the memory address of all the elements that we want to load can be computed as follows
    x_ptrs = X + m[:, None] * stride_xm + rn[None, :]
    x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # ceil(M / BLOCK_M) tiles cover the whole reduction dimension
    tiles = M // BLOCK_M
    if M % BLOCK_M > 0:
        tiles += 1

    # column predicate is loop-invariant, compute it once
    col_mask = (rn[None, :] < N)

    for _ in range(tiles):
        # load input data; pad out-of-bounds elements with 0
        # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow
        mask = (m[:, None] < M) & col_mask
        x = tl.load(x_ptrs, mask=mask, other=0.0)
        x_sum += tl.sum(x, 0)

        # move the load pointer
        x_ptrs += BLOCK_M * stride_xm
        m += BLOCK_M  # update the mask check

    tl.store(Y + rn, x_sum, mask=rn < N)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
# | ||
# This source code is licensed under the BSD license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torch | ||
import triton | ||
|
||
from xformers.triton.k_sum import k_sum_0 | ||
|
||
|
||
def sum_2d_dim_0(x: torch.Tensor):
    """
    Sum a 2D tensor across the first dimension.

    Dispatches to the dedicated Triton kernel when the reduction size M is in
    the range where it is competitive (8 <= M <= 2048); otherwise falls back
    to ``torch.sum``. Returns a 1D tensor of shape ``(x.shape[1],)`` with the
    same dtype and device as ``x``.
    """
    # Validate rank BEFORE touching x.shape[1]: previously a 1-D input hit an
    # IndexError on the output allocation instead of this explanatory message.
    assert (
        x.ndim == 2
    ), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0"
    M, N = x.shape

    # This kernel is not competitive for these sizes
    if M > 2048 or M < 8:
        return x.sum(dim=0)

    # NOTE(review): unreachable given the M < 8 fallback above; kept as a guard
    assert (
        M >= 4
    ), "This is a very specific kernel, requires the reduction dimension to be bigger than 4"

    assert x.stride(1) == 1, (
        "We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n"
        " You would probably be better served with torch.sum()"
    )

    # Allocate the output only once we know the Triton path is taken
    # (previously allocated up front, wasting an allocation on every fallback).
    out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype)

    # Scheduling heuristic: taller tiles trade off against fewer columns per
    # program to keep register pressure in check.
    BLOCK_M = min(triton.next_power_of_2(M), 2048)
    BLOCK_N = 32
    if BLOCK_M > 256:
        BLOCK_N = 16
    if BLOCK_M > 1024:
        BLOCK_N = 8

    def grid(meta):
        # One program per BLOCK_N-wide column slab
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    # fmt: off
    k_sum_0[grid](
        out, x,
        x.stride(0),
        M, N,
        x.dtype == torch.float16,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=4,
    )
    # fmt: on

    return out