
Commit 896a4d0

allow for customizing number of resnet blocks per stage in space time unet

lucidrains committed Dec 13, 2022 · 1 parent 5a93a28
Showing 2 changed files with 15 additions and 11 deletions.
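In short: the new resnet_block_depths argument sets how many plain ResnetBlocks run at each resolution stage of the space-time unet, one entry per stage. A minimal construction sketch, assuming the class is exported as SpaceTimeUnet (the class name and import path come from this repository, not from the diff itself; only parameters visible in the diff are passed):

from make_a_video_pytorch import SpaceTimeUnet

unet = SpaceTimeUnet(
    dim = 64,
    dim_mult = (1, 2, 4, 8),
    self_attns = (False, False, False, True),
    temporal_compression = (False, True, True, True),
    resnet_block_depths = (1, 1, 2, 2),   # the new knob: plain resnet blocks per stage
    condition_on_timestep = True
)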
24 changes: 14 additions & 10 deletions make_a_video_pytorch/make_a_video.py
@@ -483,12 +483,13 @@ def __init__(
         dim_mult = (1, 2, 4, 8),
         self_attns = (False, False, False, True),
         temporal_compression = (False, True, True, True),
+        resnet_block_depths = (2, 2, 2, 2),
         attn_dim_head = 64,
         attn_heads = 8,
         condition_on_timestep = True
     ):
         super().__init__()
-        assert len(dim_mult) == len(self_attns) == len(temporal_compression)
+        assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
         num_layers = len(dim_mult)

         dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
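For context, dims pairs up into per-stage (dim_in, dim_out) tuples, and the lengthened assert guarantees resnet_block_depths lines up with them. A quick illustration of the bookkeeping (the zip of consecutive dims into dim_in_out is assumed from the surrounding code, which this diff elides):

dim = 64
dim_mult = (1, 2, 4, 8)
resnet_block_depths = (2, 2, 2, 2)

dims = [dim, *map(lambda mult: mult * dim, dim_mult)]  # [64, 64, 128, 256, 512]
dim_in_out = list(zip(dims[:-1], dims[1:]))            # assumed pairing: [(64, 64), (64, 128), (128, 256), (256, 512)]

# one depth per stage, enforced by the new assert
assert len(dim_mult) == len(resnet_block_depths)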
@@ -527,18 +528,19 @@ def __init__(
         self.mid_attn = SpatioTemporalAttention(dim = mid_dim)
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim = timestep_cond_dim)

-        for _, self_attend, (dim_in, dim_out), compress_time in zip(range(num_layers), self_attns, dim_in_out, temporal_compression):
+        for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth in zip(range(num_layers), self_attns, dim_in_out, temporal_compression, resnet_block_depths):
+            assert resnet_block_depth >= 1

             self.downs.append(mlist([
                 ResnetBlock(dim_in, dim_out, timestep_cond_dim = timestep_cond_dim),
-                ResnetBlock(dim_out, dim_out),
+                mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
                 SpatioTemporalAttention(dim = dim_out, **attn_kwargs) if self_attend else None,
                 Downsample(dim_out, downsample_time = compress_time)
             ]))

             self.ups.append(mlist([
                 ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim = timestep_cond_dim),
-                ResnetBlock(dim_in + dim_out, dim_in),
+                mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
                 SpatioTemporalAttention(dim = dim_in, **attn_kwargs) if self_attend else None,
                 Upsample(dim_out, upsample_time = compress_time)

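Note the input width on the up path: only the first block of each stack receives the skip concatenation, so ind == 0 gets dim_in + dim_out channels and every later block runs at dim_in. A standalone check of that arithmetic:

dim_in, dim_out = 64, 128
resnet_block_depth = 3

widths = [dim_in + (dim_out if ind == 0 else 0) for ind in range(resnet_block_depth)]
print(widths)  # [192, 64, 64] -> only block 0 absorbs the concatenated skip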
@@ -576,12 +578,13 @@ def forward(

         hiddens = []

-        for block1, block2, maybe_attention, downsample in self.downs:
-            x = block1(x, t, enable_time = enable_time)
+        for init_block, blocks, maybe_attention, downsample in self.downs:
+            x = init_block(x, t, enable_time = enable_time)

             hiddens.append(x.clone())

-            x = block2(x, enable_time = enable_time)
+            for block in blocks:
+                x = block(x, enable_time = enable_time)

             if exists(maybe_attention):
                 x = maybe_attention(x, enable_time = enable_time)
@@ -594,16 +597,17 @@ def forward(
         x = self.mid_attn(x, enable_time = enable_time)
         x = self.mid_block2(x, t, enable_time = enable_time)

-        for block1, block2, maybe_attention, upsample in reversed(self.ups):
+        for init_block, blocks, maybe_attention, upsample in reversed(self.ups):
             x = upsample(x, enable_time = enable_time)

             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

-            x = block1(x, t, enable_time = enable_time)
+            x = init_block(x, t, enable_time = enable_time)

             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim = 1)

-            x = block2(x, enable_time = enable_time)
+            for block in blocks:
+                x = block(x, enable_time = enable_time)

             if exists(maybe_attention):
                 x = maybe_attention(x, enable_time = enable_time)
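Because each up stage pops hiddens twice (once before init_block, once before the block stack), the down path must save two skips per stage; one save is visible above after init_block, and the second sits in context this diff elides. A toy sketch of the assumed LIFO pairing:

hiddens = []
for stage in range(4):                    # down path: two saves per stage
    hiddens.append(('after_init', stage))
    hiddens.append(('after_blocks', stage))

for stage in reversed(range(4)):          # up path: two pops per stage
    assert hiddens.pop() == ('after_blocks', stage)  # consumed before init_block
    assert hiddens.pop() == ('after_init', stage)    # consumed before the block stack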
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'make-a-video-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'Make-A-Video - Pytorch',
   author = 'Phil Wang',
