Merge pull request #1 from sepal/feat/prefetch_weights
Speed up build and startup time
lucataco committed Oct 11, 2023
2 parents 38c88f5 + 64431b9, commit be77a2f
Showing 4 changed files with 26 additions and 27 deletions.
README.md: 8 changes (7 additions, 1 deletion)
@@ -2,7 +2,13 @@

This is an implementation of the [hotshotco/Hotshot-XL](https://github.com/hotshotco/hotshot-xl) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog)

-Run predictions:
+## Basic usage
+
+Before running the image, you need to fetch the weights:
+
+cog run python ./scripts/download_weights.py
+
+You can then run the image with:

cog predict -i prompt="go-pro video of a polar bear diving in the ocean, 8k, HD, dslr, nature footage" -i seed=6226

cog.yaml: 6 changes (2 additions, 4 deletions)
@@ -1,7 +1,7 @@
# Configuration for Cog
build:
gpu: true
cuda: "11.7"

python_version: "3.10"
python_packages:
- "accelerate==0.23.0"
@@ -67,12 +67,10 @@ build:
- "wandb==0.15.11"
- "zipp==3.17.0"
- "xformers"
- "git+https://github.com/hotshotco/Hotshot-XL"

run:
- apt-get update && apt-get install -y git-lfs ffmpeg
- git lfs install
- git clone https://github.com/hotshotco/Hotshot-XL /Hotshot-XL
- git clone https://huggingface.co/hotshotco/SDXL-512 /Hotshot-XL/SDXL-512


# predict.py defines how predictions are run on your model
predict.py: 28 changes (6 additions, 22 deletions)
@@ -26,20 +26,16 @@
'EulerDiscreteScheduler': EulerDiscreteScheduler,
}

-MODEL_NAME = "hotshotco/Hotshot-XL"
-MODEL_CACHE = "model-cache"
+HOTSHOTXL_CACHE = "hotshot-xl"


class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
-pipe_line_args = {
-"torch_dtype": torch.float16,
-"use_safetensors": True
-}
self.pipe = HotshotXLPipeline.from_pretrained(
-MODEL_NAME,
-**pipe_line_args,
-cache_dir=MODEL_CACHE
+HOTSHOTXL_CACHE,
+torch_dtype=torch.float16,
+use_safetensors=True
).to('cuda')

def to_pil_images(self, video_frames: torch.Tensor, output_type='pil'):
@@ -90,20 +86,8 @@ def predict(
video_length = 8
video_duration = 1000
# scheduler = "EulerAncestralDiscreteScheduler"
+pipe = self.pipe

-device = torch.device("cuda")
-pipe_line_args = {
-"torch_dtype": torch.float16,
-"use_safetensors": True
-}
-PipelineClass = HotshotXLPipeline

-pipe = PipelineClass.from_pretrained(
-MODEL_NAME,
-**pipe_line_args,
-cache_dir=MODEL_CACHE
-).to(device)

SchedulerClass = SCHEDULERS[scheduler]
if SchedulerClass is not None:
pipe.scheduler = SchedulerClass.from_config(pipe.scheduler.config)
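Taken together, the predict.py changes move model loading out of predict() and into setup(): the pipeline is read once from the local hotshot-xl directory baked into the image, and each prediction then reuses self.pipe, only swapping the scheduler. The following condensed sketch shows how the flow reads after this commit; it is abridged, not the actual file: the cog BasePredictor base class and the real predict() signature are omitted, and scheduler selection is collapsed into a hypothetical scheduler_cls argument standing in for the SCHEDULERS lookup shown above.

# Condensed, abridged sketch of the post-change flow in predict.py.
import torch
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline

HOTSHOTXL_CACHE = "hotshot-xl"  # local directory written at build time by scripts/download_weights.py


class Predictor:
    def setup(self) -> None:
        # Load once from the local directory; no download happens at startup.
        self.pipe = HotshotXLPipeline.from_pretrained(
            HOTSHOTXL_CACHE,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")

    def predict(self, prompt: str, scheduler_cls=None):
        # Reuse the pipeline loaded in setup(); only the scheduler may be swapped.
        pipe = self.pipe
        if scheduler_cls is not None:
            pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
        # ... frame generation and video encoding are unchanged by this commit.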
scripts/download_weights.py: 11 changes (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
import torch

pipe = HotshotXLPipeline.from_pretrained(
"hotshotco/Hotshot-XL",
torch_dtype=torch.float16,
use_safetensors=True
)

pipe.save_pretrained("./hotshot-xl", safe_serialization=True)
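
The script saves the pipeline into ./hotshot-xl, the same local directory (HOTSHOTXL_CACHE) that predict.py now loads from, so the weights end up baked into the image at build time. If a guard is wanted, a small check along these lines (hypothetical, not part of this commit) can confirm the prefetch actually produced safetensors files:

# Hypothetical sanity check (not in this commit): verify that
# scripts/download_weights.py left safetensors weights in ./hotshot-xl,
# so setup() in predict.py can load them without any network access.
from pathlib import Path

cache = Path("./hotshot-xl")
weights = sorted(cache.rglob("*.safetensors"))
if not weights:
    raise SystemExit(f"No .safetensors files under {cache}; run scripts/download_weights.py first.")
print(f"Found {len(weights)} safetensors file(s) under {cache}.")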
