Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cutlass for memory-efficient attention #362

Merged
merged 59 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
f526651
Enable masking in memory-efficient attention (#333)
fmassa May 12, 2022
b0a6c91
Enable dropout in memory-efficient attention (#334)
fmassa May 23, 2022
5de87c4
Fix masking corner case when full block is masked (#339)
fmassa May 24, 2022
267cc4e
Add cutlass 2.9 - 858c735856a7f17bd33fe438ec76d3c9f0234e7f
danthe3rd May 24, 2022
21ab567
Option to load from shared memory for PredicatedTileIterator
danthe3rd May 24, 2022
921c637
Add cutlass include dir
danthe3rd May 24, 2022
7078b1e
Ignore files in third-party for flake8/coverage
danthe3rd May 24, 2022
9079d7e
third-party -> third_party
danthe3rd May 24, 2022
01c8edc
Address comments
danthe3rd May 24, 2022
0282b39
Revert some un-needed mods
danthe3rd May 24, 2022
3d0e645
Add attention_forward_generic.cu
danthe3rd May 24, 2022
f6e0c8c
Add tests
danthe3rd May 24, 2022
4080f90
Fix duplicate calculations on baseline for mem efficient transformers
danthe3rd May 24, 2022
3426755
Always run all linters in CI
danthe3rd May 25, 2022
f8cb6d9
clang-format attention_forward_generic.cu
danthe3rd May 25, 2022
0b46be0
Benchmark: Add possibility to compare benchmarks
fmassa Aug 3, 2022
9611baa
[isort] Ignore third_party
danthe3rd May 25, 2022
f698a5e
black autoformat
danthe3rd May 25, 2022
1f26b59
Black again + ignore third_party properly
danthe3rd May 25, 2022
cbfef46
black
danthe3rd May 25, 2022
fd424e3
Fix memory leak between the 2 benchmarks in backward
danthe3rd May 25, 2022
9fb88bd
Exclude third_party/ without using pyproject.toml as it imposes isola…
danthe3rd May 25, 2022
f79c017
Remove progress bar when finished
danthe3rd May 25, 2022
fe5f615
mypy
danthe3rd May 25, 2022
216fa27
flake8
danthe3rd May 25, 2022
4fbe4e9
Save results to shared folder in home location
danthe3rd May 25, 2022
b1cd83c
run black
danthe3rd May 31, 2022
0d05f69
clang-format with 'run-clang-format.py'
danthe3rd May 31, 2022
feae957
Fix cutlass build for arch>=75
danthe3rd May 31, 2022
1907b68
Set tests precision for gradient more accurately
danthe3rd May 24, 2022
c32053f
Fix precision margin
danthe3rd May 31, 2022
3602c06
Revert changes to black
danthe3rd May 31, 2022
f187e25
[feat] Fix importing xformers when not built (#351)
danthe3rd May 31, 2022
c8d488e
Update black to 22.3.0
danthe3rd May 31, 2022
93a75b7
Tweak precision for mem_eff_attention test
danthe3rd Jun 1, 2022
256f2d4
mem-efficient impl for f16 (#352)
danthe3rd Jun 8, 2022
1cce7fc
Add support for f16 with tensorcores [sm70/sm75/sm80] (#354)
danthe3rd Jun 13, 2022
19d1cce
Optimize backward of memory-efficient attention by ~20% (#355)
fmassa Jun 17, 2022
a04fae4
Display results as we progress during benchmark (#357)
danthe3rd Jun 30, 2022
2568a84
RFC: Ops dispatch (#356)
danthe3rd Jun 30, 2022
71c2eab
[A100/f32] Use TensorCores for Q.K_t matmul with FastF32 (#358)
danthe3rd Jul 5, 2022
573ed14
FlashAttention implem and dispatch (#360)
danthe3rd Jul 7, 2022
36cf435
Misc performance improvements for generic mem-efficient attention (#361)
danthe3rd Jul 8, 2022
daac694
Update flashattention to support bf16 (#363)
danthe3rd Jul 12, 2022
ccf7d15
Flashattn causal (#364)
danthe3rd Jul 15, 2022
db0b9a7
Option to disable flashattention (long to build) (#362)
danthe3rd Jul 21, 2022
7d11238
Remove code duplicate in attention_scaling_coefs_updater.h (#367)
danthe3rd Jul 21, 2022
579eace
Update .gitmodules (#366)
danthe3rd Jul 25, 2022
8b61b0b
MemoryEff attention forward: Properly fuse matmul and enable TensorCo…
danthe3rd Jul 25, 2022
ff52718
Update install instructions with submodule (#365)
fmassa Jul 27, 2022
bb616fa
Generic backward implem with cutlass (#371)
danthe3rd Aug 1, 2022
67ecf34
Cutlass as submodule (#375)
fmassa Aug 4, 2022
4bc3588
Fix bad rebase
fmassa Aug 4, 2022
0de2f12
Bump tolerance for backward (#377)
fmassa Aug 10, 2022
4ef8439
Add verbose flag to CI builds (#376)
fmassa Aug 5, 2022
3946ab8
Fix for FlashAttention dispatch
fmassa Aug 10, 2022
b1dd378
Remove generated file
fmassa Aug 11, 2022
c601866
Address some reviewer feedback
fmassa Aug 12, 2022
1e17161
Perf improvement on backward (#378)
danthe3rd Aug 16, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 31 additions & 7 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ install_dep: &install_dep

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for all the CI changes, LGTM and pretty useful

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you run into any numerical instability issues during training? Or, have you done any checks as was done

hi, @lucidrains , do you meet such precision loss situation during your training? I meet it when train SWIN-T.

# start installing
source activate /home/circleci/venv

# for faster builds
conda install ninja
echo "Ninja version $(ninja --version)"

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -q
$CONDA_PYTHON -m pip install -r requirements-benchmark.txt --progress-bar off

Expand All @@ -108,100 +113,118 @@ install_dep_exp: &install_dep_exp
install_repo: &install_repo
- run:
name: Install Repository
no_output_timeout: 30m
command: |
$CONDA_PYTHON -m pip install -e .
source $BASH_ENV
source activate /home/circleci/venv
git submodule update --init --recursive
$CONDA_PYTHON -m pip install -v -e .

# Test import.
$CONDA_PYTHON -c 'import sys; sys.path = sys.path[1:]; import xformers'

install_experimental_repo: &install_experimental_repo
- run:
name: Install Repository
no_output_timeout: 30m
command: |
git submodule update --init --recursive
source $BASH_ENV

cd experimental
$CONDA_PYTHON -m pip install -e .
$CONDA_PYTHON -m pip install -v -e .

run_isort: &run_isort
- run:
name: Run Linter (isort)
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m isort . --check --profile black

run_black: &run_black
- run:
name: Run Linter (black)
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m black --check .
$CONDA_PYTHON -m black --check . --exclude "third_party/"

run_mypy: &run_mypy
- run:
name: Run type-checking (mypy)
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m mypy --ignore-missing-imports --scripts-are-modules --pretty --exclude build/ --exclude stubs/ .
$CONDA_PYTHON -m mypy --ignore-missing-imports --scripts-are-modules --pretty --exclude "(build|stubs|third_party|docs|setup.py)" .

run_flake8: &run_flake8
- run:
name: Run Linter (flake8)
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m flake8 --config .flake8 --show-source --statistics

run_clang_format: &run_clang_format
- run:
name: Run Linter (clang-format)
when: always
command: |
# install clang-format here, so that it gets cached
sudo apt-get update
sudo apt-get install clang-format
clang-format --version

# apply to our files
./.circleci/run-clang-format.py -r xformers/components/attention/csrc

run_coverage: &run_coverage
- run:
name: Run Unit Tests With Coverage
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 --cov-report=xml --cov=./ tests
#Uploading test coverage for Python code
bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python

run_unittests: &run_unittests
- run:
name: Run Unit Tests
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 tests
CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest --junitxml=test-results/junit.xml --verbose --timeout 600 tests

run_experimental_unittests: &run_experimental_unittests
- run:
name: Run Unit Tests
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON -m pytest experimental/tests
CUDA_LAUNCH_BLOCKING=1 $CONDA_PYTHON -m pytest experimental/tests

run_benchmarks: &run_benchmarks
- run:
name: Run Benchmarks
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON xformers/benchmarks/benchmark_encoder.py --activations gelu --plot -emb 128 -bs 16 -heads 4

run_pytorch_benchmark: &run_pytorch_benchmark
- run:
name: Run Pytorch benchmark
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON xformers/benchmarks/benchmark_pytorch_transformer.py

run_vit_benchmark: &run_vit_benchmark
- run:
name: Run ViT Timm benchmark
when: always
command: |
source $BASH_ENV
$CONDA_PYTHON xformers/benchmarks/benchmark_vit_timm.py
Expand All @@ -211,6 +234,7 @@ run_vit_benchmark: &run_vit_benchmark
run_doc_build: &run_doc_build
- run:
name: Testing doc build
when: always
command: |
source $BASH_ENV
cd docs
Expand Down
4 changes: 2 additions & 2 deletions .circleci/run-clang-format.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def make_diff(file, original, reformatted):
difflib.unified_diff(
original,
reformatted,
fromfile="{}\t(original)".format(file),
tofile="{}\t(reformatted)".format(file),
fromfile="a/{}\t(original)".format(file),
tofile="b/{}\t(reformatted)".format(file),
n=3,
)
)
Expand Down
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ omit =
xformers/benchmarks/*
xformers/triton/k_*
stubs/*
third_party/*
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
exclude =
.git
,.circleci/run-clang-format.py
,third_party
max-line-length = 120
copyright-check = True
select = E,F,W,C
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ examples/data
# Hydra default output dir
multirun
outputs

.benchmarks
7 changes: 7 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/HazyResearch/flash-attention.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/fmassa/cutlass.git
branch = updates_for_mha
1 change: 1 addition & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
[settings]
known_third_party =fvcore,hydra,input_pipeline,matplotlib,numpy,omegaconf,pandas,pl_bolts,pyre_extensions,pytest,pytorch_lightning,ragged_inference,recommonmark,seaborn,setuptools,sklearn,submitit,tensorflow,timm,torch,torchmetrics,torchvision,tqdm,triton,typing_extensions
skip_glob=third_party/*
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ There are two ways you can install xFormers locally:

```bash
git clone git@github.com:facebookresearch/xformers.git
git submodule update --init --recursive
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, I was chacking that when seeing that xformers now has two submodules, perfect. Thanks

conda create --name xformer_env python=3.8
conda activate xformer_env
cd xformers
pip install -r requirements.txt
pip install -e .
Expand Down
99 changes: 96 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

import setuptools
import torch
Expand Down Expand Up @@ -44,6 +46,84 @@ def find_version(version_file_path):
raise RuntimeError("Unable to find version string.")


def get_cuda_version(cuda_dir) -> int:
nvcc_bin = "nvcc" if cuda_dir is None else cuda_dir + "/bin/nvcc"
raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = int(release[0])
bare_metal_minor = int(release[1][0])

assert bare_metal_minor < 100
return bare_metal_major * 100 + bare_metal_minor


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version > 1100:
DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
elif cuda_version >= 1100:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, but cuda_version == 1100 in that case, right ?

DEFAULT_ARCHS_LIST = "7.5;8.0"
else:
return []

archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
nvcc_archs_flags = []
for arch in archs_list.split(";"):
assert len(arch) >= 3, f"Invalid sm version: {arch}"

num = 10 * int(arch[0]) + int(arch[2])
# Need at least 7.5
if num < 75:
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we printout some warnings here (or in the main setup), to recap what's being built and possibly why ? I feel like there could be a lot of issues raised around that with the build process silently skipping flashattention because of an old cuda version and users not seeing it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I'll add some log messages

But in general, we need to improve on the packaging of xformers, specially now that a lot of hardware-specific kernels are being used. @bottler might look into improving this

nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
if arch.endswith("+PTX"):
nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
if not nvcc_archs_flags:
return []

this_dir = os.path.dirname(os.path.abspath(__file__))
flash_root = os.path.join(this_dir, "third_party", "flash-attention")
return [
CUDAExtension(
name="xformers._C_flashattention",
sources=[
os.path.join(this_dir, "third_party", "flash-attention", path)
for path in [
"csrc/flash_attn/fmha_api.cpp",
"csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
]
],
extra_compile_args={
**extra_compile_args,
"nvcc": extra_compile_args.get("nvcc", [])
+ [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
"-lineinfo",
]
+ nvcc_archs_flags,
},
include_dirs=[
Path(flash_root) / "csrc" / "flash_attn",
Path(flash_root) / "csrc" / "flash_attn" / "src",
# Path(flash_root) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
Path(this_dir) / "third_party" / "cutlass" / "include",
],
)
]


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
Expand All @@ -57,9 +137,11 @@ def get_extensions():
)

sources = main_file + source_cpu

source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")

extension = CppExtension

Expand All @@ -73,31 +155,42 @@ def get_extensions():
extra_compile_args["cxx"].append("-fopenmp")

include_dirs = [extensions_dir]
ext_modules = []

if (torch.cuda.is_available() and ((CUDA_HOME is not None))) or os.getenv(
"FORCE_CUDA", "0"
) == "1":
extension = CUDAExtension
sources += source_cuda
include_dirs += [sputnik_dir]
include_dirs += [sputnik_dir, cutlass_dir]
nvcc_flags = os.getenv("NVCC_FLAGS", "")
if nvcc_flags == "":
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(" ")
cuda_version = get_cuda_version(CUDA_HOME)
if cuda_version >= 1102:
nvcc_flags += ["--threads", "4", "--ptxas-options=-v"]
extra_compile_args["nvcc"] = nvcc_flags
if (
cuda_version >= 1100
and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0"
):
ext_modules += get_flash_attention_extensions(
cuda_version=cuda_version, extra_compile_args=extra_compile_args
)

sources = [os.path.join(extensions_dir, s) for s in sources]

ext_modules = [
ext_modules.append(
extension(
"xformers._C",
sorted(sources),
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
)

return ext_modules

Expand Down