Skip to content

Commit

Permalink
move memory savings
Browse files Browse the repository at this point in the history
  • Loading branch information
Luis committed Feb 24, 2024
1 parent 8653547 commit 99696b3
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,19 @@ def setup(self) -> None:
torch_dtype=torch.float16,
cache_dir=MODEL_CACHE
).to("cuda")
# enable memory savings
self.pipe.enable_vae_slicing()
self.pipe.enable_model_cpu_offload()

@torch.inference_mode()
def predict(
self,
video: Path = Input(description="Input video"),
prompt: str = Input(description="Prompt for the model", default="panda playing a guitar, on a boat, in the ocean, high quality"),
negative_prompt: str = Input(description="Negative prompt for the model", default="bad quality, worse quality"),
negative_prompt: str = Input(description="Negative prompt for the model", default="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"),
guidance_scale: float = Input(description="Guidance scale for the model", default=7.5),
num_inference_steps: int = Input(description="Number of inference steps", default=25),
strength: float = Input(description="Strength of the initial image", default=0.5),
strength: float = Input(description="Strength of the initial image", default=0.6),
seed: int = Input(description="Random seed, leave blank to randomize the seed", default=None),
) -> Path:
"""Run a single prediction on the model"""
Expand All @@ -85,10 +88,6 @@ def predict(
)
self.pipe.scheduler = scheduler

# enable memory savings
self.pipe.enable_vae_slicing()
self.pipe.enable_model_cpu_offload()

video = load_video(str(video))
output = self.pipe(
video = video,
Expand Down

0 comments on commit 99696b3

Please sign in to comment.