Fix nits/comments on "MHA: Stricter input validation"
[ghstack-poisoned]
danthe3rd committed Dec 15, 2022
2 parents 1652dd2 + 8a44d7c commit 3660e0f
Showing 5 changed files with 26 additions and 9 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/wheels.yml
@@ -4,6 +4,8 @@ on:
   push:
     branches:
       - main
+    tags:
+      - "[0-9]+.[0-9]+.[0-9]+"

 # this yaml file can be cleaned up using yaml anchors, but they're not supported in github actions yet
 # https://github.com/actions/runner/issues/1182
@@ -15,6 +17,7 @@ env:
   MAX_JOBS: 1 # will crash otherwise
   DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
   XFORMERS_BUILD_TYPE: "Release"
+  TWINE_USERNAME: __token__

 jobs:
   build_wheels:
@@ -112,8 +115,10 @@ jobs:

     - name: Define version
       run: |
+        set -Eeuo pipefail
         git config --global --add safe.directory "*"
-        echo BUILD_VERSION=`$PY packaging/compute_rc_version.py` >> ${GITHUB_ENV}
+        version=`$PY packaging/compute_wheel_version.py`
+        echo "BUILD_VERSION=$version" >> ${GITHUB_ENV}
         cat ${GITHUB_ENV}
     - name: Setup proper pytorch dependency in "requirements.txt"
@@ -139,7 +144,6 @@
       if: matrix.config.publish
       run: $PY -m twine upload dist/*.whl
       env:
-        TWINE_USERNAME: __token__
         TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}

     - name: Upload source distribution to PyPi
@@ -149,6 +153,5 @@
         $PY setup.py sdist -d sdist/
         $PY -m twine upload sdist/*
       env:
-        TWINE_USERNAME: __token__
         TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
 # Note: it might be helpful to have additional steps that test if the built wheels actually work
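
The tags: trigger added above uses GitHub Actions' glob filter syntax, in which "+" repeats the preceding bracket class, so pushing a plain release tag such as 0.0.15 now starts a wheel build in addition to pushes to main. For local sanity checks, a rough regex equivalent of that filter (illustrative only; the workflow itself never runs this):

    import re

    # Approximate regex mirror of the workflow's "[0-9]+.[0-9]+.[0-9]+" tag filter.
    RELEASE_TAG = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")

    assert RELEASE_TAG.match("0.0.15")
    assert not RELEASE_TAG.match("v0.0.15")  # a leading "v" would not trigger the build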
13 changes: 13 additions & 0 deletions packaging/compute_wheel_version.py
@@ -8,6 +8,19 @@
 # TODO: consolidate with the code in build_conda.py
 THIS_PATH = Path(__file__).resolve()
 version = (THIS_PATH.parents[1] / "version.txt").read_text().strip()
+
+
+try:
+    tag = subprocess.check_output(["git", "describe", "--tags"], text=True).strip()
+except subprocess.CalledProcessError:  # no tag
+    tag = ""
+
+if tag:
+    assert version == tag, "The version in version.txt does not match the given tag"
+    print(tag, end="")
+    exit(0)
+
+
 num_commits = subprocess.check_output(
     ["git", "rev-list", "--count", "HEAD"], text=True
 ).strip()
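
Combined with the workflow change above, the script now prints the tag verbatim for tagged release builds and otherwise falls back to the pre-existing commit-count path. A minimal standalone sketch of that flow, assuming the untagged branch ends in a "{version}.dev{num_commits}" suffix (the tail of the script is not shown in this diff):

    import subprocess

    def resolve_build_version(version: str) -> str:
        # Mirrors the control flow of compute_wheel_version.py (simplified).
        try:
            tag = subprocess.check_output(
                ["git", "describe", "--tags"], text=True
            ).strip()
        except subprocess.CalledProcessError:  # no tag reachable (e.g. tagless CI checkout)
            tag = ""
        if tag:
            # Release build: version.txt must agree with the pushed tag.
            assert version == tag, "The version in version.txt does not match the given tag"
            return tag
        # Dev build: a monotonically increasing suffix based on the commit count.
        num_commits = subprocess.check_output(
            ["git", "rev-list", "--count", "HEAD"], text=True
        ).strip()
        return f"{version}.dev{num_commits}"  # assumed scheme for untagged builds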
5 changes: 3 additions & 2 deletions xformers/ops/fmha/__init__.py
@@ -299,6 +299,7 @@ def _memory_efficient_attention(
 def _memory_efficient_attention_forward(
     inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
 ) -> torch.Tensor:
+    inp.validate_inputs()
     output_shape = inp.normalize_bmhk()
     if op is None:
         op = _dispatch_fw(inp)
@@ -307,22 +308,21 @@ def _memory_efficient_attention_forward(
             f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
         )

-    inp.validate_bmhk()
     out, *_ = op.apply(inp, needs_gradient=False)
     return out.reshape(output_shape)


 def _memory_efficient_attention_forward_requires_grad(
     inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
 ) -> Tuple[torch.Tensor, Context]:
+    inp.validate_inputs()
     output_shape = inp.normalize_bmhk()
     if op is None:
         op = _dispatch_fw(inp)
     elif not op.supports(inp):
         raise ValueError(
             f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
         )
-    inp.validate_bmhk()
     out = op.apply(inp, needs_gradient=True)
     assert out[1] is not None
     return (out[0].reshape(output_shape), out[1])
@@ -332,6 +332,7 @@ def _memory_efficient_attention_backward(
     ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
 ) -> Gradients:
     """Warning: grad/ctx.out is potentially in BMK format"""
+    inp.validate_inputs()
     if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
         raise ValueError(
             "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
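
With validate_inputs() now called at the top of all three entry points, mixed BMK/BMHK inputs fail fast with a shape report instead of reaching operator dispatch. An illustrative repro through the public API (not part of the commit; requires a CUDA build of xformers):

    import torch
    import xformers.ops as xops

    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.half)  # BMHK
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.half)  # BMHK
    v = torch.randn(2, 128, 64, device="cuda", dtype=torch.half)     # BMK: mismatched ndim

    out = xops.memory_efficient_attention(q, k, k)  # fine: all three tensors are BMHK
    try:
        xops.memory_efficient_attention(q, k, v)    # now rejected before dispatch
    except ValueError as e:
        print(e)  # lists query/key/value shapes, per validate_inputs()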
6 changes: 3 additions & 3 deletions xformers/ops/fmha/common.py
@@ -99,11 +99,11 @@ def normalize_bmhk(self) -> Tuple[int, ...]:
             self.value = self.value.unsqueeze(2)
         return output_shape

-    def validate_bmhk(self) -> None:
+    def validate_inputs(self) -> None:
         qkv = (self.query, self.key, self.value)
-        if tuple(x.ndim for x in qkv) != (4, 4, 4):
+        if self.query.ndim not in (3, 4) or any(x.ndim != self.query.ndim for x in qkv):
             raise ValueError(
-                f"Query/Key/Value should have BMHK format.\n"
+                f"Query/Key/Value should all have BMHK or BMK shape.\n"
                 f"  query.shape: {self.query.shape}\n"
                 f"  key.shape  : {self.key.shape}\n"
                 f"  value.shape: {self.value.shape}"
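
The old check required ndim == 4 for all three tensors; the new predicate accepts all-BMK (ndim=3) or all-BMHK (ndim=4) and rejects any mix. A toy restatement of just that rule (hypothetical helper, not in the codebase):

    def ndims_ok(q_ndim: int, k_ndim: int, v_ndim: int) -> bool:
        # Same rule as validate_inputs: a supported rank, shared by all three tensors.
        return q_ndim in (3, 4) and k_ndim == q_ndim and v_ndim == q_ndim

    assert ndims_ok(3, 3, 3)       # all BMK
    assert ndims_ok(4, 4, 4)       # all BMHK
    assert not ndims_ok(4, 4, 3)   # mixed formats
    assert not ndims_ok(5, 5, 5)   # unsupported rank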
2 changes: 1 addition & 1 deletion xformers/ops/fmha/cutlass.py
@@ -77,7 +77,7 @@ class FwOp(AttentionFwOpBase):
     OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
-    SUPPORTED_MAX_K = math.inf
+    SUPPORTED_MAX_K = 65536
     SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), LowerTriangularMask}
     SUPPORTS_DROPOUT = False
     SUPPORTS_CUSTOM_SCALE = True
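
SUPPORTED_MAX_K bounds the per-head embedding size K that the CUTLASS forward accepts; replacing math.inf with 65536 makes the cap explicit. A simplified stand-in for the comparison op.supports(inp) presumably performs against it (the real check also covers device, dtype, bias type, and more):

    import math

    def embedding_supported(max_k: float, k: int) -> bool:
        # Simplified stand-in for the SUPPORTED_MAX_K clause of op.supports(inp).
        return k <= max_k

    assert embedding_supported(65536, 128)           # typical head dim
    assert embedding_supported(math.inf, 100_000)    # the old bound accepted anything
    assert not embedding_supported(65536, 100_000)   # the explicit cap rejects it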
