From 3c39840b55964ac01f47f39c8340b4bcea52a6d6 Mon Sep 17 00:00:00 2001 From: Hyunho Richard Lee Date: Thu, 14 Jul 2022 19:24:13 -0400 Subject: [PATCH] Adding an option to cache the pretrained model files in S3 --- backend/app.py | 3 +- backend/cache_in_s3.py | 56 ++++++++++++++++++++++++++++ backend/dalle_model.py | 11 ++++-- backend/requirements.txt | 1 + backend/requirements_for_caching.txt | 2 + 5 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 backend/cache_in_s3.py create mode 100644 backend/requirements_for_caching.txt diff --git a/backend/app.py b/backend/app.py index 1af698d1f..ff4e28e64 100644 --- a/backend/app.py +++ b/backend/app.py @@ -21,6 +21,7 @@ parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights") parser.add_argument("--port", type=int, default=8000, help = "backend port") parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full") +parser.add_argument("--s3_bucket", type = str, help = "An S3 bucket that has been prepared with the cache_in_s3.py script") parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk") parser.add_argument("--img_format", type = str.lower, default = "JPEG", help = "Generated images format", choices=['jpeg', 'png']) parser.add_argument("--output_dir", type = str, default = DEFAULT_IMG_OUTPUT_DIR, help = "Customer directory for generated images") @@ -62,7 +63,7 @@ def health_check(): with app.app_context(): - dalle_model = DalleModel(args.model_version) + dalle_model = DalleModel(args.model_version, args.s3_bucket) dalle_model.generate_images("warm-up", 1) print("--> DALL-E Server is up and running!") print(f"--> Model selected - DALL-E {args.model_version}") diff --git a/backend/cache_in_s3.py b/backend/cache_in_s3.py new file mode 100644 index 000000000..23c37e80c --- /dev/null +++ 
b/backend/cache_in_s3.py @@ -0,0 +1,56 @@ +import argparse +import asyncio +import os + +import boto3 +import wandb + +from consts import (DALLE_MODEL_MEGA, DALLE_MODEL_MEGA_FULL, + DALLE_MODEL_MINI, ModelSize) +from utils import parse_arg_dalle_version + + +def download_pretrained_model_and_cache_in_s3(model_version: str, s3_bucket: str) -> None: + wandb.init(anonymous="must") + + if model_version == ModelSize.MEGA_FULL: + dalle_model = DALLE_MODEL_MEGA_FULL + elif model_version == ModelSize.MEGA: + dalle_model = DALLE_MODEL_MEGA + else: + dalle_model = DALLE_MODEL_MINI + + tmp_dir = "dalle_pretrained_model" + + artifact = wandb.Api().artifact(dalle_model) + artifact.download(tmp_dir) + + s3 = boto3.client("s3", region_name="us-east-2") + for file in os.listdir(tmp_dir): + s3.upload_file(os.path.join(tmp_dir, file), s3_bucket, f"{model_version}/{file}") + + +def download_pretrained_model_from_s3(model_version: str, s3_bucket: str) -> str: + s3 = boto3.client("s3", region_name="us-east-2") + local_dir = os.path.join("/meadowrun/machine_cache", model_version) + os.makedirs(local_dir, exist_ok=True) + for file in s3.list_objects(Bucket=s3_bucket)["Contents"]: + local_path = os.path.join("/meadowrun/machine_cache", file["Key"]) + if file["Key"].startswith(f"{model_version}/") and not os.path.exists(local_path): + print(f"Downloading {file['Key']}") + s3.download_file(s3_bucket, file["Key"], local_path) + + return local_dir + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full") + parser.add_argument("--s3_bucket", required=True) + args = parser.parse_args() + + download_pretrained_model_and_cache_in_s3(str(args.model_version), args.s3_bucket) + + +if __name__ == "__main__": + main() diff --git a/backend/dalle_model.py b/backend/dalle_model.py index 485836ab3..ac50bb1bc 100644 --- a/backend/dalle_model.py +++ 
b/backend/dalle_model.py @@ -1,6 +1,8 @@ import os import random from functools import partial +from typing import Optional +from cache_in_s3 import download_pretrained_model_from_s3 import jax import numpy as np @@ -45,17 +47,20 @@ def p_decode(vqgan, indices, params): class DalleModel: - def __init__(self, model_version: ModelSize) -> None: + def __init__(self, model_version: ModelSize, s3_bucket: Optional[str]) -> None: if model_version == ModelSize.MEGA_FULL: dalle_model = DALLE_MODEL_MEGA_FULL - dtype = jnp.float16 + dtype = jnp.float32 elif model_version == ModelSize.MEGA: dalle_model = DALLE_MODEL_MEGA dtype = jnp.float16 else: dalle_model = DALLE_MODEL_MINI dtype = jnp.float32 - + + if s3_bucket is not None: + # this will now be the path to the local copy of the pretrained model + dalle_model = download_pretrained_model_from_s3(str(model_version), s3_bucket) # Load dalle-mini self.model, params = DalleBart.from_pretrained( diff --git a/backend/requirements.txt b/backend/requirements.txt index 04997de36..e6e338bb6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,3 +9,4 @@ git+https://github.com/huggingface/transformers.git git+https://github.com/patil-suraj/vqgan-jax.git git+https://github.com/borisdayma/dalle-mini.git tqdm +boto3 diff --git a/backend/requirements_for_caching.txt b/backend/requirements_for_caching.txt new file mode 100644 index 000000000..16a3ebd27 --- /dev/null +++ b/backend/requirements_for_caching.txt @@ -0,0 +1,2 @@ +boto3 +wandb