Commit

complete #295
lucidrains committed Feb 28, 2024
1 parent 354a39b commit d0c68fc
Showing 2 changed files with 75 additions and 12 deletions.
85 changes: 74 additions & 11 deletions denoising_diffusion_pytorch/karras_unet_3d.py
@@ -208,6 +208,7 @@ def __init__(
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
+ factorize_space_time_attn = False,
downsample = False,
downsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
@@ -247,15 +248,25 @@ def __init__(
self.res_mp_add = MPAdd(t = mp_add_t)

self.attn = None
+ self.factorized_attn = factorize_space_time_attn

if has_attn:
- self.attn = Attention(
+ attn_kwargs = dict(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)

+ if factorize_space_time_attn:
+     self.attn = nn.ModuleList([
+         Attention(**attn_kwargs, only_space = True),
+         Attention(**attn_kwargs, only_time = True),
+     ])
+ else:
+     self.attn = Attention(**attn_kwargs)

def forward(
self,
x,
@@ -284,7 +295,13 @@ def forward(
x = self.res_mp_add(x, res)

if exists(self.attn):
- x = self.attn(x)
+ if self.factorized_attn:
+     attn_space, attn_time = self.attn
+     x = attn_space(x)
+     x = attn_time(x)
+
+ else:
+     x = self.attn(x)

return x
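
Note: with factorize_space_time_attn enabled, the block applies a space-only attention pass and then a time-only pass instead of a single full space-time pass. Purely as an illustration (this wrapper is not part of the commit), the composition performed by the new forward amounts to:

```python
import torch
from torch import nn

class SpaceThenTimeAttention(nn.Module):
    # illustrative stand-in: attn_space and attn_time play the role of the two
    # Attention modules the block keeps in its nn.ModuleList
    def __init__(self, attn_space: nn.Module, attn_time: nn.Module):
        super().__init__()
        self.attn_space = attn_space
        self.attn_time = attn_time

    def forward(self, x):
        x = self.attn_space(x)  # tokens attend within their own frame
        x = self.attn_time(x)   # tokens attend across frames at a fixed spatial location
        return x
```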

@@ -301,6 +318,7 @@ def __init__(
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
+ factorize_space_time_attn = False,
upsample = False,
upsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
@@ -335,15 +353,25 @@ def __init__(
self.res_mp_add = MPAdd(t = mp_add_t)

self.attn = None
+ self.factorized_attn = factorize_space_time_attn

if has_attn:
- self.attn = Attention(
+ attn_kwargs = dict(
dim = dim_out,
heads = max(ceil(dim_out / attn_dim_head), 2),
dim_head = attn_dim_head,
mp_add_t = attn_res_mp_add_t,
flash = attn_flash
)

+ if factorize_space_time_attn:
+     self.attn = nn.ModuleList([
+         Attention(**attn_kwargs, only_space = True),
+         Attention(**attn_kwargs, only_time = True),
+     ])
+ else:
+     self.attn = Attention(**attn_kwargs)

def forward(
self,
x,
@@ -369,7 +397,13 @@ def forward(
x = self.res_mp_add(x, res)

if exists(self.attn):
- x = self.attn(x)
+ if self.factorized_attn:
+     attn_space, attn_time = self.attn
+     x = attn_space(x)
+     x = attn_time(x)
+
+ else:
+     x = self.attn(x)

return x

@@ -383,9 +417,13 @@ def __init__(
dim_head = 64,
num_mem_kv = 4,
flash = False,
- mp_add_t = 0.3
+ mp_add_t = 0.3,
+ only_space = False,
+ only_time = False
):
super().__init__()
+ assert (int(only_space) + int(only_time)) <= 1

self.heads = heads
hidden_dim = dim_head * heads

@@ -399,20 +437,41 @@ def __init__(

self.mp_add = MPAdd(t = mp_add_t)

+ self.only_space = only_space
+ self.only_time = only_time

def forward(self, x):
- res, b, c, t, h, w = x, *x.shape
+ res, orig_shape = x, x.shape
+ b, c, t, h, w = orig_shape
+
+ qkv = self.to_qkv(x)
+
+ if self.only_space:
+     qkv = rearrange(qkv, 'b c t x y -> (b t) c x y')
+ elif self.only_time:
+     qkv = rearrange(qkv, 'b c t x y -> (b x y) c t')
+
+ qkv = qkv.chunk(3, dim = 1)

- qkv = self.to_qkv(x).chunk(3, dim = 1)
- q, k, v = map(lambda t: rearrange(t, 'b (h c) t x y -> b h (t x y) c', h = self.heads), qkv)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) ... -> b h (...) c', h = self.heads), qkv)

- mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), self.mem_kv)
+ mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = k.shape[0]), self.mem_kv)
k, v = map(partial(torch.cat, dim = -2), ((mk, k), (mv, v)))

q, k, v = map(self.pixel_norm, (q, k, v))

out = self.attend(q, k, v)

- out = rearrange(out, 'b h (t x y) d -> b (h d) t x y', t = t, x = h, y = w)
+ out = rearrange(out, 'b h n d -> b (h d) n')
+
+ if self.only_space:
+     out = rearrange(out, '(b t) c n -> b c (t n)', t = t)
+ elif self.only_time:
+     out = rearrange(out, '(b x y) c n -> b c (n x y)', x = h, y = w)
+
+ out = out.reshape(orig_shape)

out = self.to_out(out)

return self.mp_add(out, res)
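
Note: the factorization works by regrouping the video tensor before attention: a space-only pass folds time into the batch so each frame attends over its own h x w grid, while a time-only pass folds the spatial grid into the batch so each position attends over its t frames. A minimal sketch of the rearrange patterns used above (shapes are illustrative):

```python
import torch
from einops import rearrange

b, c, t, h, w = 2, 8, 4, 16, 16
feats = torch.randn(b, c, t, h, w)

# space-only: fold time into the batch; each frame sees h * w tokens
space_grouped = rearrange(feats, 'b c t x y -> (b t) c x y')   # (8, 8, 16, 16)

# time-only: fold the spatial grid into the batch; each position sees t tokens
time_grouped = rearrange(feats, 'b c t x y -> (b x y) c t')    # (512, 8, 4)

# after attention, the module flattens back and restores the original shape,
# e.g. for the space-only path:
restored = rearrange(space_grouped, '(b t) c x y -> b c (t x y)', t = t).reshape(b, c, t, h, w)
assert torch.equal(restored, feats)
```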
@@ -446,7 +505,8 @@ def __init__(
attn_res_mp_add_t = 0.3,
resnet_mp_add_t = 0.3,
dropout = 0.1,
- self_condition = False
+ self_condition = False,
+ factorize_space_time_attn = False
):
super().__init__()

@@ -576,6 +636,7 @@ def __init__(
has_attn = curr_image_res in attn_res,
upsample = True,
upsample_config = down_and_upsample_config,
+ factorize_space_time_attn = factorize_space_time_attn,
**block_kwargs
)

@@ -593,6 +654,7 @@ def __init__(
downsample = True,
downsample_config = down_and_upsample_config,
has_attn = has_attn,
+ factorize_space_time_attn = factorize_space_time_attn,
**block_kwargs
)

@@ -777,6 +839,7 @@ def forward(self, x):
),
attn_dim_head = 8,
num_classes = 1000,
+ factorize_space_time_attn = True # whether to do attention across space and time separately
)

video = torch.randn(2, 4, 32, 64, 64)
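
Assuming the Attention class can be imported directly from denoising_diffusion_pytorch.karras_unet_3d (the module this commit edits) and a version containing this change is installed, a minimal sketch of the new only_space / only_time flags might look like the following; the dim, heads and dim_head values are illustrative and chosen so that heads * dim_head equals dim, mirroring how the U-Net sizes its attention heads:

```python
import torch
from denoising_diffusion_pytorch.karras_unet_3d import Attention  # assumed import path

# heads * dim_head == dim keeps the internal flatten/reshape round trip consistent
attn_space = Attention(dim = 32, heads = 2, dim_head = 16, only_space = True)
attn_time  = Attention(dim = 32, heads = 2, dim_head = 16, only_time = True)

feats = torch.randn(1, 32, 8, 16, 16)  # (batch, channels, frames, height, width)

out = attn_time(attn_space(feats))     # space pass, then time pass, as in Block.forward
assert out.shape == feats.shape
```

The factorized pair trades one full space-time pass for two cheaper ones: attention cost per token scales with h * w + t rather than h * w * t.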
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
- __version__ = '1.10.17'
+ __version__ = '1.11.0'
