
Inference: require max sequence length instead of assuming 2048 #52

Merged: 19 commits merged into main on Aug 29, 2022

Conversation

justheuristic (Collaborator) commented on Aug 28, 2022

  • Maximum length is now provided via .inference_session(max_length=100) (see the usage sketch below this list)

    • previously, we would always assume max length = 2048
  • added a generic way to forward **kwargs to the inference session

    • for compatibility with Priority tasks #47
    • Note to @borzunov: it does not pass them on arbitrarily, but checks for known kwarg names at the bottom level
  • run_server can now be started with a custom max_length for inference

  • renamed --cache_size_bytes to --attention_cache_bytes (to avoid a collision with --cache_dir)

  • --attn_cache_bytes now accepts human-readable sizes (e.g. 300MB instead of 314572800)

  • made some server-side errors more human-readable to the user (e.g. when max length is exceeded)
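A minimal client-side sketch of the first change (only .inference_session(max_length=...) comes from this PR; the model class, checkpoint, and session methods are illustrative placeholders):

# hypothetical usage; the names below are placeholders, not from this PR
model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd")
# the client now declares an upper bound on session length up front,
# instead of the server always reserving attention cache for 2048 tokens
with model.inference_session(max_length=100) as session:
    outputs = session.step(input_embeddings)  # steps past max_length raise a server-side error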



def parse_size_as_bytes(size: str) -> int:
    """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
justheuristic (Collaborator, Author) commented:
This cuts some corners (e.g. it treats 1GB = 1GiB). We could get it right all the time by using https://pypi.org/project/humanfriendly/ , but I'm not sure that justifies the extra dependency.

@borzunov any preferences?
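For reference, the humanfriendly route would be a one-liner (humanfriendly.parse_size is the library's documented API, and it already distinguishes decimal from binary units):

import humanfriendly

humanfriendly.parse_size("1.5GB")   # 1500000000 (decimal, base 1000)
humanfriendly.parse_size("300MiB")  # 314572800 (binary, base 1024)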

borzunov (Collaborator) commented on Aug 29, 2022:

  • I'm okay with adding this light dependency. I think it's good to avoid extra code in the PETALS codebase, to keep it small.
  • I'm (weakly) not okay with the same behavior for GB and GiB.

If you don't want a new dependency, to support both GB and GiB correctly you can actually simplify the code to something like:

units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
unit = unit.upper()  # `unit` is the suffix already split off from the input string
if 'I' in unit:  # binary suffix such as 'GiB'
    unit = 1024 ** units.index(unit.replace('I', ''))
else:  # decimal suffix such as 'GB'
    unit = 1000 ** units.index(unit)

In other words, you can use the same array of units, just with different bases :)
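Fleshed out into a complete function, the suggestion could look like this (a sketch of the idea above, not the code that was eventually merged):

import re

UNITS = ["B", "KB", "MB", "GB", "TB", "PB"]

def parse_size_as_bytes(size: str) -> int:
    """Parse a human-readable size, e.g. '300MB' -> 300_000_000, '300MiB' -> 314_572_800."""
    match = re.fullmatch(r"\s*([\d.]+)\s*([A-Za-z]*)\s*", size)
    if match is None:
        raise ValueError(f"Could not parse size: {size!r}")
    number, unit = match.groups()
    unit = unit.upper() or "B"  # a bare number is interpreted as plain bytes
    if "I" in unit:  # binary suffix, e.g. 'GiB': base 1024
        multiplier = 1024 ** UNITS.index(unit.replace("I", ""))
    else:  # decimal suffix, e.g. 'GB': base 1000
        multiplier = 1000 ** UNITS.index(unit)
    return int(float(number) * multiplier)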

@@ -8,6 +8,20 @@
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

import re
borzunov (Collaborator) commented:
Let's move this to the top of the file

parser.add_argument('--device', type=str, default=None, required=False,
                    help='all experts will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
                    help="Use this dtype to store block weights and do computations. "
                         "By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--attention_cache_bytes', type=str, default=None,
borzunov (Collaborator) commented:

Suggested change:
- parser.add_argument('--attention_cache_bytes', type=str, default=None,
+ parser.add_argument('--attn_cache_size', type=str, default=None,

I strongly advise replacing cache_bytes with cache_size because:

  • naming such arguments ..._size is standard across Python libs
  • moreover, here the size can be specified in any units, not just bytes

Also, I weakly advise replacing attention with attn because it's much shorter (but still understandable to everyone). This is optional, though.

justheuristic (Collaborator, Author) replied:
[done both]
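Combining both suggestions, the argument presumably ended up along these lines (a sketch; the exact help text is illustrative):

parser.add_argument('--attn_cache_size', type=str, default=None,
                    help='GPU memory reserved for attention caches; accepts sizes like 300MB or 0.5GiB')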

tests/test_chained_calls.py (review thread resolved)
@@ -135,13 +138,15 @@ def create(
    assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
    if expiration is None:
        expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
    if inference_max_length is None:
        inference_max_length = max_batch_size
borzunov (Collaborator) commented:
Why do we assign max batch size to max sequence size?

justheuristic (Collaborator, Author) replied:
Both are (meant to be) in tokens.

Would you prefer to set it to a constant by default?

cli/run_server.py (outdated review thread resolved)
src/server/handler.py (outdated review thread resolved)



.github/workflows/run-tests.yaml (outdated review thread resolved)

if not requested_uids:
    raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(max_length, int), f"rpc_inference metadata must contain int seq_length, got {max_length}"
borzunov (Collaborator) commented:
Suggested change:
- assert isinstance(max_length, int), f"rpc_inference metadata must contain int seq_length, got {max_length}"
+ assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"

justheuristic merged commit d271b75 into main on Aug 29, 2022