Commit

doc + removing seemingly niche options
blefaudeux committed Jan 12, 2022
1 parent 42629fa · commit 0df1adc
Showing 2 changed files with 6 additions and 21 deletions.
4 changes: 2 additions & 2 deletions examples/microGPT.py
@@ -275,7 +275,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
-BATCH = 128
+BATCH = 256

WORKERS = 4
EPOCHS = 1
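
A quick aside on the REF_BATCH / BATCH pair above: the reference batch is presumably preserved through gradient accumulation, so shrinking BATCH only trades memory for extra accumulation steps. A minimal sketch, assuming the example accumulates gradients via its PyTorch Lightning Trainer (the accumulation_steps name and the Trainer line are illustrative, not quoted from the file):

REF_BATCH = 512                           # effective batch the schedule is tuned for
BATCH = 256                               # per-step batch that fits in memory
accumulation_steps = REF_BATCH // BATCH   # 512 // 256 == 2 accumulation steps
# e.g. trainer = pl.Trainer(accumulate_grad_batches=accumulation_steps, ...)
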
@@ -303,7 +303,7 @@ def top_k_logits(logits, k):
model = GPT(
    vocab_size=train_dataset.vocab_size,
    block_size=train_dataset.block_size,
-    attention="compositional",
+    attention="scaled_dot_product",
    warmup_tokens=REF_BATCH * WARMUP,
    final_tokens=EPOCHS * len(train_dataset) * BLOCK,
)
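
Both values passed to attention= above are names in xformers' attention registry. A minimal sketch of resolving one of them directly, assuming the library's build_attention helper and a plain dict config as used in its other examples (the dropout value and tensor shapes are made up):

import torch
from xformers.components.attention import build_attention

# "scaled_dot_product" is the name the example now defaults to; swapping in
# "compositional" would select the attention this commit moves away from.
attn = build_attention({
    "name": "scaled_dot_product",
    "dropout": 0.1,
    "causal": False,
})

# Toy (batch, sequence, embedding) tensors, just to exercise the module.
q = k = v = torch.randn(2, 16, 64)
out = attn(q, k, v)  # same shape as q, assuming the usual Attention forward(q, k, v) interface
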
23 changes: 4 additions & 19 deletions xformers/components/attention/compositional.py
@@ -48,7 +48,6 @@ class CompositionalAttentionConfig(AttentionConfig):
    kdim: Optional[int] = None
    vdim: Optional[int] = None
    bias: bool = True
-    add_bias_kv: bool = False
    causal: Optional[bool] = False


@@ -70,13 +69,15 @@ class CompositionalAttention(Attention):
    Args:
        num_heads: The number of heads *for the search operation*
        dim_head: Latent space for a given head
-        dim_selection: dimension of the scoring/selection space for the retrievals
        num_rules: The number of rules to consider *for the retrieval operation*
        dropout: attention dropout probability
        qk_rule: QK product will drive the retrieval process
+        dim_selection: dimension of the scoring/selection space for the retrievals
        nonlinear: use a non-linear method to score the retrievals
        dim_attn: dimension (embedding) of the attention
-        # FIXME: to be continued
+        kdim, vdim: dimensions of K and V, if different from Q
+        bias: use bias in the initial projection step
+        causal: causal computations (attend to the past only)

        _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
    """
@@ -95,7 +96,6 @@ def __init__(
        kdim=None,
        vdim=None,
        bias=True,
-        add_bias_kv=False,
        causal=False,
        *_,
        **__,
@@ -183,12 +183,6 @@ def __init__(
            self.dim_selection + self.value_dim, 1, bias=bias
        )

-        if add_bias_kv:
-            self.bias_k = nn.Parameter(Tensor(1, 1, dim_embed))
-            self.bias_v = nn.Parameter(Tensor(1, 1, dim_embed))
-        else:
-            self.bias_k = self.bias_v = None
-
        self.causal = causal

        self.reset_parameters()
@@ -219,10 +213,6 @@ def reset_parameters(self):
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
-        if self.bias_k is not None:
-            nn.init.xavier_normal_(self.bias_k)
-        if self.bias_v is not None:
-            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
@@ -281,11 +271,6 @@ def forward(
        else:
            att_mask_additive = self._causal_mask

-        if self.bias_k is not None:
-            assert self.bias_v is not None
-            k = torch.cat([k, self.bias_k.expand(-1, B, -1)])
-            v = torch.cat([v, self.bias_v.expand(-1, B, -1)])
-
        # Flatten the heads or the rules
        q = (
            q.view(B, Sq, self.num_heads, self.dim_head)
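
With add_bias_kv gone, the remaining knobs are the ones documented in the docstring above. A rough construction sketch only: the keyword names and values mirror that docstring, and the exact set of required arguments is an assumption rather than something verified against the class signature.

from xformers.components.attention.compositional import CompositionalAttention

# Illustrative hyperparameters; argument names follow the Args list shown in the diff.
attn = CompositionalAttention(
    num_heads=4,       # heads for the search operation
    dim_head=16,       # latent space per head
    num_rules=4,       # rules for the retrieval operation
    dim_selection=8,   # scoring/selection space for the retrievals
    dim_attn=64,       # attention embedding dimension
    dropout=0.0,
    qk_rule=False,     # whether the QK product drives retrieval
    nonlinear=False,   # non-linear scoring of the retrievals
    bias=True,         # bias in the initial projection step
    causal=False,      # attend to the past only if True
    # add_bias_kv=True  <- removed by this commit, no longer accepted
)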
