Skip to content

Commit

Permalink
Choose 16x512x512 model instead of 64x512x512 model
Browse files Browse the repository at this point in the history
  • Loading branch information
jd7h committed Mar 25, 2024
1 parent 2a5bd7a commit 52fdfe3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
1 change: 0 additions & 1 deletion .dockerignore
Expand Up @@ -17,4 +17,3 @@ __pycache__
/venv
pretrained_models/
output*.mp4
.cog
32 changes: 27 additions & 5 deletions predict.py
Expand Up @@ -22,7 +22,7 @@

from opensora.datasets import save_sample
from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.config_utils import parse_configs
from opensora.utils.config_utils import merge_args
from opensora.utils.misc import to_torch_dtype
from opensora.acceleration.parallel_states import set_sequence_parallel_group
from colossalai.cluster import DistCoordinator
Expand Down Expand Up @@ -95,12 +95,30 @@ def setup(self) -> None:
# extra config:
ckpt_path = "pretrained_models/Open-Sora/OpenSora-v1-HQ-16x512x512.pth"
config_file = "configs/opensora/inference/64x512x512.py"
config_file = "configs/opensora/inference/16x512x512.py"

# load config file
self.cfg = cog_config()
self.cfg.model["from_pretrained"] = ckpt_path
if "multi_resolution" not in self.cfg:
self.cfg["multi_resolution"] = False
# option 1: manually
#self.cfg = cog_config()
#self.cfg.model["from_pretrained"] = ckpt_path
#if "multi_resolution" not in self.cfg:
# self.cfg["multi_resolution"] = False

# command line arguments from config_utils
extra_args = Config({
'seed': 42,
'ckpt_path': ckpt_path,
'batch-size': None,
'prompt-path': None,
'save-dir': None,
'num-sampling-steps': None,
'cfg_scale': None,
})

# option 2: use config_utils
self.cfg = Config.fromfile(config_file)
self.cfg = merge_args(self.cfg, args=extra_args, training=False)


# from scripts/inference

Expand Down Expand Up @@ -134,8 +152,12 @@ def setup(self) -> None:
# ======================================================
# 3.1. build model
input_size = (self.cfg.num_frames, *self.cfg.image_size)
print(f"number of frames: {self.cfg.num_frames}, image_size: {self.cfg.image_size}")
print(f"resulting input size: {input_size}")
self.vae = build_module(self.cfg.vae, MODELS)
print("vae", self.vae)
self.latent_size = self.vae.get_latent_size(input_size)
print("latent size:", self.latent_size)
self.text_encoder = build_module(self.cfg.text_encoder, MODELS, device=self.device) # T5 must be fp32
self.model = build_module(
self.cfg.model,
Expand Down

0 comments on commit 52fdfe3

Please sign in to comment.