diff --git a/.github/workflows/pre-compile-ops.yml b/.github/workflows/pre-compile-ops.yml new file mode 100644 index 000000000000..4005d4baf2fc --- /dev/null +++ b/.github/workflows/pre-compile-ops.yml @@ -0,0 +1,47 @@ +# This is a basic workflow to help you get started with Actions + +name: Tests-w-precompiled-ops + +# Controls when the action will run. +on: + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + # This workflow contains a single job called "build" + build: + # The type of runner that the job will run on + runs-on: self-hosted + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + + # Runs a single command using the runners shell + - name: environment + run: | + nvidia-smi + which python + python --version + which nvcc + nvcc --version + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + # Runs a set of commands using the runners shell + - name: Install deepspeed + run: | + DS_BUILD_OPS=1 pip install .[dev] + ds_report + + - name: Formatting checks + run: | + pre-commit run --all-files + + # Runs a set of commands using the runners shell + - name: Unit tests + run: | + if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi + TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/ diff --git a/DeepSpeedExamples b/DeepSpeedExamples index fa1d1a71c486..78d69cb2f89a 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7 +Subproject commit 78d69cb2f89a27b1e9b072df8c3e47d00c024fdc diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100644 new mode 100755 index 85ec0418971c..ebd534d04ab3 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -14,6 +14,8 @@ static std::unordered_map> s_transformer_layers; +const int init_seq_length = 128; + // C++ interface template @@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id, int hidden_dim, int num_heads, int intermediate_size, - int seq_length, float attn_dropout_ratio, float hidden_dropout_ratio, int seed, @@ -604,14 +605,14 @@ int create_transformer_layer(int layer_id, { Context::Instance().SetSeed(seed); Context::Instance().TestGemmFP16( - test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads); + test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); auto layer = std::make_shared>(layer_id, batch_size, hidden_dim, num_heads, intermediate_size, - seq_length, + init_seq_length, attn_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm, @@ -873,6 +874,12 @@ std::vector ds_transformer_backward(int layer_id, std::shared_ptr> layer = std::static_pointer_cast>(s_transformer_layers[layer_id]); + int seq_len = layer->GetSeqLength(); + if (g_output.size(1) != seq_len) { + seq_len = g_output.size(1); + layer->SetSeqLength(seq_len, bsz); + } + auto grad_input = torch::empty_like(input); auto grad_attn_qkvw = torch::empty_like(attn_qkvw); auto grad_attn_qkvb = torch::empty_like(attn_qkvb); diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index 
582da4829f47..be776b0c074d 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { auto temp = g.shfl_xor(max_val, i); @@ -113,7 +114,8 @@ __global__ void attn_softmax(float* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } @@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { auto temp = g.shfl_xor(max_val, i); @@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals, #endif int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } @@ -339,7 +343,9 @@ void launch_attn_softmax(float* vals, dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / subblock_max_workload * threads) : threads); - + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); if (sequence_length <= 512) attn_softmax<32, (threads / 128), 128><<>>( vals, attn_mask, heads, seq_length4, iterations); @@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals, dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / subblock_max_workload * threads) : threads); - + iterations = + (sequence_length < subblock_max_workload ? 
(seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); if (sequence_length <= 512) attn_softmax<32, (threads / 128), 128><<>>( vals, attn_mask, heads, seq_length4, iterations); diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 8ac0aad05562..ba6f9b5bb6bf 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -14,6 +14,7 @@ from .runtime.activation_checkpointing import checkpointing from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .utils import log_dist +from .utils.distributed import init_distributed from .pipe import PipelineModule diff --git a/deepspeed/constants.py b/deepspeed/constants.py new file mode 100644 index 000000000000..467e85aefcb6 --- /dev/null +++ b/deepspeed/constants.py @@ -0,0 +1,8 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +############################################# +# Torch distributed constants +############################################# +TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index d17948ae41a7..f04982c74f0d 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -2,8 +2,12 @@ # This is populated by setup.py from .git_version_info_installed import * except ModuleNotFoundError: - # Will be missing from checkouts that haven't been installed (e.g., readthedocs) - version = open('version.txt', 'r').read().strip() + import os + if os.path.isfile('version.txt'): + # Will be missing from checkouts that haven't been installed (e.g., readthedocs) + version = open('version.txt', 'r').read().strip() + else: + version = "0.0.0" git_hash = '[none]' git_branch = '[none]' diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py index f384d58b2c52..fd56facc4343 100644 --- a/deepspeed/launcher/constants.py +++ b/deepspeed/launcher/constants.py @@ -1,10 +1,5 @@ # Copyright 2020 The Microsoft DeepSpeed Team -############################################# -# Torch distributed constants -############################################# -TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 - PDSH_LAUNCHER = 'pdsh' PDSH_MAX_FAN_OUT = 1024 diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py index 205aee2d6ac4..0958295efe06 100755 --- a/deepspeed/launcher/launch.py +++ b/deepspeed/launcher/launch.py @@ -16,7 +16,7 @@ from collections import defaultdict from argparse import ArgumentParser, REMAINDER -from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..utils import logger @@ -113,6 +113,7 @@ def main(): # each process's rank dist_rank = global_rank_mapping[local_node][local_rank] current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) # spawn the processes cmd = [ diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 9479bb63758c..eb03502cc3f2 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -19,8 +19,8 @@ import torch.cuda from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner -from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT, \ - PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER +from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT from ..utils import logger DLTS_HOSTFILE = "/job/hostfile" diff --git a/deepspeed/ops/sparse_attention/softmax.py b/deepspeed/ops/sparse_attention/softmax.py index cd18fbcae71f..a0805ada4bc0 100644 --- 
a/deepspeed/ops/sparse_attention/softmax.py +++ b/deepspeed/ops/sparse_attention/softmax.py @@ -224,8 +224,8 @@ class Softmax: For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509 """ - - sparse_softmax = _sparse_softmax.apply + def sparse_softmax(*args, **kwargs): + return _sparse_softmax.apply(*args, **kwargs) def make_lut(self, device): """Generates the sparsity layout used in block-sparse softmax diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index a91e5ce6f08b..ea4b98848d3c 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -18,7 +18,6 @@ class TransformerConfig(): def __init__(self, batch_size, - max_seq_length, hidden_size, intermediate_size, heads, @@ -30,7 +29,6 @@ def __init__(self, self.batch_size = batch_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.max_seq_length = max_seq_length self.heads = heads self.attn_dropout_ratio = attn_dropout_ratio self.hidden_dropout_ratio = hidden_dropout_ratio @@ -92,7 +90,6 @@ class DeepSpeedTransformerConfig(TransformerConfig): """ def __init__(self, batch_size=-1, - max_seq_length=-1, hidden_size=-1, intermediate_size=-1, heads=-1, @@ -112,7 +109,6 @@ def __init__(self, super(DeepSpeedTransformerConfig, self).__init__( batch_size, - max_seq_length, hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, @@ -142,7 +138,7 @@ def from_dict(cls, json_object): @classmethod def from_json_file(cls, json_file): - with open(json_file, "r", encoding='utf-8') as reader: + with open(json_file, "r", encoding='utf-16') as reader: text = reader.read() return cls.from_dict(json.loads(text)) @@ -177,6 +173,18 @@ def forward(ctx, cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32 + inp_size = input.size() + if inp_size[1] % 16 != 0: + input = torch.cat((input, + torch.randn((inp_size[0], + (16 - (inp_size[1] % 16)), + inp_size[2]), + device=input.device, + dtype=input.dtype)), + 1) + input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \ + (16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3) + (output, inp_norm, qkv_tf, @@ -303,11 +311,17 @@ def forward(ctx, ctx.attn_layer_norm_var = attn_layer_norm_var ctx.layer_norm_var = layer_norm_var + if inp_size[1] % 16 != 0: + output = torch.narrow(output, 1, 0, inp_size[1]) return output @staticmethod def backward(ctx, grad_output): bsz = grad_output.shape[0] + grad_output_shape = grad_output.size() + if grad_output_shape[1] % 16 != 0: + grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \ + grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1) if bsz > ctx.config.batch_size: raise ValueError('grad_output batch size exceeds the limit.') @@ -398,6 +412,9 @@ def backward(ctx, grad_output): norm_w, norm_b) + if grad_output_shape[1] % 16 != 0: + grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1]) + return (grad_input, None, None, @@ -501,7 +518,6 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None): self.config.hidden_size, self.config.heads, self.config.intermediate_size, - self.config.max_seq_length, self.config.attn_dropout_ratio, 
self.config.hidden_dropout_ratio, self.config.seed, diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index a731865714fe..c56c3898f60f 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -73,11 +73,6 @@ ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer" ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False -############################################# -# Torch distributed constants -############################################# -TORCH_DISTRIBUTED_DEFAULT_PORT = "29500" - # Steps STEPS_PER_PRINT = "steps_per_print" STEPS_PER_PRINT_DEFAULT = 10 diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7c9b920d8bb6..8b2901f8452e 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,12 +24,12 @@ from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA + PLD_THETA, PLD_GAMMA from deepspeed.runtime.zero.constants import \ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS from deepspeed.runtime.csr_tensor import CSRTensor import deepspeed.runtime.lr_schedules as lr_schedules -from deepspeed.utils import logger, log_dist +from deepspeed.utils import logger, log_dist, init_distributed from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop @@ -130,29 +130,14 @@ def __init__(self, if dist_init_required is False: assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()" - # DeepSpeed will initialize torch distributed only if the user has not already intialized it. - if dist_init_required and not dist.is_initialized(): - # discover using mpi4py if user specifies the flag - if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: - # if in Azure ML environment and user specified this flag, notify the user to remove the flag. - if self._in_aml(): - logger.warning( - "Please remove the --deepspeed_mpi flag if running on AzureML.") - self._mpi_check(args, dist_init_required) - else: - # detect if we are in Azure ML environment - if self._in_aml(): - self._set_environment_variables_for_nccl_backend(args) - - logger.info("Initializing torch distributed with backend: {}".format( - self.dist_backend)) - dist.init_process_group(backend=self.dist_backend) + # Initialize torch distributed if needed + init_distributed(dist_backend=self.dist_backend) self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() - self._init_distributed(dist_init_required) + self._set_distributed_vars() if self.tensorboard_enabled() and self.global_rank == 0: self.summary_writer = self.get_summary_writer() @@ -209,87 +194,6 @@ def __init__(self, self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten - def _in_aml(self): - # read AzureML environment variable to detect if we are using an Azure ML environment - if 'AZUREML_EXPERIMENT_ID' in os.environ: - return True - else: - return False - - def _set_environment_variables_for_nccl_backend(self, - args, - master_port=6105, - verbose=True): - """Helper routine to get and set environment variables. 
- This is adapted from Azure ML's documentation available from: - https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi - """ - os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] - os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] - single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( - os.environ["WORLD_SIZE"]) - if not single_node: - master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") - os.environ["MASTER_ADDR"] = master_node_params[0] - # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE - if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"] = str(master_port) - else: - os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"] - os.environ["MASTER_PORT"] = "54965" - print("NCCL_SOCKET_IFNAME original value = {}".format( - os.environ["NCCL_SOCKET_IFNAME"])) - - os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" - args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) - - if verbose: - logger.info( - "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" - .format(os.environ['RANK'], - args.local_rank, - os.environ['WORLD_SIZE'], - os.environ['MASTER_ADDR'], - os.environ['MASTER_PORT'])) - - def _mpi_check(self, args, dist_init_required): - from mpi4py import MPI - import subprocess - comm = MPI.COMM_WORLD - rank = comm.Get_rank() - world_size = comm.Get_size() - - master_addr = None - if rank == 0: - hostname_cmd = ["hostname -I"] - result = subprocess.check_output(hostname_cmd, shell=True) - master_addr = result.decode('utf-8').split()[0] - master_addr = comm.bcast(master_addr, root=0) - - # Determine local rank by assuming hostnames are unique - proc_name = MPI.Get_processor_name() - all_procs = comm.allgather(proc_name) - local_rank = sum([i == proc_name for i in all_procs[:rank]]) - - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - args.local_rank = local_rank - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT - - logger.info( - "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" - .format(os.environ['RANK'], - args.local_rank, - os.environ['WORLD_SIZE'], - os.environ['MASTER_ADDR'], - os.environ['MASTER_PORT'])) - - if not dist_init_required and dist.is_initialized(): - assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank()) - assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( - world_size, dist.get_world_size()) - def pld_enabled(self): return self._config.pld_enabled @@ -497,7 +401,7 @@ def _scheduler_from_config(self, optimizer): else: return None - def _init_distributed(self, dist_init_required): + def _set_distributed_vars(self): if self.local_rank >= 0: torch.cuda.set_device(self.local_rank) self.device = torch.device("cuda", self.local_rank) @@ -979,7 +883,7 @@ def clip_fp32_gradients(self): torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping()) - def _take_model_step(self): + def _take_model_step(self, lr_kwargs): if self.gradient_clipping() > 0.0: if not self.fp16_enabled() and not self.amp_enabled(): self.clip_fp32_gradients() @@ -1010,14 +914,14 @@ def _take_model_step(self): self.skipped_steps += 1 else: if self.lr_scheduler is not None: - self.lr_scheduler.step() + 
self.lr_scheduler.step(**(lr_kwargs or {})) if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: self._report_progress(self.global_steps + 1) self.global_steps += 1 self.global_samples += self.train_batch_size() - def step(self): + def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ @@ -1034,7 +938,7 @@ def step(self): if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) - self._take_model_step() + self._take_model_step(lr_kwargs) self.tput_timer.stop(report_progress) diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py index 5ec106c28d67..e7e3be1e786b 100755 --- a/deepspeed/runtime/lr_schedules.py +++ b/deepspeed/runtime/lr_schedules.py @@ -381,6 +381,12 @@ def get_lr(self): lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def _update_optimizer(self, group_lrs): for param_group, lr in zip(self.optimizer.param_groups, group_lrs): param_group['lr'] = lr @@ -390,6 +396,7 @@ def step(self, batch_iteration=None): batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration self._update_optimizer(self.get_lr()) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -628,12 +635,19 @@ def get_lr(self): return self._get_cycle_lr() return self._get_decay_lr(self.last_batch_iteration - self.total_size) + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def step(self, batch_iteration=None): if batch_iteration is None: batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} @@ -690,12 +704,19 @@ def get_lr(self): self.delta_lrs) ] + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. 
+ """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + def step(self, last_batch_iteration=None): if last_batch_iteration is None: last_batch_iteration = self.last_batch_iteration + 1 self.last_batch_iteration = last_batch_iteration for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] def state_dict(self): return {'last_batch_iteration': self.last_batch_iteration} diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 954774e58912..5c5d896dfc0d 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -940,14 +940,14 @@ def _exec_recv_grads(self, buffer_id): if self.wall_clock_breakdown(): self.timers('pipe_recv_grad').stop() - def _exec_optimizer_step(self): + def _exec_optimizer_step(self, lr_kwargs=None): if self.wall_clock_breakdown(): self.timers('step_microstep').start() self.timers('step').start() self.mem_status('BEFORE STEP', reset_max=True) self._force_grad_boundary = True - self._take_model_step() + self._take_model_step(lr_kwargs) self._force_grad_boundary = False self.mem_status('AFTER STEP') diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 14bfc937705c..b784f3ffdd6c 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,6 +6,7 @@ from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.utils import logger from deepspeed.runtime.zero.constants import * +import json class DeepSpeedZeroConfig(object): @@ -54,6 +55,9 @@ def read_zero_config_deprecated(self, param_dict): def repr(self): return self.__dict__ + def __repr__(self): + return json.dumps(self.__dict__, sort_keys=True, indent=4) + def _initialize(self, zero_config_dict): self.stage = get_scalar_param(zero_config_dict, ZERO_OPTIMIZATION_STAGE, diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 37517764b375..c231edca4919 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,2 +1,3 @@ -from deepspeed.utils.logging import logger, log_dist +from .logging import logger, log_dist +from .distributed import init_distributed from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py new file mode 100644 index 000000000000..e70f00b440bb --- /dev/null +++ b/deepspeed/utils/distributed.py @@ -0,0 +1,129 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import os +import torch + +from .logging import logger +from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT + + +def init_distributed(dist_backend="nccl", + auto_mpi_discovery=True, + distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, + verbose=True): + """ + Initialize torch.distributed backend, potentially performing MPI discovery if needed + Arguments: + dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo + auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI + distributed_port (int, optional): torch distributed backend port + verbose (bool, optional): verbose logging + """ + + required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): + if verbose: + logger.info( + "Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI 
environment..." ) + if in_aml() and not in_dlts(): + patch_aml_env_for_torch_nccl_backend(verbose=verbose) + else: + mpi_discovery(distributed_port=distributed_port, verbose=verbose) + + if not torch.distributed.is_initialized(): + if verbose: + logger.info( + "Initializing torch distributed with backend: {}".format(dist_backend)) + torch.distributed.init_process_group(backend=dist_backend) + + +def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True): + """ + Discover the MPI environment via mpi4py and map it to the relevant torch.distributed state + """ + from mpi4py import MPI + import subprocess + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + master_addr = None + if rank == 0: + hostname_cmd = ["hostname -I"] + result = subprocess.check_output(hostname_cmd, shell=True) + master_addr = result.decode('utf-8').split()[0] + master_addr = comm.bcast(master_addr, root=0) + + # Determine local rank by assuming hostnames are unique + proc_name = MPI.Get_processor_name() + all_procs = comm.allgather(proc_name) + local_rank = sum([i == proc_name for i in all_procs[:rank]]) + + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(local_rank) + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = str(distributed_port) + + if verbose: + logger.info( + "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" + .format(os.environ['RANK'], + os.environ['LOCAL_RANK'], + os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], + os.environ['MASTER_PORT'])) + + if torch.distributed.is_initialized(): + assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, torch.distributed.get_rank()) + assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format( + world_size, torch.distributed.get_world_size()) + + +def in_aml(): + # Are we running inside an Azure Machine Learning (AML) environment? + return 'AZUREML_EXPERIMENT_ID' in os.environ + + +def in_dlts(): + # Are we running on a DLTS cluster? + return 'DLTS_JOB_ID' in os.environ + + +def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True): + """Helper routine to get and set environment variables. 
+ This is adapted from Azure ML's documentation available from: + https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi + """ + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( + os.environ["WORLD_SIZE"]) + + if not single_node: + master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") + os.environ["MASTER_ADDR"] = master_node_params[0] + # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(master_port) + else: + os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"] + os.environ["MASTER_PORT"] = "54965" + + if verbose: + logger.info("NCCL_SOCKET_IFNAME original value = {}".format( + os.environ["NCCL_SOCKET_IFNAME"])) + + os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" + os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] + + if verbose: + logger.info( + "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" + .format(os.environ['RANK'], + os.environ['LOCAL_RANK'], + os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], + os.environ['MASTER_PORT'])) diff --git a/docs/_pages/features.md b/docs/_pages/features.md index ec0724e11aa4..2074bb3e3b0f 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -28,7 +28,8 @@ deepspeed --hostfile= \ \ --deepspeed --deepspeed_config ds_config.json ``` -The script `` will execute on the resources specified in ``. +The script `` will execute on the resources specified in +[``](/getting-started/#resource-configuration-multi-node). ## Pipeline Parallelism DeepSpeed provides [pipeline parallelism](/tutorials/pipeline/) for memory- diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md index 1f23c64d4085..21268802d6c8 100644 --- a/docs/_tutorials/getting-started.md +++ b/docs/_tutorials/getting-started.md @@ -216,25 +216,27 @@ DeepSpeed will then make sure that these environment variables are set when launching each process on every node across their training job. -### MPI Compatibility +### MPI and AzureML Compatibility As described above, DeepSpeed provides its own parallel launcher to help launch multi-node/multi-gpu training jobs. If you prefer to launch your training job using MPI (e.g., mpirun), we provide support for this. It should be noted that DeepSpeed will still use the torch distributed NCCL backend and *not* the MPI -backend. To launch your training job with mpirun + DeepSpeed you simply pass us -an additional flag `--deepspeed_mpi`. DeepSpeed will then use -[mpi4py](https://pypi.org/project/mpi4py/) to discover the MPI environment (e.g., -rank, world size) and properly initialize torch distributed for training. In this -case you will explicitly invoke `python` to launch your model script instead of using -the `deepspeed` launcher, here is an example: -```bash -mpirun python \ - \ - --deepspeed_mpi --deepspeed --deepspeed_config ds_config.json -``` +backend. + +To launch your training job with mpirun + DeepSpeed or with AzureML (which uses +mpirun as a launcher backend) you simply need to install the +[mpi4py](https://pypi.org/project/mpi4py/) python package. DeepSpeed will use +this to discover the MPI environment and pass the necessary state (e.g., world +size, rank) to the torch distributed backend. 
-If you want to use this feature of DeepSpeed, please ensure that mpi4py is -installed via `pip install mpi4py`. +If you are using model parallelism, pipeline parallelism, or otherwise require +torch.distributed calls before calling `deepspeed.initialize(..)` we provide +the same MPI support with an additional DeepSpeed API call. Replace your initial +`torch.distributed.init_process_group(..)` call with: + +```python +deepspeed.init_distributed() +``` ## Resource Configuration (single-node) In the case that we are only running on a single node (with one or more GPUs) diff --git a/docs/code-docs/source/conf.py b/docs/code-docs/source/conf.py index 167f6427d7b4..eb9a412d8a4a 100644 --- a/docs/code-docs/source/conf.py +++ b/docs/code-docs/source/conf.py @@ -79,4 +79,4 @@ autoclass_content = 'both' -autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX", "numpy"] +autodoc_mock_imports = ["torch", "apex", "mpi4py", "tensorboardX", "numpy", "cupy"] diff --git a/install.sh b/install.sh index b027d319cdd6..b9f1501d9cad 100755 --- a/install.sh +++ b/install.sh @@ -171,5 +171,5 @@ else pdcp -w $hosts dist/deepspeed*.whl $tmp_wheel_path/ pdsh -w $hosts "$PIP_SUDO $PIP_INSTALL $tmp_wheel_path/deepspeed*.whl" pdsh -w $hosts "ds_report" - pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rmdir $tmp_wheel_path; fi" + pdsh -w $hosts "if [ -d $tmp_wheel_path ]; then rm $tmp_wheel_path/*.whl; rm $tmp_wheel_path/*.txt; rmdir $tmp_wheel_path; fi" fi diff --git a/op_builder/builder.py b/op_builder/builder.py index f44aee79637a..1f350065b4f6 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -221,7 +221,7 @@ def compute_capability_args(self, cross_compile_archs=None): 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used. + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX Format: @@ -243,6 +243,7 @@ def compute_capability_args(self, cross_compile_archs=None): if cc not in ccs: ccs.append(cc) ccs = sorted(ccs) + ccs[-1] += '+PTX' else: # Cross-compile mode, compile for various architectures # env override takes priority @@ -260,8 +261,10 @@ def compute_capability_args(self, cross_compile_archs=None): args = [] for cc in ccs: - cc = cc.replace('.', '') - args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}') + num = cc[0] + cc[2] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if cc.endswith('+PTX'): + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') return args diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index c032a8c9fdad..78620c472c9d 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -1,2 +1 @@ tqdm -psutil diff --git a/tests/unit/common.py b/tests/unit/common.py index 73d7957e29f9..62b7495a025c 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -5,6 +5,8 @@ import torch.distributed as dist from torch.multiprocessing import Process +import deepspeed + import pytest # Worker timeout *after* the first worker has completed. @@ -33,10 +35,12 @@ def dist_init(local_rank, num_procs, *func_args, **func_kwargs): """Initialize torch.distributed and execute the user function. 
""" os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29503' - dist.init_process_group(backend=backend, - init_method='env://', - rank=local_rank, - world_size=num_procs) + os.environ['LOCAL_RANK'] = str(local_rank) + # NOTE: unit tests don't support multi-node so local_rank == global rank + os.environ['RANK'] = str(local_rank) + os.environ['WORLD_SIZE'] = str(num_procs) + + deepspeed.init_distributed(dist_backend=backend) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 317cd7aa33c0..fd3f9887ad42 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -150,7 +150,7 @@ def create_models(ds_config): hidden_act="gelu", hidden_dropout_prob=ds_config.hidden_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio, - max_position_embeddings=ds_config.max_seq_length, + max_position_embeddings=512, type_vocab_size=2, initializer_range=ds_config.initializer_range) @@ -210,25 +210,18 @@ def set_seed(seed): torch.manual_seed(seed) -def run_backward(ds_config, atol=1e-2, verbose=False): +def run_backward(ds_config, seq_len, atol=1e-2, verbose=False): set_seed(123) bert_encoder, ds_encoder = create_models(ds_config) # prepare test data kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 hidden_states = torch.randn(ds_config.batch_size, - ds_config.max_seq_length, + seq_len, ds_config.hidden_size, **kwargs) - input_mask = torch.randn(ds_config.batch_size, - 1, - 1, - ds_config.max_seq_length, - **kwargs) - Y = torch.randn(ds_config.batch_size, - ds_config.max_seq_length, - ds_config.hidden_size, - **kwargs) + input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs) + Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs) # run baseline base_results = bert_encoder(hidden_states, @@ -257,12 +250,12 @@ def run_backward(ds_config, atol=1e-2, verbose=False): #test_backward[3-1024-120-16-24-True-True-0.05] @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', [ - (3,1024,120,16,24,True,False, 0.05), - (3,1024,120,16,24,True,True, 0.05), - (3,1024,56,16,24,False,False, 0.1), - (3,1024,56,16,24,False,True, 0.2), - (3,128,56,2,24,False,False, 0.1), - (3,128,56,2,24,False,True, 0.2), + (3,1024,119,16,24,True,False, 0.05), + (3,1024,115,16,24,True,True, 0.05), + (1024,128,10,2,2,False,False, 0.1), + (3,1024,52,16,24,False,True, 0.2), + (3,128,51,2,24,False,False, 0.1), + (3,128,54,2,24,False,True, 0.2), ]) # yapf: disable def test_backward(batch_size, hidden_size, @@ -282,7 +275,6 @@ def test_backward(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 @@ -291,7 +283,7 @@ def test_backward(batch_size, ds_config.initializer_range = 0.02 ds_config.fp16 = use_fp16 - run_backward(ds_config, atol=atol) + run_backward(ds_config, seq_len, atol=atol) #@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 893b66c904bb..88cb90848603 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -117,7 +117,7 @@ def create_models(ds_config): hidden_act="gelu", 
hidden_dropout_prob=ds_config.hidden_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio, - max_position_embeddings=ds_config.max_seq_length, + max_position_embeddings=512, type_vocab_size=2, initializer_range=ds_config.initializer_range, fp16=ds_config.fp16) @@ -186,13 +186,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): # prepare test data kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 - hidden_states = torch.randn(bsz, - seq_len, #ds_config.max_seq_length, - ds_config.hidden_size, - **kwargs) - input_mask = torch.randn(bsz, 1, 1, - seq_len, #ds_config.max_seq_length, - **kwargs) + hidden_states = torch.randn(bsz, seq_len, ds_config.hidden_size, **kwargs) + input_mask = torch.randn(bsz, 1, 1, seq_len, **kwargs) # run baseline base_results = bert_encoder(hidden_states, @@ -213,25 +208,25 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None): # FP16 test cases can only run on the devices support FP16. @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', [ - (8,256,128,4,3,True,False), - (8,256,128,4,3,True,True), - (64,1024,128,16,3,True,False), - (64,1024,128,16,3,True,True), - (8,1024,384,16,3,True,False), + (8,256,53,4,3,True,False), + (8,256,52,4,3,True,True), + (3,1024,51,16,3,True,False), + (3,1024,54,16,3,True,True), + (8,1024,381,16,3,True,False), (8,1024,384,16,3,True,True), (8,1024,384,16,3,True,True), - (8,1024,120,16,3,True,False), + (8,1024,119,16,3,True,False), (8,1024,120,16,3,True,True), - (8,1024,512,16,3,True,False), + (8,1024,509,16,3,True,False), (8,1024,512,16,3,True,True), (64,1024,56,16,3,False,False), - (64,1024,56,16,3,False,True), + (64,1024,53,16,3,False,True), (64,1024,24,16,3,False,False), - (64,1024,24,16,3,False,True), + (64,1024,21,16,3,False,True), (8,1024,384,16,3,False,False), (8,1024,384,16,3,False,True), (8,1024,512,16,3,False,False), - (8,1024,512,16,3,False,True), + (8,1024,511,16,3,False,True), (8,1536,128,24,3,False,False), (8,1536,128,24,3,False,True), (8,2048,128,32,3,False,False), @@ -259,7 +254,6 @@ def test_forward(batch_size, ds_config.layer_id = None ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size - ds_config.max_seq_length = 128 #seq_len ds_config.intermediate_size = 4 * hidden_size ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 @@ -297,7 +291,6 @@ def test_forward_with_small_bsz(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = 4 * hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0 @@ -332,7 +325,6 @@ def test_forward_stochastic(batch_size, ds_config.batch_size = batch_size ds_config.hidden_size = hidden_size ds_config.intermediate_size = 4 * hidden_size - ds_config.max_seq_length = seq_len ds_config.heads = heads ds_config.attn_dropout_ratio = 0.0 ds_config.hidden_dropout_ratio = 0.0
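To illustrate why the updated unit tests now exercise sequence lengths that are not multiples of 16 (e.g., 119, 51, 53): the autograd function in `deepspeed/ops/transformer/transformer.py` pads the activations and the attention mask up to the next multiple of 16 before invoking the fused kernel, then narrows the output back to the caller's sequence length. Below is a standalone sketch of that padding arithmetic; the helper name is ours, not the PR's:

```python
import torch

def pad_to_multiple_of_16(x, mask):
    """Illustrative helper (our naming): mirrors the padding performed in
    DeepSpeedTransformerFunction.forward when seq_len % 16 != 0."""
    bsz, seq_len, hidden = x.size()
    pad = (16 - seq_len % 16) % 16
    if pad == 0:
        return x, mask, seq_len
    # Activations are padded with random values; they are sliced off again
    # after the kernel runs, so their contents do not matter.
    x = torch.cat(
        (x, torch.randn(bsz, pad, hidden, device=x.device, dtype=x.dtype)), dim=1)
    # The attention mask is padded with large negative values so the extra
    # positions are effectively ignored by the softmax.
    mask_pad = torch.ones(bsz, mask.size(1), mask.size(2), pad,
                          device=mask.device, dtype=mask.dtype) * -10000
    mask = torch.cat((mask, mask_pad), dim=3)
    return x, mask, seq_len

hidden_states = torch.randn(3, 119, 1024)   # 119 is not a multiple of 16
input_mask = torch.zeros(3, 1, 1, 119)
padded, padded_mask, orig_len = pad_to_multiple_of_16(hidden_states, input_mask)
assert padded.size(1) == 128 and padded_mask.size(3) == 128
# The kernel output is narrowed back to the original sequence length.
output = torch.narrow(padded, 1, 0, orig_len)
assert output.size(1) == orig_len
```

The backward pass applies the same idea to `grad_output`, zero-padding it to a multiple of 16 and narrowing `grad_input` back to the original length.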