Skip to content

Commit

Permalink
A different way of handling multiple images passed to SVD.
Browse files Browse the repository at this point in the history
Previously when a list of 3 images [0, 1, 2] was used for a 6 frame video
they were concated like this:
[0, 1, 2, 0, 1, 2]

now they are concated like this:
[0, 0, 1, 1, 2, 2]
  • Loading branch information
comfyanonymous committed Dec 3, 2023
1 parent b2517b4 commit 61a123a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def extra_conds(self, **kwargs):
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

Expand Down
20 changes: 20 additions & 0 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,26 @@ def repeat_to_batch_size(tensor, batch_size):
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor

def resize_to_batch_size(tensor, batch_size):
in_batch_size = tensor.shape[0]
if in_batch_size == batch_size:
return tensor

if batch_size <= 1:
return tensor[:batch_size]

output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]

return output

def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
Expand Down

0 comments on commit 61a123a

Please sign in to comment.