Cherry pick LLaMA or SDXL to 1.16.2 release (round 3) (#18323)
tianleiwu committed Nov 8, 2023
1 parent 0ccca88 commit 8f06330
Showing 29 changed files with 916 additions and 88 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1253,6 +1253,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
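For reference, a minimal sketch of constructing the new RotaryEmbedding contrib op with onnx.helper, matching the signature in the table row above; the tensor names and the interleaved value are illustrative assumptions, not taken from this commit:

from onnx import helper

# Build a com.microsoft RotaryEmbedding node; position_ids is int64 (M),
# the remaining tensors share the float/float16 type T from the table.
rotary = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    name="RotaryEmbedding_0",
    domain="com.microsoft",
    interleaved=0,  # attribute registered as AttrName::Interleaved later in this diff
)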
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -51,7 +51,7 @@ half maybe2half(float x) {

// Using only powers of 2 would lead to wasted compute for sizes such as 768, which is a very common case
// in BERT. Ideally we could step by warp_size * num_unroll, but listing too many steps would cause long compile times.
constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
@@ -206,7 +206,7 @@ void LaunchSkipLayerNormKernel(
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \
if constexpr (next_size_value >= 8 * 256) { \
if constexpr (next_size_value >= 320) { \
if (can_unroll_vec8) { \
constexpr int block_size = next_size_value / 8; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \
@@ -239,6 +239,9 @@ void LaunchSkipLayerNormKernel(
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
CASE_NEXT_SIZE(kSizes[7]);
CASE_NEXT_SIZE(kSizes[8]);
CASE_NEXT_SIZE(kSizes[9]);
CASE_NEXT_SIZE(kSizes[10]);
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
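A rough Python sketch of the dispatch that the CASE_NEXT_SIZE macro expands to, under the assumption that the alignment checks (can_unroll_vec8 / can_unroll_vec4) behave as their names suggest; the authoritative logic is the CUDA code above:

K_SIZES = [128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192]

def pick_launch_config(size, can_unroll_vec8, can_unroll_vec4):
    """Return (num_unroll, block_size) roughly as CASE_NEXT_SIZE selects them."""
    if size in K_SIZES:
        # Lowering the vec8 threshold from 8 * 256 = 2048 to 320 lets the new
        # SD/SDXL hidden sizes 320, 640, and 1280 take the 8-wide vectorized path.
        if size >= 320 and can_unroll_vec8:
            return 8, size // 8
        if can_unroll_vec4:
            return 4, size // 4
    return 1, 256  # default case: fallback kernel with block_size = 256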

Large diffs are not rendered by default.

@@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot);
DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding);

DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention);
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
constexpr static std::array<const char*, 1> typeNameListDefaultV = {"V"};
constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
@@ -597,6 +599,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};

@@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},

{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
@@ -122,6 +122,7 @@ namespace AttrName

static constexpr const char* GraphFusedActivation = "activation";
static constexpr const char* GraphFusedAxis = "activation_axis";
static constexpr const char* Interleaved = "interleaved";

} // namespace AttrName

@@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper;
@@ -437,6 +437,7 @@ namespace OperatorHelper
static const int sc_sinceVer_BiasAdd = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
} // namespace MsftOperatorSet1

} // namespace OperatorHelper
2 changes: 0 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
@@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) {
})
// This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid");

PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
throw std::runtime_error("Not a valid numpy type");
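Removing the ORT_ENFORCE means a null data pointer is no longer rejected, which matters when binding zero-element tensors. A hedged usage sketch from the Python side (the model path, provider, and input name are placeholders):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
binding = sess.io_binding()
# A zero-element input has no backing buffer, so its device pointer can be 0;
# before this change the bind would have failed the pointer check.
binding.bind_input(
    name="input",
    device_type="cuda",
    device_id=0,
    element_type=np.float32,
    shape=[0, 8],
    buffer_ptr=0,
)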
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -154,6 +154,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"MatMulInteger16": self._infer_MatMulInteger,
"MaxPool": self._infer_Pool,
"Max": self._infer_symbolic_compute_ops,
"MemcpyFromHost": self._pass_on_shape_and_type,
"MemcpyToHost": self._pass_on_shape_and_type,
"Min": self._infer_symbolic_compute_ops,
"Mul": self._infer_symbolic_compute_ops,
"NonMaxSuppression": self._infer_NonMaxSuppression,
3 changes: 2 additions & 1 deletion onnxruntime/python/tools/transformers/float16.py
@@ -145,7 +145,8 @@ def make_value_info_from_tensor(tensor):


# Some operators have certain inputs whose data type is fixed as float. Key is op_type, value is list of input indices
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]}
# Note that DirectML allows float16 gamma and beta in GroupNorm. The force_fp16_inputs parameter can override this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}
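A hypothetical helper sketching how the table above might be consulted during fp16 conversion; keep_input_as_float is not part of float16.py, and the precedence shown (force_fp16_inputs wins) follows the note in the comment:

def keep_input_as_float(op_type, input_index, force_fp16_inputs):
    """Return True if this input should stay float32 when converting a model to fp16."""
    if input_index in force_fp16_inputs.get(op_type, []):
        return False  # caller explicitly forced fp16 for this input
    return input_index in ALWAYS_FLOAT_INPUTS.get(op_type, [])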


class InitializerTracker:
22 changes: 3 additions & 19 deletions onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -82,23 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
return

instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
if instance_norm_scale is None:
return
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None:
return

# Only groups=32 is supported in GroupNorm kernel. Check the scale and bias is 1D tensor with shape [32].
if not (len(instance_norm_scale.shape) == 1 and instance_norm_scale.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since scale shape is expected to be [32], Got %s", str(instance_norm_scale.shape)
)
if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
return

if not (len(instance_norm_bias.shape) == 1 and instance_norm_bias.shape[0] == 32):
logger.debug(
"Skip GroupNorm fusion since bias shape is expected to be [32], Got %s", str(instance_norm_bias.shape)
)
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
if instance_norm_bias is None or instance_norm_bias.shape != instance_norm_scale.shape:
return

if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
@@ -108,10 +96,6 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):

group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
logger.info("Skip GroupNorm fusion since channels=%d is not supported.", weight_elements)
return

self.add_initializer(
name=group_norm_name + "_gamma",
data_type=TensorProto.FLOAT,
11 changes: 11 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
@@ -61,6 +61,7 @@ def __init__(self, model_type):
if model_type in ["unet", "vae", "clip"]:
self.enable_nhwc_conv = True
self.enable_group_norm = True
self.enable_skip_group_norm = True
self.enable_bias_splitgelu = True
self.enable_packed_qkv = True
self.enable_packed_kv = True
@@ -116,6 +117,8 @@ def parse(args):
options.enable_nhwc_conv = False
if args.disable_group_norm:
options.enable_group_norm = False
if args.disable_skip_group_norm:
options.enable_skip_group_norm = False
if args.disable_bias_splitgelu:
options.enable_bias_splitgelu = False
if args.disable_packed_qkv:
@@ -250,6 +253,14 @@ def add_arguments(parser: ArgumentParser):
)
parser.set_defaults(disable_group_norm=False)

parser.add_argument(
"--disable_skip_group_norm",
required=False,
action="store_true",
help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
)
parser.set_defaults(disable_skip_group_norm=False)
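For illustration, a hedged sketch of toggling the new fusion from Python instead of the command line; the file names are placeholders, and FusionOptions/optimize_model are the module's existing entry points:

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

opts = FusionOptions("unet")
opts.enable_skip_group_norm = False  # same effect as --disable_skip_group_norm
optimized = optimize_model("unet.onnx", model_type="unet", optimization_options=opts)
optimized.save_model_to_file("unet_opt.onnx")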

parser.add_argument(
"--disable_packed_kv",
required=False,
10 changes: 8 additions & 2 deletions onnxruntime/python/tools/transformers/fusion_rotary_attention.py
@@ -29,7 +29,13 @@ def __init__(
hidden_size,
num_heads,
use_multi_head_attention=True,
search_op_types=["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Add"],
search_op_types=[
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"LayerNormalization",
"SkipLayerNormalization",
"Add",
],
)

def create_mha_node(
@@ -318,7 +324,7 @@ def check_runtime_shape_paths_for_nodes(
return True

def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
return

# qkv_nodes_1 is for LLaMA-2 Microsoft
