[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug #1710

Merged · 20 commits · Oct 18, 2022
Changes from 15 commits
7 changes: 0 additions & 7 deletions colossalai/__init__.py
@@ -1,10 +1,3 @@
try:
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)

9 changes: 8 additions & 1 deletion colossalai/fx/__init__.py
@@ -1,3 +1,10 @@
from .tracer import ColoTracer, meta_trace
try:
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False

from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
from .tracer import ColoTracer, meta_trace
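For orientation, here is a minimal sketch (not part of the diff) of the import pattern downstream code is expected to follow once the flag lives under colossalai.fx, mirroring the ckpt_solver_pofo.py change further down:

# Sketch only: the META_COMPATIBILITY flag now comes from colossalai.fx
# instead of the top-level colossalai package.
from colossalai.fx import META_COMPATIBILITY

if META_COMPATIBILITY:
    # meta-tensor based tracing/profiling is available on this PyTorch build
    ...
else:
    # fall back to the experimental profiler implementations
    ...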
40 changes: 40 additions & 0 deletions colossalai/fx/_compatibility.py
@@ -0,0 +1,40 @@
import torch

from . import META_COMPATIBILITY


def compatibility(is_backward_compatible: bool = False):
"""A decorator to make a function compatible with different versions of PyTorch.

Args:
is_backward_compatible (bool, optional): Whether the function is backward compatible. Defaults to False.

Returns:
Callable: The decorated function
"""

def decorator(func):
if META_COMPATIBILITY:
return func
else:
if is_backward_compatible:
return func
else:

def wrapper(*args, **kwargs):
raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')

return wrapper

return decorator


def check_meta_compatibility():
"""Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by its
experimental counterparts.

Returns:
bool: The meta compatibility
"""
return META_COMPATIBILITY
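For readers skimming the diff, a brief usage sketch of the new helpers; the function names below are made up for illustration and are not introduced by this PR:

from colossalai.fx._compatibility import compatibility, check_meta_compatibility


# Hypothetical functions, for illustration only.
@compatibility(is_backward_compatible=True)
def stable_pass(graph_module):
    # Backward compatible: returned unchanged on every supported PyTorch version.
    ...


@compatibility(is_backward_compatible=False)
def meta_only_pass(graph_module):
    # On PyTorch builds without meta support, calling this raises RuntimeError.
    ...


if check_meta_compatibility():
    # Safe to rely on the meta-tensor based passes here.
    ...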
colossalai/fx/_meta_registrations.py (moved from colossalai/_meta_registrations.py)
@@ -1,7 +1,10 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations

from typing import List, Optional, Tuple, Union

import torch
from torch.utils._pytree import tree_map

@@ -31,6 +34,7 @@ def add_func(op):
return wrapper


# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
@@ -165,6 +169,18 @@ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: t
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return grad_input


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
return torch.empty_like(input)
@@ -192,11 +208,8 @@ def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val:
return grad_in


@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input


# ============================== Normalization =====================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
@@ -207,6 +220,7 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
return output, running_mean, running_var


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
save_invstd, train, eps, output_mask):
@@ -241,6 +255,7 @@ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.
return dX, dgamma, dbeta


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs = input.size(0)
@@ -252,6 +267,7 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
return output, running_mean, running_var


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
@@ -261,13 +277,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
return dX, dgamma, dbeta


@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return torch.empty_like(input)
# ================================== Misc ==========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return input


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)


@register_meta(aten.index.Tensor)
@@ -360,6 +381,8 @@ def meta_index_Tensor(self, indices):
return self.new_empty(before_shape + replacement_shape + after_shape)


# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
@@ -369,13 +392,7 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens
layout=grad_output.layout)


# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type)


# ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
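As a quick illustration of what these registrations enable, the sketch below runs ops on 'meta' tensors so only shapes and dtypes are inferred; it assumes the convolution and relu meta kernels are available (natively on newer PyTorch, or via this patch on 1.12 and below):

import torch

# Shape inference without allocating real data: 'meta' tensors dispatch to
# the registered meta kernels, so no compute or memory traffic happens.
x = torch.empty(8, 3, 224, 224, device='meta')
w = torch.empty(64, 3, 7, 7, device='meta')

out = torch.ops.aten.convolution.default(
    x, w, None,                # input, weight, bias
    [2, 2], [3, 3], [1, 1],    # stride, padding, dilation
    False, [0, 0], 1)          # transposed, output_padding, groups
print(out.shape)                                   # torch.Size([8, 64, 112, 112]), inferred only
print(torch.ops.aten.relu.default(out).device)     # meta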
2 changes: 1 addition & 1 deletion colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
@@ -10,7 +10,7 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
from colossalai.fx import META_COMPATIBILITY

INF = float("inf")

11 changes: 5 additions & 6 deletions colossalai/fx/passes/concrete_info_prop.py
@@ -1,13 +1,12 @@
from dataclasses import asdict
from colossalai.fx.profiler import GraphInfo
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from typing import Any, List, Tuple, NamedTuple, Dict, Optional
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size
from torch.fx.graph_module import GraphModule


@compatibility(is_backward_compatible=True)
10 changes: 6 additions & 4 deletions colossalai/fx/passes/meta_info_prop.py
@@ -1,11 +1,13 @@
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple

import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from typing import Any, List, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_module, profile_method, activation_size, calculate_fwd_out, calculate_fwd_tmp, calculate_fwd_in


@compatibility(is_backward_compatible=True)
11 changes: 6 additions & 5 deletions colossalai/fx/profiler/__init__.py
@@ -1,11 +1,12 @@
from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .._compatibility import check_meta_compatibility

if check_meta_compatibility():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module
from .memory import calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

from .dataflow import GraphInfo
from .memory import parameter_size, activation_size, is_inplace
from .memory import activation_size, is_inplace, parameter_size
79 changes: 0 additions & 79 deletions colossalai/fx/profiler/constant.py

This file was deleted.

32 changes: 32 additions & 0 deletions colossalai/fx/profiler/constants.py
@@ -0,0 +1,32 @@
import torch
Contributor: The file colossalai/fx/profiler/constant.py is deleted and then re-created? You should change the file directly instead of deleting it, to minimize the git change.

Contributor Author: I split this file into two, so git cannot identify it as a rename.

__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']

aten = torch.ops.aten

ALIAS_ATEN = [
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]

INPLACE_NEW = [
aten.empty_like.default,
aten.new_empty_strided.default,
]

INPLACE_MATH_ATEN = [
aten.add_.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.bernoulli_.float,
]

CLONE_ATEN = [
aten.clone.default,
]
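For context, a hedged sketch of how a profiler pass might consult these lists when attributing activation memory; the accounts_new_memory helper is illustrative and not an API added by this PR:

import torch

from colossalai.fx.profiler.constants import ALIAS_ATEN, INPLACE_MATH_ATEN

aten = torch.ops.aten


def accounts_new_memory(target) -> bool:
    """Illustrative: ops that alias their input or mutate it in place do not
    produce a new activation buffer, so they are excluded from memory counts."""
    return target not in ALIAS_ATEN and target not in INPLACE_MATH_ATEN


assert not accounts_new_memory(aten.view.default)        # alias, no new storage
assert not accounts_new_memory(aten.add_.Tensor)         # in-place math
assert accounts_new_memory(aten.convolution.default)     # real allocation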
5 changes: 5 additions & 0 deletions colossalai/fx/profiler/dataflow.py
@@ -2,7 +2,10 @@
from enum import Enum
from functools import partial
from typing import Dict, List

from torch.fx import Graph, Node

from .._compatibility import compatibility
from .memory import activation_size, is_inplace


@@ -12,6 +15,7 @@ class Phase(Enum):
PLACEHOLDER = 2


@compatibility(is_backward_compatible=True)
@dataclass
class GraphInfo:
"""
@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
return n.meta['phase'] == phase


@compatibility(is_backward_compatible=False)
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
"""Analyze the autograd node dependencies and find out the memory usage.
Basically the input graph should have all nodes marked for keyword `phase`.