From 0defbacb09fa99c8c28abbadd0080d775a9aa8f9 Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Thu, 29 Feb 2024 16:30:55 -0500 Subject: [PATCH] runner: Add SVD endpoints to modal --- runner/modal_app.py | 62 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/runner/modal_app.py b/runner/modal_app.py index 683ccaf..7df2251 100644 --- a/runner/modal_app.py +++ b/runner/modal_app.py @@ -7,6 +7,7 @@ use_route_names_as_operation_ids, ) from app.routes import health +import os stub = Stub("livepeer-ai-runner") pipeline_image = ( @@ -34,7 +35,10 @@ @stub.function( - image=downloader_image, volumes={models_path: models_volume}, timeout=3600 + image=downloader_image, + volumes={models_path: models_volume}, + timeout=3600, + secrets=[Secret.from_name("huggingface")], ) def download_model(model_id: str): from huggingface_hub import snapshot_download @@ -50,6 +54,7 @@ def download_model(model_id: str): cache_dir=cache_dir, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, + token=os.environ.get("HF_TOKEN"), ) logger.info(f"Downloaded model {model_id} to volume") models_volume.commit() @@ -58,13 +63,6 @@ def download_model(model_id: str): raise -@stub.cls( - gpu="A10G", - image=pipeline_image, - memory=1024, - volumes={models_path: models_volume}, - container_idle_timeout=3 * 60, -) class Pipeline: def __init__(self, pipeline: str, model_id: str): self.pipeline = pipeline @@ -93,6 +91,28 @@ def predict(self, **kwargs): return self.pipe(**kwargs) +@stub.cls( + gpu="A10G", + image=pipeline_image, + memory=1024, + volumes={models_path: models_volume}, + container_idle_timeout=5 * 60, +) +class A10G_Pipeline(Pipeline): + pass + + +@stub.cls( + gpu="A100", + image=pipeline_image, + memory=1024, + volumes={models_path: models_volume}, + container_idle_timeout=5 * 60, +) +class A100_Pipeline(Pipeline): + pass + + # Wrap Pipeline for dependency injection in the runner FastAPI route class RunnerPipeline: def __init__(self, pipeline: Pipeline): @@ -103,7 +123,7 @@ def __call__(self, **kwargs): return self.pipeline.predict.remote(**kwargs) -def make_api(pipeline: str, model_id: str): +def make_api(pipeline: str, model_id: str, gpu: str = "A10G"): from fastapi import FastAPI config_logging() @@ -112,7 +132,13 @@ def make_api(pipeline: str, model_id: str): app.include_router(health.router) - app.pipeline = RunnerPipeline(Pipeline(pipeline, model_id)) + if gpu == "A10G": + app.pipeline = RunnerPipeline(A10G_Pipeline(pipeline, model_id)) + elif gpu == "A100": + app.pipeline = RunnerPipeline(A100_Pipeline(pipeline, model_id)) + else: + raise Exception(f"invalid gpu value {gpu}") + app.include_router(load_route(pipeline)) use_route_names_as_operation_ids(app) @@ -142,3 +168,19 @@ def text_to_image_sdxl_lightning_8step_api(): @asgi_app() def text_to_image_sdxl_turbo_api(): return make_api("text-to-image", "stabilityai/sdxl-turbo") + + +@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +@asgi_app() +def image_to_video_svd_api(): + return make_api( + "image-to-video", "stabilityai/stable-video-diffusion-img2vid-xt", "A100" + ) + + +@stub.function(image=api_image, secrets=[Secret.from_name("api-auth-token")]) +@asgi_app() +def image_to_video_svd_1_1_api(): + return make_api( + "image-to-video", "stabilityai/stable-video-diffusion-img2vid-xt-1-1", "A100" + )