Skip to content

Commit

Permalink
Reduce all-reduce call stack overhead (#2278)
Browse files Browse the repository at this point in the history
  • Loading branch information
liangan1 committed Nov 20, 2023
1 parent ed5957e commit 066c3bf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 23 deletions.
20 changes: 10 additions & 10 deletions csrc/cpu/jit/passes/graph_rewrite.cpp
Expand Up @@ -1313,46 +1313,46 @@ void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph) {

void simplifyAllReduce(std::shared_ptr<Graph>& graph) {
std::string all_reduce_v1 = R"(
graph(%a, %weight, %out_features1, %out_features2, %reduceop, %tag, %ranks, %group_size, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero):
graph(%a, %weight, %out_features1, %out_features2, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero):
%r1 = torch_ipex::tpp_linear(%a, %weight, %out_features1)
%r2 = deepspeed_comm::all_reduce(%r1, %reduceop, %tag, %ranks, %group_size)
%r2 = deepspeed_comm::all_reduce(%r1)
%r3 = torch_ipex::tpp_linear_gelu(%b, %fc_in_weight, %fc_in_bias, %out_features2)
%r4 = aten::to(%r3, %idx, %no, %no, %dtype)
%r5 = aten::contiguous(%r4, %zero)
%r6 = torch_ipex::tpp_linear(%r5, %fc_out_weight, %out_features1)
%r7 = deepspeed_comm::all_reduce(%r6, %reduceop, %tag, %ranks, %group_size)
%r7 = deepspeed_comm::all_reduce(%r6)
%r8 = aten::add_(%r7, %fc_out_bias, %alpha)
%r = aten::add(%r2, %r8, %alpha)
return (%r) )";
std::string all_reduce_repl_v1 = R"(
graph(%a, %weight, %out_features1, %out_features2, %reduceop, %tag, %ranks, %group_size, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero):
graph(%a, %weight, %out_features1, %out_features2, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero):
%r1 = torch_ipex::tpp_linear(%a, %weight, %out_features1)
%r2 = torch_ipex::tpp_linear_gelu(%b, %fc_in_weight, %fc_in_bias, %out_features2)
%r3 = aten::to(%r2, %idx, %no, %no, %dtype)
%r4 = aten::contiguous(%r3, %zero)
%r5 = torch_ipex::tpp_linear(%r4, %fc_out_weight, %out_features1)
%r6 = aten::add(%r1, %r5, %alpha)
%r7 = deepspeed_comm::all_reduce(%r6, %reduceop, %tag, %ranks, %group_size)
%r7 = deepspeed_comm::all_reduce(%r6)
%r = aten::add_(%r7, %fc_out_bias, %alpha)
return (%r) )";

std::string all_reduce_v2 = R"(
graph(%a, %weight, %reduceop, %tag, %ranks, %group_size, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha):
graph(%a, %weight, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha):
%r1 = ipex_prepack::linear_run(%a, %weight)
%r2 = deepspeed_comm::all_reduce(%r1, %reduceop, %tag, %ranks, %group_size)
%r2 = deepspeed_comm::all_reduce(%r1)
%r3 = ipex_prepack::linear_gelu_run(%b, %fc_in_weight, %fc_in_bias)
%r4 = ipex_prepack::linear_run(%r3, %fc_out_weight)
%r5 = deepspeed_comm::all_reduce(%r4, %reduceop, %tag, %ranks, %group_size)
%r5 = deepspeed_comm::all_reduce(%r4)
%r6 = aten::add_(%r5, %fc_out_bias, %alpha)
%r = aten::add(%r2, %r6, %alpha)
return (%r) )";
std::string all_reduce_repl_v2 = R"(
graph(%a, %weight, %reduceop, %tag, %ranks, %group_size, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha):
graph(%a, %weight, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha):
%r1 = ipex_prepack::linear_run(%a, %weight)
%r2 = ipex_prepack::linear_gelu_run(%b, %fc_in_weight, %fc_in_bias)
%r3 = ipex_prepack::linear_run(%r2, %fc_out_weight)
%r4 = aten::add(%r1, %r3, %alpha)
%r5 = deepspeed_comm::all_reduce(%r4, %reduceop, %tag, %ranks, %group_size)
%r5 = deepspeed_comm::all_reduce(%r4)
%r = aten::add_(%r5, %fc_out_bias, %alpha)
return (%r) )";

Expand Down
16 changes: 3 additions & 13 deletions intel_extension_for_pytorch/nn/utils/_weight_prepack.py
Expand Up @@ -2,7 +2,6 @@
import torch.nn as nn
import torch.nn.functional as F
import logging
import os
import pkg_resources
from intel_extension_for_pytorch import optim
from intel_extension_for_pytorch.cpu.tpp.utils.blocked_layout import (
Expand Down Expand Up @@ -94,28 +93,19 @@ def may_import_deepspeed_modules():
if "deepspeed" in installed_pkg:
from deepspeed import comm

def _all_reduce(self, reduceOp, tag, ranks, group_size):
# Thin wrapper registered as the CPU impl of the custom `deepspeed_comm::all_reduce`
# op. Performs a blocking (async_op=False) in-place all-reduce via DeepSpeed's
# comm.inference_all_reduce and returns the same tensor so the op is usable in
# TorchScript graph rewrites. NOTE(review): this commit drops the old
# (reduceOp, tag, ranks, group_size) arguments — presumably inference_all_reduce
# derives the group/op itself; confirm against the DeepSpeed comm API.
def _all_reduce(self):
comm.inference_all_reduce(self, async_op=False)
return self

ds_comm = torch.library.Library("deepspeed_comm", "DEF")
ds_comm.define(
"all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor"
)
ds_comm.define("all_reduce(Tensor self) -> Tensor")
ds_comm_lib_cpu = torch.library.Library("deepspeed_comm", "IMPL", "CPU")
ds_comm_lib_cpu.impl("all_reduce", _all_reduce)


def _all_reduce_and_bias_add(mp_group, original_bias, output):
if mp_group is not None:
torch.ops.deepspeed_comm.all_reduce(
output,
"sum",
"",
list(torch.arange(int(os.environ["WORLD_SIZE"]))),
int(os.environ["WORLD_SIZE"]),
)

torch.ops.deepspeed_comm.all_reduce(output)
if original_bias is not None:
output += original_bias

Expand Down

0 comments on commit 066c3bf

Please sign in to comment.