diff --git a/examples/microGPT.py b/examples/microGPT.py
index 8214d6b06..d9f687e93 100644
--- a/examples/microGPT.py
+++ b/examples/microGPT.py
@@ -275,7 +275,7 @@ def top_k_logits(logits, k):
     # Adjust batch depending on the available memory on your machine.
     # You can also use reversible layers to save memory
     REF_BATCH = 512
-    BATCH = 128
+    BATCH = 256
 
     WORKERS = 4
     EPOCHS = 1
@@ -303,7 +303,7 @@ def top_k_logits(logits, k):
     model = GPT(
         vocab_size=train_dataset.vocab_size,
         block_size=train_dataset.block_size,
-        attention="compositional",
+        attention="scaled_dot_product",
         warmup_tokens=REF_BATCH * WARMUP,
         final_tokens=EPOCHS * len(train_dataset) * BLOCK,
     )
diff --git a/xformers/components/attention/compositional.py b/xformers/components/attention/compositional.py
index e1e11a4a1..b29c3bc92 100644
--- a/xformers/components/attention/compositional.py
+++ b/xformers/components/attention/compositional.py
@@ -48,7 +48,6 @@ class CompositionalAttentionConfig(AttentionConfig):
     kdim: Optional[int] = None
     vdim: Optional[int] = None
     bias: bool = True
-    add_bias_kv: bool = False
     causal: Optional[bool] = False
@@ -70,13 +69,15 @@ class CompositionalAttention(Attention):
     Args:
         num_heads: The number of heads *for the search operation*
         dim_head: Latent space for a given head
+        dim_selection: dimension of the scoring/selection space for the retrievals
         num_rules: The number of rules to consider *for the retrieval operation*
         dropout: attention dropout probability
         qk_rule: QK product will drive the retrieval process
-        dim_selection: dimension of the scoring/selection space for the retrievals
         nonlinear: use a non linear method to score the retrievals
         dim_attn: dimension (embedding) of the attention
-        # FIXME: to be continued
+        kdim, vdim: dimensions of K and V, if different from Q
+        bias: use bias in the initial projection step
+        causal: causal computations (attend to the past only)
 
     _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
     """
@@ -95,7 +96,6 @@ def __init__(
         kdim=None,
         vdim=None,
         bias=True,
-        add_bias_kv=False,
         causal=False,
         *_,
         **__,
@@ -183,12 +183,6 @@ def __init__(
             self.dim_selection + self.value_dim, 1, bias=bias
         )
 
-        if add_bias_kv:
-            self.bias_k = nn.Parameter(Tensor(1, 1, dim_embed))
-            self.bias_v = nn.Parameter(Tensor(1, 1, dim_embed))
-        else:
-            self.bias_k = self.bias_v = None
-
         self.causal = causal
 
         self.reset_parameters()
@@ -219,10 +213,6 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
             nn.init.constant_(self.out_proj.bias, 0.0)
-        if self.bias_k is not None:
-            nn.init.xavier_normal_(self.bias_k)
-        if self.bias_v is not None:
-            nn.init.xavier_normal_(self.bias_v)
 
     def forward(
         self,
@@ -281,11 +271,6 @@ def forward(
        else:
             att_mask_additive = self._causal_mask
 
-        if self.bias_k is not None:
-            assert self.bias_v is not None
-            k = torch.cat([k, self.bias_k.expand(-1, B, -1)])
-            v = torch.cat([v, self.bias_v.expand(-1, B, -1)])
-
         # Flatten the heads or the rules
         q = (
             q.view(B, Sq, self.num_heads, self.dim_head)
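
For reference, a minimal sketch of how the attention block that the example now defaults to could be built through the xformers attention registry. This snippet is not part of the patch: the `build_attention` factory call and the `dropout`/`causal` config fields for "scaled_dot_product" are assumptions based on the library's registry-style API, and may need adjusting to the installed version. The "compositional" variant stays registered and can be requested the same way, using the arguments listed in the updated docstring above.

# Sketch only (not part of this patch): build the attention used by
# examples/microGPT.py after this change. Assumes xformers' registry factory
# `build_attention` and the `dropout` / `causal` fields of the
# "scaled_dot_product" config; adapt if the installed version differs.
import torch

from xformers.components.attention import build_attention

attention = build_attention(
    {
        "name": "scaled_dot_product",  # the example previously used "compositional"
        "dropout": 0.1,
        "causal": True,  # GPT-style: attend to the past only
    }
)

# Quick shape check: (batch, sequence, embedding) in, same shape out.
q = k = v = torch.randn(2, 16, 32)
out = attention(q, k, v)
print(out.shape)  # torch.Size([2, 16, 32])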