SwiGLU optimized fw/bw #490

Merged 36 commits on Nov 10, 2022
Changes from 11 commits
Commits
069405e  SwiGLU optimized fw/bw (Oct 24, 2022)
4b317c6  Update on "SwiGLU optimized fw/bw" (Oct 24, 2022)
11bad90  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
8b2f688  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
e1609de  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
30ca17c  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
eb9c553  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
ed2b7c2  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
e758435  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
3207254  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
dbf6092  Update on "SwiGLU optimized fw/bw" (Oct 25, 2022)
acdf239  Update on "SwiGLU optimized fw/bw" (Oct 26, 2022)
bbdc00e  Update on "SwiGLU optimized fw/bw" (Oct 26, 2022)
5fe54aa  Update on "SwiGLU optimized fw/bw" (Oct 26, 2022)
44a6fbf  Update on "SwiGLU optimized fw/bw" (Oct 26, 2022)
d3e3089  Update on "SwiGLU optimized fw/bw" (Oct 26, 2022)
db5770d  Update on "SwiGLU optimized fw/bw" (Oct 27, 2022)
4c2bfdc  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
d2d0187  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
e2d97d2  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
7224112  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
06c1487  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
783a2ff  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
69e299f  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
f6e2ceb  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
538d05c  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
0ab305f  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
c67a0ad  Update on "SwiGLU optimized fw/bw" (Oct 28, 2022)
a77aeec  Update on "SwiGLU optimized fw/bw" (Oct 31, 2022)
4b600bf  Update on "SwiGLU optimized fw/bw" (Oct 31, 2022)
dd6a285  Update on "SwiGLU optimized fw/bw" (Nov 3, 2022)
d825314  Update on "SwiGLU optimized fw/bw" (Nov 4, 2022)
e2bfbb2  Update on "SwiGLU optimized fw/bw" (Nov 4, 2022)
07135b8  Update on "SwiGLU optimized fw/bw" (Nov 4, 2022)
3490242  Update on "SwiGLU optimized fw/bw" (Nov 7, 2022)
a90fe49  Update on "SwiGLU optimized fw/bw" (Nov 10, 2022)
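
For context, the SwiGLU feed-forward block that these kernels fuse can be written in plain PyTorch roughly as below. This is a minimal reference sketch of the standard SwiGLU formulation, not the exact xformers module; the (w1, b1, w2, b2, w3, b3) parameter ordering is an assumption for illustration only.

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    # Two input projections: one passed through SiLU ("swish"), one left linear.
    gate = F.silu(F.linear(x, w1, b1))
    value = F.linear(x, w2, b2)
    # Elementwise product of the two branches, then the output projection.
    return F.linear(gate * value, w3, b3)

The decomposed op runs this as separate kernels (two GEMMs, SiLU, multiply, output GEMM); the SwiGLUFusedOp added in this PR replaces that sequence with fused kernels, presumably building on the CUTLASS 43_dual_gemm example that the diff below pulls into the build.
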
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
this_dir, "xformers", "components", "attention", "csrc"
)
extensions_dir = os.path.join(this_dir, "xformers", "components")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "autograd", "*.cpp")
)
source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

sources = main_file + source_cpu

source_cuda = glob.glob(
os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
if not os.path.exists(cutlass_dir):
raise RuntimeError(
f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
) == "1":
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir, cutlass_dir]
nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
"--generate-line-info",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
]
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
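
The setup.py change above widens source discovery from the attention csrc directory to all of xformers/components, using recursive globs so that new kernel directories such as swiglu/cuda are picked up without further build edits. A small standalone illustration of the glob pattern (the paths are only for illustration):

import glob
import os

extensions_dir = os.path.join("xformers", "components")

# "**" descends into subdirectories only when recursive=True is passed, so this
# matches e.g. xformers/components/swiglu/cuda/43_dual_gemm/dual_gemm.cu
# as well as the original attention/csrc/cuda sources.
cpp_sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
cuda_sources = glob.glob(
    os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)
print(f"{len(cpp_sources)} C++ files, {len(cuda_sources)} CUDA files")
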
41 changes: 25 additions & 16 deletions tests/test_swiglu.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import random
from contextlib import nullcontext
from typing import Optional

import pytest
@@ -13,7 +14,7 @@

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
_devices = ["cuda"] if torch.cuda.is_available() else []


def assert_allclose(
@@ -83,16 +84,19 @@ def generate_test_shapes():
# Add some random shapes
r = random.Random(0)
for _ in range(20):
shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
shapes.append(
(r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
)
return shapes


_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float, torch.float16]
_dtypes = [torch.bfloat16, torch.float16]


@pytest.mark.parametrize("autocast", [False]) # TODO: Enable autocast testing
@pytest.mark.parametrize("autocast", [True, False])
@pytest.mark.parametrize("pack_weights", [True, False])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize(
@@ -105,27 +109,37 @@ def test_forward_backward(
device,
dtype,
autocast: bool,
pack_weights: bool,
):
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
torch.manual_seed(shape[0] * shape[1] * shape[2])
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
BACKWARD_ATOL = {
torch.float: 3e-4,
torch.half: 0.5,
torch.bfloat16: 4.0, # !!
}
BACKWARD_RTOL = {
torch.float: 2e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}

if device == "cpu" and dtype is not torch.float:
pytest.skip("Half not supported on CPU")
if autocast and (device == "cpu" or dtype is not torch.half):
pytest.skip("Autocast only supported for CUDA+Half")
if autocast and pack_weights is False:
pytest.skip("TODO: Autocast only supported with pack_weights=True")

inp_model_dtype = torch.float if autocast else dtype
x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
op = xsw._SwiGLUDecomposedOp
op = xsw.SwiGLUFusedOp

module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
module = xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
)
x_f32: Optional[torch.Tensor]
ref_f32: Optional[torch.Tensor]
module_f32: Optional[torch.nn.Module]
@@ -140,21 +154,16 @@
x.requires_grad_()

# Forward
if autocast:
with torch.autocast("cuda", dtype=dtype):
ref = module(x)
else:
cm = torch.autocast("cuda", dtype=dtype) if autocast else nullcontext()
with cm:
ref = module(x)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)

if ref_f32 is None:
ref_f32 = ref

assert_allclose(
out,
ref,
ref_f32,
"fw",
atol=FORWARD_ATOL[dtype],
out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)

# Backward
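
The rewritten test drives the autocast and non-autocast paths through a single code path by choosing the context manager up front, with contextlib.nullcontext as the no-op stand-in when autocast is off. A minimal sketch of that pattern outside the test (the matmul is just a placeholder for the SwiGLU forward):

from contextlib import nullcontext

import torch

def forward(x: torch.Tensor, autocast: bool, dtype: torch.dtype = torch.float16):
    # One context-manager variable instead of duplicated if/else branches.
    cm = torch.autocast("cuda", dtype=dtype) if autocast else nullcontext()
    with cm:
        return x @ x.t()  # placeholder for module(x) / functional_swiglu(...)

if torch.cuda.is_available():
    x = torch.randn(16, 16, device="cuda")
    print(forward(x, autocast=True).dtype)   # torch.float16 under autocast
    print(forward(x, autocast=False).dtype)  # torch.float32
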
63 changes: 42 additions & 21 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -5,13 +5,15 @@


import itertools
from contextlib import nullcontext
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops.swiglu as xsw
from xformers.ops import unbind as xunbind

min_run_time = 0.5
device = torch.device("cuda")
@@ -22,10 +24,13 @@
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
# Some smaller shapes as well
(4728, 1536, 1024),
]


OP = xsw._SwiGLUDecomposedOp
# OP = xsw._SwiGLUDecomposedOp
OP = xsw.SwiGLUFusedOp


def product_dict(**kwargs):
@@ -38,13 +43,22 @@ def product_dict(**kwargs):
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.half, torch.float],
dtype=[torch.bfloat16, torch.half, "autocast_half"],
)
)

DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype, autocast = torch.float, torch.float, True
else:
inp_dtype, model_dtype, autocast = dtype, dtype, False

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
module = (
@@ -53,27 +67,25 @@ def benchmark_swiglu(shape, dtype):
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
params = module._ordered_params_for_op()

PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
yield benchmark.Timer(
stmt="fn(x, *args)",
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": module._ordered_params_for_op(),
"args": params,
"fn": partial(xsw.functional_swiglu, op=OP),
},
label="swiglu_fw",
description=OP.NAME,
sub_label=sub_label,
)
yield benchmark.Timer(
stmt="fn(x)",
stmt=f"{PREFIX}fn(x)",
globals={
"x": x,
"fn": module,
@@ -85,7 +97,12 @@


def benchmark_swiglu_bw(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype = torch.float, torch.float
cm = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
else:
inp_dtype, model_dtype = dtype, dtype
cm = nullcontext

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
x.requires_grad_()
@@ -95,16 +112,17 @@
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
params = module._ordered_params_for_op()
w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]).detach()
w1w2.requires_grad_()
params[0], params[2] = xunbind(w1w2, dim=0)
with cm():
out = xsw.functional_swiglu(x, *params, op=OP)
grad = torch.zeros_like(out)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
@@ -117,10 +135,13 @@
)
del out

with cm():
out = module(x)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": module(x),
"out": out,
"grad": grad,
},
label="swiglu_bw",
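
The backward benchmark above packs the two input-projection weights into one contiguous buffer and splits it back into views with xformers.ops.unbind, so the op sees a single packed tensor and the backward pass produces one stacked gradient for it. A standalone sketch of that packing step, assuming params is ordered (w1, b1, w2, b2, w3, b3) as returned by _ordered_params_for_op:

import torch
from xformers.ops import unbind as xunbind

def pack_input_projections(params):
    # params: [w1, b1, w2, b2, w3, b3]; w1 and w2 must have identical shapes.
    w1, w2 = params[0], params[2]
    # Concatenate into one contiguous [2, hidden_features, in_features] buffer.
    w1w2 = torch.cat([w1, w2], dim=0).view([2, *w1.shape]).detach()
    w1w2.requires_grad_()
    # unbind returns views into the packed buffer; in the backward pass the
    # gradient arrives as a single stacked tensor for w1w2.
    packed = list(params)
    packed[0], packed[2] = xunbind(w1w2, dim=0)
    return packed, w1w2
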
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include "computeUtil.h"
#include "../computeUtil.h"

namespace ge_spmm {

35 changes: 35 additions & 0 deletions xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt
@@ -0,0 +1,35 @@

# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.



cutlass_example_add_executable(
43_dual_gemm
dual_gemm.cu
)