CUDA optional deepspeed ops #2507

Merged Jan 17, 2023 · 65 commits (changes shown from 44 commits)

Commits
7f811bf
CPU-Adam: add compile-flag to enable param-copy from CPU to GPU
Aug 5, 2021
2422ddf
guard the CUDA-related include files and variables
Aug 5, 2021
e181f45
Merge branch 'master' into cpu-adam/optional_CUDA-copy
tjruwase Aug 5, 2021
248bb35
Merge branch 'master' into cpu-adam/optional_CUDA-copy
RezaYazdaniAminabadi Aug 19, 2021
2d8b7df
remove CUDA dependency from op_builder when building against CPU
Aug 19, 2021
8ce051c
Merge branch 'master' into cpu-adam/optional_CUDA-copy
RezaYazdaniAminabadi Aug 26, 2021
6eda144
fixing the builder issues
Aug 27, 2021
71a4d3d
fix formatting
Aug 27, 2021
52f7f7b
Merge branch 'master' into cpu-adam/optional_CUDA-copy
RezaYazdaniAminabadi Aug 27, 2021
ee35deb
return true when there is no mismatch on the cuda version
Aug 30, 2021
8950731
guard for when cuda is not available & test with cpu-only environment
Sep 7, 2021
04d76d5
Merge branch 'master' into cpu-adam/optional_CUDA-copy
RezaYazdaniAminabadi Sep 7, 2021
5625007
Update cpu_adam and cpu_adagrad
tjruwase Nov 15, 2022
f30dffa
merge with master
tjruwase Nov 15, 2022
52ef0ea
Format fixes
tjruwase Nov 15, 2022
472c7ba
Add configurable half precision type; Build/run in CUDA environment
tjruwase Nov 15, 2022
70a0598
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Nov 15, 2022
bf3ca2f
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Nov 16, 2022
142fb21
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Nov 29, 2022
63ef769
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Nov 30, 2022
b031aa9
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 14, 2022
c7ae7d1
Run cpu_adam and cpu_adagrad in cpu only environment
tjruwase Dec 14, 2022
7dabb62
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 16, 2022
1d9c9ac
Mark CUDA only unit tests
tjruwase Dec 16, 2022
291fb13
Merge branch 'olruwase/cuda_optional_ops' of github.com:microsoft/Dee…
tjruwase Dec 16, 2022
f19532e
CPU environment CI
tjruwase Dec 16, 2022
5c85a75
Format fixes
tjruwase Dec 16, 2022
f26bf96
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 16, 2022
6973079
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 16, 2022
980e994
Remove --forked
tjruwase Dec 17, 2022
99c4716
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 17, 2022
26a719e
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 18, 2022
fdb5c4c
Add --forked
tjruwase Dec 19, 2022
0589fc1
Merge branch 'olruwase/cuda_optional_ops' of github.com:microsoft/Dee…
tjruwase Dec 19, 2022
214f2c8
CPU only CI should pass
tjruwase Dec 20, 2022
7fd2d26
Format fixes
tjruwase Dec 20, 2022
6241f44
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 20, 2022
c3a3999
Format fixes
tjruwase Dec 20, 2022
613ca7e
Merge branch 'olruwase/cuda_optional_ops' of github.com:microsoft/Dee…
tjruwase Dec 20, 2022
fea4d85
Remove scattered pytest.skip
tjruwase Dec 20, 2022
5ed6ef1
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 21, 2022
8ca082d
Fix cpu_adam unit test
tjruwase Dec 21, 2022
8da0c26
Merge branch 'olruwase/cuda_optional_ops' of github.com:microsoft/Dee…
tjruwase Dec 21, 2022
86d72a4
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Dec 21, 2022
7fe7d50
Merge branch 'master' into olruwase/cuda_optional_ops
jeffra Dec 21, 2022
1ebb130
Update .github/workflows/nv-torch-latest-cpu.yml
tjruwase Dec 21, 2022
7c14844
Update .github/workflows/nv-torch-latest-cpu.yml
tjruwase Dec 21, 2022
e1f7372
Address PR feedback
tjruwase Dec 21, 2022
3583f97
OpenMP linking
tjruwase Dec 22, 2022
58d4106
Rebase
tjruwase Dec 23, 2022
f654f6a
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 7, 2023
fcf8c33
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 8, 2023
813dbba
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 9, 2023
f5a44d3
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 10, 2023
ca9ffd3
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 10, 2023
b94c878
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 10, 2023
a8abd7c
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 11, 2023
9b344a2
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 11, 2023
93bc338
Fix unit tests
tjruwase Jan 13, 2023
fbde29c
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 13, 2023
3aead39
Resolve conflict
tjruwase Jan 13, 2023
dc7017f
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 14, 2023
4aaf788
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 17, 2023
5e2d556
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 17, 2023
4d47047
Merge branch 'master' into olruwase/cuda_optional_ops
tjruwase Jan 17, 2023
52 changes: 52 additions & 0 deletions .github/workflows/nv-torch-latest-cpu.yml
@@ -0,0 +1,52 @@
name: nv-torch-latest-cpu

on:
  push:
    branches:
      - 'master'
      - 'staging**'
    paths-ignore:
      - 'docs/**'
  pull_request:
    paths-ignore:
      - 'docs/**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  unit-tests:
    runs-on: ubuntu-20.04

    steps:
      - uses: actions/checkout@v2

      - name: environment
        run: |
          echo "JobID: $AISC_NODE_INSTANCE_ID"
          which python
          python --version
          pip install --upgrade pip
          pip uninstall --yes torch torchvision triton
          pip install torch==1.12.0+cpu torchvision==0.13.0+cpu torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu
          python -c "import torch; print('torch:', torch.__version__, torch)"
          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

      - name: Install deepspeed
        run: |
          pip uninstall --yes deepspeed
          pip install .[dev,autotuning]
          ds_report

      - name: Python environment
        run: |
          pip list

      - name: Unit tests
        run: |
          unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
          if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
          cd tests
          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose -n 4 unit/ --torch_ver="1.12"
          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose -m 'sequential' unit/ --torch_ver="1.12"
37 changes: 24 additions & 13 deletions csrc/adagrad/cpu_adagrad.cpp
@@ -1,16 +1,16 @@
#include "cpu_adagrad.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

@@ -20,7 +20,7 @@ void Adagrad_Optimizer::Step_1(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -30,17 +30,19 @@ void Adagrad_Optimizer::Step_1(float* _params,
#endif
if (_param_size > rounded_size) {
float step_size = -1 * _alpha;
__half* grads_cast_h;
__half* params_cast_h;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<__half*>(grads);
params_cast_h = reinterpret_cast<__half*>(_params);
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
@@ -55,21 +57,24 @@ void Adagrad_Optimizer::Step_1(float* _params,
grad += _eps;
grad = momentum / grad;
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

#endif
if (half_precision)
params_cast_h[k] = (__half)param;
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
// STORE UPDATE TERM TO GRAD'S MEMORY
grads[k] = grad * step_size;
_exp_avg_sq[k] = variance;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
#endif
}
}
}
@@ -78,7 +83,7 @@ void Adagrad_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -130,7 +135,7 @@ void Adagrad_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -170,7 +175,9 @@ int ds_adagrad_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay);
opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0));

#if defined(__ENABLE_CUDA__)
opt->SynchronizeStreams();
#endif
return 0;
}

@@ -184,14 +191,15 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

std::shared_ptr<Adagrad_Optimizer> opt =
@@ -206,6 +214,9 @@ int ds_adagrad_step_plus_copy(int optimizer_id,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}

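Note on the ds_half_precision_t alias that replaces __half throughout this file: it is defined in the shared optimizer headers, which are not part of this excerpt. A minimal sketch of how such an alias could be selected at build time (an illustrative assumption, not the literal header contents):

// Hypothetical excerpt in the spirit of the cpu_adagrad/cpu_adam headers:
// choose the half-precision storage type according to how the op was built.
#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>                        // provides __half
typedef __half ds_half_precision_t;           // CUDA build: native fp16 type
#else
typedef unsigned short ds_half_precision_t;   // CPU-only build: 16-bit stand-in
#endif

With an alias like this, the same Step_* code compiles in both configurations, and only a CUDA build performs a real __half conversion when half_precision is set.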
40 changes: 26 additions & 14 deletions csrc/adam/cpu_adam.cpp
@@ -1,16 +1,18 @@
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <cassert>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>

#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
#endif

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

@@ -21,7 +23,7 @@ void Adam_Optimizer::Step_1(float* _params,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -41,19 +43,20 @@ void Adam_Optimizer::Step_1(float* _params,

float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half* grads_cast_h;
__half* params_cast_h;
ds_half_precision_t* grads_cast_h;
ds_half_precision_t* params_cast_h;
if (half_precision) {
grads_cast_h = reinterpret_cast<__half*>(grads);
params_cast_h = reinterpret_cast<__half*>(_params);
grads_cast_h = reinterpret_cast<ds_half_precision_t*>(grads);
params_cast_h = reinterpret_cast<ds_half_precision_t*>(_params);
}

for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }

#endif
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = half_precision ? (float)grads_cast_h[k] : grads[k];
@@ -73,21 +76,24 @@ void Adam_Optimizer::Step_1(float* _params,
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
#if defined(__ENABLE_CUDA__)
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

#endif
if (half_precision)
params_cast_h[k] = (__half)param;
params_cast_h[k] = (ds_half_precision_t)param;
else
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);

_buf_index = !_buf_index;
}
#endif
}
}
}
@@ -97,7 +103,7 @@ void Adam_Optimizer::Step_4(float* _params,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -166,7 +172,7 @@ void Adam_Optimizer::Step_8(float* _params,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t rounded_size = 0;
@@ -228,7 +234,9 @@ int ds_adam_step(int optimizer_id,
nullptr,
(params.options().dtype() == at::kHalf));

#if defined(__ENABLE_CUDA__)
opt->SynchronizeStreams();
#endif
return 0;
}

@@ -246,6 +254,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
#if defined(__ENABLE_CUDA__)
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
@@ -254,7 +263,7 @@ int ds_adam_step_plus_copy(int optimizer_id,

float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

Expand All @@ -271,6 +280,9 @@ int ds_adam_step_plus_copy(int optimizer_id,
(params.options().dtype() == at::kHalf));

opt->SynchronizeStreams();
#else
assert(false);
#endif
return 0;
}

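Taken together, the two source files follow one pattern: every CUDA-specific include, member, and call site is wrapped in #if defined(__ENABLE_CUDA__), and the *_plus_copy entry points, which exist only to overlap the CPU update with a copy into GPU memory, compile to a hard failure in a CPU-only build. A simplified, self-contained sketch of that pattern (names and the synchronous copy are illustrative, not the real optimizer API):

#include <cassert>
#include <cstddef>
#if defined(__ENABLE_CUDA__)
#include <cuda_runtime_api.h>
#endif

// The numerical update runs on the CPU in every build configuration.
void cpu_step(float* params, const float* grads, size_t n)
{
    for (size_t k = 0; k < n; ++k) { params[k] -= 1e-3f * grads[k]; }
}

// The "step plus copy back to GPU" variant is only meaningful when the
// extension was built against CUDA; otherwise it mirrors the diff and asserts.
int step_plus_copy(float* params, const float* grads, size_t n, void* dev_params)
{
#if defined(__ENABLE_CUDA__)
    cpu_step(params, grads, n);
    cudaMemcpy(dev_params, params, n * sizeof(float), cudaMemcpyHostToDevice);
#else
    (void)params; (void)grads; (void)n; (void)dev_params;
    assert(false);  // unreachable in a CPU-only DeepSpeed build
#endif
    return 0;
}

The new nv-torch-latest-cpu workflow added above exercises the CPU-only side of this build configuration in CI.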