Inference: require max sequence length instead of assuming 2048 #52
Conversation
cli/run_server.py (outdated)

```python
def parse_size_as_bytes(size: str) -> int:
    """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
```
This cuts some corners (e.g. it treats 1GB = 1GiB). We could get it right all the time by using https://pypi.org/project/humanfriendly/ , but I'm not sure it justifies the extra dependency.
@borzunov any preferences?
- I'm okay with adding this light dependency. I think it's good if we don't have extra code in the PETALS codebase to keep it small.
- I'm (weakly) not okay with the same behavior for `GB` and `GiB`.

If you don't want a new dependency, to support both `GB` and `GiB` in a correct way, you can actually simplify the code to something like:

```python
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
unit = unit.upper()
if 'I' in unit:
    multiplier = 1024 ** units.index(unit.replace('I', ''))
else:
    multiplier = 1000 ** units.index(unit)
```

In other words, you can use the same array of units, just with different bases :)
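Putting the two ideas together, here is a minimal self-contained sketch of a `parse_size_as_bytes` that distinguishes decimal and binary units; the regex and error handling are my own assumptions for illustration, not the PR's exact code:

```python
import re

# Decimal unit names; an 'i' infix (KiB, GiB, ...) switches the base to 1024.
UNITS = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']

def parse_size_as_bytes(size: str) -> int:
    """Parse a human-readable size such as '1.5GB' or '300MiB' into bytes."""
    match = re.fullmatch(r'\s*([\d.]+)\s*([A-Za-z]+)\s*', size)
    if match is None:
        raise ValueError(f"Cannot parse size: {size!r}")
    number, unit = float(match.group(1)), match.group(2).upper()
    if 'I' in unit:
        # Binary units: GiB = 1024 ** 3
        multiplier = 1024 ** UNITS.index(unit.replace('I', ''))
    else:
        # Decimal units: GB = 1000 ** 3
        multiplier = 1000 ** UNITS.index(unit)
    return int(number * multiplier)
```

With this, `parse_size_as_bytes("1GB")` yields 1,000,000,000 while `parse_size_as_bytes("1GiB")` yields 1,073,741,824, matching the reviewer's point about the two bases.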
cli/run_server.py (outdated)

```diff
@@ -8,6 +8,20 @@
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+import re
```
Let's move this to the top of the file
cli/run_server.py (outdated)

```python
parser.add_argument('--device', type=str, default=None, required=False,
                    help='all experts will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
                    help="Use this dtype to store block weights and do computations. "
                         "By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--attention_cache_bytes', type=str, default=None,
```
Suggested change:

```diff
-parser.add_argument('--attention_cache_bytes', type=str, default=None,
+parser.add_argument('--attn_cache_size', type=str, default=None,
```
I strongly advise replacing `cache_bytes` with `cache_size` because:

- Specifying the size in bytes is standard across Python libs
- Moreover, here the size can be specified in any units

Also, I weakly advise replacing `attention` with `attn` because it's much shorter (but still understandable to everyone). This is optional though.
[done both]
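For illustration, the renamed flag can plug a size parser straight into argparse via the `type=` hook. This is a sketch, not the PR's actual code: the inline `parse_size` handles decimal units only, and the flag name follows the review suggestion above.

```python
import argparse
import re

def parse_size(size: str) -> int:
    """Minimal stand-in for the size parser discussed above (decimal units only)."""
    number, unit = re.fullmatch(r'([\d.]+)([A-Z]+)', size.upper()).groups()
    return int(float(number) * 1000 ** ['B', 'KB', 'MB', 'GB', 'TB'].index(unit))

parser = argparse.ArgumentParser()
# Renamed per the review: --attn_cache_size instead of --attention_cache_bytes.
parser.add_argument('--attn_cache_size', type=parse_size, default=None,
                    help='attention cache size, e.g. 300MB or 2GB')

args = parser.parse_args(['--attn_cache_size', '300MB'])
print(args.attn_cache_size)  # 300000000
```

Using `type=` keeps validation at the CLI boundary, so a malformed size fails at parse time rather than deep inside server startup.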
src/server/server.py (outdated)

```diff
@@ -135,13 +138,15 @@ def create(
     assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
     if expiration is None:
         expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+    if inference_max_length is None:
+        inference_max_length = max_batch_size
```
Why do we assign max batch size to max sequence size?
Both are (meant to be) in tokens. Would you prefer to set it to a constant by default?
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
src/server/handler.py (outdated)

```python
if not requested_uids:
    raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(max_length, int), f"rpc_inference metadata must contain int seq_length, got {max_length}"
```
Suggested change:

```diff
-assert isinstance(max_length, int), f"rpc_inference metadata must contain int seq_length, got {max_length}"
+assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
```
Maximum length is now provided in `.inference_session(max_length=100)`.
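To show what the server side now expects, here is a hypothetical validation helper mirroring the checks discussed above; the plain-dict metadata shape and the function name are my assumptions for illustration, not the PR's exact code:

```python
def check_inference_metadata(metadata: dict, max_supported_length: int) -> int:
    """Validate the client-provided max_length before allocating attention caches."""
    max_length = metadata.get("max_length")
    assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
    if not 0 < max_length <= max_supported_length:
        # A human-readable error instead of a bare assertion failure,
        # in the spirit of the PR's "more human-readable errors" item.
        raise ValueError(
            f"Cannot allocate cache for {max_length} tokens, "
            f"this server supports at most {max_supported_length}"
        )
    return max_length
```

Checking the limit up front lets the server reject an over-long session with a clear message before any cache memory is reserved.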
- Added a generic way to forward `**kwargs` to the inference session
- `run_server` can be started with a custom `max_length` for inference
- Renamed `--cache_size_bytes` to `--attention_cache_bytes` (to avoid collision with `--cache_dir`)
- `--attn_cache_bytes` can now support human-readable file sizes (e.g. 300MB instead of 314572800)
- Made some server-side errors more human-readable to the user (e.g. when max length is exceeded)
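The `**kwargs` forwarding item can be sketched as follows; the class and attribute names here are illustrative, not Petals' actual API:

```python
class InferenceSession:
    """Toy session object that accepts arbitrary configuration kwargs."""

    def __init__(self, max_length: int = 2048, **extra):
        self.max_length = max_length
        self.extra = extra  # any future options land here unchanged


class RemoteModel:
    """Toy model that forwards session kwargs down without inspecting them."""

    def inference_session(self, **kwargs):
        # New session options (e.g. max_length=100) pass straight through,
        # so adding one does not require touching intermediate layers.
        return InferenceSession(**kwargs)


session = RemoteModel().inference_session(max_length=100)
print(session.max_length)  # 100
```

The design choice is that intermediate layers stay oblivious to individual options, which is what lets `max_length` reach the session without a dedicated parameter at every level.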