Commit 5eddcff: Now running
blefaudeux committed Nov 3, 2021
1 parent 962db66 commit 5eddcff
Showing 5 changed files with 56 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -43,3 +43,6 @@ my_runs.md
 
 # Watchman config files
 .watchmanconfig
+
+# examples demo files
+examples/input.txt
7 changes: 4 additions & 3 deletions examples/microGPT.py
@@ -271,9 +271,9 @@ def top_k_logits(logits, k):
 if __name__ == "__main__":
     seed_everything(42)
     REF_BATCH = 512
-    BATCH = 256  # adjust depending on the avaiable memory on your machine
+    BATCH = 512  # adjust depending on the avaiable memory on your machine
     WORKERS = 8
-    EPOCHS = 2
+    EPOCHS = 1
     BLOCK = 128
     WARMUP = 20
 
@@ -298,10 +298,11 @@ def top_k_logits(logits, k):
     model = GPT(
         vocab_size=train_dataset.vocab_size,
         block_size=train_dataset.block_size,
-        attention="scaled_dot_product",
+        attention="nystrom",
         warmup_tokens=REF_BATCH * WARMUP,
         final_tokens=EPOCHS * len(train_dataset) * BLOCK,
     )
+    print(model)
 
     trainer = Trainer(
         gpus=1,
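
A hedged aside on the REF_BATCH / BATCH pair adjusted above: the reference batch fixes the effective batch size the schedule is tuned for, while BATCH is what fits in memory per step. The arithmetic below is a generic sketch of that pattern, not microGPT.py code (the example's actual Trainer wiring is not part of this diff).

```python
# Generic sketch of the REF_BATCH / BATCH relationship (assumed pattern, not microGPT.py code):
# the effective batch stays at REF_BATCH, and gradient accumulation bridges whatever gap
# remains once BATCH is tuned to the available GPU memory.
REF_BATCH = 512  # effective batch size the schedule is calibrated against
BATCH = 512      # per-step batch size that fits on the device

accumulation_steps = max(1, REF_BATCH // BATCH)
effective_batch = BATCH * accumulation_steps
assert effective_batch == REF_BATCH  # with BATCH raised to 512, accumulation drops to 1
```
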
7 changes: 5 additions & 2 deletions xformers/components/attention/core.py
@@ -17,7 +17,7 @@
 from ._sputnik_sparse import SparseCS
 
 # NOTE: Could do with a better option on when to use triton and not
-_use_triton = True
+_use_triton = torch.cuda.is_available()
 if _use_triton:
     try:
         from xformers.triton.softmax import softmax as triton_softmax
@@ -195,7 +195,10 @@ def scaled_query_key_softmax(
 
     # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
    q = q / math.sqrt(k.size(-1))
-    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask)
+
+    # Matmul with mask, if boolean
+    is_bool_mask = att_mask is not None and att_mask.dtype == torch.bool
+    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask if is_bool_mask else None)
 
     if att_mask is not None and att_mask.dtype != torch.bool:
         att = att + att_mask
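
To make the two mask conventions handled above concrete, here is a small standalone sketch in plain PyTorch (not the library's `_matmul_with_mask`): a boolean mask marks positions to keep, while an additive float mask holds 0 for kept positions and -inf for dropped ones and is simply added to the raw scores before the softmax. Both routes give the same attention weights.

```python
# Standalone illustration of boolean vs. additive attention masks (not xformers internals).
import torch

q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
scores = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)  # (1, 4, 4) raw attention scores

bool_mask = torch.tril(torch.ones(1, 4, 4, dtype=torch.bool))  # True = may attend
additive_mask = torch.zeros(1, 4, 4)
additive_mask[~bool_mask] = float("-inf")                      # same positions, additive form

att_bool = torch.softmax(scores.masked_fill(~bool_mask, float("-inf")), dim=-1)
att_add = torch.softmax(scores + additive_mask, dim=-1)
assert torch.allclose(att_bool, att_add)  # identical attention weights either way
```
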
51 changes: 32 additions & 19 deletions xformers/components/attention/nystrom.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from typing import Optional
 
@@ -15,7 +16,11 @@
     scaled_dot_product_attention,
     scaled_query_key_softmax,
 )
-from xformers.components.attention.utils import iterative_pinv, reshape_key_padding_mask
+from xformers.components.attention.utils import (
+    bool_mask_to_additive,
+    iterative_pinv,
+    reshape_key_padding_mask,
+)
 
 
 @dataclass
@@ -165,30 +170,33 @@ def forward(
         """
         batched_dim = k.size(0)
         seq_len = k.size(-2)
+        tt = {"dtype": q.dtype, "device": q.device}
 
         if key_padding_mask is not None:
-            assert key_padding_mask.dtype == torch.bool
+            if key_padding_mask.dtype == torch.bool:
+                logging.warning(
+                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
+                )
+                key_padding_mask = bool_mask_to_additive(key_padding_mask)
+
+            assert key_padding_mask is not None  # mypy is drunk
 
             if key_padding_mask.ndim == 2:
                 key_padding_mask = reshape_key_padding_mask(
                     key_padding_mask, batched_dim
                 )
 
-            assert key_padding_mask.size() == (batched_dim, 1, seq_len,), (
+            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                 f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                 f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
             )
 
         if self.num_landmarks >= seq_len:
             mask: Optional[torch.Tensor] = None
             if self.causal:
-                mask = self._tril_mask(batched_dim, seq_len, seq_len)
+                mask = self._tril_mask(batched_dim, seq_len, seq_len, **tt)
             if key_padding_mask is not None:
-                mask = (
-                    key_padding_mask
-                    if mask is None
-                    else mask.logical_and(key_padding_mask)
-                )
+                mask = key_padding_mask if mask is None else mask + key_padding_mask
             x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
 
         else:
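
The assertion above expects the padding mask as (batched_dim, 1, seq_len) because the attention inputs are flattened to (batch * heads, seq, dim). A rough sketch of that expansion follows; the helper below is hypothetical and only illustrates the shape contract, it is not the library's reshape_key_padding_mask.

```python
# Hypothetical illustration of the (batch, seq) -> (batch * heads, 1, seq) mask expansion.
import torch

def expand_key_padding_mask(mask: torch.Tensor, batched_dim: int) -> torch.Tensor:
    batch, seq_len = mask.shape
    heads = batched_dim // batch
    # Repeat the per-sequence mask once per head and add a broadcastable middle dimension.
    return mask.repeat_interleave(heads, dim=0).unsqueeze(1)

additive_mask = torch.tensor([[0.0, 0.0, float("-inf")]])         # batch=1, seq=3, additive form
expanded = expand_key_padding_mask(additive_mask, batched_dim=4)  # e.g. 4 heads
assert expanded.shape == (4, 1, 3)
```
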
@@ -197,18 +205,17 @@ def forward(
 
             if self.causal and (
                 self.causal_mask_1 is None
-                or (batched_dim, seq_len, self.num_landmarks)
-                != self.causal_mask_1.size()
+                or (seq_len, self.num_landmarks) != self.causal_mask_1.size()[1:]
             ):
                 self.causal_mask_1 = self._tril_mask(
-                    batched_dim, seq_len, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, seq_len, self.num_landmarks, **tt
+                )
                 self.causal_mask_2 = self._tril_mask(
-                    batched_dim, self.num_landmarks, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
+                )
                 self.causal_mask_3 = self._tril_mask(
-                    batched_dim, self.num_landmarks, seq_len
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, seq_len, **tt
+                )
 
             mask_1: Optional[torch.Tensor] = self.causal_mask_1
             mask_2: Optional[torch.Tensor] = self.causal_mask_2
@@ -258,5 +265,11 @@ def forward(
         x = self.attn_drop(x)
         return x
 
-    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int):
-        return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)
+    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
+        device = kwargs["device"]
+        dtype = kwargs["dtype"]
+
+        return torch.tril(
+            torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
+            diagonal=0,
+        ).expand(dim_1, -1, -1)
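
The tt dictionary of dtype and device threaded through _tril_mask above replaces the earlier .to(q.device) calls: the mask is allocated directly on the right device and in the right dtype instead of being built on the CPU and copied over. Below is a minimal sketch of that pattern with a hypothetical helper; its mask convention is the standard additive causal one, not necessarily identical to _tril_mask.

```python
# Sketch of constructing a mask directly with the query's dtype/device (hypothetical helper).
import torch

def additive_causal_mask(seq_len: int, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # 0 where a position may attend, -inf strictly above the diagonal (future positions).
    full = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
    return torch.triu(full, diagonal=1)

q = torch.randn(2, 16, 32)                     # (batch * heads, seq, head_dim)
tt = {"dtype": q.dtype, "device": q.device}
mask = additive_causal_mask(q.size(-2), **tt)  # no separate .to(q.device) copy needed
assert mask.device == q.device and mask.dtype == q.dtype
```
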
12 changes: 12 additions & 0 deletions xformers/components/attention/utils.py
@@ -87,3 +87,15 @@ def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=F
             13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
         )
     return v
+
+
+def bool_mask_to_additive(
+    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
+):
+    assert (
+        mask.dtype == torch.bool
+    ), "This util is meant to convert in between bool masks and additive ones"
+
+    mask_ = torch.zeros_like(mask, dtype=dtype)
+    mask_[~mask] = float("-inf")
+    return mask_
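
A short usage sketch for the helper added above, following its definition: a boolean mask with True meaning keep becomes an additive mask with 0 for kept positions and -inf for dropped ones.

```python
# Usage sketch for the new helper, based on the definition above.
import torch

from xformers.components.attention.utils import bool_mask_to_additive

key_padding_mask = torch.tensor([[True, True, False], [True, False, False]])
additive = bool_mask_to_additive(key_padding_mask)
# tensor([[0., 0., -inf],
#         [0., -inf, -inf]])
```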
