Skip to content

Commit

Permalink
Add output format options (#54) (#91)
Browse files Browse the repository at this point in the history
refactor postprocessing

Co-authored-by: zappityzap <128872140+zappityzap@users.noreply.github.com>
  • Loading branch information
continue-revolution and zappityzap committed Sep 16, 2023
1 parent a94ab8c commit e7b5eea
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 43 deletions.
34 changes: 4 additions & 30 deletions scripts/animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scripts.animatediff_logger import logger_animatediff as logger
from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup
from scripts.animatediff_mm import mm_animatediff as motion_module
from scripts.animatediff_output import AnimateDiffOutput


script_dir = scripts.basedir()
Expand All @@ -25,14 +26,11 @@ def show(self, is_img2img):

def ui(self, is_img2img):
model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(script_dir, "model"))
ui_group = AnimateDiffUiGroup()
return (ui_group.render(is_img2img, model_dir),)
return (AnimateDiffUiGroup().render(is_img2img, model_dir),)

def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
if params.enable:
logger.info(f"AnimateDiff process start with video Max frames {params.video_length}, FPS {params.fps}, duration {params.video_length/params.fps}, motion module {params.model}.")
assert params.video_length > 0 and params.fps > 0, "Video length and FPS should be positive."
p.batch_size = params.video_length
params.set_p(p)
motion_module.inject(p.sd_model, params.model)

def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs):
Expand All @@ -47,36 +45,12 @@ def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiff
def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
if params.enable:
motion_module.restore(p.sd_model)
video_paths = []
logger.info("Merging images into GIF.")
from pathlib import Path
Path(f"{p.outpath_samples}/AnimateDiff").mkdir(exist_ok=True, parents=True)
for i in range(res.index_of_first_image, len(res.images), params.video_length):
video_list = res.images[i:i+params.video_length]

if 0 in params.reverse:
video_list_reverse = video_list[::-1]
if 1 in params.reverse:
video_list_reverse.pop(0)
if 2 in params.reverse:
video_list_reverse.pop(-1)
video_list = video_list + video_list_reverse

seq = images.get_next_sequence_number(f"{p.outpath_samples}/AnimateDiff", "")
filename = f"{seq:05}-{res.seed}"
video_path = f"{p.outpath_samples}/AnimateDiff/{filename}.gif"
video_paths.append(video_path)
imageio.mimsave(video_path, video_list, duration=(1/params.fps), loop=params.loop_number)
res.images = video_paths
AnimateDiffOutput().output(p, res, params)
logger.info("AnimateDiff process end.")

def on_ui_settings():
section = ('animatediff', "AnimateDiff")
shared.opts.add_option("animatediff_model_path", shared.OptionInfo(os.path.join(script_dir, "model"), "Path to save AnimateDiff motion modules", gr.Textbox, section=section))
shared.opts.add_option("animatediff_hack_gn", shared.OptionInfo(
True, "Check if you want to hack GroupNorm. By default, V1 hacks GroupNorm, which avoids a performance degradation. "
"If you choose not to hack GroupNorm for V1, you will be able to use this extension in img2img in all cases, but the generated GIF will have flickers. "
"V2 does not hack GroupNorm, so that this option will not influence v2 inference.", gr.Checkbox, section=section))


script_callbacks.on_ui_settings(on_ui_settings)
Expand Down
8 changes: 4 additions & 4 deletions scripts/animatediff_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
unet = sd_model.model.diffusion_model
self._load(model_name)
TimestepEmbedSequential.forward = mm_tes_forward
if shared.opts.data.get("animatediff_hack_gn", False) and (not self.mm.using_v2):
if not self.mm.using_v2:
logger.info(f"Hacking GroupNorm32 forward function.")
def groupnorm32_mm_forward(self, x):
x = rearrange(x, '(b f) c h w -> b c f h w', b=2)
Expand Down Expand Up @@ -112,7 +112,7 @@ def restore(self, sd_model):
if self.mm.using_v2:
logger.info(f"Removing motion module from SD1.5 UNet middle block.")
unet.middle_block.pop(-2)
if shared.opts.data.get("animatediff_hack_gn", False) and (not self.mm.using_v2):
if not self.mm.using_v2:
logger.info(f"Restoring GroupNorm32 forward function.")
GroupNorm32.forward = gn32_original_forward
TimestepEmbedSequential.forward = tes_original_forward
Expand All @@ -130,10 +130,10 @@ def _set_ddim_alpha(self, sd_model):
alphas_cumprod_prev = torch.cat(
(torch.tensor([1.0], dtype=torch.float32, device=device), alphas_cumprod[:-1]))
self.prev_beta = sd_model.betas
sd_model.betas = betas
self.prev_alpha_cumprod = sd_model.alphas_cumprod
sd_model.alphas_cumprod = alphas_cumprod
self.prev_alpha_cumprod_prev = sd_model.alphas_cumprod_prev
sd_model.betas = betas
sd_model.alphas_cumprod = alphas_cumprod
sd_model.alphas_cumprod_prev = alphas_cumprod_prev

