Cache Refactor and Improvements #710
@@ -364,15 +371,17 @@ def create_vllm_build_dir(
     nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")

     data_dir = build_dir / "data"
-    credentials_file = data_dir / "service_account.json"
+    gcs_credentials_file = data_dir / "service_account.json"
+    s3_credentials_file = data_dir / "s3_credentials.json"
     dockerfile_content = dockerfile_template.render(
This is starting to get very unwieldy — is there a better way to format it?
what specifically did you have in mind here?
overall looks good, thanks for making these changes. My main high-level feedback is that while we're at this, we should rename a couple other things:
- HuggingFaceCache -> ModelCache
- HuggingFaceCache.repo_id -> ModelCache.path
wdyt?
try:
    proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
    proc.wait()
except Exception as e:
we should almost never do `except Exception`. Imagine we misspelled `proc.wait()` as `proc.wai`. This would throw an attribute-not-found exception (`AttributeError`), and it would be very hard to figure out. Let's instead enumerate the network-related errors that could happen here.
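A minimal sketch of the narrower handling suggested above. It assumes the download runs as a subprocess; `run_download` and its `timeout_s` parameter are hypothetical names for illustration and do not model the real `_download_from_url_using_b10cp` helper.

```python
import subprocess
import sys
from typing import Sequence


def run_download(cmd: Sequence[str], timeout_s: float = 60.0) -> int:
    """Run a download command and return its exit code (hypothetical helper)."""
    # A missing binary raises FileNotFoundError here; we let it propagate as-is.
    proc = subprocess.Popen(cmd)
    try:
        return proc.wait(timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # The one failure we handle: stop the stuck process, then re-raise.
        proc.kill()
        proc.wait()
        raise


# Stand-in for the real download command: a Python process that exits cleanly.
exit_code = run_download([sys.executable, "-c", "pass"])
```

Because only `subprocess.TimeoutExpired` is caught, a typo such as `proc.wai()` would surface as an `AttributeError` with a normal traceback instead of being swallowed.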
# open the json file
with open(file_path, "r") as f:
    data = json.load(f)
class RepositoryFile:
could you make this an abstract base class (https://docs.python.org/3/library/abc.html)
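For reference, a small sketch of what an abstract base class could look like here. The constructor arguments mirror the `create(repo_name, file_name, revision_name)` signature from the diff; the `download_to_cache` method and the `FakeFile` subclass are assumptions made up for illustration.

```python
from abc import ABC, abstractmethod


class RepositoryFile(ABC):
    def __init__(self, repo_name: str, file_name: str, revision_name: str) -> None:
        self.repo_name = repo_name
        self.file_name = file_name
        self.revision_name = revision_name

    @abstractmethod
    def download_to_cache(self) -> str:
        """Fetch the file and return the path it was written to."""


class FakeFile(RepositoryFile):
    # Hypothetical subclass: pretends to download and returns a fixed location.
    def download_to_cache(self) -> str:
        return f"/tmp/{self.file_name}"
```

With `ABC`, instantiating `RepositoryFile` directly raises `TypeError`, so a subclass that forgets to implement `download_to_cache` fails at construction time rather than at call time.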
        self.is_private = True

    @staticmethod
    def create(repo_name, file_name, revision_name):
nit: I tend to prefer `from` over `create`
@@ -35,6 +36,23 @@ def _download_from_url_using_b10cp(
    )


def parse_s3_service_account_file(file_path):
please add input & output types
def parse_s3_service_account_file(file_path):
    # open the json file
    with open(file_path, "r") as f:
        data = json.load(f)
nit: consider using something like python dataclass or pydantic to define the type.
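A sketch of how the two suggestions (type annotations and a dataclass) might combine. The `aws_secret_access_key` and `aws_region` keys appear in the diff; the `aws_access_key_id` key and the `S3Credentials` class name are assumptions.

```python
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Union


@dataclass(frozen=True)
class S3Credentials:
    aws_access_key_id: str  # assumed key, not shown in the diff
    aws_secret_access_key: str
    aws_region: str


def parse_s3_service_account_file(file_path: Union[str, Path]) -> S3Credentials:
    """Read the JSON credentials file and return a typed record."""
    with open(file_path, "r") as f:
        data = json.load(f)
    # A missing key raises KeyError, which names the absent field.
    return S3Credentials(
        aws_access_key_id=data["aws_access_key_id"],
        aws_secret_access_key=data["aws_secret_access_key"],
        aws_region=data["aws_region"],
    )
```

Callers then get attribute access (`creds.aws_region`) that a type checker can verify, instead of indexing into an untyped dict or tuple.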
try:
    proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
    proc.wait()
except Exception as e:
see below my note about using except Exception as e
cache_dir = Path(f"/app/model_cache/{self.bucket_name}")
cache_dir.mkdir(parents=True, exist_ok=True)

dst_file = Path(f"{cache_dir}/{self.file_name}")
I think `dst_file = cache_dir / self.file_name` or `dst_file = cache_dir / Path(self.file_name)` should work here
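A quick illustration of the `/` operator on `pathlib.Path`; the bucket and file names below are placeholders.

```python
from pathlib import Path

cache_dir = Path("/app/model_cache") / "my-bucket"

# The right-hand side may be a str or a Path; both forms give the same result.
dst_from_str = cache_dir / "weights/model.bin"
dst_from_path = cache_dir / Path("weights/model.bin")

# Caveat: an absolute right-hand side replaces the left-hand side entirely.
absolute_wins = cache_dir / "/etc/passwd"
```

The last line is worth keeping in mind if `file_name` can ever start with a slash, since the cache directory would then be dropped silently.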
aws_secret_access_key = data["aws_secret_access_key"]
aws_region = data["aws_region"]
class GCSFile(RepositoryFile):
    def connect(self, key_file="/app/data/service_account.json"):
what is the point of this `connect` function? could we do all of this as a part of `download`?
truss/truss_config.py
Outdated
@@ -502,6 +503,12 @@ def from_dict(d):
    def from_yaml(yaml_path: Path):
        with yaml_path.open() as yaml_file:
            raw_data = yaml.safe_load(yaml_file) or {}
            if "hf_cache" in raw_data:
                warnings.warn(
let's use `logger` instead of `warnings` here
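Roughly what the logger version could look like. This is a sketch that assumes `hf_cache` is a legacy key that should be read as `model_cache`; `normalize_cache_key` is a hypothetical helper, and the message text is made up.

```python
import logging

logger = logging.getLogger(__name__)


def normalize_cache_key(raw_data: dict) -> dict:
    """Map the legacy `hf_cache` key onto `model_cache` (assumed behavior)."""
    data = dict(raw_data)  # avoid mutating the caller's dict
    if "hf_cache" in data:
        logger.warning("`hf_cache` is deprecated; use `model_cache` instead.")
        # If both keys are present, the explicit `model_cache` value wins.
        data.setdefault("model_cache", data.pop("hf_cache"))
    return data
```

Unlike `warnings.warn`, which by default shows a given warning only once per call site, a logger call is emitted every time and follows the application's logging configuration.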
-def test_null_hf_cache_key():
-    config_yaml_dict = {"hf_cache": None}
+def test_null_model_cache_key():
+    config_yaml_dict = {"model_cache": None}
In case I missed this, could we have a case where there's a yaml file with the key "hf_cache" and check that it parses out correctly?
Agree with HuggingFaceCache -> ModelCache, but think it might be better to keep
k - I think that's fine. Still technically makes sense with gcs & s3
@@ -364,15 +370,17 @@ def create_vllm_build_dir(
     nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")

     data_dir = build_dir / "data"
-    credentials_file = data_dir / "service_account.json"
+    gcs_credentials_file = data_dir / "service_account.json"
+    s3_credentials_file = data_dir / "s3_credentials.json"
could we move these key files to constants?
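One possible shape for this, as a sketch. The two file names come from the diff; the constant names and the `credentials_paths` helper are hypothetical.

```python
from pathlib import Path
from typing import Dict

# Single source of truth for the credential file names.
GCS_CREDENTIALS_FILENAME = "service_account.json"
S3_CREDENTIALS_FILENAME = "s3_credentials.json"


def credentials_paths(data_dir: Path) -> Dict[str, Path]:
    """Return the expected credential file locations under data_dir."""
    return {
        "gcs": data_dir / GCS_CREDENTIALS_FILENAME,
        "s3": data_dir / S3_CREDENTIALS_FILENAME,
    }
```

Builders and the runtime download code could then import the same constants instead of repeating string literals.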
@@ -253,13 +257,13 @@ def fetch_files_to_cache(cached_files: list, repo_id: str, filtered_repo_files:
         repo_id = f"gs://{bucket_name}"

         for filename in filtered_repo_files:
-            cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}")
+            cached_files.append(f"/app/model_cache/{bucket_name}/{filename}")
     elif repo_id.startswith("s3://"):
could you use the new nice classes that you made for this instead of a big if statement?
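A sketch of what class-based dispatch could look like in place of the `if`/`elif` chain. It does not reproduce the PR's actual classes: `BucketCache`, `GCSCache`, `S3Cache`, and `cache_from_repo` are hypothetical names, and only the `gs://` and `s3://` prefixes from the diff are handled.

```python
from typing import List

MODEL_CACHE_ROOT = "/app/model_cache"


class BucketCache:
    prefix = ""  # overridden by subclasses

    def __init__(self, repo_name: str) -> None:
        # "gs://bucket/sub/dir" -> "bucket"
        self.bucket_name = repo_name[len(self.prefix):].split("/", 1)[0]

    def cached_paths(self, filenames: List[str]) -> List[str]:
        return [f"{MODEL_CACHE_ROOT}/{self.bucket_name}/{name}" for name in filenames]


class GCSCache(BucketCache):
    prefix = "gs://"


class S3Cache(BucketCache):
    prefix = "s3://"


def cache_from_repo(repo_name: str) -> BucketCache:
    """Pick the cache class whose prefix matches the repo name."""
    for cls in (GCSCache, S3Cache):
        if repo_name.startswith(cls.prefix):
            return cls(repo_name)
    raise ValueError(f"unsupported repo: {repo_name}")
```

Adding another storage backend then means adding a subclass and one tuple entry, rather than another branch in `fetch_files_to_cache`.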
# Create S3 Client
bucket_name, _ = split_path(repo_name, prefix="s3://")

key_file = "/app/data/s3_credentials.jso"
typo? Also let's move this path to a constant
@varunshenoy how did this work before?
except ValueError as value_error:
    raise RuntimeError(f"Failure due to an error: {value_error}")

except Exception as general_error:
I think it's ok to just let this issue throw, we don't need to catch it
     )
-except FileNotFoundError:
+except Exception as exc:
can we be more specific w/ this exception?
Feel free to do this in a follow-up, but let's try to be more specific here
 config.build.arguments[
     model_key
-] = f"/app/hf_cache/{model_name.replace('gs://', '')}"
+] = f"/app/model_cache/{model_name.replace('gs://', '')}"
could we move this string into a function w/ comments? it's not clear to me why we need to do this transformation of the config object
This is very specifically for TGI/vLLM, where the model may be specified directly as an HF repo. If it's a GCS or S3 bucket, we want to alias that bucket to the cache and make sure the model server pulls from the cache instead of throwing an error.
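A sketch of the helper being asked for, with the reasoning from the reply captured in comments. `cache_alias` is a hypothetical name, and handling `s3://` alongside `gs://` is an assumption based on the explanation above (the diff itself only strips `gs://`).

```python
MODEL_CACHE_ROOT = "/app/model_cache"
BUCKET_PREFIXES = ("gs://", "s3://")


def cache_alias(model_name: str) -> str:
    """Return the path the model server should load the model from.

    TGI/vLLM accept an HF repo id directly, so those pass through unchanged.
    A bucket URL is rewritten to its location in the local model cache, so
    the server reads the cached files instead of failing on the URL.
    """
    for prefix in BUCKET_PREFIXES:
        if model_name.startswith(prefix):
            return f"{MODEL_CACHE_ROOT}/{model_name[len(prefix):]}"
    return model_name
```

The call site would then read something like `config.build.arguments[model_key] = cache_alias(model_name)`, which keeps the transformation documented in one place.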
@@ -1,5 +1,5 @@
 {% for file in cached_files %}
-{%- if credentials_exists %}
+{%- if file.startswith("/app/model_cache/") %}
I think this would be a little bit cleaner if we could have these templates be more logicless. Is there something else that we can check here? It's not clear from reading this template file what the implications of the file being named /app/model_cache/ are
This is the cache copying mechanism. The HuggingFace files have a special root directory while other files do not. Let me think about this and get back.
It might make sense to keep the `.startswith` but just use `/app/` instead. If the file is relative to `app`, we want to copy it to the same place.
"The HuggingFace files have a special root directory while other files do not. Let me think about this and get back."
This assumption is baked in here & implicit but not made explicit anywhere. In the future someone might wonder why we're doing this. Here's an example of an approach that makes this explicit instead of implicit:
@dataclass
class CachedFile:
    src: str
    dst: str

cached_files = [
    # Huggingface files have a special root directory, while the others do not.
    # Being in app/model_cache implies that it is not a huggingface file
    CachedFile(src=... if file.startswith(...), dst=...) for file in files
]
data = {
    ...
    "cached_files": cached_files,
    ...
}
render_template(data)
Just a couple more comments! Lmk when you've tested it again and i'll throw a ✅
    raise RuntimeError(f"Failure due to file ({file_name}) not found: {file_error}")

except TimeoutError as timeout_error:
    raise RuntimeError(f"Failure due to timeout: {timeout_error}")
For TimeoutError, OSError, and ValueError, why not just throw that exception (Instead of catching and reraising a RuntimeError)? I don't think the RuntimeError adds anything here
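To illustrate the point, a small sketch where the original exception is allowed to propagate. Both function names are hypothetical and the file path is a placeholder.

```python
def read_credentials_text(path: str) -> str:
    # No try/except: FileNotFoundError, PermissionError, etc. propagate unchanged.
    with open(path, "r") as f:
        return f.read()


def read_credentials_or_default(path: str, default: str) -> str:
    # The caller can react to one specific failure because its type was preserved.
    try:
        return read_credentials_text(path)
    except FileNotFoundError:
        return default


text = read_credentials_or_default("/nonexistent/dir/creds.json", default="{}")
```

Had `read_credentials_text` wrapped everything in `RuntimeError`, the caller would have to parse the message to tell a missing file from a timeout.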
@@ -1,7 +1,3 @@
{% for file in cached_files %}
{%- if credentials_exists %}
awesome
 WORKDIR /app

 {% if hf_access_token %}
 ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
 {% endif %}
-{%- if credentials_exists %}
+{%- if gcs_credentials_exists %}
we repeat these magic file paths in a lot of places. I wonder if we could just pass in `credentials` here and do:
COPY ./data/{{ credentials }} ...
instead of having branching logic? Then we can define these constants in one place
        self.revision = revision

    @staticmethod
    def from_repo(repo_name, data_dir):
please add types
)
model_cache = RemoteCache.from_repo(repo_id, truss_dir / config.data_dir)
remote_filtered_files = model_cache.filter(allow_patterns, ignore_patterns)
local_cached_files += model_cache.prepare_for_cache(remote_filtered_files)
I think `local_cached_files` makes it seem like they are already cached. That hasn't happened yet, maybe `files_to_cache`?
awesome work here! I think there's a little bit of cleanup we can do in the serving_builder, but let's try to get this in.
The only thing that I'd consider is moving the doc changes to a different PR if you want to merge this now. If we merge this now, the docs will automatically deploy and be incorrect until we push the new context builder.
So I'd say if you want to merge this now, let's move the doc changes to a different PR, else, we can merge on mon and do a new context builder
This PR adds the following features:
- cache_warmer.py
- model_cache to hf_cache
- app/model_cache instead of app/hf_cache
Tested the following on dev:
- model_cache