diff --git a/CHANGELOG.md b/CHANGELOG.md index 46a30e15e..28b4db658 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## TBD +### Added +- Compositional Attention [#41] ## [0.0.8] - 2022-01-07 ### Fixed diff --git a/README.md b/README.md index 32b991b68..b2aa27174 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)* - See BigBird, Longformers,.. - [FourierMix](xformers/components/attention/fourier_mix.py) - *[FNet: Mixing Tokens with Fourier Transforms, Lee-Thorp et al.](https://arxiv.org/abs/2105.03824v1)* +- [CompositionalAttention](xformers/components/attention/compositional.py) + - *[Compositional Attention: Disentangling search and retrieval, S. Mittal et al.](https://arxiv.org/pdf/2110.09419v1.pdf)* + - ... add a new one [see Contribution.md](CONTRIBUTING.md)
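A minimal usage sketch for the new component, assuming it is built through xformers' attention registry via `build_attention`, under the name "compositional" registered further down in `compositional.py`; all sizes are illustrative, not taken from the patch.

```python
import torch

from xformers.components.attention import build_attention

BATCH, SEQ, EMB, HEADS = 2, 128, 384, 6  # illustrative sizes only

attention = build_attention(
    {
        "name": "compositional",   # registry name used by @register_attention below
        "dropout": 0.1,
        "num_heads": HEADS,
        "dim_head": EMB // HEADS,  # embed_dim is rebuilt internally as num_heads * dim_head
        "num_rules": 2 * HEADS,    # retrieval rules; falls back to num_heads when None
        "causal": True,
    }
)

# Inputs follow the library convention [batch, sequence, embedding].
# This block owns its projections and multi-head logic (requires_skip_multi_head),
# so it is called directly on the full-width embeddings.
x = torch.rand(BATCH, SEQ, EMB)
y = attention(x, x, x)
print(y.shape)  # torch.Size([2, 128, 384])
```

The same keys ("dim_head", "num_rules") are what the `examples/microGPT.py` change below threads through the model factory.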

diff --git a/examples/microGPT.py b/examples/microGPT.py index fe664d1c1..7629f8aea 100644 --- a/examples/microGPT.py +++ b/examples/microGPT.py @@ -68,6 +68,8 @@ def __init__( "dropout": self.hparams.attn_pdrop, "causal": True, "seq_len": self.hparams.block_size, + "dim_head": self.hparams.n_embd // self.hparams.n_head, + "num_rules": 2 * self.hparams.n_head, }, }, "feedforward_config": { @@ -273,7 +275,7 @@ def top_k_logits(logits, k): # Adjust batch depending on the available memory on your machine. # You can also use reversible layers to save memory REF_BATCH = 512 - BATCH = 256 + BATCH = 32 WORKERS = 4 EPOCHS = 1 @@ -301,7 +303,7 @@ def top_k_logits(logits, k): model = GPT( vocab_size=train_dataset.vocab_size, block_size=train_dataset.block_size, - attention="scaled_dot_product", + attention="compositional", warmup_tokens=REF_BATCH * WARMUP, final_tokens=EPOCHS * len(train_dataset) * BLOCK, ) diff --git a/tests/test_attentions.py b/tests/test_attentions.py index ffa8bee78..b6bfd7339 100644 --- a/tests/test_attentions.py +++ b/tests/test_attentions.py @@ -43,10 +43,11 @@ def _get_multihead( "dropout": attn_dropout, "causal": causal, "seq_len": SEQ, - "window_size": SEQ // 8 + 1, + "window_size": SEQ // 8 + 1, # local attention "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO, "num_heads": heads, - "dim_head": MODEL / heads, + "dim_head": MODEL // heads, + "num_rules": 2, # Compositional Attention } if skip_output_projection: diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py index 24fa0bd7b..c73712b4a 100644 --- a/tests/test_block_factory.py +++ b/tests/test_block_factory.py @@ -59,9 +59,10 @@ def test_xformer_encoder_block( "seq_len": SEQ, "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO, "num_heads": heads, - "dim_head": MODEL / heads, + "dim_head": MODEL // heads, "layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long), "block_size": block_size, + "num_rules": 2, # Compositional Attention } multi_head_config = { @@ -151,6 +152,7 @@ def test_xformer_decoder_block( "dim_head": MODEL / heads, "layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long), "block_size": block_size, + "num_rules": 2, # Compositional Attention } multi_head_config = { diff --git a/xformers/components/attention/base.py b/xformers/components/attention/base.py index 4728a0db2..511dcd04f 100644 --- a/xformers/components/attention/base.py +++ b/xformers/components/attention/base.py @@ -11,6 +11,8 @@ import torch import torch.nn as nn +from xformers.components.attention import AttentionMask + @dataclass class AttentionConfig: @@ -29,7 +31,7 @@ class AttentionConfig: class Attention(nn.Module, metaclass=ABCMeta): r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention""" - _causal_mask: Optional[torch.Tensor] = None + _causal_mask: Optional[AttentionMask] = None @abstractmethod def __init__(self, dropout: Optional[float] = None, *args, **kwargs): @@ -47,6 +49,10 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs): # Requires that K and Q have the same sequence length self.requires_same_k_q_dimensions = False + # Whether the attention owns the single head/multihead mechanism + # so that the MHA wrapper should skip it + self.requires_skip_multi_head = False + @classmethod def from_config(cls: Type[Self], config: AttentionConfig) -> Self: # Generate the class inputs from the config diff --git a/xformers/components/attention/compositional.py 
b/xformers/components/attention/compositional.py new file mode 100644 index 000000000..8d8129c47 --- /dev/null +++ b/xformers/components/attention/compositional.py @@ -0,0 +1,410 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# Credits: this is heavily inspired by the official implementation, present in +# https://github.com/sarthmit/Compositional-Attention +# Original author: Sarthak Mittal + +# This is a simplified version, for the sake of clarity, and because some features could be exposed later +# via the library directly. +# In particular, code paths for TPUs, quantization and gumbel softmax have been removed +# We're also following the same dimension ordering as in the rest of the xformers library +# which is to say [Batch, Sequence, Embedding] wherever possible + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + register_attention, +) + + +@dataclass +class CompositionalAttentionConfig(AttentionConfig): + num_heads: int + dim_head: int + num_rules: Optional[int] + dropout: float + qk_rule: bool = False + selection_dim: Optional[int] = None + nonlinear: bool = False + q_compose: bool = False + attn_dim: Optional[int] = None + kdim: Optional[int] = None + vdim: Optional[int] = None + bias: bool = True + add_bias_kv: bool = False + add_zero_attn: bool = False + causal: Optional[bool] = False + + +@register_attention("compositional", CompositionalAttentionConfig) +class CompositionalAttention(Attention): + """Compositional Attention, as proposed in + "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al. 
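+
+    Search and retrieval are disentangled: ``num_heads`` sets the number of search
+    (query/key) heads, while ``num_rules`` sets the number of retrieval rules
+    (the value projections a head can read from), so the two no longer have to match.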
+ + _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf + """ + + def __init__( + self, + num_heads, + dim_head, + num_rules=None, + dropout=0.0, + qk_rule=False, + selection_dim=None, + nonlinear=False, + q_compose=False, + attn_dim=None, + kdim=None, + vdim=None, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + causal=False, + *_, + **__, + ): + super().__init__() + + # Define the inherited flags + self.requires_input_projection = ( + False # This attention handles its own projection + ) + + # FIXME: We fold and unfold here + self.requires_dim_headension = ( + False # This attention requires the heads to be visible in the inputs + ) + + self.requires_separate_masks = ( + True # Key and Attention masks are passed seperately + ) + + self.requires_skip_multi_head = ( + True # This attention owns the multi-head mechanism + ) + + # Handle defaults / undefined values + num_rules = ( + num_heads if num_rules is None else num_rules + ) # Does not make a lot of sense #FIXME + embed_dim = int(num_heads * dim_head) + attn_dim = embed_dim if attn_dim is None else attn_dim + selection_dim = ( + embed_dim // num_heads if selection_dim is None else selection_dim + ) + + # All the initial definition plumbing + self.embed_dim = embed_dim + self.attn_dim = attn_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.num_rules = num_rules + self.qk_rule = qk_rule + self.selection_dim = selection_dim + self.nonlinear = nonlinear + self.q_compose = q_compose + + self.dropout_module = nn.Dropout(dropout) + self.dim_head = embed_dim // num_heads + self.value_dim = attn_dim // num_rules + + assert ( + self.dim_head * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + assert ( + self.value_dim * num_rules == self.attn_dim + ), "embed_dim must be divisible by num_heads" + + self.scaling = self.dim_head ** -0.5 + self.scaling_values = self.selection_dim ** -0.5 + + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, attn_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(self.num_heads * self.value_dim, embed_dim, bias=bias) + + if self.qk_rule: + self.value_k = nn.Linear(self.value_dim, self.selection_dim, bias=bias) + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.selection_dim, bias=bias) + else: + self.value_q = nn.Linear( + embed_dim, self.selection_dim * self.num_heads, bias=bias + ) + else: + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.selection_dim, bias=bias) + else: + self.value_q = nn.Linear( + embed_dim, self.selection_dim * self.num_heads, bias=bias + ) + if self.nonlinear: + # Can change the capacity of the score_network MLP here + self.score_network1 = nn.Linear( + self.selection_dim + self.value_dim, + self.selection_dim, + bias=bias, + ) + self.score_network2 = nn.Linear(self.selection_dim, 1, bias=bias) + else: + self.score_network = nn.Linear( + self.selection_dim + self.value_dim, 1, bias=bias + ) + + if add_bias_kv: + self.bias_k = nn.Parameter(Tensor(1, 1, embed_dim)) + self.bias_v = nn.Parameter(Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.causal = causal + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + 
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
+
+        if self.qk_rule:
+            nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.value_q.weight)
+            if self.nonlinear:
+                nn.init.xavier_uniform_(self.score_network1.weight)
+                nn.init.xavier_uniform_(self.score_network2.weight)
+            else:
+                nn.init.xavier_uniform_(self.score_network.weight)
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        att_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> Tensor:
+        """
+        Input shape: Batch x Sequence x Embedding
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, Sk)`, where
+                padding elements are indicated by 1s.
+            att_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+        """
+
+        B, Sq, E = q.shape
+        assert E == self.embed_dim
+        _, Sk, _ = k.shape
+
+        # First define projected query/key/values
+        # We keep the projected and original tensors in flight,
+        # depending on the options the original values could be reused
+        q_unprojected = q
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        q *= self.scaling
+
+        # Init causal mask if needed, now that we know the context length
+        if self.causal and (
+            self._causal_mask is None or self._causal_mask.shape[0] != Sk
+        ):
+            self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)
+
+        # Convenience, create an attention mask if a tensor was passed
+        # This sanitizes different mask types being passed, from now on it's additive
+        if isinstance(att_mask, torch.Tensor):
+            # By default we don't know whether the mask is causal, and checking would be expensive
+            att_mask = (
+                AttentionMask.from_bool(att_mask)
+                if att_mask.dtype == torch.bool
+                else AttentionMask(att_mask, is_causal=False)
+            )
+
+        # Handle the attention and key padding masks
+        if self._causal_mask is not None:
+            # Optionally add the causal mask
+            if att_mask is not None:
+                att_mask += self._causal_mask
+            else:
+                att_mask = self._causal_mask
+
+        # FIXME: This extends the mask by 1, not sure why
+        # if att_mask is not None  # and self.add_zero_attn: ?
+ # att_mask = torch.cat( + # [att_mask, att_mask.new_zeros(att_mask.size(0), 1)], dim=1 + # ) + + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, B, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, B, 1)]) + + # Flatten the heads or the rules + q = ( + q.view(B, Sq, self.num_heads, self.dim_head) + .movedim(2, 1) + .flatten(0, 1) # [B * num_heads, Sq, dim_head] + ) + k = ( + k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1) + ) # [B * num_heads, Sk, dim_head] + v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == B + assert key_padding_mask.size(1) == Sk + + if self.add_zero_attn: + Sk += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [B * self.num_heads, Sq, Sk] + + if att_mask is not None: + attn_weights += att_mask.values + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk) + + attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + v = v.view(B, 1, self.num_rules, Sk, self.value_dim) + attn_probs = attn_probs.unsqueeze(2) + attn = torch.matmul(attn_probs, v).view( + B * self.num_heads * self.num_rules, Sq, self.value_dim + ) + assert list(attn.size()) == [ + B * self.num_heads * self.num_rules, + Sq, + self.value_dim, + ] + + attn = attn.view(B, self.num_heads, self.num_rules, Sq, self.value_dim).movedim( + 3, 1 + ) + + assert list(attn.size()) == [ + B, + Sq, + self.num_heads, + self.num_rules, + self.value_dim, + ] + + if self.q_compose: + v_q = self.value_q(q.transpose(0, 1)).view( + B, Sq, self.num_heads, 1, self.selection_dim + ) + else: + v_q = self.value_q(q_unprojected).view( + B, Sq, self.num_heads, 1, self.selection_dim + ) + + if self.qk_rule: + v_q *= self.scaling_values + v_k = ( + self.value_k(attn) + .view(B, Sq, self.num_heads, self.num_rules, self.selection_dim) + .transpose(4, 3) + .contiguous() + ) + v_score = torch.matmul(v_q, v_k).view( + B, Sq, self.num_heads, self.num_rules, 1 + ) + else: + v_q = v_q.repeat(1, 1, 1, self.num_rules, 1) + v_in = torch.cat([attn, v_q], dim=-1) + if self.nonlinear: + v_score = self.score_network1(v_in).view( + B, Sq, self.num_heads, self.num_rules, self.selection_dim + ) + v_score = self.score_network2(F.relu(v_score)) + else: + v_score = self.score_network(v_in).view( + B, Sq, self.num_heads, self.num_rules, 1 + ) + + v_score = F.softmax(v_score, dim=3) + + attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim) + attn = self.out_proj(attn) + + # TODO: Add these outputs across the lib, needed for 
other attention mechanisms + # if need_weights: + # attn_weights = attn_weights_float.view( + # B, self.num_heads, Sq, Sk + # ).transpose(1, 0) + # if not need_head_weights: + # # average attention weights over heads + # attn_weights = attn_weights.mean(dim=0) + + return attn # , attn_weights diff --git a/xformers/components/attention/favor.py b/xformers/components/attention/favor.py index 88be336b7..9eead62b3 100644 --- a/xformers/components/attention/favor.py +++ b/xformers/components/attention/favor.py @@ -50,7 +50,8 @@ def __init__( **__, ): r""" - Kernelized attention, as proposed in Performers_. + Kernelized attention, as proposed in Performers_ + ("Rethinking attention with performers." K. Choromanski et al. (2020).). FAVOR stands for "Fast Attention Via positive Orthogonal Random features" @@ -61,8 +62,7 @@ def __init__( feature_map_type (FeatureMapType): the type of feature map being used, for instance orthogonal random features. - .. _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). - https://arxiv.org/pdf/2009.14794v1.pdf + .. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf """ super().__init__() diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py index 8e83dd629..96b153b21 100644 --- a/xformers/components/multi_head_dispatch.py +++ b/xformers/components/multi_head_dispatch.py @@ -127,6 +127,7 @@ def forward( key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None, att_mask: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Expected input dimensions are [batch size, sequence length, embed dim] @@ -164,6 +165,11 @@ def forward( else: k, q, v = key, query, value + if self.attention.requires_skip_multi_head: + return self.attention( + q, k, v, att_mask=att_mask, key_padding_mask=key_padding_mask + ) + # Optional: rotary embedding, add relative positioning information if self.rotary_embeddings: # rotary requires the head dimension @@ -187,7 +193,9 @@ def forward( v = reshape_fn(v, B, S_K, self.num_heads, self.dim_k) # Self-attend - y = self.attention(q=q, k=k, v=v, att_mask=att_mask) + y = self.attention( + q=q, k=k, v=v, att_mask=att_mask, key_padding_mask=key_padding_mask + ) # Re-assemble all head outputs side by side y = (