Skip to content

Commit

Permalink
v2 configs seem to be robust and working well + random seed fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Sep 28, 2023
1 parent b64dfae commit d537b94
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 63 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Bo Dai
[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://animatediff.github.io/)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/guoyww/AnimateDiff)
[![Replicate](https://replicate.com/zsxkib/animate-diff/badge)](https://replicate.com/zsxkib/animate-diff)


## Features
Expand Down
153 changes: 90 additions & 63 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

# %%
import hashlib
from typing import List
from cog import BasePredictor, Input
from cog import Path as CogPath
Expand Down Expand Up @@ -31,16 +31,19 @@
import csv, pdb, glob
import math
from pathlib import Path
from dataclasses import dataclass


def main(args):
video_paths = []

*_, func_args = inspect.getargvalues(inspect.currentframe())
func_args = dict(func_args)

# Compute a hash of the config string to get a fixed length string.
config_hash = hashlib.md5(args.config.encode()).hexdigest()

time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"samples/{Path(args.config).stem}-{time_str}"
savedir = f"samples/{config_hash}-{time_str}"
os.makedirs(savedir)

if os.path.exists(args.config): # Check if args.config is a file path
Expand Down Expand Up @@ -130,9 +133,7 @@ def main(args):

os.makedirs(os.path.dirname(f"{savedir}/sample/"), exist_ok=True)
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
# video_paths.append(
yield CogPath(f"{savedir}/sample/{sample_idx}-{prompt}.gif")
# )
print(f"save to {savedir}/sample/{prompt}.gif")

sample_idx += 1
Expand All @@ -141,11 +142,6 @@ def main(args):
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

OmegaConf.save(config, f"{savedir}/config.yaml")
# return video_paths


# %%
from dataclasses import dataclass


@dataclass
Expand All @@ -158,26 +154,47 @@ class Arguments:
H: int = 512


# %%
# Template for the prompt config that `Predictor.predict` builds as a string
# instead of reading a YAML file from disk.
#
# Every `{name}` placeholder is substituted with `str.format` in `predict`:
#   - `{motion_module_lora_type}` becomes the single top-level key; `predict`
#     sets it to "Cog" when no Motion LoRA is selected.
#   - When `motion_module_lora_path` is empty, `predict` afterwards drops every
#     line whose stripped text starts with "motion_module_lora_configs",
#     "path:" or "alpha:", so those three lines must keep exactly these
#     prefixes for that filter to work.
#
# NOTE(review): prompt / n_prompt are interpolated inside double quotes without
# escaping — a prompt containing `"` presumably yields invalid YAML; confirm.
# NOTE(review): YAML nesting is whitespace-sensitive; confirm the indentation of
# this literal matches the structure `main` expects before editing it.
FAKE_YAML_TEMPLATE = """
{motion_module_lora_type}:
inference_config: "{inference_config}"
motion_module:
- "{motion_module}"
motion_module_lora_configs:
- path: {motion_module_lora_path}
alpha: {alpha}
dreambooth_path: "{dreambooth_path}"
lora_model_path: "{lora_model_path}"
seed: {seed}
steps: {steps}
guidance_scale: {guidance_scale}
prompt:
- "{prompt}"
n_prompt:
- "{n_prompt}"
"""


class Predictor(BasePredictor):
def setup(self) -> None:
    """Set-up hook of the predictor; currently a no-op.

    Nothing is loaded or cached here (the body is just ``pass``), so no
    state is shared between predictions through this method.

    NOTE(review): presumably the models are loaded on each call inside
    ``main`` instead — confirm, and consider moving the loading here if
    repeated predictions should reuse them.
    """
    pass

def predict(
self,
inference_config: str = Input(
description="Select Inference Config",
default="configs/inference/inference-v2.yaml",
),
motion_module: str = Input( # Commented out since it's unused
motion_module_type: str = Input(
description="Select a Motion Model",
default="models/Motion_Module/mm_sd_v15_v2.ckpt",
default="mm_sd_v15_v2",
choices=[
# "mm_sd_v15", # TODO: Will only work with v1 inference_config
# "mm_sd_v14", # TODO: Will only work with v1 inference_config
"mm_sd_v15_v2",
],
),
motion_module_lora_type: str = Input(
description="Select a Motion LoRA Path",
description="Select a Motion LoRA type (set `alpha` if anything other than `None` is chosen)",
default="ZoomOut",
choices=[
"None",
"ZoomIn",
"ZoomOut",
"PanLeft",
Expand All @@ -188,66 +205,76 @@ def predict(
"RollingClockwise",
],
),
alpha: float = Input(description="Alpha", default=1.0), # Commented out since it's unused
dreambooth_path: str = Input( # Commented out since it's unused
description="Select a DreamBooth Path",
default="models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors",
),
lora_model_path: str = Input( # Commented out since it's unused
description="Select a LoRA Model Path", default=""
),
seed: int = Input( # Commented out since it's unused
description="Seed (0 = random, maximum: 2147483647)", default=45987230
alpha: float = Input(
description="Only active when `motion_module_lora_type` has been chosen (i.e. `motion_module_lora_type` is not `None`)Alpha",
default=1.0,
),
steps: int = Input( # Commented out since it's unused
description="Number of inference steps", ge=1, le=100, default=25
dreambooth_type: str = Input(
description="Select a DreamBooth type",
default="realisticVisionV20_v20",
choices=[
"realisticVisionV20_v20",
"lyriel_v16",
"majicmixRealistic_v5Preview",
"rcnzCartoon3d_v10",
"toonyou_beta3",
],
),
guidance_scale: float = Input( # Commented out since it's unused
description="Guidance Scale", ge=1, le=10, default=7.5
seed: int = Input(
description="Seed (-1 = random, maximum: 2147483647)", ge=-1, le=2147483647, default=-1
),
prompt: str = Input( # Commented out since it's unused
steps: int = Input(description="Number of inference steps", ge=1, le=100, default=25),
guidance_scale: float = Input(description="Guidance Scale", ge=1, le=10, default=7.5),
prompt: str = Input(
description="Input prompt",
default="photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
),
n_prompt: str = Input( # Commented out since it's unused
n_prompt: str = Input(
description="Negative prompt",
default="blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
),
pretrained_model_path: str = Input(
description="Pretrained Model Path", default="models/StableDiffusion/stable-diffusion-v1-5"
),
# config: str = Input(
# description="Config", default="configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml"
# ),
L: int = Input(description="Length", default=16),
W: int = Input(description="Width", default=512),
H: int = Input(description="Height", default=512),
) -> List[CogPath]:

motion_module_lora_path = f"models/MotionLoRA/v2_lora_{motion_module_lora_type}.ckpt"
fake_yaml = f'''{motion_module_lora_type}:
inference_config: "{inference_config}"
motion_module:
- "{motion_module}"
motion_module_lora_configs:
- path: {motion_module_lora_path}
alpha: {alpha}
dreambooth_path: "{dreambooth_path}"
lora_model_path: "{lora_model_path}"
seed: {seed}
steps: {steps}
guidance_scale: {guidance_scale}
lora_model_path = ""
pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
inference_config = "configs/inference/inference-v2.yaml"
motion_module = f"models/Motion_Module/{motion_module_type}.ckpt"
dreambooth_path = f"models/DreamBooth_LoRA/{dreambooth_type}.safetensors"

if motion_module_lora_type.lower() != "none":
motion_module_lora_path = f"models/MotionLoRA/v2_lora_{motion_module_lora_type}.ckpt"
else:
motion_module_lora_type = "Cog"
motion_module_lora_path = ""

# Replace placeholders directly in the template
config = FAKE_YAML_TEMPLATE.format(
motion_module_lora_type=motion_module_lora_type,
inference_config=inference_config,
motion_module=motion_module,
motion_module_lora_path=motion_module_lora_path,
alpha=alpha,
dreambooth_path=dreambooth_path,
lora_model_path=lora_model_path,
seed=seed,
steps=steps,
guidance_scale=guidance_scale,
prompt=prompt,
n_prompt=n_prompt,
)

prompt:
- "{prompt}"
# Clean up the config if necessary
# HACK: Remove `motion_module_lora_configs` part of fake yaml so we don't use MotionLoRAs
if not motion_module_lora_path:
config = "\n".join(
line
for line in config.split("\n")
if not line.strip().startswith(("motion_module_lora_configs", "path:", "alpha:"))
)

n_prompt:
- "{n_prompt}"'''

config = fake_yaml
args = Arguments(
pretrained_model_path=pretrained_model_path,
inference_config=inference_config,
Expand All @@ -256,5 +283,5 @@ def predict(
W=W,
H=H,
)
print(f"{'-'*30}")

yield from main(args)

0 comments on commit d537b94

Please sign in to comment.