Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache Refactor and Improvements #710

Merged
merged 41 commits into from
Nov 6, 2023
Merged

Cache Refactor and Improvements #710

merged 41 commits into from
Nov 6, 2023

Conversation

varunshenoy
Copy link
Contributor

@varunshenoy varunshenoy commented Oct 26, 2023

This PR adds the following features:

  • Caching with public S3 buckets
  • A refactored cache_warmer.py
  • Individual trusses can contain cached files from different cloud stores
  • Update docs to include information about S3
  • Alias model_cache to hf_cache.
  • models are now saved in app/model_cache instead of app/hf_cache

Tested the following on dev:

  • public gcs
  • private gcs
  • public s3
  • private s3
  • model_cache

@varunshenoy varunshenoy changed the title Varun/cache refactor [WIP] Cache Refactor and Improvements Oct 26, 2023
@@ -364,15 +371,17 @@ def create_vllm_build_dir(
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")

data_dir = build_dir / "data"
credentials_file = data_dir / "service_account.json"
gcs_credentials_file = data_dir / "service_account.json"
s3_credentials_file = data_dir / "s3_credentials.json"
dockerfile_content = dockerfile_template.render(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is starting to get very unwieldy — is there a better way to format it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what specifically did you have in mind here?

@varunshenoy varunshenoy marked this pull request as ready for review October 28, 2023 01:12
@varunshenoy varunshenoy changed the title [WIP] Cache Refactor and Improvements Cache Refactor and Improvements Oct 28, 2023
Copy link
Collaborator

@squidarth squidarth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall looks good, thanks for making these changes. My main high-level feedback is that while we're at this, we should rename a couple other things:

  • HuggingFaceCache -> ModelCache
  • HuggingFaceCache.repo_id -> ModelCache.path

wdyt?

try:
proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
proc.wait()
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should almost never do except Exception. Imagine we mispelled proc.wait() as proc.wai. This would throw an attribute not found exception, and it would be very hard to figure out. Let's instead enumerate the network-related errors that could happen here.

# open the json file
with open(file_path, "r") as f:
data = json.load(f)
class RepositoryFile:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you make this an abstract base class (https://docs.python.org/3/library/abc.html)

self.is_private = True

@staticmethod
def create(repo_name, file_name, revision_name):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I tend to prefer from over create

@@ -35,6 +36,23 @@ def _download_from_url_using_b10cp(
)


def parse_s3_service_account_file(file_path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add input & output types

def parse_s3_service_account_file(file_path):
# open the json file
with open(file_path, "r") as f:
data = json.load(f)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider using something like python dataclass or pydantic to define the type.

try:
proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
proc.wait()
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see below my note about using except Exception as e

cache_dir = Path(f"/app/model_cache/{self.bucket_name}")
cache_dir.mkdir(parents=True, exist_ok=True)

dst_file = Path(f"{cache_dir}/{self.file_name}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think dst_file = cache_dir / self.file_name or dst_file = cache_dir / Path(self.file_name) should work here

aws_secret_access_key = data["aws_secret_access_key"]
aws_region = data["aws_region"]
class GCSFile(RepositoryFile):
def connect(self, key_file="/app/data/service_account.json"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the point of this connect function? could we do all of this as a part of download?

@@ -502,6 +503,12 @@ def from_dict(d):
def from_yaml(yaml_path: Path):
with yaml_path.open() as yaml_file:
raw_data = yaml.safe_load(yaml_file) or {}
if "hf_cache" in raw_data:
warnings.warn(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use logger instead of warnings here

def test_null_hf_cache_key():
config_yaml_dict = {"hf_cache": None}
def test_null_model_cache_key():
config_yaml_dict = {"model_cache": None}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case I missed this, could we have a case where there's a yaml file with the key "hf_cache" and check that it parses out correctly?

@varunshenoy
Copy link
Contributor Author

overall looks good, thanks for making these changes. My main high-level feedback is that while we're at this, we should rename a couple other things:

  • HuggingFaceCache -> ModelCache
  • HuggingFaceCache.repo_id -> ModelCache.path

wdyt?

Agree with HuggingFaceCache -> ModelCache, but think it might be better to keep repo_id since most folks are caching from Hugging Face anyways.

@squidarth
Copy link
Collaborator

Agree with HuggingFaceCache -> ModelCache, but think it might be better to keep repo_id since most folks are caching from Hugging Face anyways.

k - I think that's fine. Still technically makes sense with gcs & s3

@@ -364,15 +370,17 @@ def create_vllm_build_dir(
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")

data_dir = build_dir / "data"
credentials_file = data_dir / "service_account.json"
gcs_credentials_file = data_dir / "service_account.json"
s3_credentials_file = data_dir / "s3_credentials.json"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we move these key files to constants?

@@ -253,13 +257,13 @@ def fetch_files_to_cache(cached_files: list, repo_id: str, filtered_repo_files:
repo_id = f"gs://{bucket_name}"

for filename in filtered_repo_files:
cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}")
cached_files.append(f"/app/model_cache/{bucket_name}/{filename}")
elif repo_id.startswith("s3://"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you use teh new nice clases that you made for this instead of big if statement?

@@ -364,15 +371,17 @@ def create_vllm_build_dir(
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")

data_dir = build_dir / "data"
credentials_file = data_dir / "service_account.json"
gcs_credentials_file = data_dir / "service_account.json"
s3_credentials_file = data_dir / "s3_credentials.json"
dockerfile_content = dockerfile_template.render(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what specifically did you have in mind here?

# Create S3 Client
bucket_name, _ = split_path(repo_name, prefix="s3://")

key_file = "/app/data/s3_credentials.jso"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo? Also let's move this path to a constant

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varunshenoy how did this work before?

except ValueError as value_error:
raise RuntimeError(f"Failure due to an error: {value_error}")

except Exception as general_error:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok to just let this issue throw, we don't need to catch it

)
except FileNotFoundError:
except Exception as exc:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we be more specific w/ this exception?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to do this in a follow-up, but let's try to be more specific here


config.build.arguments[
model_key
] = f"/app/hf_cache/{model_name.replace('gs://', '')}"
] = f"/app/model_cache/{model_name.replace('gs://', '')}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we move this string into a function w/ comments? it's not clear to me why we need to do this transformation of the config object

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very specifically for TGI/vLLM, where the model maybe specified directly as an HF repo. If it's a GCS or S3 bucket, we want to alias that bucket to the cache and make sure the model server pulls from the cache instead of throwing an error.

@@ -1,5 +1,5 @@
{% for file in cached_files %}
{%- if credentials_exists %}
{%- if file.startswith("/app/model_cache/") %}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be a little bit cleaner if we could have these templates be more logicless. Is there something else that we can check here? It's not clear from reading this template file what the implications of the file being named /app/model_cache/ are

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the cache copying mechanism. The HuggingFace files have a special root directory while other files do not. Let me think about this and get back.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might make sense to keep the .startswith but instead just use /app/ instead. If the file is relative to app we want to copy it to the same place.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The HuggingFace files have a special root directory while other files do not. Let me think about this and get back."

This assumption is baked in here & implicit but not made explicit anywhere. In the future someone might wonder why we're doing this. Here's an example of an approach that makes this explicit instead of implicit:

@dataclass
class CachedFile:
    source: str
    dst: str

cached_files = [
  # Huggingface files have a special root directory, while the others do not.
  # Being in app/model_cache implies that it is not a huggingface file
  CachedFile(src=... if file.startswith(...) , dst=...) for file in files
]

data = {
   ...
  cached_files: cached_files
  ...
}

render_template(data)

Copy link
Collaborator

@squidarth squidarth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a couple more comments! Lmk when you've tested it again and i'll throw a ✅

raise RuntimeError(f"Failure due to file ({file_name}) not found: {file_error}")

except TimeoutError as timeout_error:
raise RuntimeError(f"Failure due to timeout: {timeout_error}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For TimeoutError, OSError, and ValueError, why not just throw that exception (Instead of catching and reraising a RuntimeError)? I don't think the RuntimeError adds anything here

@@ -1,7 +1,3 @@
{% for file in cached_files %}
{%- if credentials_exists %}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

WORKDIR /app

{% if hf_access_token %}
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
{% endif %}
{%- if credentials_exists %}
{%- if gcs_credentials_exists %}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we repeat these magic file paths in a lot of places, I wonder if we could just pass in credentials here, and can just do:

COPY ./data/{{ credentials}} ...

instead of having branching logic? And then we can define these constants in one place

self.revision = revision

@staticmethod
def from_repo(repo_name, data_dir):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add types

)
model_cache = RemoteCache.from_repo(repo_id, truss_dir / config.data_dir)
remote_filtered_files = model_cache.filter(allow_patterns, ignore_patterns)
local_cached_files += model_cache.prepare_for_cache(remote_filtered_files)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think local_cached_files makes it seem like they are already cached. That hasn't happened yet, maybe files_to_cache?

Copy link
Collaborator

@squidarth squidarth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome work here! I think there's a little bit of cleanup we can do in the serving_builder, but let's try to get this in.

The only thing that I'd consider is moving the doc changes to a different PR if you want to merge this now. If we merge this now, the docs will automatically deploy and be incorrect until we push the new context builder.

So I'd say if you want to merge this now, let's move the doc changes to a different PR, else, we can merge on mon and do a new context builder

)
except FileNotFoundError:
except Exception as exc:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to do this in a follow-up, but let's try to be more specific here

@varunshenoy varunshenoy merged commit 1dc9be5 into main Nov 6, 2023
3 checks passed
@varunshenoy varunshenoy deleted the varun/cache-refactor branch November 6, 2023 19:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants