Skip to content

Commit

Permalink
handle greyscale edgecase part deux
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 21, 2023
1 parent b4ae2d6 commit 5ef07be
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions magvit2_pytorch/magvit2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,8 +1640,15 @@ def forward(

batch, channels, frames = video.shape[:3]

is_greyscale = channels == 1

assert divisible_by(frames - int(video_contains_first_frame), self.time_downsample_factor), f'number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}'

# handle greyscale

if is_greyscale:
video = repeat(video, 'b 1 ... -> b c ...', c = 3)

# encoder

x = self.encode(video, cond = cond, video_contains_first_frame = video_contains_first_frame)
Expand Down Expand Up @@ -1745,10 +1752,6 @@ def forward(
input_vgg_input = pick_video_frame(video, frame_indices)
recon_vgg_input = pick_video_frame(recon_video, frame_indices)

if channels == 1:
input_vgg_input = repeat(input_vgg_input, 'b 1 h w -> b c h w', c = 3)
recon_vgg_input = repeat(recon_vgg_input, 'b 1 h w -> b c h w', c = 3)

input_vgg_feats = self.vgg(input_vgg_input)
recon_vgg_feats = self.vgg(recon_vgg_input)

Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.33'
__version__ = '0.1.34'

0 comments on commit 5ef07be

Please sign in to comment.