def _restore_ddim_alpha(self, sd_model):
Expand Down
74 changes: 74 additions & 0 deletions scripts/animatediff_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import imageio
from pathlib import Path

from modules import images
from modules.processing import StableDiffusionProcessing, Processed

from scripts.animatediff_logger import logger_animatediff as logger
from scripts.animatediff_ui import AnimateDiffProcess

class AnimateDiffOutput:

def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
video_paths = []
logger.info("Merging images into GIF.")
Path(f"{p.outpath_samples}/AnimateDiff").mkdir(exist_ok=True, parents=True)
for i in range(res.index_of_first_image, len(res.images), params.video_length):
video_list = res.images[i:i+params.video_length]

seq = images.get_next_sequence_number(f"{p.outpath_samples}/AnimateDiff", "")
filename = f"{seq:05}-{res.seed}"
video_path_prefix = f"{p.outpath_samples}/AnimateDiff/{filename}."

video_list = self._add_reverse(params, video_list)
video_paths += self._save(params, video_list, video_path_prefix, res, i)
if len(video_paths) > 0:
res.images = video_paths

def _add_reverse(self, params: AnimateDiffProcess, video_list: list):
if 0 in params.reverse:
video_list_reverse = video_list[::-1]
if 1 in params.reverse:
video_list_reverse.pop(0)
if 2 in params.reverse:
video_list_reverse.pop(-1)
return video_list + video_list_reverse
return video_list

def _save(self, params: AnimateDiffProcess, video_list: list, video_path_prefix: str, res: Processed, index: int):
video_paths = []
if "GIF" in params.format:
video_path_gif = video_path_prefix + "gif"
video_paths.append(video_path_gif)
imageio.mimsave(video_path_gif, video_list, duration=(1/params.fps), loop=params.loop_number)
if "Optimize GIF" in params.format:
self._optimize_gif(video_path_gif)
if "MP4" in params.format:
video_path_mp4 = video_path_prefix + "mp4"
video_paths.append(video_path_mp4)
imageio.mimsave(video_path_mp4, video_list, fps=params.fps)
if "TXT" in params.format and res.images[index].info is not None:
video_path_txt = video_path_prefix + "txt"
self._save_txt(params, video_path_txt, res, index)
return video_paths

def _optimize_gif(self, video_path: str):
try:
import pygifsicle
except ImportError:
from launch import run_pip
run_pip("install pygifsicle", "sd-webui-animatediff GIF optimization requirement: pygifsicle")
import pygifsicle
finally:
try:
pygifsicle.optimize(video_path)
except FileNotFoundError:
logger.warn("gifsicle not found, required for optimized GIFs, try: apt install gifsicle")

def _save_txt(self, params: AnimateDiffProcess, video_path: str, res: Processed, i: int):
res.images[i].info['motion_module'] = params.model
res.images[i].info['video_length'] = params.video_length
res.images[i].info['fps'] = params.fps
res.images[i].info['loop_number'] = params.loop_number
with open(video_path, "w", encoding="utf8") as file:
file.write(f"{res.images[i].info}\n")
29 changes: 20 additions & 9 deletions scripts/animatediff_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@ def get_block_name(self):


class AnimateDiffProcess:

def __init__(
self,
enable=False,
loop_number=0,
video_length=16,
fps=8,
model="mm_sd_v15.ckpt",
model="mm_sd_v15_v2.ckpt",
format=["GIF", "PNG"],
reverse=[]):
self.enable = enable
self.loop_number = loop_number
self.video_length = video_length
self.fps = fps
self.model = model
self.format = format
self.reverse = reverse

def get_list(self):
Expand All @@ -37,9 +40,20 @@ def get_list(self):
self.video_length,
self.fps,
self.model,
self.format,
self.reverse
]

def _check(self):
assert self.video_length > 0 and self.fps > 0, "Video length and FPS should be positive."
assert not set(["GIF", "MP4", "PNG"]).isdisjoint(self.format), "At least one saving format should be selected."

def set_p(self, p):
self._check()
p.batch_size = self.video_length
if "PNG" not in self.format:
p.do_not_save_samples = True


class AnimateDiffUiGroup:
txt2img_submit_button = None
Expand Down Expand Up @@ -73,15 +87,12 @@ def refresh_models(*inputs):
self.params.fps = gr.Number(value=8, label="Frames per second (FPS)", precision=0)
self.params.loop_number = gr.Number(minimum=0, value=0, label="Display loop number (0 = infinite loop)", precision=0)
with gr.Row():
self.params.format = gr.CheckboxGroup(
choices=["GIF", "MP4", "PNG", "TXT", "Optimize GIF"],
label="Save", type="value", value=["GIF", "PNG"])
self.params.reverse = gr.CheckboxGroup(
choices=[
"Add Reverse Frame",
"Remove head",
"Remove tail"
],
label="Reverse",
type="index"
)
choices=["Add Reverse Frame", "Remove head", "Remove tail"],
label="Reverse",type="index")
with gr.Row():
unload = gr.Button(value="Move motion module to CPU (default if lowvram)")
remove = gr.Button(value="Remove motion module from any memory")
Expand Down

0 comments on commit e7b5eea

Please sign in to comment.