
Support libp2p relays for NAT traversal #186

Merged 11 commits on Jan 9, 2023
setup.cfg (1 addition, 1 deletion)
@@ -37,7 +37,7 @@ install_requires =
huggingface-hub==0.11.1
transformers==4.25.1
speedtest-cli==2.1.3
-hivemind==1.1.3
+hivemind==1.1.5
tensor_parallel==1.0.23
humanfriendly
async-timeout>=4.0.2
src/petals/cli/run_server.py (7 additions, 1 deletion)
@@ -127,7 +127,10 @@ def main():
parser.add_argument("--mean_balance_check_period", type=float, default=60,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")

parser.add_argument("--auto_relay", action='store_true', help="Enabling relay for NAT traversal")
Member (suggested change):

-    parser.add_argument("--auto_relay", action='store_true', help="Enabling relay for NAT traversal")
+    parser.add_argument("--auto_relay", action='store_true', help="Enable relay for NAT traversal")

parser.add_argument('--no-auto_relay', dest='auto_relay', action='store_false')
parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")

parser.add_argument('--load_in_8bit', type=str, default=None,
help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
"Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
@@ -140,7 +143,7 @@ def main():
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")

parser.set_defaults(auto_relay=True)
# fmt:on
args = vars(parser.parse_args())
args.pop("config", None)
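The paired `--auto_relay` / `--no-auto_relay` flags plus `parser.set_defaults(auto_relay=True)` yield a boolean that defaults to on but can be explicitly disabled. A minimal standalone sketch of the same pattern (the parser here is a toy, not the PR's full parser):

```python
import argparse

# --auto_relay forces True, --no-auto_relay forces False, and
# set_defaults() makes True the value when neither flag is given
# (set_defaults takes precedence over the store_true default of False).
parser = argparse.ArgumentParser()
parser.add_argument("--auto_relay", action="store_true",
                    help="Enable relay for NAT traversal")
parser.add_argument("--no-auto_relay", dest="auto_relay", action="store_false")
parser.set_defaults(auto_relay=True)

print(parser.parse_args([]).auto_relay)                   # True (default)
print(parser.parse_args(["--no-auto_relay"]).auto_relay)  # False
```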
@@ -158,6 +161,8 @@ def main():

announce_maddrs = args.pop("announce_maddrs")
public_ip = args.pop("public_ip")
use_auto_relay = args.pop("auto_relay")

if public_ip is not None:
assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time"
assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)"
@@ -197,6 +202,7 @@ def main():
compression=compression,
max_disk_space=max_disk_space,
attn_cache_size=attn_cache_size,
use_auto_relay=use_auto_relay,
)
try:
server.run()
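The hunks above follow a `vars()` + `pop()` pattern: values that need renaming or special handling (like `auto_relay` becoming `use_auto_relay`) are popped out of the argument dict before the remainder is `**`-forwarded to `Server(...)`. A toy sketch of that pattern (the parser is simplified, not the PR's real one):

```python
import argparse

# A stand-in for run_server.py's parser, reduced to two arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--auto_relay", action="store_true")
parser.add_argument("--port", type=int, default=0)
parser.set_defaults(auto_relay=True)

args = vars(parser.parse_args([]))
# Pop values that need renaming before the rest is **-forwarded to Server(...).
use_auto_relay = args.pop("auto_relay")

print(use_auto_relay)  # True
print(args)            # {'port': 0}
```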
src/petals/client/remote_model.py (2 additions, 0 deletions)
@@ -107,6 +107,8 @@ def __init__(self, config: DistributedBloomConfig):
num_workers=n_layer,
startup_timeout=config.daemon_startup_timeout,
start=True,
use_relay=True,
use_auto_relay=True,
)
)
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
src/petals/server/server.py (10 additions, 1 deletion)
@@ -78,6 +78,8 @@ def __init__(
load_in_8bit: Optional[bool] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
use_relay: bool = True,
Member:

It's best to either remove the default argument values or remove these arguments entirely: we might forget to change the defaults here in the future, and the required values will be passed via kwargs anyway.

@borzunov (Collaborator), Jan 8, 2023:

I see your point, but I think we should keep it: the convention in Petals is that all Server defaults match the defaults of run_server.py (whereas the hivemind default for use_auto_relay is different).

Everything you said would apply if the defaults here matched hivemind's.

@Vahe1994 (Collaborator, Author), Jan 8, 2023:

I agree with @borzunov here. My arguments:

  1. use_relay will not be passed from run_server, and we want it to default to True.
  2. It is nice to see all the parameters that matter listed explicitly in the arguments.
  3. It is usually a bad idea to depend on another library's default argument: it can change without notice and lead to strange behavior.

Member:

I see your point about explicitly listing the arguments used to create a Server, though it somewhat contradicts the existence of **kwargs in __init__. My primary concern is that we should strive for consistent defaults across different locations: one error-proof way to do this would be to declare a common constant with the default value and use it in both places. Besides, petals-cli is part of Petals, so these files belong to the same library.

Collaborator:

We would need to create constants for all defaults in that case (tens of them). I think this is a more general problem that should be addressed outside of this PR (maybe with something like reflection).
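One error-proof variant of the reviewer's shared-constant suggestion: declare the default once and reference it from both the Server constructor and the CLI parser, so the two locations can never drift apart. A minimal sketch (the names `DEFAULT_USE_AUTO_RELAY` and `make_server` are hypothetical, not Petals API):

```python
import argparse

# Single source of truth for the default (would live in a shared module).
DEFAULT_USE_AUTO_RELAY = True

# Server side: the constructor references the constant...
def make_server(*, use_auto_relay: bool = DEFAULT_USE_AUTO_RELAY, **kwargs):
    return {"use_auto_relay": use_auto_relay, **kwargs}

# ...and the CLI side sets the same constant as the argparse default.
parser = argparse.ArgumentParser()
parser.add_argument("--auto_relay", action="store_true")
parser.add_argument("--no-auto_relay", dest="auto_relay", action="store_false")
parser.set_defaults(auto_relay=DEFAULT_USE_AUTO_RELAY)

# The two defaults agree by construction.
assert make_server()["use_auto_relay"] == parser.parse_args([]).auto_relay
```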

use_auto_relay: bool = True,
**kwargs,
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,7 +119,14 @@ def __init__(
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")