
Commit 7edabdb

Merge branch 'master' into olruwase/zero_multi_models

2 parents 3b86860 + 729dfaf · commit 7edabdb

File tree

12 files changed: +32 −60 lines changed

.github/workflows/no-torch.yml

Lines changed: 2 additions & 1 deletion

```diff
@@ -32,11 +32,12 @@ jobs:
       run: |
         pip uninstall torch --yes
         pip install setuptools
+        pip install build
         pip list

     - name: Build deepspeed
       run: |
-        DS_BUILD_STRING=" " python setup.py sdist
+        DS_BUILD_STRING=" " python -m build --sdist

     - name: Open GitHub issue if nightly CI fails
       if: ${{ failure() && (github.event_name == 'schedule') }}
```
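Note: both workflow edits in this commit follow the same migration: the deprecated `python setup.py sdist` invocation is replaced by the PEP 517 front-end `python -m build`, which is why `pip install build` now precedes the build step. As a rough sketch (not part of this commit), the `build` package exposes the same operation programmatically; `ProjectBuilder` builds in the current environment, closest in spirit to `--no-isolation`:

```python
import os
from build import ProjectBuilder  # pip install build

# Mirror the env var the workflow sets inline for the build command.
os.environ["DS_BUILD_STRING"] = " "

builder = ProjectBuilder(".")  # source dir containing setup.py/pyproject.toml
sdist_path = builder.build("sdist", "dist/")  # roughly `python -m build --sdist --no-isolation`
print(f"built {sdist_path}")
```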

.github/workflows/release.yml

Lines changed: 2 additions & 1 deletion

```diff
@@ -26,7 +26,8 @@ jobs:
     - name: Build DeepSpeed
       run: |
         pip install setuptools
-        DS_BUILD_STRING=" " python setup.py sdist
+        pip install build
+        DS_BUILD_STRING=" " python -m build --sdist
     - name: Publish to PyPI
       uses: pypa/gh-action-pypi-publish@release/v1
       with:
```

build_win.bat

Lines changed: 1 addition & 1 deletion

```diff
@@ -11,6 +11,6 @@ set DS_BUILD_GDS=0
 set DS_BUILD_RAGGED_DEVICE_OPS=0
 set DS_BUILD_SPARSE_ATTN=0

-python setup.py bdist_wheel
+python -m build --wheel --no-isolation

 :end
```

csrc/aio/common/deepspeed_aio_common.cpp

Lines changed: 5 additions & 4 deletions

```diff
@@ -284,12 +284,13 @@ int open_file(const char* filename, const bool read_op)

 int regular_read(const char* filename, std::vector<char>& buffer)
 {
-    int64_t num_bytes;
-    const auto f_size = get_file_size(filename, num_bytes);
-    assert(f_size != -1);
-    buffer.resize(num_bytes);
     const auto fd = open(filename, O_RDONLY, 0600);
     assert(fd != -1);
+    struct stat fs;
+    const auto result = fstat(fd, &fs);
+    assert(result != -1);
+    int64_t num_bytes = fs.st_size;
+    buffer.resize(num_bytes);
     int64_t read_bytes = 0;
     auto r = 0;
     do {
```
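The change above replaces a path-based size query (`get_file_size` before `open`) with `fstat` on the already-open descriptor, so the size and the subsequent reads are guaranteed to describe the same file even if the path is swapped concurrently. A rough Python analogue of the patched logic, for illustration only:

```python
import os

def regular_read(filename: str) -> bytes:
    # Open first, then size the open descriptor with fstat (mirroring the
    # patched C++ regular_read): no second path lookup, so the size cannot
    # come from a different file than the one being read.
    fd = os.open(filename, os.O_RDONLY)
    try:
        num_bytes = os.fstat(fd).st_size
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = os.read(fd, remaining)
            if not chunk:  # hit EOF earlier than the stat'd size
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)
    finally:
        os.close(fd)
```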

deepspeed/comm/torch.py

Lines changed: 5 additions & 16 deletions

```diff
@@ -145,22 +145,11 @@ def has_reduce_scatter_tensor(self):

     def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
-            if not required_torch_version(min_version=2.4):
-                # Windows torch builds do not come with lib_uv by default.
-                # More information here: https://pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html
-                use_libuv = False if os.name == "nt" else True
-                torch.distributed.init_process_group(backend,
-                                                     timeout=timeout,
-                                                     init_method=init_method,
-                                                     rank=rank,
-                                                     world_size=world_size,
-                                                     use_libuv=use_libuv)
-            else:
-                torch.distributed.init_process_group(backend,
-                                                     timeout=timeout,
-                                                     init_method=init_method,
-                                                     rank=rank,
-                                                     world_size=world_size)
+            torch.distributed.init_process_group(backend,
+                                                 timeout=timeout,
+                                                 init_method=init_method,
+                                                 rank=rank,
+                                                 world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

     @disable_compiler_collective
```
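The deleted branch carried a Windows-only workaround (`use_libuv=False`) that torch needed before 2.4; its removal implies the surrounding code now assumes torch >= 2.4, where `init_process_group` manages the libuv TCPStore backend itself. A hypothetical guard expressing that assumption:

```python
from deepspeed.utils.torch import required_torch_version

# Assumption (inferred from the removed branch, not stated in this commit):
# torch >= 2.4 handles the libuv TCPStore backend without an explicit
# use_libuv toggle, so the Windows special case is dead code.
if not required_torch_version(min_version=2.4):
    raise RuntimeError("this simplified init path assumes torch >= 2.4")
```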

deepspeed/module_inject/layers.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -48,7 +48,8 @@ def move(tensor, device):
     # to save host resources when DP > 1。

     if tensor.is_meta:
-        return torch.empty_like(tensor, device=device)
+        # Keep tensor in meta device if tensor is meta.
+        return tensor
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
```
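With this change, meta tensors pass through `move()` untouched instead of being materialized as uninitialized storage on the target device. A small sketch of the behavior (the `move` below is a simplified stand-in for the function in this file, not its full implementation):

```python
import torch

def move(tensor: torch.Tensor, device) -> torch.Tensor:
    # Simplified stand-in for move() in deepspeed/module_inject/layers.py.
    if tensor.is_meta:
        # Meta tensors carry only shape/dtype, no storage: nothing to copy yet.
        return tensor
    # copy=True returns a fresh tensor even for cpu -> cpu moves.
    return tensor.to(device, copy=True)

meta_t = torch.empty(2, 3, device="meta")
real_t = torch.ones(2, 3)
assert move(meta_t, "cpu").is_meta               # stays on the meta device
assert move(real_t, "cpu").device.type == "cpu"  # real data is copied
```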

deepspeed/utils/logging.py

Lines changed: 10 additions & 31 deletions

```diff
@@ -7,7 +7,8 @@
 import logging
 import sys
 import os
-from deepspeed.runtime.compiler import is_compile_supported, is_compiling
+import torch
+from deepspeed.utils.torch import required_torch_version

 log_levels = {
     "debug": logging.DEBUG,
@@ -20,31 +21,6 @@

 class LoggerFactory:

-    def create_warning_filter(logger):
-        warn = False
-
-        def warn_once(record):
-            nonlocal warn
-            if is_compile_supported() and is_compiling() and not warn:
-                warn = True
-                logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
-                               " disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
-            return True
-
-        return warn_once
-
-    @staticmethod
-    def logging_decorator(func):
-
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            if is_compiling():
-                return
-            else:
-                return func(*args, **kwargs)
-
-        return wrapper
-
     @staticmethod
     def create_logger(name=None, level=logging.INFO):
         """create a logger
@@ -70,12 +46,15 @@ def create_logger(name=None, level=logging.INFO):
         ch.setLevel(level)
         ch.setFormatter(formatter)
         logger_.addHandler(ch)
-        if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
-            for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
+        if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
+            excluded_set = {
+                item.strip()
+                for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",")
+            }
+            ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set
+            for method in ignore_set:
                 original_logger = getattr(logger_, method)
-                setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
-            else:
-                logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
+                torch._dynamo.config.ignore_logger_methods.add(original_logger)
         return logger_
```
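Rather than wrapping each logger method in a compile-time decorator, the new code registers the methods with `torch._dynamo.config.ignore_logger_methods` (hence the torch >= 2.6 gate), letting Dynamo skip them during tracing without graph breaks. A hypothetical usage sketch; the env vars must be set before DeepSpeed builds its logger at import time:

```python
import os

# Hypothetical usage of the two env vars read by create_logger().
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
# Comma-separated logger methods to keep live even under torch.compile:
os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "error,critical"

from deepspeed.utils import logger

logger.info("registered as ignorable: skipped inside compiled regions")
logger.error("excluded from the ignore set: behaves normally everywhere")
```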

docs/_tutorials/advanced-install.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -84,7 +84,7 @@ This should complete the full build 2-3 times faster. You can adjust `-j` to spe
 You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.)

 ```bash
-DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel
+DS_BUILD_OPS=1 python -m build --wheel --no-isolation --config-setting="--build-option=build_ext" --config-setting="--build-option=-j8"
 ```

 This will create a pypi binary wheel under `dist`, e.g., ``dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`` and then you can install it directly on multiple machines, in our example:
````

docs/_tutorials/ds-sequence.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -111,7 +111,7 @@ pip install .
 cd ${WORK_DIR}
 git clone -b v1.0.4 https://github.com/HazyResearch/flash-attention
 cd flash-attention
-python setup.py install
+python -m pip install .
 ```

 You may also want to ensure your model configuration is compliant with FlashAttention's requirements. For instance, to achieve optimal performance, the head size should be divisible by 8. Refer to the FlashAttention documentation for more details.
````

install.sh

Lines changed: 1 addition & 1 deletion

```diff
@@ -152,7 +152,7 @@ if [ ! -f $hostfile ]; then
 fi

 echo "Building deepspeed wheel"
-python setup.py $VERBOSE bdist_wheel
+python -m build $VERBOSE --wheel --no-isolation

 if [ "$local_only" == "1" ]; then
     echo "Installing deepspeed"
```
