
Commit 371899e

take care of flash attention for Muse, and make sure it supports the normed queries and keys with custom scale

lucidrains committed Jul 15, 2023
1 parent 88e188b commit 371899e
Showing 4 changed files with 149 additions and 18 deletions.

README.md (5 additions, 5 deletions)
@@ -285,10 +285,10 @@ images # List[PIL.Image.Image]
 ```
 
 ```bibtex
-@misc{gilmer2023intriguing
-    title = {Intriguing Properties of Transformer Training Instabilities},
-    author = {Justin Gilmer, Andrea Schioppa, and Jeremy Cohen},
-    year = {2023},
-    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
+@inproceedings{dao2022flashattention,
+    title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
+    author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
+    booktitle = {Advances in Neural Information Processing Systems},
+    year = {2022}
 }
 ```

muse_maskgit_pytorch/attend.py (126 additions, 0 deletions)
@@ -0,0 +1,126 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

# constants

AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        scale = 8,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v, mask = None):
        default_scale = q.shape[-1] ** -0.5

        is_cuda = q.is_cuda

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # scaled_dot_product_attention does not allow for custom scale
        # so hack it in, to support rmsnorm-ed queries and keys

        rescale = self.scale / default_scale

        q = q * (rescale ** 0.5)
        k = k * (rescale ** 0.5)
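        # after this rescaling, the kernel's internal (q @ k^T) * default_scale
        # equals the original (q @ k^T) * self.scale, i.e. the custom scale is recovered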

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 scaled_dot_product_attention takes q, k, v, mask, dropout and causal, but no custom softmax scale, hence the rescaling above

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v, mask = None, force_non_flash = False):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        if self.flash and not force_non_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # masking

        if exists(mask):
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, mask_value)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
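
For reference, a minimal sketch of exercising the new Attend module on its own. The tensor sizes below are illustrative assumptions rather than anything taken from this commit, and flash = True additionally requires pytorch >= 2.0 (per the assert in __init__):

```python
import torch
from muse_maskgit_pytorch.attend import Attend

attend = Attend(scale = 8, dropout = 0., flash = False)

# (batch, heads, sequence length, head dimension) - illustrative sizes
q = torch.randn(1, 8, 256, 64)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

out = attend(q, k, v)   # (1, 8, 256, 64)

# with flash = True (pytorch >= 2.0), the rescaling trick above should give results
# numerically close to the plain einsum path
# out_flash = Attend(scale = 8, flash = True)(q, k, v)
```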

muse_maskgit_pytorch/muse_maskgit_pytorch.py (17 additions, 12 deletions)
@@ -17,6 +17,7 @@
 
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
 from muse_maskgit_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
+from muse_maskgit_pytorch.attend import Attend
 
 from tqdm.auto import tqdm
 
@@ -94,7 +95,9 @@ def __init__(
         dim_head = 64,
         heads = 8,
         cross_attend = False,
-        scale = 8
+        scale = 8,
+        flash = True,
+        dropout = 0.
     ):
         super().__init__()
         self.scale = scale
@@ -104,6 +107,12 @@ def __init__(
         self.cross_attend = cross_attend
         self.norm = LayerNorm(dim)
 
+        self.attend = Attend(
+            flash = flash,
+            dropout = dropout,
+            scale = scale
+        )
+
         self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
@@ -122,6 +131,7 @@ def forward(
     ):
         assert not (exists(context) ^ self.cross_attend)
 
+        n = x.shape[-2]
         h, is_cross_attn = self.heads, exists(context)
 
         x = self.norm(x)
@@ -142,17 +152,11 @@ def forward(
         q = q * self.q_scale
         k = k * self.k_scale
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
-
         if exists(context_mask):
-            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
+            context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = n)
             context_mask = F.pad(context_mask, (1, 0), value = True)
 
-            mask_value = -torch.finfo(sim.dtype).max
-            sim = sim.masked_fill(~context_mask, mask_value)
-
-        attn = sim.softmax(dim = -1)
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = self.attend(q, k, v, mask = context_mask)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
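
A side note on the mask plumbing above: the boolean context mask is expanded to shape (b, h, i, j) and then left-padded with a column of True, presumably so the learned null key / value (self.null_kv) is never masked out, before being handed to Attend. A small sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F
from einops import repeat

b, h, i, j = 2, 8, 16, 16                                  # made-up sizes
context_mask = torch.ones(b, j, dtype = torch.bool)

context_mask = repeat(context_mask, 'b j -> b h i j', h = h, i = i)
context_mask = F.pad(context_mask, (1, 0), value = True)   # extra True column for the null key / value

print(context_mask.shape)                                  # torch.Size([2, 8, 16, 17])
```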
@@ -165,15 +169,16 @@ def __init__(
         depth,
         dim_head = 64,
         heads = 8,
-        ff_mult = 4
+        ff_mult = 4,
+        flash = True
     ):
         super().__init__()
         self.layers = nn.ModuleList([])
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads),
-                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, flash = flash),
+                Attention(dim = dim, dim_head = dim_head, heads = heads, cross_attend = True, flash = flash),
                 FeedForward(dim = dim, mult = ff_mult)
             ]))
 
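Taken together, these changes thread a flash flag from TransformerBlocks down into each Attention, which in turn builds Attend(flash = ..., dropout = ..., scale = ...). A rough sketch of opting in or out (the constructor arguments other than flash are illustrative, and the surrounding Transformer / MaskGit wiring is unchanged and not shown in this diff):

```python
from muse_maskgit_pytorch.muse_maskgit_pytorch import TransformerBlocks

# flash defaults to True; pass flash = False to force the plain einsum attention path
blocks = TransformerBlocks(dim = 512, depth = 4, dim_head = 64, heads = 8, flash = False)
```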

setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'muse-maskgit-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.2.2',
   license='MIT',
   description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
   author = 'Phil Wang',
