From 441b372c4ab2d6b88d7d82137540f8b9235c7cf9 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 16 Nov 2023 12:17:43 -0800
Subject: [PATCH] also handle output conv separately for first frame,
 depending on whether separate_first_frame_encoding is set to True

---
 magvit2_pytorch/magvit2_pytorch.py | 70 +++++++++++++++++++-----------
 magvit2_pytorch/version.py         |  2 +-
 2 files changed, 46 insertions(+), 26 deletions(-)

diff --git a/magvit2_pytorch/magvit2_pytorch.py b/magvit2_pytorch/magvit2_pytorch.py
index c12dd26..307657b 100644
--- a/magvit2_pytorch/magvit2_pytorch.py
+++ b/magvit2_pytorch/magvit2_pytorch.py
@@ -828,6 +828,11 @@ def forward(self, x):
 
 # autoencoder - only best variant here offered, with causal conv 3d
 
+def SameConv2d(dim_in, dim_out, kernel_size, padding_mode = 'constant'):
+    kernel_size = cast_tuple(kernel_size, 2)
+    padding = [k // 2 for k in kernel_size]
+    return nn.Conv2d(dim_in, dim_out, kernel_size = kernel_size, padding = padding, padding_mode = padding_mode)
+
 class CausalConv3d(Module):
     @beartype
     def __init__(
@@ -1047,13 +1052,12 @@ def __init__(
 
         # whether to encode the first frame separately or not
 
-        self.conv_first_frame = nn.Identity()
+        self.conv_in_first_frame = nn.Identity()
+        self.conv_out_first_frame = nn.Identity()
 
         if separate_first_frame_encoding:
-            spatial_kernel_size = input_conv_kernel_size[-2:]
-            spatial_same_padding = tuple(k // 2 for k in spatial_kernel_size)
-
-            self.conv_first_frame = nn.Conv2d(channels, init_dim, spatial_kernel_size, padding = spatial_same_padding, padding_mode = pad_mode)
+            self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:], padding_mode = pad_mode)
+            self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:], padding_mode = pad_mode)
 
         self.separate_first_frame_encoding = separate_first_frame_encoding
 
@@ -1379,7 +1383,8 @@ def init_and_load_from(cls, path, strict = True):
     def parameters(self):
         return [
             *self.conv_in.parameters(),
-            *self.conv_first_frame.parameters(),
+            *self.conv_in_first_frame.parameters(),
+            *self.conv_out_first_frame.parameters(),
             *self.conv_out.parameters(),
             *self.encoder_layers.parameters(),
             *self.decoder_layers.parameters(),
@@ -1447,6 +1452,11 @@ def encode(
     ):
         encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
 
+        # whether to pad video or not
+
+        if video_contains_first_frame:
+            video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
+
         # conditioning, if needed
 
         assert (not self.has_cond) or exists(cond), '`cond` must be passed into tokenizer forward method since conditionable layers were specified'
@@ -1462,12 +1472,12 @@ def encode(
 
         if encode_first_frame_separately:
             first_frame, video = video[:, :, 0], video[:, :, 1:]
-            xf = self.conv_first_frame(first_frame)
+            xff = self.conv_in_first_frame(first_frame)
 
         x = self.conv_in(video)
 
         if encode_first_frame_separately:
-            x, _ = pack([xf, x], 'b c * h w')
+            x, _ = pack([xff, x], 'b c * h w')
 
         # encoder layers
 
@@ -1500,19 +1510,19 @@ def decode_from_code_indices(
             codes = rearrange(codes, 'b (f h w) -> b f h w', h = self.fmap_size, w = self.fmap_size)
 
         quantized = self.quantizers.indices_to_codes(codes)
 
-        out = self.decode(quantized, cond = cond)
-        if video_contains_first_frame:
-            out = out[:, :, self.time_padding:]
-
-        return out
+        return self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)
 
     @beartype
     def decode(
         self,
         quantized: Tensor,
-        cond: Optional[Tensor] = None
+        cond: Optional[Tensor] = None,
+        video_contains_first_frame = True
     ):
+        should_pad_video_for_first_frame = video_contains_first_frame and not self.separate_first_frame_encoding
+        decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
+
         batch = quantized.shape[0]
 
         # conditioning, if needed
@@ -1540,7 +1550,25 @@ def decode(
 
         # to pixels
 
-        return self.conv_out(x)
+        if decode_first_frame_separately:
+
+            padding_idx = 0 if not video_contains_first_frame else self.time_padding
+            left_pad, xff, x = x[:, :, :padding_idx], x[:, :, padding_idx], x[:, :, (padding_idx + 1):]
+
+            out = self.conv_out(x)
+            outff = self.conv_out_first_frame(xff)
+
+            video, _ = pack([outff, out], 'b c * h w')
+
+        else:
+            video = self.conv_out(x)
+
+            # if video were padded, remove padding
+
+            if video_contains_first_frame:
+                video = video[:, :, self.time_padding:]
+
+        return video
 
     @torch.no_grad()
     def tokenize(self, video):
@@ -1579,14 +1607,9 @@ def forward(
 
         assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'
 
-        # pad the time, accounting for total time downsample factor, so that images can be trained independently
-
-        if video_contains_first_frame:
-            padded_video = pad_at_dim(video, (self.time_padding, 0), value = 0., dim = 2)
-
         # encoder
 
-        x = self.encode(padded_video, cond = cond, video_contains_first_frame = video_contains_first_frame)
+        x = self.encode(video, cond = cond, video_contains_first_frame = video_contains_first_frame)
 
         # lookup free quantization
 
@@ -1603,10 +1626,7 @@ def forward(
 
         # decoder
 
-        padded_recon_video = self.decode(quantized, cond = cond)
-
-        if video_contains_first_frame:
-            recon_video = padded_recon_video[:, :, self.time_padding:]
+        recon_video = self.decode(quantized, cond = cond, video_contains_first_frame = video_contains_first_frame)
 
         if return_codes:
             return codes, recon_video
diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py
index b5fb6c4..cdf16bc 100644
--- a/magvit2_pytorch/version.py
+++ b/magvit2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.23'
+__version__ = '0.1.24'
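
Note on the change above: the sketch below shows, end to end, the asymmetry this patch removes. Before, only the input side had a dedicated 2d conv for the first frame (conv_first_frame); now the output side mirrors it (conv_in_first_frame / conv_out_first_frame). This is a hedged illustration only, not the magvit2_pytorch API: ToyTokenizer, its dimensions, and the plain nn.Conv3d stand-ins for the library's CausalConv3d are all hypothetical; it assumes only torch and einops are installed.

import torch
import torch.nn.functional as F
from torch import nn
from einops import pack

class ToyTokenizer(nn.Module):
    # hypothetical toy module mirroring the structure of the patched tokenizer
    def __init__(self, channels = 3, dim = 8, time_kernel_size = 3):
        super().__init__()
        # a causal conv pads only on the left along time
        self.time_padding = time_kernel_size - 1

        # 3d convs for the remaining frames (stand-ins for the library's CausalConv3d)
        self.conv_in  = nn.Conv3d(channels, dim, (time_kernel_size, 3, 3), padding = (0, 1, 1))
        self.conv_out = nn.Conv3d(dim, channels, (time_kernel_size, 3, 3), padding = (0, 1, 1))

        # the pair of 2d convs this patch completes: one on the way in, one on the way out
        self.conv_in_first_frame  = nn.Conv2d(channels, dim, 3, padding = 1)
        self.conv_out_first_frame = nn.Conv2d(dim, channels, 3, padding = 1)

    def encode(self, video):
        # split the first frame off and encode it with its own 2d conv
        first_frame, rest = video[:, :, 0], video[:, :, 1:]
        xff = self.conv_in_first_frame(first_frame)             # (b, d, h, w)

        # causal left-pad in time, then the 3d conv for the remaining frames
        rest = F.pad(rest, (0, 0, 0, 0, self.time_padding, 0))
        x = self.conv_in(rest)                                  # (b, d, t, h, w)

        # put the first-frame latent back at the front of the time axis
        x, _ = pack([xff, x], 'b c * h w')
        return x

    def decode(self, x):
        # mirror image of encode: the first-frame latent gets the 2d output conv
        xff, x = x[:, :, 0], x[:, :, 1:]
        outff = self.conv_out_first_frame(xff)

        x = F.pad(x, (0, 0, 0, 0, self.time_padding, 0))
        out = self.conv_out(x)

        video, _ = pack([outff, out], 'b c * h w')
        return video

tokenizer = ToyTokenizer()
video = torch.randn(1, 3, 5, 16, 16)   # (batch, channels, frames, height, width)
recon = tokenizer.decode(tokenizer.encode(video))
assert recon.shape == video.shape

Running it round-trips a 5-frame video back to its original shape, which is the property the new output conv preserves: the first frame is reconstructed by conv_out_first_frame rather than by a causal 3d conv whose only available history would be zero padding.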