Add inference cache #3

Open
wants to merge 15 commits into base: weight-sharing
37 changes: 28 additions & 9 deletions dalle_pytorch/attention.py
@@ -37,7 +37,8 @@ def apply_pos_emb(pos_emb, qkv):
# classes

class Attention(nn.Module):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False):
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0., stable = False,
static_mask = None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
@@ -46,41 +47,53 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou

self.stable = stable
self.causal = causal
self.register_buffer('static_mask', static_mask, persistent=False)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, mask = None, rotary_pos_emb = None):
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax if not self.stable else stable_softmax
offset = cache.get('offset', 0) if exists(cache) else 0

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

if exists(rotary_pos_emb):
q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

q = q * self.scale

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
if offset > 0:
k_top, v_top = cache[cache_key]
k = torch.cat([k_top, k], dim=-2)
v = torch.cat([v_top, v], dim=-2)
if exists(cache):
cache[cache_key] = k, v

dots = q @ k.swapaxes(-1, -2)
mask_value = max_neg_value(dots)

if exists(mask):
mask = rearrange(mask, 'b j -> b () () j')
dots.masked_fill_(~mask, mask_value)
del mask

if self.causal:
if self.causal and offset == 0: # causality is naturally enforced for the cached inference
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)

if exists(self.static_mask):
dots.masked_fill_(~self.static_mask[offset:offset + n, :offset + n], mask_value)

attn = softmax(dots, dim=-1)

out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = attn @ v
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
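For orientation (not part of the diff): the caching pattern this forward pass implements reduces to a few lines. The first call stores the full key/value tensors under the layer's cache key; each later call receives only the newest token, appends its key/value pair to the cached ones, and lets the single new query attend over the whole history, which is why the explicit causal mask can be skipped once offset > 0. A minimal standalone sketch with hypothetical names, omitting scaling and padding masks:

import torch

def cached_attention_step(q_new, k_new, v_new, cache, cache_key):
    # q_new/k_new/v_new: (batch, heads, n_new, dim_head) for the newest token(s)
    if cache_key in cache:
        k_old, v_old = cache[cache_key]
        k_new = torch.cat([k_old, k_new], dim=-2)   # grow the key history
        v_new = torch.cat([v_old, v_new], dim=-2)   # grow the value history
    cache[cache_key] = (k_new, v_new)

    # attend from the new queries over the full history; no causal mask is
    # needed because nothing later than the current position is cached yet
    attn = (q_new @ k_new.swapaxes(-1, -2)).softmax(dim=-1)
    return attn @ v_new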
@@ -109,7 +122,13 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
nn.Dropout(dropout)
)

def forward(self, x, mask = None, rotary_pos_emb = None):
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
n0 = x.shape[1]
if exists(cache):
if cache_key in cache:
x = torch.cat([cache[cache_key], x], dim=-2)
cache[cache_key] = x

b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
softmax = torch.softmax if not self.stable else stable_softmax

@@ -204,7 +223,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):

out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
out = self.to_out(out)
return out[:, :n]
return out[:, n - n0:n]
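The convolution-like attention above takes a different route (not key/value caching): it caches the raw layer inputs, re-runs attention over the whole cached sequence, and returns only the rows belonging to the newly appended tokens, which is what the change from out[:, :n] to out[:, n - n0:n] selects. A minimal sketch of that input-caching pattern, with hypothetical names:

import torch

def recompute_with_input_cache(layer, x_new, cache, cache_key):
    # Grow the cached input sequence, run the layer over all of it,
    # and keep only the outputs for the freshly added positions.
    n_new = x_new.shape[1]
    if cache_key in cache:
        x_new = torch.cat([cache[cache_key], x_new], dim=-2)
    cache[cache_key] = x_new
    out = layer(x_new)
    return out[:, -n_new:]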

# sparse axial causal attention

@@ -229,7 +248,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
nn.Dropout(dropout)
)

def forward(self, x, mask = None, rotary_pos_emb = None):
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
softmax = torch.softmax if not self.stable else stable_softmax

18 changes: 15 additions & 3 deletions dalle_pytorch/dalle_pytorch.py
@@ -344,6 +344,7 @@ def __init__(
shared_attn_ids = None,
shared_ff_ids = None,
share_input_output_emb = False,
use_static_masks = False,
):
super().__init__()
assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -391,6 +392,7 @@ def __init__(
rotary_emb = rotary_emb,
shared_attn_ids = shared_attn_ids,
shared_ff_ids = shared_ff_ids,
use_static_masks = use_static_masks,
)

self.stable = stable
@@ -484,7 +486,8 @@ def generate_images(
filter_thres = 0.5,
temperature = 1.,
img = None,
num_init_img_tokens = None
num_init_img_tokens = None,
use_cache = False,
):
vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
total_len = text_seq_len + image_seq_len
@@ -503,12 +506,13 @@
indices = indices[:, :num_img_tokens]
out = torch.cat((out, indices), dim = -1)

cache = {} if use_cache else None
for cur_len in range(out.shape[1], total_len):
is_image = cur_len >= text_seq_len

text, image = out[:, :text_seq_len], out[:, text_seq_len:]

logits = self(text, image, mask = mask)[:, -1, :]
logits = self(text, image, mask = mask, cache = cache)[:, -1, :]

filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim = -1)
@@ -536,6 +540,7 @@ def forward(
text,
image = None,
mask = None,
cache = None,
return_loss = False
):
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
@@ -584,7 +589,9 @@ def forward(
alpha = 0.1
tokens = tokens * alpha + tokens.detach() * (1 - alpha)

out = self.transformer(tokens)
if exists(cache) and cache.get('offset'):
tokens = tokens[:, -1:]
out = self.transformer(tokens, cache=cache)

if self.stable:
out = self.norm_by_max(out)
@@ -594,9 +601,14 @@ def forward(
# mask logits to make sure text predicts text (except last token), and image predicts image

logits_mask = self.logits_mask[:, :seq_len]
if exists(cache) and cache.get('offset'):
logits_mask = logits_mask[:, -1:]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(logits_mask, max_neg_value)

if exists(cache):
cache['offset'] = cache.get('offset', 0) + logits.shape[1]

if not return_loss:
return logits
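As a usage note (an illustrative sketch, not part of the diff): the cache dict is created inside generate_images, so enabling it only changes the call site.

# Hypothetical call site: `dalle` is a trained DALLE instance and `text_tokens`
# a (batch, text_seq_len) LongTensor of token ids. With use_cache = True, every
# sampling step after the first feeds only the newest token through the
# transformer and reuses the cached per-layer state.
images = dalle.generate_images(text_tokens, use_cache = True)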

103 changes: 92 additions & 11 deletions dalle_pytorch/transformer.py
@@ -1,3 +1,4 @@
from collections import deque
from collections.abc import Iterable
from functools import partial
from itertools import islice, cycle
@@ -35,6 +36,15 @@ def forward(self, x):
maxes = x.amax(dim = self.dim, keepdim = True)
return x / maxes

class CachedAs(nn.Module):
def __init__(self, cache_key, fn):
super().__init__()
self.cache_key = cache_key
self.fn = fn

def forward(self, x, *, cache=None, **kwargs):
return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
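A brief usage sketch (illustrative; the sizes and the import of this branch's new class are assumptions): a single cache dict is threaded through the whole stack, and CachedAs pins the key each wrapped module writes under so per-layer state never collides.

import torch
from dalle_pytorch.attention import Attention
from dalle_pytorch.transformer import CachedAs

attn = Attention(dim = 64, seq_len = 128, heads = 4, dim_head = 16)
layer = CachedAs('attn_0', attn)    # every call forwards cache_key='attn_0' to attn

cache = {}
x = torch.randn(1, 128, 64)
y = layer(x, cache = cache)         # Attention stores its (k, v) under cache['attn_0']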

# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
@@ -83,7 +93,7 @@ def __init__(self, dim, dropout = 0., mult = 4.):
nn.Linear(dim * mult, dim)
)

def forward(self, x):
def forward(self, x, cache=None, cache_key=None):
return self.net(x)

# token shift classes
@@ -94,12 +104,30 @@ def __init__(self, fn, image_size, seq_len):
self.fn = fn
self.image_size = image_size
self.seq_len = seq_len
self.img_seq_len = image_size ** 2
self.text_len = seq_len - self.img_seq_len + 1

def forward(self, x, cache=None, cache_key=None, **kwargs):
seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len

if exists(cache) and cache_key in cache:
offset = cache['offset']
assert offset >= text_len, "cached inference for text is not supported"
q = cache[cache_key]
assert isinstance(q, deque) and len(q) == image_size

x_top, x_left, *x_pass = x[:, -1].chunk(4, dim=-1)

q.append((x_top, x_left))
x_top = q.popleft()[0]
x_left = q[-2][1]
if (offset - text_len) % image_size == 0:
x_left = torch.zeros_like(x_left)

x = torch.cat((x_top, x_left, *x_pass), dim=-1)
return self.fn(x[:, None], cache=cache, **kwargs)

def forward(self, x, **kwargs):
n = x.shape[1]
seq_len, image_size = self.seq_len, self.image_size
img_seq_len = image_size ** 2
text_len = seq_len - img_seq_len + 1
padding = seq_len - n + 1

# get text and image tokens
@@ -124,8 +152,22 @@ def forward(self, x, **kwargs):
# merge text and image sequence back together

x_img = rearrange(x_img, 'b h w d -> b (h w) d')
x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
return self.fn(x, **kwargs)
x_img = x_img[:, :-padding]
x = torch.cat((x_text, x_img), dim = 1)

if exists(cache):
dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim=-1)
dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)

q = deque()
x_img = x_img[:, -image_size:]
for _ in range(image_size - x_img.shape[1]):
q.append((dummy_top, dummy_left))
for i in range(x_img.shape[1]):
q.append(x_img[:, i].chunk(4, dim=-1)[:2])
cache[cache_key] = q

return self.fn(x, cache=cache, **kwargs)
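The deque used above is a rolling window over the last image_size positions: in a row-major image grid that window always contains both the left neighbour (one step back) and the token directly above (image_size steps back), which is everything the shift needs (the real code also zeroes the left chunk at the start of each image row). A tiny self-contained check of the indexing, with integer step indices standing in for the cached feature chunks:

from collections import deque

image_size = 4
i = image_size                          # current step index
q = deque(range(i - image_size, i))     # the previous image_size steps

q.append(i)
above = q.popleft()                     # step i - image_size: same column, previous row
left = q[-2]                            # step i - 1: previous column, same row
assert (above, left) == (i - image_size, i - 1)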

# main transformer class

@@ -152,11 +194,15 @@ def __init__(
rotary_emb = True,
shared_attn_ids = None,
shared_ff_ids = None,
use_static_masks = False,
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)

self.seq_len = seq_len
self.image_fmap_size = image_fmap_size

attn_types = default(attn_types, ('full',))
attn_types = cast_tuple(attn_types)
attn_type_layer = islice(cycle(attn_types), depth)
@@ -173,9 +219,15 @@ def __init__(
elif attn_type == 'sparse':
attn_class = SparseAttention
elif attn_type == 'axial_row':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
if use_static_masks:
attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
else:
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_fmap_size, stable = stable)
elif attn_type == 'axial_col':
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
if use_static_masks:
attn_class = partial(Attention, stable = stable, static_mask = self._get_static_mask(attn_type))
else:
attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_fmap_size, stable = stable)
elif attn_type == 'conv_like':
attn_class = partial(SparseConvCausalAttention, seq_len = seq_len, image_size = image_fmap_size, stable = stable)
elif attn_type == 'mlp':
@@ -199,8 +251,11 @@ def __init__(
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
shared_ff_layers[ff_id] = ff

attn = CachedAs(f'attn_{ind}', attn)

if shift_tokens:
attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
ff = CachedAs(f'preshift_ff_{ind}', PreShiftToken(ff, image_size = image_fmap_size, seq_len = seq_len))

layers.append(nn.ModuleList([
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
@@ -209,7 +264,9 @@ def __init__(

execute_type = ReversibleSequence if reversible else SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
route_all = ((True, True),) * depth
attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
'cache': route_all}

self.layers = execute_type(layers, args_route = attn_route_map)

@@ -245,3 +302,27 @@ def __init__(

def forward(self, x, **kwargs):
return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)

def _get_static_mask(self, attn_type):
# In case of attn_type = "axial_{row,col}",
# the sparse implementation is most efficient for training,
# but the full attention with a static mask is most efficient for inference
# since caching is implemented in this case.

img_seq_len = self.image_fmap_size ** 2
text_len = self.seq_len + 1 - img_seq_len

static_mask = torch.zeros(self.seq_len, self.seq_len, dtype=torch.bool)
static_mask[:, :text_len] = True
if attn_type == 'axial_row':
for row in range(self.image_fmap_size):
begin = text_len + row * self.image_fmap_size
end = text_len + (row + 1) * self.image_fmap_size
static_mask[begin:end, begin:end] = True
elif attn_type == 'axial_col':
for col in range(self.image_fmap_size):
begin = text_len + col
static_mask[begin::self.image_fmap_size, begin::self.image_fmap_size] = True
else:
raise ValueError(f'attention type "{attn_type}" can\'t be simulated with a static mask')
return static_mask
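To make these masks concrete, an illustrative reconstruction for a toy configuration (sizes chosen only for readability; the causal constraint is still applied separately inside Attention): under axial_row every image token may attend to all text positions plus the tokens of its own image row, and axial_col is the analogous per-column pattern.

import torch

image_fmap_size, text_len = 3, 2
seq_len = text_len + image_fmap_size ** 2 - 1     # mirrors text_len = seq_len + 1 - img_seq_len

static_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
static_mask[:, :text_len] = True                  # every position sees all text tokens
for row in range(image_fmap_size):
    begin = text_len + row * image_fmap_size
    end = begin + image_fmap_size
    static_mask[begin:end, begin:end] = True      # ...and the tokens of its own image row

print(static_mask.int())                          # 10x10: text columns plus per-row diagonal blocks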