Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CUDA_VISIBLE_DEVICES environment variable when using the --gpus flag #345

Merged
merged 8 commits into from
Jan 5, 2024
12 changes: 8 additions & 4 deletions mlcube/mlcube/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ def parse_extra_arg(
key = "--security-opt" if platform == "docker" else "--security"
runner_run_args[key] = parsed_args["security"]
if parsed_args.get("gpus", None):
cuda_visible_devices = parsed_args["gpus"]
if "device" in cuda_visible_devices:
cuda_visible_devices = cuda_visible_devices.replace("device=", "")
elif str(cuda_visible_devices).isnumeric():
cuda_visible_devices = str(list(range(int(cuda_visible_devices))))
cuda_visible_devices = cuda_visible_devices.replace(" ", "")[1:-1]
if platform == "docker":
runner_run_args["--gpus"] = parsed_args["gpus"]
runner_run_args["--gpus"] = cuda_visible_devices
else:
runner_run_args["--nv"] = ""
os.environ["SINGULARITYENV_CUDA_VISIBLE_DEVICES"] = parsed_args[
"gpus"
]
os.environ["SINGULARITYENV_CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
if parsed_args.get("memory", None):
key = "--memory" if platform == "docker" else "--vm-ram"
runner_run_args[key] = parsed_args["memory"]
Expand Down
13 changes: 13 additions & 0 deletions runners/mlcube_docker/mlcube_docker/docker_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,19 @@ def run(self) -> None:
if extra_args:
run_args += " " + extra_args

valid_gpu_flag = "--gpus" in self.mlcube.runner and self.mlcube.runner["--gpus"] is not None

if valid_gpu_flag:
cuda_visible_devices = self.mlcube.runner["--gpus"]
else:
cuda_visible_devices = num_gpus

if str(cuda_visible_devices).isnumeric():
cuda_visible_devices = str(list(range(int(cuda_visible_devices))))
cuda_visible_devices = cuda_visible_devices.replace(" ", "")[1:-1]

run_args += f" --env CUDA_VISIBLE_DEVICES={cuda_visible_devices}"

if "entrypoint" in self.mlcube.tasks[self.task]:
logger.info(
"Using custom task entrypoint: task=%s, entrypoint='%s'",
Expand Down
Loading