From 71a75c82216097fe2f00cc9984461417b8e3490f Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Thu, 13 Jan 2022 19:35:41 -0800
Subject: [PATCH] Refactor the projection, align args on other attentions

---
 examples/microGPT.py                        |   5 +-
 tests/test_attentions.py                    |   2 +-
 tests/test_block_factory.py                 |   4 +-
 tests/test_compositional_attention.py       |   2 +-
 xformers/components/__init__.py             |   5 +
 .../components/attention/compositional.py   | 122 +++++++++---------
 xformers/components/multi_head_dispatch.py  |  10 +-
 7 files changed, 75 insertions(+), 75 deletions(-)

diff --git a/examples/microGPT.py b/examples/microGPT.py
index d9f687e93..26a890559 100644
--- a/examples/microGPT.py
+++ b/examples/microGPT.py
@@ -68,7 +68,6 @@ def __init__(
                         "dropout": self.hparams.attn_pdrop,
                         "causal": True,
                         "seq_len": self.hparams.block_size,
-                        "dim_head": self.hparams.n_embd // self.hparams.n_head,
                         "num_rules": self.hparams.n_head,
                     },
                 },
@@ -275,7 +274,7 @@ def top_k_logits(logits, k):
     # Adjust batch depending on the available memory on your machine.
     # You can also use reversible layers to save memory
     REF_BATCH = 512
-    BATCH = 256
+    BATCH = 128

     WORKERS = 4
     EPOCHS = 1
@@ -303,7 +302,7 @@ def top_k_logits(logits, k):
     model = GPT(
         vocab_size=train_dataset.vocab_size,
         block_size=train_dataset.block_size,
-        attention="scaled_dot_product",
+        attention="compositional",
         warmup_tokens=REF_BATCH * WARMUP,
         final_tokens=EPOCHS * len(train_dataset) * BLOCK,
     )
diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index b6bfd7339..d01a99473 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -45,8 +45,8 @@ def _get_multihead(
         "seq_len": SEQ,
         "window_size": SEQ // 8 + 1,  # local attention
         "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
+        "dim_model": MODEL,
         "num_heads": heads,
-        "dim_head": MODEL // heads,
         "num_rules": 2,  # Compositional Attention
     }

diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index c73712b4a..af2076a7f 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -58,8 +58,8 @@ def test_xformer_encoder_block(
         "window_size": SEQ // 8 + 1,
         "seq_len": SEQ,
         "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
+        "dim_model": MODEL,
         "num_heads": heads,
-        "dim_head": MODEL // heads,
         "layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
         "block_size": block_size,
         "num_rules": 2,  # Compositional Attention
@@ -148,8 +148,6 @@ def test_xformer_decoder_block(
         "window_size": SEQ // 8 + 1,
         "seq_len": SEQ,
         "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
-        "num_heads": heads,
-        "dim_head": MODEL / heads,
         "layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
         "block_size": block_size,
         "num_rules": 2,  # Compositional Attention
diff --git a/tests/test_compositional_attention.py b/tests/test_compositional_attention.py
index a206843cf..99484d20b 100644
--- a/tests/test_compositional_attention.py
+++ b/tests/test_compositional_attention.py
@@ -61,8 +61,8 @@ def test_build_and_run(
         "seq_len": SEQ,
         "window_size": SEQ // 8 + 1,  # local attention
         "attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
+        "dim_model": MODEL,
         "num_heads": heads,
-        "dim_head": MODEL // heads,
         "num_rules": 2,  # Compositional Attention
         "q_compose": q_compose,
         "rules": rules,
diff --git a/xformers/components/__init__.py b/xformers/components/__init__.py
index 2bf470714..bfe4e1281 100644
--- a/xformers/components/__init__.py
+++ b/xformers/components/__init__.py
@@ -48,6 +48,11 @@ def build_multi_head_attention(
             "num_heads"
         ]

+    if "dim_model" not in multi_head_config["attention"]:
+        multi_head_config["attention"]["dim_model"] = multi_head_config[
+            "dim_model"
+        ]
+
     if (
         "dim_features" not in multi_head_config["attention"]
         or multi_head_config["attention"]["dim_features"] is None
diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py
index b29c3bc92..e7f871e4f 100644
--- a/xformers/components/attention/compositional.py
+++ b/xformers/components/attention/compositional.py
@@ -33,22 +33,30 @@ if _is_triton_available:
     from xformers.triton.softmax import softmax

+from xformers.components.in_proj_container import InProjContainer, InProjParams
+
+
+def _either_or(a: Optional[int], b: int) -> int:
+    return a if a is not None else b
+

 @dataclass
 class CompositionalAttentionConfig(AttentionConfig):
+    dim_model: int
     num_heads: int
-    dim_head: int
-    num_rules: Optional[int]
+    dim_attn: Optional[int] = None
+    num_rules: Optional[int] = None
+    dim_key: Optional[int] = None
+    dim_value: Optional[int] = None
+    dim_selection: Optional[int] = None
     dropout: float
     qk_rule: bool = False
-    dim_selection: Optional[int] = None
     nonlinear: bool = False
     q_compose: bool = False
-    dim_attn: Optional[int] = None
-    kdim: Optional[int] = None
-    vdim: Optional[int] = None
     bias: bool = True
     causal: Optional[bool] = False
+    in_proj_container: Optional[InProjContainer] = None
+    use_separate_proj_weight: Optional[bool] = False


 @register_attention("compositional", CompositionalAttentionConfig)
@@ -67,15 +75,15 @@ class CompositionalAttention(Attention):
     may not fit in memory.

     Args:
-        num_heads: The number of heads *for the search operation*
-        dim_head: Latent space for a given head
+        dim_model: dimension of the incoming latent space
+        num_heads: number of heads *for the search operation*
+        dim_attn: dimension (embedding) of the attention
+        num_rules: number of rules to consider *for the retrieval operation*
         dim_selection: dimension of the scoring/selection space for the retrievals
-        numn_rules: The number of rules to consider *for the retrieval operation*
+        dim_key, dim_value: dimensions of K and V, if different from Q
         dropout: attention dropout probability
         qk_rule: QK product will drive the retrieval process
         nonlinear: use a non linear method to score the retrievals
-        dim_attn: dimension (embedding) of the attention
-        kdim, vdim: dimensions of K and V, if different from Q
         bias: use bias in the initial projection step
         causal: causal computations (attend to the past only)

@@ -84,17 +92,19 @@ class CompositionalAttention(Attention):

     def __init__(
         self,
-        num_heads,
-        dim_head,
-        num_rules=None,
+        dim_model: int,
+        num_heads: int,
+        dim_attn: Optional[int] = None,
+        num_rules: Optional[int] = None,
+        dim_selection: Optional[int] = None,
+        dim_key: Optional[int] = None,
+        dim_value: Optional[int] = None,
         dropout=0.0,
         qk_rule=False,
-        dim_selection=None,
         nonlinear=False,
         q_compose=False,
-        dim_attn=None,
-        kdim=None,
-        vdim=None,
+        in_proj_container: Optional[InProjContainer] = None,
+        use_separate_proj_weight: Optional[bool] = False,
         bias=True,
         causal=False,
         *_,
@@ -103,28 +113,33 @@ def __init__(
         super().__init__()

         # Define the inherited flags
-        self.requires_input_projection = (
-            False  # This attention handles its own projection
-        )
-
         self.requires_skip_multi_head = (
             True  # This attention owns the multi-head mechanism
         )

         # Handle defaults / undefined values
-        num_rules = num_heads if num_rules is None else num_rules
-        dim_embed = int(num_heads * dim_head)
-        dim_attn = dim_embed if dim_attn is None else dim_attn
-        dim_selection = (
-            dim_embed // num_heads if dim_selection is None else dim_selection
-        )
+        self.dim_model = dim_model
+        num_rules = _either_or(num_rules, num_heads)
+        dim_selection = _either_or(dim_selection, dim_model // num_heads)

         # All the initial definition plumbing
-        self.dim_embed = dim_embed
-        self.dim_attn = dim_attn
-        self.kdim = kdim if kdim is not None else dim_embed
-        self.vdim = vdim if vdim is not None else dim_embed
-        self.qkv_same_dim = self.kdim == dim_embed and self.vdim == dim_embed
+        dim_attn = _either_or(dim_attn, dim_model)
+        dim_key = _either_or(dim_key, dim_model)
+        dim_value = _either_or(dim_value, dim_model)
+
+        self.in_proj_container = (
+            in_proj_container
+            if in_proj_container is not None
+            else InProjContainer(
+                query_proj_params=InProjParams(dim_model, dim_key, bias=bias),
+                key_proj_params=InProjParams(dim_model, dim_key, bias=bias)
+                if use_separate_proj_weight
+                else None,
+                value_proj_params=InProjParams(dim_model, dim_value, bias=bias)
+                if use_separate_proj_weight
+                else None,
+            )
+        )

         self.num_heads = num_heads
         self.num_rules = num_rules
@@ -134,24 +149,17 @@ def __init__(
         self.q_compose = q_compose

         self.dropout_module = nn.Dropout(dropout)
-        self.dim_head = dim_embed // num_heads
+        self.dim_head = dim_model // num_heads
         self.value_dim = dim_attn // num_rules

         assert (
-            self.dim_head * num_heads == self.dim_embed
-        ), "dim_embed must be divisible by num_heads"
-
-        assert (
-            self.value_dim * num_rules == self.dim_attn
+            self.value_dim * num_rules == dim_attn
         ), "value_dim must be divisible by num_rules"

         self.scaling = self.dim_head ** -0.5
         self.scaling_values = self.dim_selection ** -0.5

-        self.k_proj = nn.Linear(self.kdim, dim_embed, bias=bias)
-        self.v_proj = nn.Linear(self.vdim, dim_attn, bias=bias)
-        self.q_proj = nn.Linear(dim_embed, dim_embed, bias=bias)
-        self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_embed, bias=bias)
+        self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

         if self.qk_rule:
             self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
@@ -159,17 +167,17 @@ def __init__(
                 self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
             else:
                 self.value_q = nn.Linear(
-                    dim_embed, self.dim_selection * self.num_heads, bias=bias
+                    dim_model, self.dim_selection * self.num_heads, bias=bias
                 )
         else:
             if self.q_compose:
                 self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
             else:
                 self.value_q = nn.Linear(
-                    dim_embed, self.dim_selection * self.num_heads, bias=bias
+                    dim_model, self.dim_selection * self.num_heads, bias=bias
                 )
             if self.nonlinear:
-                self.score_network = nn.Sequential(
+                self.score_network: nn.Module = nn.Sequential(
                     nn.Linear(
                         self.dim_selection + self.value_dim,
                         self.dim_selection,
@@ -185,19 +193,10 @@ def __init__(

         self.causal = causal

-        self.reset_parameters()
+        self._reset_parameters()

-    def reset_parameters(self):
-        if self.qkv_same_dim:
-            # Empirically observed the convergence to be much better with
-            # the scaled initialization
-            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
-            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
-            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
-        else:
-            nn.init.xavier_uniform_(self.k_proj.weight)
-            nn.init.xavier_uniform_(self.v_proj.weight)
-            nn.init.xavier_uniform_(self.q_proj.weight)
+    def _reset_parameters(self):
+        # NOTE: in_proj_container is already initialized

         if self.qk_rule:
             nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
@@ -235,15 +234,14 @@ def forward(
         B, Sq, E = q.shape
         _, Sk, _ = k.shape

-        assert E == self.dim_embed
+        assert E == self.dim_model

         # First define projected query/key/values
         # We keep the projected and original tensors in flight,
         # depending on the options the original values could be reused
         q_unprojected = q
-        q = self.q_proj(q) * self.scaling
-        k = self.k_proj(k)
-        v = self.v_proj(v)
+        q, k, v = self.in_proj_container(query=q, key=k, value=v)
+        q *= self.scaling

         # Init causal mask if needed, now that we know the context length
         if self.causal and (
diff --git a/xformers/components/multi_head_dispatch.py b/xformers/components/multi_head_dispatch.py
index 96b153b21..104062e8b 100644
--- a/xformers/components/multi_head_dispatch.py
+++ b/xformers/components/multi_head_dispatch.py
@@ -159,17 +159,17 @@ def forward(
                 + "In that case causality is ill-determined. Please pad your sequences accordingly"
             )

+        if self.attention.requires_skip_multi_head:
+            return self.attention(
+                query, key, value, att_mask=att_mask, key_padding_mask=key_padding_mask
+            )
+
         # Calculate query, key, values for all heads in batch
         if self.attention.requires_input_projection:
            q, k, v = self.in_proj_container(query=query, key=key, value=value)
         else:
             k, q, v = key, query, value

-        if self.attention.requires_skip_multi_head:
-            return self.attention(
-                q, k, v, att_mask=att_mask, key_padding_mask=key_padding_mask
-            )
-
         # Optional: rotary embedding, add relative positioning information
         if self.rotary_embeddings:
             # rotary requires the head dimension
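
Usage sketch (not part of the applied diff): a minimal, hedged example of how the refactored arguments are meant to be used, mirroring the updated test configurations above. It assumes the xformers attention registry and the build_attention helper; the concrete sizes (MODEL, HEADS, SEQ, BATCH) are illustrative only.

    import torch

    from xformers.components.attention import build_attention

    MODEL, HEADS, SEQ, BATCH = 384, 6, 128, 2  # illustrative sizes, not from the patch

    # "dim_model" and "num_heads" replace the removed "dim_head" argument;
    # build_multi_head_attention now copies "dim_model" into the attention
    # sub-config automatically when it is missing.
    attention = build_attention(
        {
            "name": "compositional",
            "dropout": 0.0,
            "causal": False,
            "dim_model": MODEL,
            "num_heads": HEADS,
            "num_rules": HEADS,  # defaults to num_heads when omitted
        }
    )

    # CompositionalAttention sets requires_skip_multi_head, so it is called
    # directly on (batch, seq, dim_model) tensors and owns the head handling.
    x = torch.rand(BATCH, SEQ, MODEL)
    out = attention(x, x, x)
    print(out.shape)  # expected: torch.Size([2, 128, 384])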