SwiGLU optimized fw/bw
ghstack-source-id: 7998ff3210011362be7c379666655e9bc5078dde
Pull Request resolved: #490
danthe3rd committed Nov 10, 2022
1 parent 0428e12 commit 51c9861
Showing 16 changed files with 1,057 additions and 99 deletions.
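For context, the SwiGLU feed-forward block these kernels target computes two projections of the input, gates one with SiLU, multiplies them elementwise, and applies an output projection. Below is a minimal eager-mode reference sketch; it assumes the conventional SwiGLU formulation and a (w1, b1, w2, b2, w3, b3) parameter ordering, so names and shapes are illustrative rather than the exact xformers signatures.

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    # Two GEMMs over the same input -- the "dual GEMM" that the fused kernel combines.
    x1 = F.linear(x, w1, b1)          # (B, hidden)
    x2 = F.linear(x, w2, b2)          # (B, hidden)
    # SiLU-gated elementwise product, then the output projection.
    hidden = F.silu(x1) * x2          # (B, hidden)
    return F.linear(hidden, w3, b3)   # (B, in_features)

# Example with one of the test shapes used below: B=4728, in_features=1536, hidden=2736.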
3 changes: 3 additions & 0 deletions .circleci/config.yml
@@ -14,6 +14,9 @@ cpu_py38: &cpu_py38
docker:
- image: cimg/python:3.8
resource_class: large
environment:
# We're a bit short on RAM
MAX_JOBS: "4"

gpu_cu114: &gpu_cu114
environment:
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
this_dir, "xformers", "components", "attention", "csrc"
)
extensions_dir = os.path.join(this_dir, "xformers", "components")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "autograd", "*.cpp")
)
source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

sources = main_file + source_cpu

source_cuda = glob.glob(
os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
if not os.path.exists(cutlass_dir):
raise RuntimeError(
f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
) == "1":
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir, cutlass_dir]
nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
"--generate-line-info",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
]
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
73 changes: 48 additions & 25 deletions tests/test_swiglu.py
@@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.

import random
from typing import Optional
from contextlib import nullcontext
from typing import Optional, Sequence

import pytest
import torch
@@ -13,7 +14,13 @@

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
if torch.cuda.is_available():
_devices = ["cuda"]
_is_sm80 = torch.cuda.get_device_capability(_devices[0])[0] >= 8
else:
_devices = []
_is_sm80 = False
sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80")


def assert_allclose(
@@ -76,23 +83,32 @@ def generate_test_shapes():
shapes = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# GPT-3 (small)
(2048, 2048, 5632),
# Chinchilla
(2048, 8192, 22016),
]
# Add some random shapes
r = random.Random(0)
for _ in range(20):
shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
shapes.append(
(r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
)
return shapes


_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float, torch.float16]
_dtypes = [torch.bfloat16, torch.float16]
_ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]


@pytest.mark.parametrize("autocast", [False]) # TODO: Enable autocast testing
@pytest.mark.parametrize("autocast", [False, True])
@pytest.mark.parametrize("pack_weights", [False, True])
@pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize(
@@ -103,29 +119,41 @@ def generate_test_shapes():
def test_forward_backward(
shape,
device,
op,
dtype,
autocast: bool,
pack_weights: bool,
):
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
torch.manual_seed(shape[0] * shape[1] * shape[2])
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
BACKWARD_ATOL = {
torch.float: 3e-4,
torch.half: 0.5,
torch.bfloat16: 4.0, # !!
}
BACKWARD_RTOL = {
torch.float: 2e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}

if device == "cpu" and dtype is not torch.float:
pytest.skip("Half not supported on CPU")
if autocast and (device == "cpu" or dtype is not torch.half):
pytest.skip("Autocast only supported for CUDA+Half")
if not op.supports(
xsw.SwiGLUOpDispatch(
device=device,
dtype=dtype,
dtype_autocast_gpu=dtype if autocast and device == "cuda" else None,
packed_weights=pack_weights,
)
):
pytest.skip("Not supported by operator")

inp_model_dtype = torch.float if autocast else dtype
x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
op = xsw._SwiGLUDecomposedOp

module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
module = xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
)
x_f32: Optional[torch.Tensor]
ref_f32: Optional[torch.Tensor]
module_f32: Optional[torch.nn.Module]
@@ -140,21 +168,16 @@ def test_forward_backward(
x.requires_grad_()

# Forward
if autocast:
with torch.autocast("cuda", dtype=dtype):
ref = module(x)
else:
cm = torch.autocast("cuda", dtype=dtype) if autocast else nullcontext()
with cm:
ref = module(x)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)

if ref_f32 is None:
ref_f32 = ref

assert_allclose(
out,
ref,
ref_f32,
"fw",
atol=FORWARD_ATOL[dtype],
out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)

# Backward
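Condensed, the call pattern the updated test exercises looks roughly like the sketch below. It assumes `xsw` is the swiglu module imported at the top of tests/test_swiglu.py (the import is elided in this diff) and fixes one dtype and shape instead of the full parametrization; tolerances are taken from the FORWARD_ATOL/FORWARD_RTOL tables above.

import torch

# xsw: the swiglu module imported in tests/test_swiglu.py (import path elided in the diff above).
module = xsw._SwiGLUModule(in_features=1536, hidden_features=2736, pack_weights=True)
module = module.to("cuda").to(torch.float16)
x = torch.randn(4728, 1536, device="cuda", dtype=torch.float16, requires_grad=True)

ref = module(x)                                # eager reference path
out = xsw.functional_swiglu(                   # fused op under test
    x, *module._ordered_params_for_op(), op=xsw.SwiGLUPackedFusedOp
)
out.backward(torch.ones_like(out))             # exercises the fused backward as well

torch.testing.assert_close(out, ref, atol=1e-2, rtol=4e-3)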
2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 27 files
+21 −13 examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu
+2 −2 examples/40_cutlass_py/README.md
+0 −5 examples/41_multi_head_attention/fused_multihead_attention.cu
+1 −1 examples/43_dual_gemm/CMakeLists.txt
+59 −37 examples/43_dual_gemm/device/dual_gemm.h
+92 −89 examples/43_dual_gemm/dual_gemm.cu
+101 −85 examples/43_dual_gemm/dual_gemm_run.h
+73 −155 examples/43_dual_gemm/kernel/dual_gemm.h
+0 −0 examples/43_dual_gemm/test_run.h
+150 −0 examples/43_dual_gemm/thread/left_silu_and_mul.h
+52 −27 examples/43_dual_gemm/threadblock/dual_epilogue.h
+0 −0 examples/43_dual_gemm/threadblock/dual_mma_base.h
+0 −0 examples/43_dual_gemm/threadblock/dual_mma_multistage.h
+0 −275 examples/50_dual_gemm/reference/device/tensor_scale_bias.h
+0 −648 examples/50_dual_gemm/threadblock/dual_mma_multistage2.h
+1 −1 examples/CMakeLists.txt
+1 −0 include/cutlass/epilogue/thread/activation.h
+2 −3 include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h
+414 −0 include/cutlass/gemm/device/gemm_with_k_reduction.h
+1 −1 include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h
+1 −0 include/cutlass/gemm/kernel/gemm_with_k_reduction.h
+1 −1 include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h
+1 −0 include/cutlass/gemm/threadblock/mma_base.h
+4 −2 include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h
+1 −1 include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h
+1 −1 media/docs/fundamental_types.md
+2 −2 tools/library/scripts/pycutlass/README.md
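The CUTLASS submodule bump pulls in example 43 (dual_gemm) together with a left_silu_and_mul epilogue. Judging from the file names (this diff does not spell it out), the fusion computes both hidden projections in one kernel that reuses the loaded input tiles and applies the SiLU-and-multiply in the epilogue, rather than running two separate GEMMs plus an elementwise pass. A sketch of the contrast, with illustrative names:

import torch
import torch.nn.functional as F

def hidden_unfused(x, w1, w2):
    # Three passes over memory: both GEMM outputs are written to global memory,
    # then re-read by the elementwise SiLU-and-multiply.
    x1 = x @ w1.t()
    x2 = x @ w2.t()
    return F.silu(x1) * x2

# Dual-GEMM fusion (what the CUTLASS example appears to implement): one kernel
# computes both partial products tile-by-tile from the same loaded x tiles and
# emits silu(x @ w1.t()) * (x @ w2.t()) directly from the epilogue, so the
# intermediates x1/x2 never round-trip through global memory.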
78 changes: 52 additions & 26 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -5,6 +5,7 @@


import itertools
from contextlib import nullcontext
from functools import partial

import torch
@@ -19,13 +20,21 @@
SHAPES = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# Some smaller shapes as well
(4728, 1536, 1024),
# GPT-3 (small)
(32768, 2048, 5632),
# Chinchilla
(32768, 8192, 22016),
]


OP = xsw._SwiGLUDecomposedOp
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp


def product_dict(**kwargs):
@@ -38,42 +47,51 @@ def product_dict(**kwargs):
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.half, torch.float],
dtype=[torch.bfloat16, torch.half, "autocast_half"],
)
)

DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype, autocast = torch.float, torch.float, True
else:
inp_dtype, model_dtype, autocast = dtype, dtype, False

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
module = (
xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=True
)
.to(device)
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
params = module._ordered_params_for_op()

PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
yield benchmark.Timer(
stmt="fn(x, *args)",
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": module._ordered_params_for_op(),
"args": params,
"fn": partial(xsw.functional_swiglu, op=OP),
},
label="swiglu_fw",
description=OP.NAME,
sub_label=sub_label,
)
yield benchmark.Timer(
stmt="fn(x)",
stmt=f"{PREFIX}fn(x)",
globals={
"x": x,
"fn": module,
@@ -85,26 +103,31 @@ def benchmark_swiglu(shape, dtype):


def benchmark_swiglu_bw(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype = torch.float, torch.float
cm = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
else:
inp_dtype, model_dtype = dtype, dtype
cm = nullcontext

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
x.requires_grad_()
module = (
xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=True
)
.to(device)
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
params = module._ordered_params_for_op()
with cm():
out = xsw.functional_swiglu(x, *params, op=OP)
grad = torch.zeros_like(out)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
@@ -117,10 +140,13 @@ def benchmark_swiglu_bw(shape, dtype):
)
del out

with cm():
out = module(x)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": module(x),
"out": out,
"grad": grad,
},
label="swiglu_bw",
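The PREFIX trick in the benchmark prepends an autocast block to the timed statement, so the "f16.ac" rows time the op under torch.autocast on float32 inputs rather than with native half tensors. A self-contained sketch of that pattern, using torch.mm as a stand-in workload (tensor names and shapes are illustrative; requires a CUDA device):

import torch
import torch.utils.benchmark as benchmark

x = torch.randn(4728, 1536, device="cuda")
w = torch.randn(2736, 1536, device="cuda")

autocast = True
PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n    ' if autocast else ""
timer = benchmark.Timer(
    stmt=f"{PREFIX}fn(x, w.t())",   # multi-line stmt: the autocast block wraps the call
    globals={"x": x, "w": w, "fn": torch.mm},
    label="autocast-prefix demo",
    sub_label="f16.ac B=4728, I=1536, H=2736",
    description="torch.mm",
)
print(timer.blocked_autorange())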
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include "computeUtil.h"
#include "../computeUtil.h"

namespace ge_spmm {

