Skip to content

Commit

Permalink
MHA: Stricter input validation
Browse files Browse the repository at this point in the history
ghstack-source-id: 9881e58fd13488c0b64aa4c9e4f070a49974bca3
Pull Request resolved: #592
  • Loading branch information
danthe3rd committed Dec 15, 2022
1 parent 89177a9 commit 9972795
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xformers/ops/fmha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def _memory_efficient_attention_forward(
f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
)

inp.validate_bmhk()
out, *_ = op.apply(inp, needs_gradient=False)
return out.reshape(output_shape)

Expand All @@ -321,6 +322,7 @@ def _memory_efficient_attention_forward_requires_grad(
raise ValueError(
f"xformers.memory_efficient_attention: Operator {op.NAME} does not support this input"
)
inp.validate_bmhk()
out = op.apply(inp, needs_gradient=True)
assert out[1] is not None
return (out[0].reshape(output_shape), out[1])
Expand Down
37 changes: 37 additions & 0 deletions xformers/ops/fmha/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,43 @@ def normalize_bmhk(self) -> Tuple[int, ...]:
self.value = self.value.unsqueeze(2)
return output_shape

def validate_bmhk(self) -> None:
    """Validate query/key/value and the dropout probability before dispatch.

    The checks run in order and the first failure is reported:

    1. query, key and value must each be 4-dimensional;
    2. they must all be on the same device;
    3. they must all have the same dtype;
    4. if any of them is a ``TensorWithSeqLen``, all of them must be, and
       each must have a first dimension of size 1;
    5. ``self.p`` must lie within ``[0, 1]``.

    Raises:
        ValueError: if any of the checks above fails.
    """
    qkv = (self.query, self.key, self.value)
    if any(x.ndim != 4 for x in qkv):
        raise ValueError(
            "Query/Key/Value should have BMHK format.\n"
            f" query.shape: {self.query.shape}\n"
            f" key.shape : {self.key.shape}\n"
            f" value.shape: {self.value.shape}"
        )
    if any(x.device != self.query.device for x in qkv):
        raise ValueError("Query/Key/Value should all be on the same device")
    if any(x.dtype != self.query.dtype for x in qkv):
        raise ValueError(
            "Query/Key/Value should all have the same dtype\n"
            f" query.dtype: {self.query.dtype}\n"
            f" key.dtype : {self.key.dtype}\n"
            f" value.dtype: {self.value.dtype}"
        )
    with_seqlen = [isinstance(x, TensorWithSeqLen) for x in qkv]
    if any(with_seqlen):
        # Sequence-length information is all-or-nothing across q/k/v.
        if not all(with_seqlen):
            raise ValueError(
                "One of Query/Key/Value has sequence length information, but not all of them\n"
                f" type(query): {type(self.query)}\n"
                f" type(key) : {type(self.key)}\n"
                f" type(value): {type(self.value)}"
            )
        if any(x.shape[0] != 1 for x in qkv):
            raise ValueError(
                "Expected batch_size=1 when using sequence length information\n"
                f" query.shape: {self.query.shape}\n"
                f" key.shape : {self.key.shape}\n"
                f" value.shape: {self.value.shape}"
            )
    # Written as a positive range test so that NaN is rejected as well:
    # `p < 0.0 or p > 1.0` is False for NaN and would let it through.
    if not (0.0 <= self.p <= 1.0):
        raise ValueError(f"Invalid dropout probability: p={self.p}")


@dataclass
class Context:
Expand Down

0 comments on commit 9972795

Please sign in to comment.