Merged
45 changes: 32 additions & 13 deletions deepspeed/compile/backend.py
@@ -3,7 +3,7 @@

 # DeepSpeed Team

-from typing import Dict, List, Callable, Tuple
+from typing import Dict, List, Callable, Tuple, Set
 import time
 import gc
 from collections import OrderedDict, deque
@@ -12,7 +12,6 @@
 from torch.fx import Graph, GraphModule

 try:
-    import torch.utils._pytree as pytree
     import torch._dynamo
     from functorch.compile import make_boxed_func
     from torch._functorch.aot_autograd import aot_module_simplified
@@ -47,11 +46,19 @@ class GraphOrder:
     def __init__(self):
         self.frames = OrderedDict()

-    def add_graph(self, graph_id: int, frame_id: int, needs_backward: bool):
+    def __len__(self):
+        return len(self.frames)
+
+    def add_graph(self, graph_id: int, frame_id: int):
         if frame_id not in self.frames:
-            self.frames[frame_id] = (graph_id, needs_backward)
+            self.frames[frame_id] = (graph_id, None)
+
+    def set_needs_backward(self, frame_id: int, needs_backward: bool):
+        if frame_id in self.frames:
+            self.frames[frame_id] = (self.frames[frame_id][0], needs_backward)

     def get_graph_order(self) -> List[Tuple[int, bool]]:
+        assert all(isinstance(needs_backward, bool) for _, needs_backward in self.frames.values())
         return list(self.frames.values())

     def clear(self):
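For orientation, a minimal usage sketch (not part of the diff, with made-up ids) of the reworked `GraphOrder` protocol: `add_graph` now registers a frame with `needs_backward` unknown (stored as `None`), and `set_needs_backward` resolves the flag later, once partitioning has revealed whether a backward module exists. The new assert in `get_graph_order` guards against reading the order before every frame has been resolved.

```python
order = GraphOrder()
order.add_graph(graph_id=101, frame_id=0)  # needs_backward unknown -> stored as None
order.add_graph(graph_id=102, frame_id=1)

# ... later, after AOT autograd has (or has not) partitioned each joint graph ...
order.set_needs_backward(frame_id=0, needs_backward=True)
order.set_needs_backward(frame_id=1, needs_backward=False)

assert len(order) == 2                     # __len__ added in this PR
print(order.get_graph_order())             # [(101, True), (102, False)]
```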
@@ -61,6 +68,7 @@ def clear(self):
 graph_order_with_frame_id = GraphOrder()

 frames_needing_bwd = set()
+frames_partitioned: Set[int] = set()
 profiling_results: Dict[int, ProfilingResult] = {}
 opt_pass_times = []
 opt_passes = {}
@@ -96,6 +104,7 @@ def launch_compile_passes(global_steps: int):
     graph_order_with_frame_id.clear()
     profiling_results.clear()
     param_manager.clear()
+    frames_partitioned.clear()


 def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results):
@@ -225,11 +234,16 @@ def make_backend(backend, compile_config, compile_kwargs={}):
     def backend_fn(gm: GraphModule, real_inputs):
         graph_id = id(gm.graph)

-        needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
+        # The existence of input tensors that require grad is not, by itself, enough to determine `needs_backward`:
+        # AOT autograd also checks the graph's data flow and skips the backward pass if no output requires grad and
+        # no input requiring grad is mutated.
+        #
+        # Instead of replicating AOT autograd's backward-pass determination (which would be too costly), we infer
+        # whether a backward pass is needed by checking whether the joint graph is partitioned (into a forward and
+        # a backward module). This check cannot be placed here, because AOT autograd creates the fw/bw compiler
+        # callables before graph partitioning; it is therefore postponed to the point where the fw compiler is called.
         frame_id = gm.meta["dynamo_compile_id"].frame_id
-        graph_order_with_frame_id.add_graph(graph_id, frame_id, needs_backward)
-
-        graph_order = graph_order_with_frame_id.get_graph_order()
+        graph_order_with_frame_id.add_graph(graph_id, frame_id)

         z3_partition = any(hasattr(v, "ds_id") for v in real_inputs)
         if z3_partition:
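The control flow described in the new comment can be condensed into a standalone sketch (names simplified; `frames_partitioned` plays the same role as the module-level set added above). The key observation: AOT autograd only invokes the partition function when it has decided a backward pass exists, so a side effect inside the partitioner is a reliable signal that the forward compiler, which runs after partitioning, can read back.

```python
from typing import Set

frames_partitioned: Set[int] = set()

def make_partition_fn(frame_id: int, inner_partition_fn):
    def partition_fn(*args, **kwargs):
        # Only reached when AOT autograd partitions the joint graph,
        # i.e. when a backward module will actually be produced.
        frames_partitioned.add(frame_id)
        return inner_partition_fn(*args, **kwargs)
    return partition_fn

def make_fw_compiler(frame_id: int):
    def fw_compiler(gm, sample_inputs):
        # The fw compiler runs after partitioning (if any) is done,
        # so set membership resolves needs_backward for this frame.
        needs_backward = frame_id in frames_partitioned
        return gm
    return fw_compiler
```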
Expand Down Expand Up @@ -258,11 +272,14 @@ def backend_fn(gm: GraphModule, real_inputs):
if graph_id not in profiling_results:
profiling_results[graph_id] = ProfilingResult()
profiling_results[graph_id].param_indices = param_indices
profiling_results[graph_id].needs_backward = needs_backward

def make_fw_graph(gm, sample_inputs):
time_start = time.time()
graph_index = len(graph_order) - 1
graph_index = len(graph_order_with_frame_id) - 1

needs_backward = frame_id in frames_partitioned
graph_order_with_frame_id.set_needs_backward(frame_id, needs_backward)
profiling_results[graph_id].needs_backward = needs_backward

if needs_backward:
if len(frames_needing_bwd) == 0:
@@ -290,7 +307,7 @@ def make_fw_graph(gm, sample_inputs):
                 opt_passes=next_passes,
                 gm=gm,
                 graph_id=graph_id,
-                graph_order=graph_order,
+                graph_order=graph_order_with_frame_id.get_graph_order(),
                 profiling_results=profiling_results,
                 create_inputs_fn=lambda: real_inputs_with_rng,
                 mem_budget=.0,  # unused
@@ -308,6 +325,7 @@ def make_fw_graph(gm, sample_inputs):
         def make_bw_graph(gm, sample_inputs):
             time_start = time.time()

+            graph_order = graph_order_with_frame_id.get_graph_order()
             graph_index = get_index_by_graph_id(graph_order, graph_id)
             log_rank0(
                 f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
@@ -368,7 +386,8 @@ def compiler_fn(gm, sample_inputs):

             return compiler_fn

-        partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition)
+        partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition,
+                                               frame_id, frames_partitioned)
         aot_mod = aot_module_simplified(gm,
                                         real_inputs,
                                         fw_compiler=make_compiler_fn(make_fw_graph),
@@ -377,7 +396,7 @@ def compiler_fn(gm, sample_inputs):
             return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
         elif backend == "inductor":
             patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
-                                                 param_indices, param_manager)
+                                                 param_indices, param_manager, frame_id, frames_partitioned)

             return torch._inductor.compile(gm, real_inputs)

34 changes: 21 additions & 13 deletions deepspeed/compile/inductor.py
@@ -3,6 +3,8 @@

 # DeepSpeed Team

+from typing import Set
+
 import torch

 try:
@@ -60,31 +62,37 @@ def wrapped_compiler(gm, fake_inputs):
     return wrapped_compiler


-def wrap_partition_fn(partition_fn, real_inputs, param_indices):
+def wrap_partition_fn(z3_partition: bool, partition_fn, real_inputs, param_indices, frame_id: int,
+                      frames_partitioned: Set[int]):

     def wrapped_partition_fn(*args, **kwargs):

-        fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn)
+        fn = get_wrapped_partitioner(z3_partition,
+                                     param_indices,
+                                     partition_fn=partition_fn,
+                                     frame_id=frame_id,
+                                     frames_partitioned=frames_partitioned)
         fw_module, bw_module = fn(*args, **kwargs)

-        # get parameter names
-        pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
+        if z3_partition:
+            # get parameter names
+            pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)

-        def fix_placeholder_meta(graph):
-            for n in graph.nodes:
-                if n.op == "placeholder" and n.name in pm.param_names:
-                    n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)
+            def fix_placeholder_meta(graph):
+                for n in graph.nodes:
+                    if n.op == "placeholder" and n.name in pm.param_names:
+                        n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)

-        fix_placeholder_meta(fw_module.graph)
-        fix_placeholder_meta(bw_module.graph)
+            fix_placeholder_meta(fw_module.graph)
+            fix_placeholder_meta(bw_module.graph)

         return fw_module, bw_module

     return wrapped_partition_fn


 def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
-                                         param_indices, param_manager):
+                                         param_indices, param_manager, frame_id: int, frames_partitioned: Set[int]):

     from torch._dynamo.backends.common import AotAutograd
     import functools
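The `z3_partition` guard added above matters because the placeholder-meta rewrite only makes sense for ZeRO-3-partitioned parameters, whose local shard is empty until the parameter is gathered. A self-contained toy version of that rewrite (using a traced function, and a hypothetical `param_names` set standing in for `DSGraphParamManager.param_names`):

```python
import torch
from torch.fx import symbolic_trace

def f(x, w):
    return x @ w

gm = symbolic_trace(f)
param_names = {"w"}  # stand-in for the names DSGraphParamManager identifies

for n in gm.graph.nodes:
    if n.op == "placeholder" and n.name in param_names:
        # Record a size-0 example value, mirroring the diff's
        # n.meta["val"] = torch.empty([0], ...) for partitioned params.
        n.meta["val"] = torch.empty([0], dtype=torch.float32)
```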
Expand Down Expand Up @@ -112,8 +120,8 @@ def patched_init(self, **kwargs):
bwd=True)
kwargs["inference_compiler"] = kwargs["fw_compiler"]

if z3_partition:
kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)
kwargs["partition_fn"] = wrap_partition_fn(z3_partition, kwargs["partition_fn"], real_inputs,
param_indices, frame_id, frames_partitioned)

original_init(self, **kwargs)

Expand Down
5 changes: 4 additions & 1 deletion deepspeed/compile/partitioner.py
@@ -3,7 +3,7 @@

 # DeepSpeed Team

-from typing import Tuple, List
+from typing import Tuple, List, Set

 import torch
 from torch.fx import GraphModule, Graph, Node
@@ -85,10 +85,13 @@ def get_wrapped_partitioner(
     z3_partition: bool,
     param_indices: List[Tuple[int, int, torch.Size]],
     partition_fn,
+    frame_id: int,
+    frames_partitioned: Set[int],
 ):

     def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *, num_fwd_outputs,
                                       **kwargs) -> Tuple[GraphModule, GraphModule]:
+        frames_partitioned.add(frame_id)
         if z3_partition:
             _recompute_param_aliases(joint_module.graph, param_indices)
         return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs)