
Commit 6d90037
SwiGLU optimized fw/bw
ghstack-source-id: abc12d1ec3cabbec2ebd5c2fffb72167609f3d85
Pull Request resolved: #490
danthe3rd committed Oct 25, 2022
1 parent 5227f2f commit 6d90037
Showing 20 changed files with 4,136 additions and 60 deletions.
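
For context, the operation being optimized is the SwiGLU feed-forward block: two linear projections of the same input, one gated through SiLU, followed by an output projection. A minimal unfused PyTorch sketch of that computation (function and parameter names here are illustrative assumptions, not the API added by this commit):

import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    # Two projections of the same input; one is passed through SiLU as a gate.
    x1 = F.linear(x, w1, b1)         # [batch, hidden]
    x2 = F.linear(x, w2, b2)         # [batch, hidden]
    hidden = F.silu(x1) * x2         # element-wise gating
    return F.linear(hidden, w3, b3)  # project back to [batch, out_features]

The fused kernels added below compute the same mapping; the tests and benchmarks compare them against this kind of decomposed implementation.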
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
this_dir, "xformers", "components", "attention", "csrc"
)
extensions_dir = os.path.join(this_dir, "xformers", "components")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "autograd", "*.cpp")
)
source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

sources = main_file + source_cpu

source_cuda = glob.glob(
os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
if not os.path.exists(cutlass_dir):
raise RuntimeError(
f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
) == "1":
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir, cutlass_dir]
nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
"--generate-line-info",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
]
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
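The net effect of this setup.py change is that C++/CUDA sources are discovered recursively under xformers/components, so the new SwiGLU kernels are compiled without being listed explicitly. A small sketch of the widened globbing (patterns copied from the diff above; the example paths come from this commit's file list):

import glob
import os

extensions_dir = os.path.join("xformers", "components")

# Any C++ source anywhere under components/ ...
cpp_sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

# ... and any CUDA source under a cuda/ directory at any depth, which picks up
# both xformers/components/attention/csrc/cuda/ and the new
# xformers/components/swiglu/cuda/ sources.
cuda_sources = glob.glob(
    os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)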
41 changes: 25 additions & 16 deletions tests/test_swiglu.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import random
from contextlib import nullcontext
from typing import Optional

import pytest
Expand All @@ -13,7 +14,7 @@

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
_devices = ["cuda"] if torch.cuda.is_available() else []


def assert_allclose(
@@ -83,16 +84,19 @@ def generate_test_shapes():
# Add some random shapes
r = random.Random(0)
for _ in range(20):
shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
shapes.append(
(r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
)
return shapes


_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float, torch.float16]
_dtypes = [torch.bfloat16, torch.float16]


@pytest.mark.parametrize("autocast", [False]) # TODO: Enable autocast testing
@pytest.mark.parametrize("autocast", [True, False])
@pytest.mark.parametrize("pack_weights", [True, False])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize(
@@ -105,27 +109,37 @@ def test_forward_backward(
device,
dtype,
autocast: bool,
pack_weights: bool,
):
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
torch.manual_seed(shape[0] * shape[1] * shape[2])
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
BACKWARD_ATOL = {
torch.float: 3e-4,
torch.half: 0.5,
torch.bfloat16: 4.0, # !!
}
BACKWARD_RTOL = {
torch.float: 2e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}

if device == "cpu" and dtype is not torch.float:
pytest.skip("Half not supported on CPU")
if autocast and (device == "cpu" or dtype is not torch.half):
pytest.skip("Autocast only supported for CUDA+Half")
if autocast and pack_weights is False:
pytest.skip("TODO: Autocast only supported with pack_weights=True")

inp_model_dtype = torch.float if autocast else dtype
x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
op = xsw._SwiGLUDecomposedOp
op = xsw.SwiGLUFusedOp

module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
module = xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
)
x_f32: Optional[torch.Tensor]
ref_f32: Optional[torch.Tensor]
module_f32: Optional[torch.nn.Module]
@@ -140,21 +154,16 @@ def test_forward_backward(
x.requires_grad_()

# Forward
if autocast:
with torch.autocast("cuda", dtype=dtype):
ref = module(x)
else:
cm = torch.autocast("cuda", dtype=dtype) if autocast else nullcontext()
with cm:
ref = module(x)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)

if ref_f32 is None:
ref_f32 = ref

assert_allclose(
out,
ref,
ref_f32,
"fw",
atol=FORWARD_ATOL[dtype],
out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)

# Backward
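Outside the test harness, the call pattern exercised above looks roughly like the following. This is a sketch assuming a CUDA device and the names used in the test; exact constructor arguments may differ in the final API:

import torch
import xformers.ops.swiglu as xsw

x = torch.randn(4728, 1536, device="cuda", dtype=torch.float16, requires_grad=True)
module = xsw._SwiGLUModule(in_features=1536, hidden_features=4096, pack_weights=True)
module = module.to("cuda").to(torch.float16)

# Dispatch to the fused forward/backward instead of the decomposed op.
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=xsw.SwiGLUFusedOp)
out.sum().backward()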
63 changes: 42 additions & 21 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -5,13 +5,15 @@


import itertools
from contextlib import nullcontext
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops.swiglu as xsw
from xformers.ops import unbind as xunbind

min_run_time = 0.5
device = torch.device("cuda")
@@ -22,10 +24,13 @@
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
# Some smaller shapes as well
(4728, 1536, 1024),
]


OP = xsw._SwiGLUDecomposedOp
# OP = xsw._SwiGLUDecomposedOp
OP = xsw.SwiGLUFusedOp


def product_dict(**kwargs):
@@ -38,13 +43,22 @@ def product_dict(**kwargs):
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.half, torch.float],
dtype=[torch.bfloat16, torch.half, "autocast_half"],
)
)

DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype, autocast = torch.float, torch.float, True
else:
inp_dtype, model_dtype, autocast = dtype, dtype, False

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
module = (
@@ -53,27 +67,25 @@ def benchmark_swiglu(shape, dtype):
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
params = module._ordered_params_for_op()

PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
yield benchmark.Timer(
stmt="fn(x, *args)",
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": module._ordered_params_for_op(),
"args": params,
"fn": partial(xsw.functional_swiglu, op=OP),
},
label="swiglu_fw",
description=OP.NAME,
sub_label=sub_label,
)
yield benchmark.Timer(
stmt="fn(x)",
stmt=f"{PREFIX}fn(x)",
globals={
"x": x,
"fn": module,
Expand All @@ -85,7 +97,12 @@ def benchmark_swiglu(shape, dtype):


def benchmark_swiglu_bw(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype = torch.float, torch.float
cm = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
else:
inp_dtype, model_dtype = dtype, dtype
cm = nullcontext

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
x.requires_grad_()
@@ -95,16 +112,17 @@
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
params = module._ordered_params_for_op()
w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]).detach()
w1w2.requires_grad_()
params[0], params[2] = xunbind(w1w2, dim=0)
with cm():
out = xsw.functional_swiglu(x, *params, op=OP)
grad = torch.zeros_like(out)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
@@ -117,10 +135,13 @@
)
del out

with cm():
out = module(x)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": module(x),
"out": out,
"grad": grad,
},
label="swiglu_bw",
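One detail worth noting in the backward benchmark: the two input-projection weights are packed into a single [2, hidden, in] buffer and re-split with xformers.ops.unbind, so both logical weights are views of one tensor (presumably so their gradients accumulate into a single contiguous buffer). A standalone sketch mirroring the lines above:

import torch
from xformers.ops import unbind as xunbind

w1 = torch.randn(4096, 1536, device="cuda", dtype=torch.float16)
w2 = torch.randn(4096, 1536, device="cuda", dtype=torch.float16)

# Pack both input-projection weights into a single [2, hidden, in] buffer ...
w1w2 = torch.cat([w1, w2], dim=0).view([2, *w1.shape]).detach()
w1w2.requires_grad_()
# ... and recover the two logical weights as views of that buffer.
w1_view, w2_view = xunbind(w1w2, dim=0)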
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include "computeUtil.h"
#include "../computeUtil.h"

namespace ge_spmm {

35 changes: 35 additions & 0 deletions xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt
@@ -0,0 +1,35 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
43_dual_gemm
dual_gemm.cu
)
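
The vendored CUTLASS example implements a "dual GEMM": two matrix products that consume the same left-hand operand, which matches the shape of the SwiGLU forward, where both input projections read the same activations. Conceptually, in PyTorch terms (a sketch of the math only, not of the kernel's interface; the fusion rationale is an assumption):

import torch

def dual_gemm_reference(x, w1, w2):
    # Two GEMMs sharing the same left-hand operand x; fusing them lets a
    # kernel reuse the loaded activation tiles instead of reading x twice.
    return x @ w1.t(), x @ w2.t()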
