34 changes: 20 additions & 14 deletions ldm/modules/attention.py
@@ -1,9 +1,10 @@
-from inspect import isfunction
 import math
+from inspect import isfunction
+
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
 from einops import rearrange, repeat
+from torch import nn, einsum
 
 from ldm.modules.diffusionmodules.util import checkpoint
 
@@ -13,7 +14,7 @@ def exists(val):
 
 
 def uniq(arr):
-    return{el: True for el in arr}.keys()
+    return {el: True for el in arr}.keys()
 
 
 def default(val, d):
@@ -82,14 +83,14 @@ def __init__(self, dim, heads=4, dim_head=32):
         super().__init__()
         self.heads = heads
         hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
 
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
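For context on the linear-attention path touched above, here is a minimal standalone sketch of the two einsum contractions; tensor names and sizes are illustrative placeholders, not values from the model config:

```python
# Standalone sketch of the two einsum contractions in LinearAttention.forward.
# Shapes follow the rearrange above: (batch, heads, dim_head, h*w).
import torch

b, heads, d, n = 2, 4, 32, 64 * 64               # batch, heads, dim_head, spatial positions (made up)
q = torch.randn(b, heads, d, n)
k = torch.randn(b, heads, d, n).softmax(dim=-1)  # k normalized over positions, as in the module
v = torch.randn(b, heads, d, n)

# Aggregate values into a (dim_head x dim_head) context per head: O(n * d^2),
# so the cost grows linearly in the number of spatial positions n.
context = torch.einsum('bhdn,bhen->bhde', k, v)

# Read the context back out with the queries; the result has the same shape as q.
out = torch.einsum('bhde,bhdn->bhen', context, q)
assert out.shape == (b, heads, d, n)
```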
@@ -131,12 +132,12 @@ def forward(self, x):
         v = self.v(h_)
 
         # compute attention
-        b,c,h,w = q.shape
+        b, c, h, w = q.shape
         q = rearrange(q, 'b c h w -> b (h w) c')
         k = rearrange(k, 'b c h w -> b c (h w)')
         w_ = torch.einsum('bij,bjk->bik', q, k)
 
-        w_ = w_ * (int(c)**(-0.5))
+        w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
 
         # attend to values
@@ -146,7 +147,7 @@ def forward(self, x):
         h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
         h_ = self.proj_out(h_)
 
-        return x+h_
+        return x + h_
 
 
 class CrossAttention(nn.Module):
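The spatial self-attention forward reformatted above follows the usual flatten → scale → softmax → attend → residual sequence. Below is a self-contained sketch of that math with random tensors standing in for the q/k/v/proj_out projections; the "attend to values" lines in the middle are assumed from the surrounding file, since they sit in the collapsed region of this diff:

```python
# Minimal walk-through of the attention math in the forward above.
# Sizes are illustrative only.
import torch
from einops import rearrange

b, c, h, w = 1, 64, 16, 16
x = torch.randn(b, c, h, w)
q, k, v = torch.randn(3, b, c, h, w)          # stand-ins for self.q(h_), self.k(h_), self.v(h_)

q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)       # (b, h*w, h*w) attention logits
w_ = w_ * (int(c) ** (-0.5))                  # scale by 1/sqrt(c)
w_ = torch.nn.functional.softmax(w_, dim=2)   # each query position sums to 1 over keys

v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)      # attend to values
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
out = x + h_                                  # residual connection, as in `return x + h_`
assert out.shape == x.shape
```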
@@ -174,6 +175,7 @@ def forward(self, x, context=None, mask=None):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
+        device_type = x.device.type
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
@@ -188,9 +190,11 @@
             sim.masked_fill_(~mask, max_neg_value)
             del mask
 
-        # attention, what we cannot get enough of, by halves
-        sim[4:] = sim[4:].softmax(dim=-1)
-        sim[:4] = sim[:4].softmax(dim=-1)
+        if device_type == 'mps': #special case for M1 - disable neonsecret optimization
+            sim = sim.softmax(dim=-1)
+        else:
+            sim[4:] = sim[4:].softmax(dim=-1)
+            sim[:4] = sim[:4].softmax(dim=-1)
 
         sim = einsum('b i j, b j d -> b i d', sim, v)
         sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
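The functional change in this hunk is the 'mps' branch: on Apple silicon the softmax is applied in one call instead of in-place over two slices of the (batch*heads) dimension. Both branches compute the same values, since softmax over the last dimension is row-independent; the split only lowers peak memory. A quick self-contained check, with arbitrary sizes rather than real model shapes:

```python
# Sanity check that the "by halves" in-place softmax and the one-shot
# 'mps' fallback agree: softmax over the last dim is computed per row,
# so splitting along the leading dim changes memory use, not values.
import torch

sim = torch.randn(8, 4096, 64)       # stand-in for (b*h, n, n) attention logits

full = sim.softmax(dim=-1)           # the 'mps' branch

halved = sim.clone()                 # the in-place, two-slice branch
halved[4:] = halved[4:].softmax(dim=-1)
halved[:4] = halved[:4].softmax(dim=-1)

assert torch.allclose(full, halved)
```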
@@ -200,7 +204,8 @@
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
         super().__init__()
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
+                                    dropout=dropout) # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                     heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
@@ -228,6 +233,7 @@ class SpatialTransformer(nn.Module):
     Then apply standard transformer action.
     Finally, reshape to image
     """
+
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None):
         super().__init__()
@@ -243,7 +249,7 @@ def __init__(self, in_channels, n_heads, d_head,
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
-                for d in range(depth)]
+             for d in range(depth)]
         )
 
         self.proj_out = zero_module(nn.Conv2d(inner_dim,