SwiGLU optimized fw/bw #490

Merged: 36 commits, Nov 10, 2022
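
For reference, the SwiGLU feed-forward block that this PR adds fused forward/backward kernels for can be written, unfused, as the sketch below. This is illustrative only; the parameter names and their ordering are assumptions, not the exact operator signature used in xformers.

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    # Two input projections, a SiLU-gated elementwise product, one output projection.
    x1 = F.linear(x, w1, b1)          # (batch, hidden)
    x2 = F.linear(x, w2, b2)          # (batch, hidden)
    hidden = F.silu(x1) * x2          # silu(u) = u * sigmoid(u)
    return F.linear(hidden, w3, b3)   # (batch, out_features)

The fused ops introduced here target the same computation while avoiding the intermediate memory round-trips of this decomposed form.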
Changes from all commits
Commits (36)
069405e
SwiGLU optimized fw/bw
Oct 24, 2022
4b317c6
Update on "SwiGLU optimized fw/bw"
Oct 24, 2022
11bad90
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
8b2f688
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
e1609de
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
30ca17c
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
eb9c553
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
ed2b7c2
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
e758435
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
3207254
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
dbf6092
Update on "SwiGLU optimized fw/bw"
Oct 25, 2022
acdf239
Update on "SwiGLU optimized fw/bw"
Oct 26, 2022
bbdc00e
Update on "SwiGLU optimized fw/bw"
Oct 26, 2022
5fe54aa
Update on "SwiGLU optimized fw/bw"
Oct 26, 2022
44a6fbf
Update on "SwiGLU optimized fw/bw"
Oct 26, 2022
d3e3089
Update on "SwiGLU optimized fw/bw"
Oct 26, 2022
db5770d
Update on "SwiGLU optimized fw/bw"
Oct 27, 2022
4c2bfdc
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
d2d0187
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
e2d97d2
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
7224112
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
06c1487
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
783a2ff
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
69e299f
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
f6e2ceb
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
538d05c
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
0ab305f
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
c67a0ad
Update on "SwiGLU optimized fw/bw"
Oct 28, 2022
a77aeec
Update on "SwiGLU optimized fw/bw"
Oct 31, 2022
4b600bf
Update on "SwiGLU optimized fw/bw"
Oct 31, 2022
dd6a285
Update on "SwiGLU optimized fw/bw"
Nov 3, 2022
d825314
Update on "SwiGLU optimized fw/bw"
Nov 4, 2022
e2bfbb2
Update on "SwiGLU optimized fw/bw"
Nov 4, 2022
07135b8
Update on "SwiGLU optimized fw/bw"
Nov 4, 2022
3490242
Update on "SwiGLU optimized fw/bw"
Nov 7, 2022
a90fe49
Update on "SwiGLU optimized fw/bw"
Nov 10, 2022
3 changes: 3 additions & 0 deletions .circleci/config.yml
@@ -14,6 +14,9 @@ cpu_py38: &cpu_py38
docker:
- image: cimg/python:3.8
resource_class: large
environment:
# We're a bit short on RAM
MAX_JOBS: "4"

gpu_cu114: &gpu_cu114
environment:
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
this_dir, "xformers", "components", "attention", "csrc"
)
extensions_dir = os.path.join(this_dir, "xformers", "components")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "autograd", "*.cpp")
)
source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

sources = main_file + source_cpu

source_cuda = glob.glob(
os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
if not os.path.exists(cutlass_dir):
raise RuntimeError(
f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
) == "1":
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir, cutlass_dir]
nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
nvcc_flags = [
"-DHAS_PYTORCH",
"--use_fast_math",
"--generate-line-info",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--extended-lambda",
]
if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
nvcc_flags.append("-DNDEBUG")
nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
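The new nvcc flags re-enable native half-precision operators and conversions (by undefining the __CUDA_NO_HALF_* guards) and turn on extended device lambdas, presumably required by the new CUDA kernels. A hedged sketch of how the environment hooks visible in this hunk, plus the MAX_JOBS cap from the CircleCI change above, can be used when building from source; the values are illustrative, and MAX_JOBS being honored by torch's extension builder is an assumption.

import os
import subprocess
import sys

env = dict(os.environ)
env["MAX_JOBS"] = "4"                          # cap parallel compile jobs on low-RAM machines
env["XFORMERS_ENABLE_DEBUG_ASSERTIONS"] = "0"  # keep -DNDEBUG, i.e. assertions compiled out
env["NVCC_FLAGS"] = "--ptxas-options=-v"       # extra flags, appended via shlex.split above
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."], env=env)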
73 changes: 48 additions & 25 deletions tests/test_swiglu.py
@@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.

import random
from typing import Optional
from contextlib import nullcontext
from typing import Optional, Sequence

import pytest
import torch
@@ -13,7 +14,13 @@

torch.backends.cuda.matmul.allow_tf32 = False
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
if torch.cuda.is_available():
_devices = ["cuda"]
_is_sm80 = torch.cuda.get_device_capability(_devices[0])[0] >= 8
else:
_devices = []
_is_sm80 = False
sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80")


def assert_allclose(
Expand Down Expand Up @@ -76,23 +83,32 @@ def generate_test_shapes():
shapes = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# GPT-3 (small)
(2048, 2048, 5632),
# Chinchilla
(2048, 8192, 22016),
]
# Add some random shapes
r = random.Random(0)
for _ in range(20):
shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
shapes.append(
(r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
)
return shapes


_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float, torch.float16]
_dtypes = [torch.bfloat16, torch.float16]
_ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]


@pytest.mark.parametrize("autocast", [False]) # TODO: Enable autocast testing
@pytest.mark.parametrize("autocast", [False, True])
@pytest.mark.parametrize("pack_weights", [False, True])
@pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize(
@@ -103,29 +119,41 @@ def generate_test_shapes():
def test_forward_backward(
shape,
device,
op,
dtype,
autocast: bool,
pack_weights: bool,
):
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
torch.manual_seed(shape[0] * shape[1] * shape[2])
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
BACKWARD_ATOL = {
torch.float: 3e-4,
torch.half: 0.5,
torch.bfloat16: 4.0, # !!
}
BACKWARD_RTOL = {
torch.float: 2e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}

if device == "cpu" and dtype is not torch.float:
pytest.skip("Half not supported on CPU")
if autocast and (device == "cpu" or dtype is not torch.half):
pytest.skip("Autocast only supported for CUDA+Half")
if not op.supports(
xsw.SwiGLUOpDispatch(
device=device,
dtype=dtype,
dtype_autocast_gpu=dtype if autocast and device == "cuda" else None,
packed_weights=pack_weights,
)
):
pytest.skip("Not supported by operator")

inp_model_dtype = torch.float if autocast else dtype
x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
op = xsw._SwiGLUDecomposedOp

module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
module = xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
)
x_f32: Optional[torch.Tensor]
ref_f32: Optional[torch.Tensor]
module_f32: Optional[torch.nn.Module]
@@ -140,21 +168,16 @@ def test_forward_backward(
x.requires_grad_()

# Forward
if autocast:
with torch.autocast("cuda", dtype=dtype):
ref = module(x)
else:
cm = torch.autocast("cuda", dtype=dtype) if autocast else nullcontext()
with cm:
ref = module(x)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)

if ref_f32 is None:
ref_f32 = ref

assert_allclose(
out,
ref,
ref_f32,
"fw",
atol=FORWARD_ATOL[dtype],
out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)

# Backward
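A hedged usage sketch of the operator surface this test exercises. The import path for the xsw alias is an assumption (the import is not visible in the hunk above), and a CUDA device with fp16 support is assumed.

import torch
import xformers.ops.swiglu as xsw  # assumed module path for the xsw alias used in the test

device, dtype = "cuda", torch.float16
module = xsw._SwiGLUModule(in_features=1536, hidden_features=2736, pack_weights=True)
module = module.to(device).to(dtype)

x = torch.randn(4728, 1536, device=device, dtype=dtype, requires_grad=True)
out = xsw.functional_swiglu(
    x, *module._ordered_params_for_op(), op=xsw.SwiGLUPackedFusedOp
)
out.sum().backward()  # fused backward populates grads for x and the packed weights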
2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 27 files
+21 −13 examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu
+2 −2 examples/40_cutlass_py/README.md
+0 −5 examples/41_multi_head_attention/fused_multihead_attention.cu
+1 −1 examples/43_dual_gemm/CMakeLists.txt
+59 −37 examples/43_dual_gemm/device/dual_gemm.h
+92 −89 examples/43_dual_gemm/dual_gemm.cu
+101 −85 examples/43_dual_gemm/dual_gemm_run.h
+73 −155 examples/43_dual_gemm/kernel/dual_gemm.h
+0 −0 examples/43_dual_gemm/test_run.h
+150 −0 examples/43_dual_gemm/thread/left_silu_and_mul.h
+52 −27 examples/43_dual_gemm/threadblock/dual_epilogue.h
+0 −0 examples/43_dual_gemm/threadblock/dual_mma_base.h
+0 −0 examples/43_dual_gemm/threadblock/dual_mma_multistage.h
+0 −275 examples/50_dual_gemm/reference/device/tensor_scale_bias.h
+0 −648 examples/50_dual_gemm/threadblock/dual_mma_multistage2.h
+1 −1 examples/CMakeLists.txt
+1 −0 include/cutlass/epilogue/thread/activation.h
+2 −3 include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h
+414 −0 include/cutlass/gemm/device/gemm_with_k_reduction.h
+1 −1 include/cutlass/gemm/kernel/default_gemm_with_k_reduction.h
+1 −0 include/cutlass/gemm/kernel/gemm_with_k_reduction.h
+1 −1 include/cutlass/gemm/threadblock/default_mma_core_with_reduction.h
+1 −0 include/cutlass/gemm/threadblock/mma_base.h
+4 −2 include/cutlass/gemm/warp/default_mma_with_reduction_tensor_op.h
+1 −1 include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h
+1 −1 media/docs/fundamental_types.md
+2 −2 tools/library/scripts/pycutlass/README.md
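
The cutlass submodule bump pulls in example 43_dual_gemm together with a left_silu_and_mul epilogue, which suggests the fusion strategy: the two input projections of SwiGLU read the same activation matrix, so they can be issued as a single dual GEMM whose epilogue applies the SiLU gate. A sketch of the computation, inferred from the file names rather than from the kernel sources:

\begin{aligned}
X_1 &= X W_1^\top + b_1, \qquad X_2 = X W_2^\top + b_2 && \text{(one dual GEMM over a shared read of } X \text{)} \\
H   &= \mathrm{SiLU}(X_1) \odot X_2, \qquad \mathrm{SiLU}(u) = \frac{u}{1 + e^{-u}} && \text{(fused epilogue: left\_silu\_and\_mul)} \\
Y   &= H W_3^\top + b_3 && \text{(second GEMM)}
\end{aligned}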
78 changes: 52 additions & 26 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -5,6 +5,7 @@


import itertools
from contextlib import nullcontext
from functools import partial

import torch
@@ -19,13 +20,21 @@
SHAPES = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 4096),
(4440, 1536, 4096),
(4728, 1536, 4096),
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# Some smaller shapes as well
(4728, 1536, 1024),
# GPT-3 (small)
(32768, 2048, 5632),
# Chinchilla
(32768, 8192, 22016),
]


OP = xsw._SwiGLUDecomposedOp
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp


def product_dict(**kwargs):
@@ -38,42 +47,51 @@ def product_dict(**kwargs):
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.half, torch.float],
dtype=[torch.bfloat16, torch.half, "autocast_half"],
)
)

DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype, autocast = torch.float, torch.float, True
else:
inp_dtype, model_dtype, autocast = dtype, dtype, False

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
module = (
xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=True
)
.to(device)
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
params = module._ordered_params_for_op()

PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
yield benchmark.Timer(
stmt="fn(x, *args)",
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": module._ordered_params_for_op(),
"args": params,
"fn": partial(xsw.functional_swiglu, op=OP),
},
label="swiglu_fw",
description=OP.NAME,
sub_label=sub_label,
)
yield benchmark.Timer(
stmt="fn(x)",
stmt=f"{PREFIX}fn(x)",
globals={
"x": x,
"fn": module,
@@ -85,26 +103,31 @@ def benchmark_swiglu(shape, dtype):


def benchmark_swiglu_bw(shape, dtype):
inp_dtype, model_dtype, autocast = dtype, dtype, False
if dtype == "autocast_half":
inp_dtype, model_dtype = torch.float, torch.float
cm = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
else:
inp_dtype, model_dtype = dtype, dtype
cm = nullcontext

x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
x.requires_grad_()
module = (
xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
xsw._SwiGLUModule(
in_features=shape[1], hidden_features=shape[2], pack_weights=True
)
.to(device)
.to(model_dtype)
)

dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}.get(dtype, dtype)
dtype_str = DTYPE2STR.get(dtype, dtype)
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

assert not autocast
out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
params = module._ordered_params_for_op()
with cm():
out = xsw.functional_swiglu(x, *params, op=OP)
grad = torch.zeros_like(out)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
@@ -117,10 +140,13 @@ def benchmark_swiglu_bw(shape, dtype):
)
del out

with cm():
out = module(x)

yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": module(x),
"out": out,
"grad": grad,
},
label="swiglu_bw",
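One detail worth noting in this benchmark: for the autocast_half configuration, the autocast context manager is prepended to the timed statement itself, so the casting overhead is included in the measurement. A minimal, self-contained sketch of that pattern; the workload fn here is an illustrative stand-in rather than the fused SwiGLU op, and a CUDA device is assumed.

import torch
import torch.utils.benchmark as benchmark

def fn(x):
    return x * torch.sigmoid(x)  # stand-in elementwise workload (SiLU)

x = torch.randn(4096, 4096, device="cuda")
prefix = 'with torch.autocast("cuda", dtype=torch.half):\n    '
timer = benchmark.Timer(
    stmt=f"{prefix}fn(x)",  # the context manager becomes part of the timed statement
    globals={"x": x, "fn": fn},
    label="swiglu_fw",
    sub_label="f16.ac sketch",
)
print(timer.blocked_autorange(min_run_time=0.5))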
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/types.h>
#include "computeUtil.h"
#include "../computeUtil.h"

namespace ge_spmm {
