Skip to content

Commit

Permalink
if dropout > 0.0 disable Flash until pytorch fix. don't assert fail sigh
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Feb 2, 2023
1 parent d8b1a94 commit 1e87509
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions model.py
Expand Up @@ -49,9 +49,9 @@ def __init__(self, config):
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
if not self.flash:
print("WARNING: using slow attention, install PyTorch nightly for fast Flash Attention")
print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
Expand All @@ -68,7 +68,6 @@ def forward(self, x):
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
assert self.dropout == 0.0, "need dropout=0.0 for now, PyTorch team is working on fix in #92917"
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
else:
# manual implementation of attention
Expand Down

0 comments on commit 1e87509

Please sign in to comment.