18 changes: 12 additions & 6 deletions diffsynth/models/wan_video_vae.py
@@ -1216,7 +1216,6 @@ def single_decode(self, hidden_state, device):


 def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
-
     videos = [video.to("cpu") for video in videos]
     hidden_states = []
     for video in videos:
@@ -1234,11 +1233,18 @@ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(1


 def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
-    if tiled:
-        video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
-    else:
-        video = self.single_decode(hidden_states, device)
-    return video
+    hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
+    videos = []
+    for hidden_state in hidden_states:
+        hidden_state = hidden_state.unsqueeze(0)
+        if tiled:
+            video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
+        else:
+            video = self.single_decode(hidden_state, device)
+        video = video.squeeze(0)
+        videos.append(video)
+    videos = torch.stack(videos)
+    return videos
Comment on lines +1236 to +1247

Contributor (severity: medium):

The implementation correctly fixes the batch decoding issue by iterating through the batch. However, the code can be made more readable and slightly more memory-efficient.

1. The current implementation reassigns `hidden_states` and then uses `hidden_state` as a loop variable that is also reassigned inside the loop, which can be confusing. Using new variable names improves clarity.
2. It is slightly more memory-efficient to move each hidden state to the CPU inside the loop, rather than creating a new list of all hidden states on the CPU at once.

Here is a suggested refactoring that addresses these points:

        videos = []
        for hs in hidden_states:
            hs_batch = hs.to("cpu").unsqueeze(0)
            if tiled:
                video = self.tiled_decode(hs_batch, device, tile_size, tile_stride)
            else:
                video = self.single_decode(hs_batch, device)
            videos.append(video.squeeze(0))
        return torch.stack(videos)
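
The per-sample shape handling in this pattern can be checked in isolation with a stand-in decoder. This is a minimal sketch, not the repository's code: `fake_decode`, `decode_batch`, and all tensor sizes below are hypothetical, chosen only to exercise the unsqueeze/squeeze/stack round trip.

```python
import torch

def fake_decode(hidden_state):
    # Stand-in for single_decode / tiled_decode: like the real decoders,
    # it expects a leading batch dimension of size 1.
    assert hidden_state.shape[0] == 1
    return hidden_state * 2.0

def decode_batch(hidden_states):
    videos = []
    for hs in hidden_states:                  # iterate over the batch
        hs_batch = hs.to("cpu").unsqueeze(0)  # move one sample to CPU, add batch dim
        video = fake_decode(hs_batch)
        videos.append(video.squeeze(0))       # drop the batch dim again
    return torch.stack(videos)                # reassemble the full batch

batch = torch.randn(3, 16, 4, 8, 8)          # (batch, C, T, H, W), made-up sizes
out = decode_batch(batch)
print(out.shape)                              # torch.Size([3, 16, 4, 8, 8])
```

Because each sample is decoded under a batch dimension of 1, the decoder's batch-size assumption holds for every element, and `torch.stack` restores the original batch ordering at the end.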



@staticmethod