Skip to content

Commit

Permalink
runner: Add SVD endpoints to modal
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 29, 2024
1 parent 1d9b52b commit 0defbac
Showing 1 changed file with 52 additions and 10 deletions.
62 changes: 52 additions & 10 deletions runner/modal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use_route_names_as_operation_ids,
)
from app.routes import health
import os

stub = Stub("livepeer-ai-runner")
pipeline_image = (
Expand Down Expand Up @@ -34,7 +35,10 @@


@stub.function(
image=downloader_image, volumes={models_path: models_volume}, timeout=3600
image=downloader_image,
volumes={models_path: models_volume},
timeout=3600,
secrets=[Secret.from_name("huggingface")],
)
def download_model(model_id: str):
from huggingface_hub import snapshot_download
Expand All @@ -50,6 +54,7 @@ def download_model(model_id: str):
cache_dir=cache_dir,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
token=os.environ.get("HF_TOKEN"),
)
logger.info(f"Downloaded model {model_id} to volume")
models_volume.commit()
Expand All @@ -58,13 +63,6 @@ def download_model(model_id: str):
raise


@stub.cls(
gpu="A10G",
image=pipeline_image,
memory=1024,
volumes={models_path: models_volume},
container_idle_timeout=3 * 60,
)
class Pipeline:
def __init__(self, pipeline: str, model_id: str):
self.pipeline = pipeline
Expand Down Expand Up @@ -93,6 +91,28 @@ def predict(self, **kwargs):
return self.pipe(**kwargs)


# GPU-specific registration of Pipeline: the stub.cls options below request
# gpu="A10G"; everything else is inherited from Pipeline unchanged.
@stub.cls(
    image=pipeline_image,
    gpu="A10G",
    memory=1024,
    # presumably seconds (i.e. five minutes) -- confirm against the Modal docs
    container_idle_timeout=5 * 60,
    volumes={models_path: models_volume},
)
class A10G_Pipeline(Pipeline):
    """Pipeline variant registered on the stub with ``gpu="A10G"``.

    Defines no attributes or methods of its own; it exists only so the
    A10G-specific ``stub.cls`` options can be attached to ``Pipeline``.
    """


# GPU-specific registration of Pipeline: the stub.cls options below request
# gpu="A100"; everything else is inherited from Pipeline unchanged.
@stub.cls(
    image=pipeline_image,
    gpu="A100",
    memory=1024,
    # presumably seconds (i.e. five minutes) -- confirm against the Modal docs
    container_idle_timeout=5 * 60,
    volumes={models_path: models_volume},
)
class A100_Pipeline(Pipeline):
    """Pipeline variant registered on the stub with ``gpu="A100"``.

    Defines no attributes or methods of its own; it exists only so the
    A100-specific ``stub.cls`` options can be attached to ``Pipeline``.
    """


# Wrap Pipeline for dependency injection in the runner FastAPI route
class RunnerPipeline:
def __init__(self, pipeline: Pipeline):
Expand All @@ -103,7 +123,7 @@ def __call__(self, **kwargs):
return self.pipeline.predict.remote(**kwargs)


def make_api(pipeline: str, model_id: str):
def make_api(pipeline: str, model_id: str, gpu: str = "A10G"):
from fastapi import FastAPI

config_logging()
Expand All @@ -112,7 +132,13 @@ def make_api(pipeline: str, model_id: str):

app.include_router(health.router)

app.pipeline = RunnerPipeline(Pipeline(pipeline, model_id))
if gpu == "A10G":
app.pipeline = RunnerPipeline(A10G_Pipeline(pipeline, model_id))
elif gpu == "A100":
app.pipeline = RunnerPipeline(A100_Pipeline(pipeline, model_id))
else:
raise Exception(f"invalid gpu value {gpu}")

app.include_router(load_route(pipeline))

use_route_names_as_operation_ids(app)
Expand Down Expand Up @@ -142,3 +168,19 @@ def text_to_image_sdxl_lightning_8step_api():
@asgi_app()
def text_to_image_sdxl_turbo_api():
return make_api("text-to-image", "stabilityai/sdxl-turbo")


@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")])
@asgi_app()
def image_to_video_svd_api():
    """Build the runner API for the image-to-video pipeline with SVD-XT.

    Returns:
        Whatever ``make_api`` produces for the
        ``stabilityai/stable-video-diffusion-img2vid-xt`` model, with the
        ``"A100"`` GPU option selected.
    """
    svd_app = make_api(
        pipeline="image-to-video",
        model_id="stabilityai/stable-video-diffusion-img2vid-xt",
        gpu="A100",
    )
    return svd_app


@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")])
@asgi_app()
def image_to_video_svd_1_1_api():
    """Build the runner API for the image-to-video pipeline with SVD-XT 1.1.

    Returns:
        Whatever ``make_api`` produces for the
        ``stabilityai/stable-video-diffusion-img2vid-xt-1-1`` model, with the
        ``"A100"`` GPU option selected.
    """
    svd_1_1_app = make_api(
        pipeline="image-to-video",
        model_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1",
        gpu="A100",
    )
    return svd_1_1_app

0 comments on commit 0defbac

Please sign in to comment.