Skip to content

Commit

Permalink
fix #4137
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jun 7, 2024
1 parent cce0fad commit 8bf9da6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/llamafactory/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,8 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
disable_torchrun = os.environ.get("TORCHRUN_DISABLED", "0").lower() in ["true", "1"]
if disable_torchrun and get_device_count() > 1:
logger.warning("`torchrun` cannot be disabled when device count > 1.")
disable_torchrun = False

if (not disable_torchrun) and (get_device_count() > 0):
force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
if force_torchrun or get_device_count() > 1:
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
Expand Down
3 changes: 3 additions & 0 deletions src/llamafactory/webui/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dic
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1"

self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
yield from self.monitor()

Expand Down

0 comments on commit 8bf9da6

Please sign in to comment.