Adding an option to cache the pretrained model files in S3
hrichardlee committed Jul 15, 2022
1 parent 85ecaf8 commit 3c39840
Showing 5 changed files with 69 additions and 4 deletions.
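
For context, the intended workflow is: run the new cache_in_s3.py script once to copy the pretrained model files into an S3 bucket you control, then start the backend with --s3_bucket pointing at that bucket. A minimal sketch of the same flow in Python, assuming a hypothetical bucket named "my-dalle-cache" and AWS credentials already configured in the environment (the functions used here are defined in the files shown below):

from consts import ModelSize
from cache_in_s3 import download_pretrained_model_and_cache_in_s3
from dalle_model import DalleModel

BUCKET = "my-dalle-cache"  # hypothetical bucket name

# One-time step: pull the pretrained weights from wandb and upload them to S3,
# mirroring what cache_in_s3.py's main() does.
download_pretrained_model_and_cache_in_s3(str(ModelSize.MINI), BUCKET)

# Afterwards, the backend builds the model from the S3 copy instead of wandb,
# mirroring the DalleModel(args.model_version, args.s3_bucket) call in app.py.
dalle_model = DalleModel(ModelSize.MINI, BUCKET)
dalle_model.generate_images("warm-up", 1)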
3 changes: 2 additions & 1 deletion backend/app.py
@@ -21,6 +21,7 @@
parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
parser.add_argument("--port", type=int, default=8000, help = "backend port")
parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
parser.add_argument("--s3_bucket", type = str, help = "An S3 bucket that has been prepared with the cache_in_s3.py script")
parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
parser.add_argument("--img_format", type = str.lower, default = "JPEG", help = "Generated images format", choices=['jpeg', 'png'])
parser.add_argument("--output_dir", type = str, default = DEFAULT_IMG_OUTPUT_DIR, help = "Customer directory for generated images")
@@ -62,7 +63,7 @@ def health_check():


with app.app_context():
-    dalle_model = DalleModel(args.model_version)
+    dalle_model = DalleModel(args.model_version, args.s3_bucket)
    dalle_model.generate_images("warm-up", 1)
    print("--> DALL-E Server is up and running!")
    print(f"--> Model selected - DALL-E {args.model_version}")
56 changes: 56 additions & 0 deletions backend/cache_in_s3.py
@@ -0,0 +1,56 @@
import argparse
import asyncio
import os

import boto3
import wandb

from consts import (DALLE_MODEL_MEGA, DALLE_MODEL_MEGA_FULL,
                    DALLE_MODEL_MINI, ModelSize)
from utils import parse_arg_dalle_version


def download_pretrained_model_and_cache_in_s3(model_version: str, s3_bucket: str) -> None:
    wandb.init(anonymous="must")

    if model_version == ModelSize.MEGA_FULL:
        dalle_model = DALLE_MODEL_MEGA_FULL
    elif model_version == ModelSize.MEGA:
        dalle_model = DALLE_MODEL_MEGA
    else:
        dalle_model = DALLE_MODEL_MINI

    tmp_dir = "dalle_pretrained_model"

    artifact = wandb.Api().artifact(dalle_model)
    artifact.download(tmp_dir)

    s3 = boto3.client("s3", region_name="us-east-2")
    for file in os.listdir(tmp_dir):
        s3.upload_file(os.path.join(tmp_dir, file), s3_bucket, f"{model_version}/{file}")


def download_pretrained_model_from_s3(model_version: str, s3_bucket: str) -> str:
    s3 = boto3.client("s3", region_name="us-east-2")
    local_dir = os.path.join("/meadowrun/machine_cache", model_version)
    os.makedirs(local_dir, exist_ok=True)
    for file in s3.list_objects(Bucket=s3_bucket)["Contents"]:
        local_path = os.path.join("/meadowrun/machine_cache", file["Key"])
        if file["Key"].startswith(f"{model_version}/") and not os.path.exists(local_path):
            print(f"Downloading {file['Key']}")
            s3.download_file(s3_bucket, file["Key"], local_path)

    return local_dir


async def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
    parser.add_argument("--s3_bucket", required=True)
    args = parser.parse_args()

    download_pretrained_model_and_cache_in_s3(str(args.model_version), args.s3_bucket)


if __name__ == "__main__":
    asyncio.run(main())
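
To sanity-check what the script uploaded, one could list the bucket afterwards. A small sketch, assuming the same hypothetical bucket name and the us-east-2 region hard-coded in the script:

import boto3

s3 = boto3.client("s3", region_name="us-east-2")
# Keys follow the "{model_version}/{file}" layout written by the upload loop above.
for obj in s3.list_objects_v2(Bucket="my-dalle-cache").get("Contents", []):
    print(obj["Key"], obj["Size"])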
11 changes: 8 additions & 3 deletions backend/dalle_model.py
@@ -1,6 +1,8 @@
import os
import random
from functools import partial
from typing import Optional
from cache_in_s3 import download_pretrained_model_from_s3

import jax
import numpy as np
@@ -45,17 +47,20 @@ def p_decode(vqgan, indices, params):


class DalleModel:
-    def __init__(self, model_version: ModelSize) -> None:
+    def __init__(self, model_version: ModelSize, s3_bucket: Optional[str]) -> None:
        if model_version == ModelSize.MEGA_FULL:
            dalle_model = DALLE_MODEL_MEGA_FULL
-            dtype = jnp.float16
+            dtype = jnp.float32
        elif model_version == ModelSize.MEGA:
            dalle_model = DALLE_MODEL_MEGA
            dtype = jnp.float16
        else:
            dalle_model = DALLE_MODEL_MINI
            dtype = jnp.float32

        if s3_bucket is not None:
            # this will now be the path to the local copy of the pretrained model
            dalle_model = download_pretrained_model_from_s3(str(model_version), s3_bucket)

        # Load dalle-mini
        self.model, params = DalleBart.from_pretrained(
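
When an S3 bucket is passed in, dalle_model above ends up pointing at a local directory under /meadowrun/machine_cache rather than a wandb artifact reference, and per the comment in the diff, DalleBart.from_pretrained then reads that local copy. A short sketch of exercising the download path directly (bucket name hypothetical):

import os
from consts import ModelSize
from cache_in_s3 import download_pretrained_model_from_s3

# Fetch the cached files for the Mega model and inspect what lands on disk.
local_dir = download_pretrained_model_from_s3(str(ModelSize.MEGA), "my-dalle-cache")
print("Pretrained model files cached at:", local_dir)
print(os.listdir(local_dir))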
1 change: 1 addition & 0 deletions backend/requirements.txt
@@ -9,3 +9,4 @@ git+https://github.com/huggingface/transformers.git
git+https://github.com/patil-suraj/vqgan-jax.git
git+https://github.com/borisdayma/dalle-mini.git
tqdm
boto3
2 changes: 2 additions & 0 deletions backend/requirements_for_caching.txt
@@ -0,0 +1,2 @@
boto3
wandb
