Skip to content

Commit

Permalink
also handle output conv separately for first frame, depending on whether separate_first_frame_encoding is set to True
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 16, 2023
1 parent e3bab3d commit 441b372
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
70 changes: 45 additions & 25 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,11 @@ def forward(self, x):

# autoencoder - only best variant here offered, with causal conv 3d

def SameConv2d(dim_in, dim_out, kernel_size, padding_mode = 'constant'):
    """
    Build an `nn.Conv2d` padded by `k // 2` along each spatial axis.

    dim_in       - number of input channels
    dim_out      - number of output channels
    kernel_size  - int or 2-tuple; normalized with `cast_tuple(kernel_size, 2)`
    padding_mode - padding mode for the conv. 'constant' (the `F.pad` name for
                   zero padding) is accepted and translated to 'zeros', the
                   equivalent name `nn.Conv2d` understands. Any other value is
                   forwarded to `nn.Conv2d` untouched.

    Returns the constructed `nn.Conv2d`.

    Note: `k // 2` padding preserves the spatial size only for odd kernel sizes
    (at stride 1, which is the `nn.Conv2d` default used here).
    """
    kernel_size = cast_tuple(kernel_size, 2)
    padding = tuple(k // 2 for k in kernel_size)

    # nn.Conv2d only accepts 'zeros' | 'reflect' | 'replicate' | 'circular' and
    # raises ValueError on 'constant', which is this helper's own default

    if padding_mode == 'constant':
        padding_mode = 'zeros'

    return nn.Conv2d(
        dim_in,
        dim_out,
        kernel_size = kernel_size,
        padding = padding,
        padding_mode = padding_mode
    )

class CausalConv3d(Module):
@beartype
def __init__(
Expand Down Expand Up @@ -1047,13 +1052,12 @@ def __init__(

# whether to encode the first frame separately or not

self.conv_first_frame = nn.Identity()
self.conv_in_first_frame = nn.Identity()
self.conv_out_first_frame = nn.Identity()

if separate_first_frame_encoding:
spatial_kernel_size = input_conv_kernel_size[-2:]
spatial_same_padding = tuple(k // 2 for k in spatial_kernel_size)

self.conv_first_frame = nn.Conv2d(channels, init_dim, spatial_kernel_size, padding = spatial_same_padding, padding_mode = pad_mode)
self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:], padding_mode = pad_mode)
self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:], padding_mode = pad_mode)

self.separate_first_frame_encoding = separate_first_frame_encoding

Expand Down Expand Up @@ -1379,7 +1383,8 @@ def init_and_load_from(cls, path, strict = True):
def parameters(self):
return [
*self.conv_in.parameters(),
*self.conv_first_frame.parameters(),
*self.conv_in_first_frame.parameters(),
*self.conv_out_first_frame.parameters(),
*self.conv_out.parameters(),
*self.encoder_layers.parameters(),
*self.decoder_layers.parameters(),
Expand Down Expand Up @@ -1447,6 +1452,11 @@ def encode(
):
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame

# whether to pad video or not

if video_contains_first_frame:
video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)

# conditioning, if needed

assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
Expand All @@ -1462,12 +1472,12 @@ def encode(

if encode_first_frame_separately:
first_frame, video = video[:, :, 0], video[:, :, 1:]
xf = self.conv_first_frame(first_frame)
xff = self.conv_in_first_frame(first_frame)

x = self.conv_in(video)

if encode_first_frame_separately:
x, _ = pack([xf, x], 'b c * h w')
x, _ = pack([xff, x], 'b c * h w')

# encoder layers

Expand Down Expand Up @@ -1500,19 +1510,19 @@ def decode_from_code_indices(
codes = rearrange(codes, 'b (f h w) -> b f h w', h = self.fmap_size, w = self.fmap_size)

quantized = self.quantizers.indices_to_codes(codes)
out = self.decode(quantized, cond = cond)

if video_contains_first_frame:
out = out[:, :, self.time_padding:]

return out
return self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)

@beartype
def decode(
self,
quantized: Tensor,
cond: Optional[Tensor] = None
cond: Optional[Tensor] = None,
video_contains_first_frame = True
):
should_pad_video_for_first_frame = video_contains_first_frame and not self.separate_first_frame_encoding
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame

batch = quantized.shape[0]

# conditioning, if needed
Expand Down Expand Up @@ -1540,7 +1550,25 @@ def decode(

# to pixels

return self.conv_out(x)
if decode_first_frame_separately:

padding_idx = 0 if not video_contains_first_frame else self.time_padding
left_pad, xff, x = x[:, :, :padding_idx], x[:, :, padding_idx], x[:, :, (padding_idx + 1):]

out = self.conv_out(x)
outff = self.conv_out_first_frame(xff)

video, _ = pack([outff, out], 'b c * h w')

else:
video = self.conv_out(x)

# if video were padded, remove padding

if video_contains_first_frame:
video = video[:, :, self.time_padding:]

return video

@torch.no_grad()
def tokenize(self, video):
Expand Down Expand Up @@ -1579,14 +1607,9 @@ def forward(

assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'

# pad the time, accounting for total time downsample factor, so that images can be trained independently

if video_contains_first_frame:
padded_video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)

# encoder

x = self.encode(padded_video, cond = cond, video_contains_first_frame = video_contains_first_frame)
x = self.encode(video, cond = cond, video_contains_first_frame = video_contains_first_frame)

# lookup free quantization

Expand All @@ -1603,10 +1626,7 @@ def forward(

# decoder

padded_recon_video = self.decode(quantized, cond = cond)

if video_contains_first_frame:
recon_video = padded_recon_video[:, :, self.time_padding:]
recon_video = self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)

if return_codes:
return codes, recon_video
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.23'
__version__ = '0.1.24'

0 comments on commit 441b372

Please sign in to comment.