Skip to content

Commit

Permalink
Use torch.cuda.device_count() (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaden Smith committed Mar 10, 2020
1 parent 8ad8a26 commit 27d8385
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 17 deletions.
21 changes: 5 additions & 16 deletions deepspeed/pt/deepspeed_run.py
Expand Up @@ -5,14 +5,16 @@
import os
import sys
import json
import pynvml
import shutil
import base64
import logging
import argparse
import subprocess
import collections
from copy import deepcopy

import torch.cuda

from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT

DLTS_HOSTFILE = "/job/hostfile"
Expand Down Expand Up @@ -213,19 +215,6 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
exclude_str=exclusion)


def local_gpu_count():
    """Return the number of NVIDIA GPUs on the local host, queried via NVML.

    Returns:
        int: the device count reported by NVML, or ``None`` if NVML could
        not be initialized/queried (e.g., no NVIDIA driver or no GPUs on
        this host). Callers treat ``None`` as "unknown".
    """
    # Keep the try body minimal: only the NVML calls can raise NVMLError.
    try:
        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
    except pynvml.NVMLError:
        # Best-effort probe: log and return None rather than aborting launch.
        logging.error("Unable to get GPU count information, perhaps there are "
                      "no GPUs on this host?")
        return None
    # Was a stray debug print(); use the module's logging like the rest of the file.
    logging.info("device count: %d", device_count)
    return device_count


def encode_world_info(world_info):
world_info_json = json.dumps(world_info).encode('utf-8')
world_info_base64 = base64.urlsafe_b64encode(world_info_json).decode('utf-8')
Expand All @@ -243,8 +232,8 @@ def main(args=None):
resource_pool = fetch_hostfile(args.hostfile)
if not resource_pool:
resource_pool = {}
device_count = local_gpu_count()
if device_count is None:
device_count = torch.cuda.device_count()
if device_count == 0:
raise RuntimeError("Unable to proceed, no GPU resources available")
resource_pool['localhost'] = device_count
args.master_addr = "127.0.0.1"
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Expand Up @@ -5,7 +5,6 @@ tqdm
psutil
tensorboardX==1.8
tensorflow-gpu==1.15.2
nvidia-ml-py3
pytest
pytest-forked
pre-commit

0 comments on commit 27d8385

Please sign in to comment.