
Commit

Initial implementation
blefaudeux committed Jan 10, 2022
1 parent 9cf3c38 · commit 88e0226
Showing 9 changed files with 444 additions and 10 deletions.
CHANGELOG.md (2 changes: 2 additions & 0 deletions)
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
+### Added
+- Compositional Attention [#41]

## [0.0.8] - 2022-01-07
### Fixed
README.md (3 changes: 3 additions & 0 deletions)
@@ -139,6 +139,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- See BigBird, Longformers,..
- [FourierMix](xformers/components/attention/fourier_mix.py)
- *[FNet: Mixing Tokens with Fourier Transforms, Lee-Thorp et al.](https://arxiv.org/abs/2105.03824v1)*
+- [CompositionalAttention](xformers/components/attention/compositional.py)
+- *[Compositional Attention: Disentangling search and retrieval, S. Mittal et al.](https://arxiv.org/pdf/2110.09419v1.pdf)*

- ... add a new one [see Contribution.md](CONTRIBUTING.md)

</p></details>
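For orientation, here is a minimal, hypothetical sketch of how the new mechanism could be instantiated through the attention registry. It assumes the `build_attention` helper and reuses only the parameter names touched by this commit (`num_heads`, `dim_head`, `num_rules`); the authoritative argument list lives in `xformers/components/attention/compositional.py`, so treat the exact keys and defaults here as assumptions.

```python
from xformers.components.attention import build_attention

# Illustrative dimensions only; real values come from your model configuration
HEADS, MODEL = 4, 128

attention = build_attention(
    {
        "name": "compositional",     # registry name used by examples/microGPT.py below
        "dropout": 0.0,
        "num_heads": HEADS,
        "dim_head": MODEL // HEADS,  # per-head width (integer division, as in the updated tests)
        "num_rules": 2 * HEADS,      # number of retrieval rules, same heuristic as microGPT.py
    }
)

# In a full model this module is wired in through the xFormers block factory,
# as the examples/microGPT.py and test changes in this commit show.
print(attention)
```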
examples/microGPT.py (6 changes: 4 additions & 2 deletions)
@@ -68,6 +68,8 @@ def __init__(
"dropout": self.hparams.attn_pdrop,
"causal": True,
"seq_len": self.hparams.block_size,
"dim_head": self.hparams.n_embd // self.hparams.n_head,
"num_rules": 2 * self.hparams.n_head,
},
},
"feedforward_config": {
@@ -273,7 +275,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
-BATCH = 256
+BATCH = 32

WORKERS = 4
EPOCHS = 1
@@ -301,7 +303,7 @@ def top_k_logits(logits, k):
model = GPT(
vocab_size=train_dataset.vocab_size,
block_size=train_dataset.block_size,
attention="scaled_dot_product",
attention="compositional",
warmup_tokens=REF_BATCH * WARMUP,
final_tokens=EPOCHS * len(train_dataset) * BLOCK,
)
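The two new keys are derived from the model hyper-parameters; a tiny illustrative calculation with made-up numbers (microGPT reads `n_embd` and `n_head` from its command-line arguments):

```python
# Hypothetical values; microGPT takes these from its CLI arguments
n_embd, n_head = 512, 8

dim_head = n_embd // n_head  # 64: width of each attention head
num_rules = 2 * n_head       # 16: twice as many retrieval rules as heads, as configured above

assert dim_head * n_head == n_embd
print(dim_head, num_rules)
```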
tests/test_attentions.py (5 changes: 3 additions & 2 deletions)
@@ -43,10 +43,11 @@ def _get_multihead(
"dropout": attn_dropout,
"causal": causal,
"seq_len": SEQ,
"window_size": SEQ // 8 + 1,
"window_size": SEQ // 8 + 1, # local attention
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"num_heads": heads,
"dim_head": MODEL / heads,
"dim_head": MODEL // heads,
"num_rules": 2, # Compositional Attention
}

if skip_output_projection:
tests/test_block_factory.py (4 changes: 3 additions & 1 deletion)
@@ -59,9 +59,10 @@ def test_xformer_encoder_block(
"seq_len": SEQ,
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"num_heads": heads,
"dim_head": MODEL / heads,
"dim_head": MODEL // heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
}

multi_head_config = {
@@ -151,6 +152,7 @@ def test_xformer_decoder_block(
"dim_head": MODEL / heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
}

multi_head_config = {
xformers/components/attention/base.py (8 changes: 7 additions & 1 deletion)
@@ -11,6 +11,8 @@
import torch
import torch.nn as nn

+from xformers.components.attention import AttentionMask


@dataclass
class AttentionConfig:
@@ -29,7 +31,7 @@ class AttentionConfig:
class Attention(nn.Module, metaclass=ABCMeta):
r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""

-_causal_mask: Optional[torch.Tensor] = None
+_causal_mask: Optional[AttentionMask] = None

@abstractmethod
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
@@ -47,6 +49,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
# Requires that K and Q have the same sequence length
self.requires_same_k_q_dimensions = False

+# Whether the attention owns the single head/multihead mechanism
+# so that the MHA wrapper should skip it
+self.requires_skip_multi_head = False

@classmethod
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
# Generate the class inputs from the config
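The new `requires_skip_multi_head` attribute is only a hint for the multi-head wrapper. The sketch below shows, with a purely hypothetical `TinyMultiHeadWrapper`, how a dispatcher could honor it; the real handling belongs to xFormers' `MultiHeadDispatch` and may differ.

```python
import torch
import torch.nn as nn


class TinyMultiHeadWrapper(nn.Module):
    """Toy dispatcher illustrating the `requires_skip_multi_head` hint."""

    def __init__(self, attention: nn.Module, num_heads: int, dim_model: int):
        super().__init__()
        self.attention = attention
        self.num_heads = num_heads
        self.out_proj = nn.Linear(dim_model, dim_model)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if getattr(self.attention, "requires_skip_multi_head", False):
            # The attention owns its own head mechanism (e.g. compositional attention):
            # hand over the full-width tensors untouched.
            return self.attention(q, k, v)

        # Otherwise fold the heads into the batch dimension, attend per head, then merge back.
        B, S, _ = q.shape

        def split(x: torch.Tensor) -> torch.Tensor:
            return x.view(B, S, self.num_heads, -1).transpose(1, 2).flatten(0, 1)

        out = self.attention(split(q), split(k), split(v))
        out = out.view(B, self.num_heads, S, -1).transpose(1, 2).reshape(B, S, -1)
        return self.out_proj(out)
```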

