Commit 7840085

Merge branch 'master' into staging-1bit-nccl-v2

conglongli committed Dec 23, 2020
2 parents a6dba72 + 7435b2f commit 7840085
Showing 28 changed files with 342 additions and 208 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/pre-compile-ops.yml
@@ -0,0 +1,47 @@
# This is a basic workflow to help you get started with Actions

name: Tests-w-precompiled-ops

# Controls when the action will run.
on:
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: self-hosted

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2

# Runs a single command using the runners shell
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
DS_BUILD_OPS=1 pip install .[dev]
ds_report
- name: Formatting checks
run: |
pre-commit run --all-files
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
13 changes: 10 additions & 3 deletions csrc/transformer/ds_transformer_cuda.cpp
100644 → 100755
@@ -14,6 +14,8 @@

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
@@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id,
int hidden_dim,
int num_heads,
int intermediate_size,
int seq_length,
float attn_dropout_ratio,
float hidden_dropout_ratio,
int seed,
@@ -604,14 +605,14 @@
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
seq_length,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
pre_or_postLayerNorm,
@@ -873,6 +874,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

int seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len, bsz);
}

auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
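The net effect of this file's changes: the layer is constructed with a fixed init_seq_length of 128 instead of a user-supplied seq_length, and ds_transformer_backward re-sizes the layer whenever the incoming tensor's sequence dimension disagrees with the cached value. A minimal Python sketch of that resize-on-demand pattern (the class and method names here are hypothetical, not the DeepSpeed C++ API):

import torch

INIT_SEQ_LENGTH = 128  # mirrors `const int init_seq_length = 128;` above

class LazySeqLenLayer:
    """Toy stand-in for the layer's sequence-length handling."""
    def __init__(self):
        self.seq_length = INIT_SEQ_LENGTH

    def set_seq_length(self, seq_len, bsz):
        # placeholder for the buffer re-allocation SetSeqLength performs in C++
        self.seq_length = seq_len

    def backward(self, g_output, bsz):
        # same check as the hunk above: compare dim 1 against the cached value
        if g_output.size(1) != self.seq_length:
            self.set_seq_length(g_output.size(1), bsz)
        return torch.empty_like(g_output)

layer = LazySeqLenLayer()
layer.backward(torch.randn(4, 384, 1024), bsz=4)  # grows 128 -> 384 on first use
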
20 changes: 14 additions & 6 deletions csrc/transformer/softmax_kernels.cu
@@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -113,7 +114,8 @@
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
@@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals,
#endif

int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);

for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -339,7 +343,9 @@ void launch_attn_softmax<float>(float* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
@@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals,
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);

iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
vals, attn_mask, heads, seq_length4, iterations);
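A rough Python rendering of the `iterations` arithmetic added to both launchers, with illustrative values for the constants (the actual `threads`, `MAX_THREAD_ITERATIONS`, and `subblock_max_workload` are defined elsewhere in this file and are not reproduced in this diff):

# Illustrative constants only; chosen for the example, not read from the kernel.
THREADS = 128
MAX_THREAD_ITERATIONS = 8
SUBBLOCK_MAX_WORKLOAD = MAX_THREAD_ITERATIONS * 4 * THREADS  # assumption

def softmax_iterations(sequence_length: int) -> int:
    seq_length4 = sequence_length // 4  # assumes float4-style packing, 4 values per slot
    if sequence_length < SUBBLOCK_MAX_WORKLOAD:
        return (seq_length4 + THREADS - 1) // THREADS  # ceil-divide the row across threads
    return MAX_THREAD_ITERATIONS

print(softmax_iterations(512), softmax_iterations(8192))  # short row vs. capped long row
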
1 change: 1 addition & 0 deletions deepspeed/__init__.py
@@ -14,6 +14,7 @@
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .utils import log_dist
from .utils.distributed import init_distributed

from .pipe import PipelineModule

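With `init_distributed` re-exported at the package level, callers can initialize torch.distributed through DeepSpeed directly; a minimal usage sketch (defaults and keyword arguments live in `deepspeed/utils/distributed.py`, which is not shown in this diff):

import deepspeed

# Reads rank/world-size/master-address settings from the environment exported by
# the DeepSpeed launcher (or another launcher), then sets up torch.distributed.
deepspeed.init_distributed()
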
8 changes: 8 additions & 0 deletions deepspeed/constants.py
@@ -0,0 +1,8 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
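The default port constant now lives at package level; the two launcher modules later in this diff switch to importing it from here, roughly as sketched below (the fallback logic is illustrative):

from deepspeed.constants import TORCH_DISTRIBUTED_DEFAULT_PORT

master_port = None  # e.g. the value of a --master_port CLI flag, if the user passed one
port = master_port if master_port is not None else TORCH_DISTRIBUTED_DEFAULT_PORT
print(port)  # 29500
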
8 changes: 6 additions & 2 deletions deepspeed/git_version_info.py
@@ -2,8 +2,12 @@
# This is populated by setup.py
from .git_version_info_installed import *
except ModuleNotFoundError:
# Will be missing from checkouts that haven't been installed (e.g., readthedocs)
version = open('version.txt', 'r').read().strip()
import os
if os.path.isfile('version.txt'):
# Will be missing from checkouts that haven't been installed (e.g., readthedocs)
version = open('version.txt', 'r').read().strip()
else:
version = "0.0.0"
git_hash = '[none]'
git_branch = '[none]'

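After this change, importing version metadata from an uninstalled checkout (for example on readthedocs) falls back to placeholder values instead of raising; a quick check might look like:

from deepspeed.git_version_info import version, git_hash, git_branch

# On an installed build these come from git_version_info_installed; on a bare
# checkout without version.txt they fall back to "0.0.0" / "[none]".
print(version, git_hash, git_branch)
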
5 changes: 0 additions & 5 deletions deepspeed/launcher/constants.py
@@ -1,10 +1,5 @@
# Copyright 2020 The Microsoft DeepSpeed Team

#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500

PDSH_LAUNCHER = 'pdsh'
PDSH_MAX_FAN_OUT = 1024

3 changes: 2 additions & 1 deletion deepspeed/launcher/launch.py
@@ -16,7 +16,7 @@
from collections import defaultdict
from argparse import ArgumentParser, REMAINDER

from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger


@@ -113,6 +113,7 @@ def main():
# each process's rank
dist_rank = global_rank_mapping[local_node][local_rank]
current_env["RANK"] = str(dist_rank)
current_env["LOCAL_RANK"] = str(local_rank)

# spawn the processes
cmd = [
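The launcher now exports `LOCAL_RANK` next to `RANK` for every spawned worker; a typical consumer-side sketch (illustrative, not part of this diff) reads it to pick the GPU for the process:

import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)  # one worker per GPU, indexed by LOCAL_RANK
print(f"rank={os.environ.get('RANK')} local_rank={local_rank}")
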
4 changes: 2 additions & 2 deletions deepspeed/launcher/runner.py
@@ -19,8 +19,8 @@
import torch.cuda

from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \
PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..utils import logger

DLTS_HOSTFILE = "/job/hostfile"
4 changes: 2 additions & 2 deletions deepspeed/ops/sparse_attention/softmax.py
@@ -224,8 +224,8 @@ class Softmax:
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
"""

sparse_softmax = _sparse_softmax.apply
def sparse_softmax(*args, **kwargs):
return _sparse_softmax.apply(*args, **kwargs)

def make_lut(self, device):
"""Generates the sparsity layout used in block-sparse softmax
28 changes: 22 additions & 6 deletions deepspeed/ops/transformer/transformer.py
@@ -18,7 +18,6 @@
class TransformerConfig():
def __init__(self,
batch_size,
max_seq_length,
hidden_size,
intermediate_size,
heads,
@@ -30,7 +29,6 @@ def __init__(self,
self.batch_size = batch_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_seq_length = max_seq_length
self.heads = heads
self.attn_dropout_ratio = attn_dropout_ratio
self.hidden_dropout_ratio = hidden_dropout_ratio
@@ -92,7 +90,6 @@ class DeepSpeedTransformerConfig(TransformerConfig):
"""
def __init__(self,
batch_size=-1,
max_seq_length=-1,
hidden_size=-1,
intermediate_size=-1,
heads=-1,
@@ -112,7 +109,6 @@ def __init__(self,
super(DeepSpeedTransformerConfig,
self).__init__(
batch_size,
max_seq_length,
hidden_size,
(intermediate_size if intermediate_size > 0 else 4 * hidden_size),
heads,
@@ -142,7 +138,7 @@ def from_dict(cls, json_object):

@classmethod
def from_json_file(cls, json_file):
with open(json_file, "r", encoding='utf-8') as reader:
with open(json_file, "r", encoding='utf-16') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

@@ -177,6 +173,18 @@ def forward(ctx,
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

inp_size = input.size()
if inp_size[1] % 16 != 0:
input = torch.cat((input,
torch.randn((inp_size[0],
(16 - (inp_size[1] % 16)),
inp_size[2]),
device=input.device,
dtype=input.dtype)),
1)
input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \
(16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3)

(output,
inp_norm,
qkv_tf,
@@ -303,11 +311,17 @@ def forward(ctx,
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var

if inp_size[1] % 16 != 0:
output = torch.narrow(output, 1, 0, inp_size[1])
return output

@staticmethod
def backward(ctx, grad_output):
bsz = grad_output.shape[0]
grad_output_shape = grad_output.size()
if grad_output_shape[1] % 16 != 0:
grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \
grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1)

if bsz > ctx.config.batch_size:
raise ValueError('grad_output batch size exceeds the limit.')
@@ -398,6 +412,9 @@ def backward(ctx, grad_output):
norm_w,
norm_b)

if grad_output_shape[1] % 16 != 0:
grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

return (grad_input,
None,
None,
@@ -501,7 +518,6 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
self.config.hidden_size,
self.config.heads,
self.config.intermediate_size,
self.config.max_seq_length,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.seed,
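The forward and backward hunks above pad the sequence dimension up to the next multiple of 16 (randn/-10000 fill for activations and mask, zeros for gradients) and then `torch.narrow` the result back to the caller's length. A standalone sketch of that round trip, using plain zero padding for simplicity:

import torch

def pad_seq_to_multiple_of_16(x):
    """Pad dim 1 of a [batch, seq, hidden] tensor up to the next multiple of 16."""
    seq = x.size(1)
    if seq % 16 == 0:
        return x
    pad = 16 - (seq % 16)
    filler = torch.zeros(x.size(0), pad, x.size(2), device=x.device, dtype=x.dtype)
    return torch.cat((x, filler), dim=1)

x = torch.randn(2, 100, 1024)
padded = pad_seq_to_multiple_of_16(x)            # shape [2, 112, 1024]
trimmed = torch.narrow(padded, 1, 0, x.size(1))  # trim back, as the hunks do with `output`
assert trimmed.shape == x.shape
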
5 changes: 0 additions & 5 deletions deepspeed/runtime/constants.py
@@ -73,11 +73,6 @@
ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False

#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = "29500"

# Steps
STEPS_PER_PRINT = "steps_per_print"
STEPS_PER_PRINT_DEFAULT = 10