diff --git a/src/Nncase.Core/IR/Tensors/Where.cs b/src/Nncase.Core/IR/Tensors/Where.cs index 5506b6afea..d64e2472ad 100644 --- a/src/Nncase.Core/IR/Tensors/Where.cs +++ b/src/Nncase.Core/IR/Tensors/Where.cs @@ -21,17 +21,22 @@ public sealed partial class Where : Op /// /// Gets condition. /// - public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond"); + public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond", ParameterKind.Input); /// /// Gets x. /// - public static readonly ParameterInfo X = new(typeof(Where), 1, "x"); + public static readonly ParameterInfo X = new(typeof(Where), 1, "x", ParameterKind.Input); /// /// Gets y. /// - public static readonly ParameterInfo Y = new(typeof(Where), 2, "y"); + public static readonly ParameterInfo Y = new(typeof(Where), 2, "y", ParameterKind.Input); public bool IsTfWhere { get; } + + public override string DisplayProperty() + { + return $"IsTfWhere: {IsTfWhere}"; + } } diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs index db614bc6d9..f45d41b737 100644 --- a/src/Nncase.Evaluator/Tensors/Where.cs +++ b/src/Nncase.Evaluator/Tensors/Where.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.RegularExpressions; using NetFabric.Hyperlinq; using Nncase.CostModel; using Nncase.IR; @@ -45,9 +46,20 @@ public IValue Visit(IEvaluateContext context, Where where) /// public IRType Visit(ITypeInferenceContext context, Where target) { - var cond = context.CheckArgumentType(target, Where.Cond); - var x = context.CheckArgumentType(target, Where.X); - var y = context.CheckArgumentType(target, Where.Y); + var cond = context.CheckArgumentType(target, Where.Cond); + var x = context.CheckArgumentType(target, Where.X); + var y = context.CheckArgumentType(target, Where.Y); + + return (cond, x, y) switch + { + (DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target), + (TensorType a, TensorType b, TensorType c) => Visit(a, b, c, target), + _ => new InvalidType(cond.GetType().ToString()), + }; + } + + public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target) + { if (target.IsTfWhere) { return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank)); @@ -56,12 +68,60 @@ public IRType Visit(ITypeInferenceContext context, Where target) return TypeInference.BroadcastType(x.DType, cond, x, y); } + public IRType Visit(DistributedType cond, DistributedType x, DistributedType y, Where target) + { + var invalid = new InvalidType($"{cond}, {x}, {y} not support"); + if (cond.Placement != x.Placement || x.Placement != y.Placement) + { + return invalid; + } + + if (target.IsTfWhere) + { + return invalid; + } + + var targetType = (TensorType)TypeInference.BroadcastType(x.TensorType.DType, cond.TensorType, x.TensorType, y.TensorType); + if (cond.TensorType.Shape != targetType.Shape) + { + return invalid; + } + + var ndsbp = new SBP[cond.Placement.Rank]; + + for (int i = 0; i < cond.Placement.Rank; i++) + { + switch (cond.NdSBP[i], x.NdSBP[i], y.NdSBP[i]) + { + case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPSplit { Axis: int }): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPSplit { Axis: int }): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPBroadCast): + ndsbp[i] = SBP.S(ic); + break; + case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.S(ic); + break; + case (SBPBroadCast, SBPBroadCast, SBPBroadCast): + ndsbp[i] = SBP.B; + break; + default: + return invalid; + } + } + + return new DistributedType(targetType, ndsbp, cond.Placement); + } + public Cost Visit(ICostEvaluateContext context, Where target) { - var cond = context.GetArgumentType(target, Where.Cond); - var x = context.GetArgumentType(target, Where.X); - var y = context.GetArgumentType(target, Where.Y); - var ret = context.GetReturnType(); + var cond = context.GetArgumentType(target, Where.Cond); + var x = context.GetArgumentType(target, Where.X); + var y = context.GetArgumentType(target, Where.Y); + var ret = context.GetReturnType(); return new() { [CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(cond, x, y), diff --git a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs b/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs deleted file mode 100644 index 83edfe5e5e..0000000000 --- a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Nncase.IR; -using Nncase.IR.NN; -using Nncase.PatternMatch; -using static Nncase.IR.F.NN; -using static Nncase.IR.F.Tensors; -using static Nncase.IR.TypePatternUtility; -using static Nncase.PatternMatch.F.Math; -using static Nncase.PatternMatch.F.NN; -using static Nncase.PatternMatch.F.Tensors; -using static Nncase.PatternMatch.Utility; - -namespace Nncase.Passes.Rules.Neutral; - -/// -/// Fold nop . -/// -[RuleGenerator] -public sealed partial class FoldPrePostReshapeSoftmax : IRewriteRule -{ - /// - public IPattern Pattern { get; } = IsReshape( - "reshape", - "reshapeCall", - _ => true, - IsSoftmax("softmax", IsReshape("rehsape2", "reshapeCall2", _ => true, IsWildcard("input"), IsTensorConst("shape2"))), - IsTensorConst("shape1")); - - private Expr? GetReplace(Expr input) - { - return Softmax(input, 3); - } -} diff --git a/tests/config.toml b/tests/config.toml index 56f1c0f1bb..41c7216f3d 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -131,3 +131,16 @@ threshold = 0.999 [target.k230.mode.ptq] enabled = true threshold = 0.96 + +[target.xpu] +eval = false +infer = true +similarity_name = 'cosine' + +[target.xpu.mode.noptq] +enabled = true +threshold = 0.999 + +[target.xpu.mode.ptq] +enabled = false +threshold = 0.9 \ No newline at end of file diff --git a/tests/test_runner.py b/tests/test_runner.py index 3aa5ce0777..58906311d2 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -277,7 +277,7 @@ def run(self, model_file: Union[List[str], str]): actual = self.run_evaluator(compiler, tmp_dir) else: actual = self.run_inference( - compiler, k_target, v_mode['enabled'], tmp_dir) + compiler, k_target, k_mode == "ptq" and v_mode['enabled'], tmp_dir) target_dir = os.path.join(self.case_dir, stage, k_target) os.makedirs(target_dir, exist_ok=True) mode_dir = os.path.join(target_dir, k_mode) diff --git a/toolchains/k800.linux.toolchain.cmake b/toolchains/k800.linux.toolchain.cmake new file mode 100644 index 0000000000..660754fc95 --- /dev/null +++ b/toolchains/k800.linux.toolchain.cmake @@ -0,0 +1,29 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +if(DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) +endif() + +if(NOT RISCV_ROOT_PATH) + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") +endif() + +set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") + +set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu/") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(ENABLE_VULKAN_RUNTIME OFF) +set(ENABLE_OPENMP OFF) +set(ENABLE_VULKAN OFF) +set(ENABLE_HALIDE OFF) +set(BUILD_PYTHON_BINDING OFF) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64gv") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gv")