Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase/ds_2921
tjruwase committed Oct 25, 2023
2 parents f25ff5b + 869629c commit f638d92
Showing 4 changed files with 11 additions and 4 deletions.
6 changes: 4 additions & 2 deletions .readthedocs.yml
@@ -1,6 +1,9 @@

# Required
version: 2
build:
os: "ubuntu-22.04"
tools:
python: "3.8"

# Build documentation in the docs/ directory with Sphinx
sphinx:
@@ -13,6 +16,5 @@ formats:

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
install:
- requirements: requirements/requirements-readthedocs.txt
3 changes: 3 additions & 0 deletions csrc/includes/cublas_wrappers.h
@@ -13,6 +13,9 @@
#ifndef __HIP_PLATFORM_AMD__
#include <mma.h>
#endif
#ifdef __HIP_PLATFORM_AMD__
#include <rocblas/rocblas.h>
#endif
#include <stdio.h>

int cublas_gemm_ex(cublasHandle_t handle,
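A note on the header change above (context I am adding, not part of the commit): the extra include only takes effect when DeepSpeed's C++/CUDA extensions are compiled against a ROCm build of PyTorch, where `__HIP_PLATFORM_AMD__` is defined. A rough way to check which build you have:

```python
import torch

# torch.version.hip is a version string on ROCm builds of PyTorch and None on
# CUDA-only builds; the rocblas include above only matters in the former case.
print("HIP/ROCm runtime:", torch.version.hip)
print("CUDA runtime:", torch.version.cuda)
```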
2 changes: 1 addition & 1 deletion deepspeed/runtime/pipe/engine.py
@@ -552,7 +552,7 @@ def _aggregate_total_loss(self):
agg_loss /= self.dp_world_size

assert self.global_rank in self.grid.pp_group
losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
losses = torch.stack([self.dp_group_loss, agg_loss])
if self.is_pipe_parallel:
dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
else:
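As context (not part of the diff itself), the engine.py change replaces `torch.Tensor([...]).to(self.device)` with `torch.stack([...])` for values that are already tensors. A minimal sketch of the difference, using made-up scalar losses in place of `self.dp_group_loss` and `agg_loss`:

```python
import torch

# Hypothetical stand-ins for self.dp_group_loss and agg_loss: scalar tensors
# that already live on the training device.
device = "cuda" if torch.cuda.is_available() else "cpu"
dp_group_loss = torch.tensor(0.25, device=device)
agg_loss = torch.tensor(0.30, device=device)

# Old pattern: torch.Tensor([...]) reads the scalar values out on the host
# (a device sync for CUDA tensors), builds a new default-dtype CPU tensor that
# is detached from autograd, and then copies it back with .to(device).
losses_old = torch.Tensor([dp_group_loss, agg_loss]).to(device)

# Pattern from this commit: torch.stack keeps the operands as tensors, so the
# result stays on their device with their dtype and needs no extra transfer.
losses_new = torch.stack([dp_group_loss, agg_loss])

print(losses_old.device, losses_new.device, losses_new.dtype)
```

Either form yields a length-2 tensor for the subsequent broadcast; the stacked version just avoids the host round trip.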
4 changes: 3 additions & 1 deletion requirements/requirements-readthedocs.txt
@@ -4,6 +4,8 @@ hjson
packaging
psutil
py-cpuinfo
pydantic
pydantic<2.0.0
recommonmark
sphinx_rtd_theme
torch
tqdm
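A side note on the `pydantic<2.0.0` pin (my reading, not stated in the diff): pydantic 2.0 changed parts of the v1 `BaseModel` API, so the docs build is held to the 1.x series. A small, hypothetical version check:

```python
# Hypothetical check, not from the repository: report which pydantic major
# version is installed, since 2.x renamed several v1 BaseModel methods
# (e.g. .dict() -> .model_dump(), .parse_obj() -> .model_validate()).
import pydantic

major = int(pydantic.VERSION.split(".")[0])
if major >= 2:
    print("pydantic 2.x installed; does not satisfy the '<2.0.0' pin above")
else:
    print("pydantic 1.x installed; matches requirements-readthedocs.txt")
```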
