Refactor the projection, align args with the other attentions
blefaudeux committed Jan 14, 2022
1 parent 79cdac9 commit 71a75c8
Showing 7 changed files with 75 additions and 75 deletions.
5 changes: 2 additions & 3 deletions examples/microGPT.py
@@ -68,7 +68,6 @@ def __init__(
"dropout": self.hparams.attn_pdrop,
"causal": True,
"seq_len": self.hparams.block_size,
"dim_head": self.hparams.n_embd // self.hparams.n_head,
"num_rules": self.hparams.n_head,
},
},
@@ -275,7 +274,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 256
BATCH = 128

WORKERS = 4
EPOCHS = 1
@@ -303,7 +302,7 @@ def top_k_logits(logits, k):
model = GPT(
vocab_size=train_dataset.vocab_size,
block_size=train_dataset.block_size,
attention="scaled_dot_product",
attention="compositional",
warmup_tokens=REF_BATCH * WARMUP,
final_tokens=EPOCHS * len(train_dataset) * BLOCK,
)
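With `dim_head` dropped from the attention sub-config, the compositional attention in microGPT is configured like the other attentions: `dim_model` and `num_heads` are supplied by the surrounding multi-head machinery (see the `xformers/components/__init__.py` change below). A minimal sketch of the resulting sub-config, with illustrative values standing in for `self.hparams`:

```python
# Illustrative hyperparameters; only attention-specific settings remain here,
# the multi-head wrapper fills in dim_model and num_heads.
n_embd, n_head, block_size, attn_pdrop = 512, 8, 128, 0.1

attention_config = {
    "name": "compositional",
    "dropout": attn_pdrop,
    "causal": True,
    "seq_len": block_size,
    "num_rules": n_head,
}
```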
2 changes: 1 addition & 1 deletion tests/test_attentions.py
@@ -45,8 +45,8 @@ def _get_multihead(
"seq_len": SEQ,
"window_size": SEQ // 8 + 1, # local attention
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"dim_head": MODEL // heads,
"num_rules": 2, # Compositional Attention
}

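The shared test config now feeds every attention the same `dim_model` / `num_heads` pair instead of a precomputed `dim_head`. A minimal sketch of building the compositional variant from such a dict, assuming `build_attention` from `xformers.components.attention` accepts the registered name plus the keyword fields used in these tests:

```python
import torch
from xformers.components.attention import build_attention

MODEL, HEADS, SEQ = 384, 4, 128  # illustrative sizes mirroring the test constants

attention = build_attention(
    {
        "name": "compositional",
        "dim_model": MODEL,
        "num_heads": HEADS,
        "num_rules": 2,
        "dropout": 0.0,
        "causal": False,
    }
)

q = k = v = torch.randn(2, SEQ, MODEL)
out = attention(q, k, v)  # compositional attention runs its own projection and heads
```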
4 changes: 1 addition & 3 deletions tests/test_block_factory.py
@@ -58,8 +58,8 @@ def test_xformer_encoder_block(
"window_size": SEQ // 8 + 1,
"seq_len": SEQ,
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"dim_head": MODEL // heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
@@ -148,8 +148,6 @@ def test_xformer_decoder_block(
"window_size": SEQ // 8 + 1,
"seq_len": SEQ,
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"num_heads": heads,
"dim_head": MODEL / heads,
"layout": torch.eye(SEQ // block_size, SEQ // block_size, dtype=torch.long),
"block_size": block_size,
"num_rules": 2, # Compositional Attention
2 changes: 1 addition & 1 deletion tests/test_compositional_attention.py
@@ -61,8 +61,8 @@ def test_build_and_run(
"seq_len": SEQ,
"window_size": SEQ // 8 + 1, # local attention
"attention_query_mask": torch.rand((SEQ, 1)) < GLOBAL_ATTENTION_RATIO,
"dim_model": MODEL,
"num_heads": heads,
"dim_head": MODEL // heads,
"num_rules": 2, # Compositional Attention
"q_compose": q_compose,
"rules": rules,
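The same renaming applies when instantiating the class directly. A minimal sketch with the realigned keyword arguments (sizes and flag values are illustrative, not the test's parameter grid):

```python
import torch
from xformers.components.attention.compositional import CompositionalAttention

attention = CompositionalAttention(
    dim_model=384,
    num_heads=4,
    num_rules=2,
    dropout=0.1,
    qk_rule=True,   # let the QK product drive the retrieval
    q_compose=False,
    causal=False,
)

q = k = v = torch.randn(2, 128, 384)  # (batch, sequence, dim_model)
out = attention(q, k, v)
```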
5 changes: 5 additions & 0 deletions xformers/components/__init__.py
@@ -48,6 +48,11 @@ def build_multi_head_attention(
"num_heads"
]

if "dim_model" not in multi_head_config["attention"]:
multi_head_config["attention"]["dim_model"] = multi_head_config[
"dim_model"
]

if (
"dim_features" not in multi_head_config["attention"]
or multi_head_config["attention"]["dim_features"] is None
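`build_multi_head_attention` now also mirrors `dim_model` into the nested attention config when it is not set explicitly, the same way it already did for `num_heads`. A standalone sketch of that defaulting pattern (not the actual factory code):

```python
def fill_attention_defaults(multi_head_config: dict) -> dict:
    # Nested attention settings inherit from the multi-head level when absent.
    attention = multi_head_config.setdefault("attention", {})
    for key in ("num_heads", "dim_model"):
        if key not in attention and key in multi_head_config:
            attention[key] = multi_head_config[key]
    return multi_head_config


config = fill_attention_defaults(
    {"dim_model": 384, "num_heads": 4, "attention": {"name": "compositional"}}
)
assert config["attention"]["dim_model"] == 384
assert config["attention"]["num_heads"] == 4
```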
122 changes: 60 additions & 62 deletions xformers/components/attention/compositional.py
@@ -33,22 +33,30 @@
if _is_triton_available:
from xformers.triton.softmax import softmax

from xformers.components.in_proj_container import InProjContainer, InProjParams


def _either_or(a: Optional[int], b: int) -> int:
return a if a is not None else b
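`_either_or` is a small fallback helper used further down to resolve the optional sizes (`num_rules`, `dim_selection`, `dim_attn`, `dim_key`, `dim_value`) against `dim_model` and `num_heads`. A quick illustration with made-up values:

```python
from typing import Optional


def _either_or(a: Optional[int], b: int) -> int:
    # Redefined here so the snippet stands alone: take a if set, else fall back to b.
    return a if a is not None else b


dim_model, num_heads = 384, 8  # illustrative sizes

num_rules = _either_or(None, num_heads)                   # defaults to num_heads -> 8
dim_selection = _either_or(None, dim_model // num_heads)  # defaults to the head dim -> 48
dim_attn = _either_or(96, dim_model)                      # explicit value wins -> 96
```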


@dataclass
class CompositionalAttentionConfig(AttentionConfig):
dim_model: int
num_heads: int
dim_head: int
num_rules: Optional[int]
dim_attn: Optional[int] = None
num_rules: Optional[int] = None
dim_key: Optional[int] = None
dim_value: Optional[int] = None
dim_selection: Optional[int] = None
dropout: float
qk_rule: bool = False
dim_selection: Optional[int] = None
nonlinear: bool = False
q_compose: bool = False
dim_attn: Optional[int] = None
kdim: Optional[int] = None
vdim: Optional[int] = None
bias: bool = True
causal: Optional[bool] = False
in_proj_container: Optional[InProjContainer] = None
use_separate_proj_weight: Optional[bool] = False


@register_attention("compositional", CompositionalAttentionConfig)
@@ -67,15 +75,15 @@ class CompositionalAttention(Attention):
may not fit in memory.
Args:
num_heads: The number of heads *for the search operation*
dim_head: Latent space for a given head
dim_model: dimension of the incoming latent space
num_heads: number of heads *for the search operation*
dim_attn: dimension (embedding) of the attention
num_rules: number of rules to consider *for the retrieval operation*
dim_selection: dimension of the scoring/selection space for the retrievals
num_rules: The number of rules to consider *for the retrieval operation*
dim_key, dim_value: dimensions of K and V, if different from Q
dropout: attention dropout probability
qk_rule: QK product will drive the retrieval process
nonlinear: use a non linear method to score the retrievals
dim_attn: dimension (embedding) of the attention
kdim, vdim: dimensions of K and V, if different from Q
bias: use bias in the initial projection step
causal: causal computations (attend to the past only)
@@ -84,17 +92,19 @@ class CompositionalAttention(Attention):

def __init__(
self,
num_heads,
dim_head,
num_rules=None,
dim_model: int,
num_heads: int,
dim_attn: Optional[int] = None,
num_rules: Optional[int] = None,
dim_selection: Optional[int] = None,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
dropout=0.0,
qk_rule=False,
dim_selection=None,
nonlinear=False,
q_compose=False,
dim_attn=None,
kdim=None,
vdim=None,
in_proj_container: Optional[InProjContainer] = None,
use_separate_proj_weight: Optional[bool] = False,
bias=True,
causal=False,
*_,
@@ -103,28 +113,33 @@ def __init__(
super().__init__()

# Define the inherited flags
self.requires_input_projection = (
False # This attention handles its own projection
)

self.requires_skip_multi_head = (
True # This attention owns the multi-head mechanism
)

# Handle defaults / undefined values
num_rules = num_heads if num_rules is None else num_rules
dim_embed = int(num_heads * dim_head)
dim_attn = dim_embed if dim_attn is None else dim_attn
dim_selection = (
dim_embed // num_heads if dim_selection is None else dim_selection
)
self.dim_model = dim_model
num_rules = _either_or(num_rules, num_heads)
dim_selection = _either_or(dim_selection, dim_model // num_heads)

# All the initial definition plumbing
self.dim_embed = dim_embed
self.dim_attn = dim_attn
self.kdim = kdim if kdim is not None else dim_embed
self.vdim = vdim if vdim is not None else dim_embed
self.qkv_same_dim = self.kdim == dim_embed and self.vdim == dim_embed
dim_attn = _either_or(dim_attn, dim_model)
dim_key = _either_or(dim_key, dim_model)
dim_value = _either_or(dim_value, dim_model)

self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InProjContainer(
query_proj_params=InProjParams(dim_model, dim_key, bias=bias),
key_proj_params=InProjParams(dim_model, dim_key, bias=bias)
if use_separate_proj_weight
else None,
value_proj_params=InProjParams(dim_model, dim_value, bias=bias)
if use_separate_proj_weight
else None,
)
)

self.num_heads = num_heads
self.num_rules = num_rules
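The hand-rolled `q_proj` / `k_proj` / `v_proj` layers are replaced by an `InProjContainer`, matching what the multi-head dispatcher uses. A minimal standalone sketch of that container with illustrative sizes, spelling out the separate-projection path:

```python
import torch
from xformers.components.in_proj_container import InProjContainer, InProjParams

dim_model = 384  # illustrative

container = InProjContainer(
    query_proj_params=InProjParams(dim_model, dim_model, bias=True),
    key_proj_params=InProjParams(dim_model, dim_model, bias=True),
    value_proj_params=InProjParams(dim_model, dim_model, bias=True),
)

q = k = v = torch.randn(2, 128, dim_model)
q_p, k_p, v_p = container(query=q, key=k, value=v)  # projected query, key, value
```

When `use_separate_proj_weight` is left off, only the query projection parameters are passed and the key/value entries stay `None`, as in the constructor above.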
@@ -134,42 +149,35 @@ def __init__(
self.q_compose = q_compose

self.dropout_module = nn.Dropout(dropout)
self.dim_head = dim_embed // num_heads
self.dim_head = dim_model // num_heads
self.value_dim = dim_attn // num_rules

assert (
self.dim_head * num_heads == self.dim_embed
), "dim_embed must be divisible by num_heads"

assert (
self.value_dim * num_rules == self.dim_attn
self.value_dim * num_rules == dim_attn
), "value_dim must be divisible by num_rules"

self.scaling = self.dim_head ** -0.5
self.scaling_values = self.dim_selection ** -0.5

self.k_proj = nn.Linear(self.kdim, dim_embed, bias=bias)
self.v_proj = nn.Linear(self.vdim, dim_attn, bias=bias)
self.q_proj = nn.Linear(dim_embed, dim_embed, bias=bias)
self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_embed, bias=bias)
self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

if self.qk_rule:
self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_embed, self.dim_selection * self.num_heads, bias=bias
dim_model, self.dim_selection * self.num_heads, bias=bias
)
else:
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_embed, self.dim_selection * self.num_heads, bias=bias
dim_model, self.dim_selection * self.num_heads, bias=bias
)
if self.nonlinear:
self.score_network = nn.Sequential(
self.score_network: nn.Module = nn.Sequential(
nn.Linear(
self.dim_selection + self.value_dim,
self.dim_selection,
@@ -185,19 +193,10 @@ def __init__(

self.causal = causal

self.reset_parameters()
self._reset_parameters()

def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
def _reset_parameters(self):
# NOTE: in_proj_container is already initialized

if self.qk_rule:
nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
@@ -235,15 +234,14 @@ def forward(
B, Sq, E = q.shape
_, Sk, _ = k.shape

assert E == self.dim_embed
assert E == self.dim_model

# First define projected query/key/values
# We keep the projected and original tensors in flight,
# depending on the options, the original values could be reused
q_unprojected = q
q = self.q_proj(q) * self.scaling
k = self.k_proj(k)
v = self.v_proj(v)
q, k, v = self.in_proj_container(query=q, key=k, value=v)
q *= self.scaling

# Init causal mask if needed, now that we know the context length
if self.causal and (
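The forward pass now projects through the shared container, scales the projected queries, and only then lazily builds the causal mask once the context length is known. An illustrative helper for such a lazily built additive causal mask (not the library's exact code):

```python
import torch


def causal_additive_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # -inf strictly above the diagonal blocks attention to future positions,
    # zeros on and below the diagonal leave past/present positions untouched.
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)
```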
10 changes: 5 additions & 5 deletions xformers/components/multi_head_dispatch.py
@@ -159,17 +159,17 @@ def forward(
+ "In that case causality is ill-determined. Please pad your sequences accordingly"
)

if self.attention.requires_skip_multi_head:
return self.attention(
query, key, value, att_mask=att_mask, key_padding_mask=key_padding_mask
)

# Calculate query, key, values for all heads in batch
if self.attention.requires_input_projection:
q, k, v = self.in_proj_container(query=query, key=key, value=value)
else:
k, q, v = key, query, value

if self.attention.requires_skip_multi_head:
return self.attention(
q, k, v, att_mask=att_mask, key_padding_mask=key_padding_mask
)

# Optional: rotary embedding, add relative positioning information
if self.rotary_embeddings:
# rotary requires the head dimension
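Moving the `requires_skip_multi_head` check ahead of the projection means attentions that own the whole multi-head mechanism, such as the compositional attention above, now receive the raw query/key/value and run their own input projection. A schematic of that control flow (illustrative only, with a hypothetical `shared_projection` standing in for the dispatcher's in-projection container):

```python
def dispatch(attention, query, key, value, shared_projection=None):
    # Attentions owning the full multi-head mechanism are called first, on the raw inputs.
    if getattr(attention, "requires_skip_multi_head", False):
        return attention(query, key, value)

    # Everyone else optionally goes through the shared input projection.
    if getattr(attention, "requires_input_projection", True) and shared_projection is not None:
        query, key, value = shared_projection(query=query, key=key, value=value)

    return attention(query, key, value)
```

For attentions that previously set `requires_input_projection = False`, the result is the same as before; the early return simply makes the ownership of the projection explicit.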
