forked from saharmor/dalle-playground
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding an option to cache the pretrained model files in S3
- Loading branch information
1 parent
85ecaf8
commit 3c39840
Showing
5 changed files
with
69 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import argparse | ||
import asyncio | ||
import os | ||
|
||
import boto3 | ||
import wandb | ||
|
||
from consts import (DALLE_MODEL_MEGA, DALLE_MODEL_MEGA_FULL, | ||
DALLE_MODEL_MINI, ModelSize) | ||
from utils import parse_arg_dalle_version | ||
|
||
|
||
def download_pretrained_model_and_cache_in_s3(model_version: str, s3_bucket: str) -> None: | ||
wandb.init(anonymous="must") | ||
|
||
if model_version == ModelSize.MEGA_FULL: | ||
dalle_model = DALLE_MODEL_MEGA_FULL | ||
elif model_version == ModelSize.MEGA: | ||
dalle_model = DALLE_MODEL_MEGA | ||
else: | ||
dalle_model = DALLE_MODEL_MINI | ||
|
||
tmp_dir = "dalle_pretrained_model" | ||
|
||
artifact = wandb.Api().artifact(dalle_model) | ||
artifact.download(tmp_dir) | ||
|
||
s3 = boto3.client("s3", region_name="us-east-2") | ||
for file in os.listdir(tmp_dir): | ||
s3.upload_file(os.path.join(tmp_dir, file), s3_bucket, f"{model_version}/{file}") | ||
|
||
|
||
def download_pretrained_model_from_s3(model_version: str, s3_bucket: str) -> str: | ||
s3 = boto3.client("s3", region_name="us-east-2") | ||
local_dir = os.path.join("/meadowrun/machine_cache", model_version) | ||
os.makedirs(local_dir, exist_ok=True) | ||
for file in s3.list_objects(Bucket=s3_bucket)["Contents"]: | ||
local_path = os.path.join("/meadowrun/machine_cache", file["Key"]) | ||
if file["Key"].startswith(f"{model_version}/") and not os.path.exists(local_path): | ||
print(f"Downloading {file['Key']}") | ||
s3.download_file(s3_bucket, file["Key"], local_path) | ||
|
||
return local_dir | ||
|
||
|
||
def main() -> None: | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full") | ||
parser.add_argument("--s3_bucket", required=True) | ||
args = parser.parse_args() | ||
|
||
download_pretrained_model_and_cache_in_s3(str(args.model_version), args.s3_bucket) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
boto3 | ||
wandb |