Skip to content

Commit

Permalink
Refactor CLI arg to look as --use_auto_relay False/True
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jan 8, 2023
1 parent d641b9b commit d99d710
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/petals/cli/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
logger = get_logger(__file__)


TRUE_CONSTANTS = ["TRUE", "1"]


def main():
# fmt:off
parser = configargparse.ArgParser(default_config_files=["config.yml"],
Expand Down Expand Up @@ -127,8 +130,9 @@ def main():
parser.add_argument("--mean_balance_check_period", type=float, default=60,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")

parser.add_argument("--auto_relay", action='store_true', help="Enabling relay for NAT traversal")
parser.add_argument('--no-auto_relay', dest='auto_relay', action='store_false')
parser.add_argument("--use_auto_relay", type=str, default="True",
help="Look for libp2p relays for NAT traversal. "
"Use `--use_auto_relay False/True` to disable/enable this")
parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")

parser.add_argument('--load_in_8bit', type=str, default=None,
Expand All @@ -143,7 +147,7 @@ def main():
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.set_defaults(auto_relay=True)

# fmt:on
args = vars(parser.parse_args())
args.pop("config", None)
Expand All @@ -161,8 +165,6 @@ def main():

announce_maddrs = args.pop("announce_maddrs")
public_ip = args.pop("public_ip")
use_auto_relay = args.pop("auto_relay")

if public_ip is not None:
assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time"
assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)"
Expand Down Expand Up @@ -191,9 +193,11 @@ def main():
if args.pop("new_swarm"):
args["initial_peers"] = []

use_auto_relay = args.pop("use_auto_relay").upper() in TRUE_CONSTANTS

load_in_8bit = args.pop("load_in_8bit")
if load_in_8bit is not None:
args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
load_in_8bit = load_in_8bit.upper() in TRUE_CONSTANTS

server = Server(
**args,
Expand All @@ -203,6 +207,7 @@ def main():
max_disk_space=max_disk_space,
attn_cache_size=attn_cache_size,
use_auto_relay=use_auto_relay,
load_in_8bit=load_in_8bit,
)
try:
server.run()
Expand Down

0 comments on commit d99d710

Please sign in to comment.