SwiGLU optimized fw/bw
ghstack-source-id: 7b874c69561bf1756e95ccfad9407e4ea9d18e85
Pull Request resolved: #490
danthe3rd committed Oct 25, 2022
1 parent 5227f2f commit 59249e2
Showing 20 changed files with 4,072 additions and 36 deletions.
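For context: SwiGLU projects the input twice, gates one projection with silu of the other, and projects the product back out; the fused op added in this commit implements that in fewer kernel launches. A minimal eager-mode sketch of the semantics, assuming the (weight, bias) ordering that `_ordered_params_for_op` uses in the diffs below (bias handling in the real op may differ):

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    x1 = F.linear(x, w1, b1)          # "gate" branch
    x2 = F.linear(x, w2, b2)          # "value" branch
    hidden = F.silu(x1) * x2          # SwiGLU activation
    return F.linear(hidden, w3, b3)   # output projection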
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
-    extensions_dir = os.path.join(
-        this_dir, "xformers", "components", "attention", "csrc"
-    )
+    extensions_dir = os.path.join(this_dir, "xformers", "components")

     main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

-    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "autograd", "*.cpp")
-    )
+    source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

     sources = main_file + source_cpu

     source_cuda = glob.glob(
-        os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
+        os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
     )

     sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
     cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
+    cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
     if not os.path.exists(cutlass_dir):
         raise RuntimeError(
             f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
     ) == "1":
         extension = CUDAExtension
         sources += source_cuda
-        include_dirs += [sputnik_dir, cutlass_dir]
-        nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
+        include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
+        nvcc_flags = [
+            "-DHAS_PYTORCH",
+            "--use_fast_math",
+            "--generate-line-info",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "--extended-lambda",
+        ]
         if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
             nvcc_flags.append("-DNDEBUG")
         nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
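The build change above widens source discovery from the attention csrc directory to every component, so the new swiglu sources are picked up without further setup.py edits. (The new nvcc flags re-enable native half-precision operators, which PyTorch's headers disable by default, and `--extended-lambda` permits device lambdas in the new kernels.) A sketch of what the new patterns match, with illustrative paths:

import glob
import os

extensions_dir = os.path.join("xformers", "components")
# Any C++ source under any component, e.g. attention/csrc/*.cpp
cpp = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
# Any CUDA source under a component's cuda/ subtree, e.g. swiglu/cuda/*.cu
cu = glob.glob(os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True)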
23 changes: 14 additions & 9 deletions tests/test_swiglu.py
@@ -83,16 +83,19 @@ def generate_test_shapes():
     # Add some random shapes
     r = random.Random(0)
     for _ in range(20):
-        shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
+        shapes.append(
+            (r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
+        )
     return shapes


 _test_shapes = list(generate_test_shapes())
 _test_shapes_ids = [str(s) for s in _test_shapes]
-_dtypes = [torch.float, torch.float16]
+_dtypes = [torch.float16]


 @pytest.mark.parametrize("autocast", [False])  # TODO: Enable autocast testing
+@pytest.mark.parametrize("pack_weights", [True, False])
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
 @pytest.mark.parametrize(
@@ -105,8 +108,11 @@ def test_forward_backward(
     device,
     dtype,
     autocast: bool,
+    pack_weights: bool,
 ):
-    FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
+    torch.manual_seed(shape[0] * shape[1] * shape[2])
+    FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2}
+    FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3}
     BACKWARD_ATOL = {
         torch.float: 3e-4,
         torch.half: 0.5,
@@ -124,8 +130,11 @@ def test_forward_backward(
     inp_model_dtype = torch.float if autocast else dtype
     x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
     op = xsw._SwiGLUDecomposedOp
+    op = xsw.SwiGLUFusedOp

-    module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
+    module = xsw._SwiGLUModule(
+        in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
+    )
     x_f32: Optional[torch.Tensor]
     ref_f32: Optional[torch.Tensor]
     module_f32: Optional[torch.nn.Module]
@@ -150,11 +159,7 @@ def test_forward_backward(
         ref_f32 = ref

     assert_allclose(
-        out,
-        ref,
-        ref_f32,
-        "fw",
-        atol=FORWARD_ATOL[dtype],
+        out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
     )

     # Backward
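Taken together, the test drives the fused op through the same entry point the library exposes and compares it against the eager reference. A condensed sketch of the call pattern under test, with illustrative shapes and the names from the diff above:

import torch
import xformers.ops.swiglu as xsw

module = xsw._SwiGLUModule(in_features=1536, hidden_features=4096, pack_weights=True)
module = module.to("cuda", dtype=torch.half)
x = torch.randn([64, 1536], device="cuda", dtype=torch.half, requires_grad=True)
# Forward through the fused CUDA op, then backward through its fused gradients
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=xsw.SwiGLUFusedOp)
out.backward(torch.ones_like(out))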
21 changes: 17 additions & 4 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -12,6 +12,7 @@
 from utils import benchmark_main_helper

 import xformers.ops.swiglu as xsw
+from xformers.ops import unbind as xunbind

 min_run_time = 0.5
 device = torch.device("cuda")
@@ -22,10 +23,13 @@
     (9456, 1536, 4096),
     (4440, 1536, 4096),
     (4728, 1536, 4096),
+    # Some smaller shapes as well
+    (4728, 1536, 1024),
 ]


-OP = xsw._SwiGLUDecomposedOp
+# OP = xsw._SwiGLUDecomposedOp
+OP = xsw.SwiGLUFusedOp


 def product_dict(**kwargs):
@@ -38,7 +42,7 @@ def product_dict(**kwargs):
 CASES = list(
     product_dict(
         shape=SHAPES,
-        dtype=[torch.half, torch.float],
+        dtype=[torch.half],
     )
 )

@@ -61,11 +65,16 @@ def benchmark_swiglu(shape, dtype):
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

     assert not autocast
+
+    params = module._ordered_params_for_op()
+    # w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape])
+    # params[0], params[2] = w1w2.unbind(dim=0)
+
     yield benchmark.Timer(
         stmt="fn(x, *args)",
         globals={
             "x": x,
-            "args": module._ordered_params_for_op(),
+            "args": params,
             "fn": partial(xsw.functional_swiglu, op=OP),
         },
         label="swiglu_fw",
@@ -103,7 +112,11 @@ def benchmark_swiglu_bw(shape, dtype):
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

     assert not autocast
-    out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
+    params = module._ordered_params_for_op()
+    w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]).detach()
+    w1w2.requires_grad_()
+    params[0], params[2] = xunbind(w1w2, dim=0)
+    out = xsw.functional_swiglu(x, *params, op=OP)
     grad = torch.zeros_like(out)
     yield benchmark.Timer(
         stmt="out.backward(grad, retain_graph=True)",
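The cat/view/xunbind dance above packs w1 and w2 into one contiguous [2, hidden, in] tensor, so autograd accumulates both weight gradients into a single buffer and the backward pass saves a kernel. A standalone sketch of the mechanics, with illustrative shapes and assuming xunbind matches torch.unbind's autograd behavior:

import torch
from xformers.ops import unbind as xunbind

w1 = torch.randn(4096, 1536)
w2 = torch.randn(4096, 1536)
# One contiguous buffer holding both weights, gradient-tracked as a unit
w1w2 = torch.cat([w1, w2], dim=0).view([2, 4096, 1536]).detach()
w1w2.requires_grad_()
w1_v, w2_v = xunbind(w1w2, dim=0)  # views into the packed tensor
(w1_v.sum() + w2_v.sum()).backward()
print(w1w2.grad.shape)  # torch.Size([2, 4096, 1536]): both grads in one buffer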
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/types.h>
-#include "computeUtil.h"
+#include "../computeUtil.h"

 namespace ge_spmm {

35 changes: 35 additions & 0 deletions xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt
@@ -0,0 +1,35 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
43_dual_gemm
dual_gemm.cu
)
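This vendored CUTLASS example implements a "dual GEMM": two GEMMs that share the same left operand and run in a single kernel, which is exactly the pair of projections SwiGLU needs. In eager PyTorch terms (a sketch with illustrative shapes, not the kernel itself):

import torch

x = torch.randn(4728, 1536, device="cuda", dtype=torch.half)
w1 = torch.randn(4096, 1536, device="cuda", dtype=torch.half)
w2 = torch.randn(4096, 1536, device="cuda", dtype=torch.half)
# Two separate GEMMs each re-read x from global memory; the dual GEMM
# kernel reads x once and produces both outputs in one launch.
d1 = x @ w1.T
d2 = x @ w2.T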
