diff --git a/src/Nncase.Core/IR/Tensors/Where.cs b/src/Nncase.Core/IR/Tensors/Where.cs
index 5506b6afea..d64e2472ad 100644
--- a/src/Nncase.Core/IR/Tensors/Where.cs
+++ b/src/Nncase.Core/IR/Tensors/Where.cs
@@ -21,17 +21,22 @@ public sealed partial class Where : Op
///
/// Gets condition.
///
- public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond");
+ public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond", ParameterKind.Input);
///
/// Gets x.
///
- public static readonly ParameterInfo X = new(typeof(Where), 1, "x");
+ public static readonly ParameterInfo X = new(typeof(Where), 1, "x", ParameterKind.Input);
///
/// Gets y.
///
- public static readonly ParameterInfo Y = new(typeof(Where), 2, "y");
+ public static readonly ParameterInfo Y = new(typeof(Where), 2, "y", ParameterKind.Input);
public bool IsTfWhere { get; }
+
+ public override string DisplayProperty()
+ {
+ return $"IsTfWhere: {IsTfWhere}";
+ }
}
diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs
index db614bc6d9..f45d41b737 100644
--- a/src/Nncase.Evaluator/Tensors/Where.cs
+++ b/src/Nncase.Evaluator/Tensors/Where.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Text.RegularExpressions;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.IR;
@@ -45,9 +46,20 @@ public IValue Visit(IEvaluateContext context, Where where)
///
public IRType Visit(ITypeInferenceContext context, Where target)
{
- var cond = context.CheckArgumentType(target, Where.Cond);
- var x = context.CheckArgumentType(target, Where.X);
- var y = context.CheckArgumentType(target, Where.Y);
+ var cond = context.CheckArgumentType(target, Where.Cond);
+ var x = context.CheckArgumentType(target, Where.X);
+ var y = context.CheckArgumentType(target, Where.Y);
+
+ return (cond, x, y) switch
+ {
+ (DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target),
+ (TensorType a, TensorType b, TensorType c) => Visit(a, b, c, target),
+ _ => new InvalidType(cond.GetType().ToString()),
+ };
+ }
+
+ public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target)
+ {
if (target.IsTfWhere)
{
return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank));
@@ -56,12 +68,60 @@ public IRType Visit(ITypeInferenceContext context, Where target)
return TypeInference.BroadcastType(x.DType, cond, x, y);
}
+ public IRType Visit(DistributedType cond, DistributedType x, DistributedType y, Where target)
+ {
+ var invalid = new InvalidType($"{cond}, {x}, {y} not support");
+ if (cond.Placement != x.Placement || x.Placement != y.Placement)
+ {
+ return invalid;
+ }
+
+ if (target.IsTfWhere)
+ {
+ return invalid;
+ }
+
+ var targetType = (TensorType)TypeInference.BroadcastType(x.TensorType.DType, cond.TensorType, x.TensorType, y.TensorType);
+ if (cond.TensorType.Shape != targetType.Shape)
+ {
+ return invalid;
+ }
+
+ var ndsbp = new SBP[cond.Placement.Rank];
+
+ for (int i = 0; i < cond.Placement.Rank; i++)
+ {
+ switch (cond.NdSBP[i], x.NdSBP[i], y.NdSBP[i])
+ {
+ case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPSplit { Axis: int }):
+ ndsbp[i] = SBP.S(ic);
+ break;
+ case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPSplit { Axis: int }):
+ ndsbp[i] = SBP.S(ic);
+ break;
+ case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPBroadCast):
+ ndsbp[i] = SBP.S(ic);
+ break;
+ case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPBroadCast):
+ ndsbp[i] = SBP.S(ic);
+ break;
+ case (SBPBroadCast, SBPBroadCast, SBPBroadCast):
+ ndsbp[i] = SBP.B;
+ break;
+ default:
+ return invalid;
+ }
+ }
+
+ return new DistributedType(targetType, ndsbp, cond.Placement);
+ }
+
public Cost Visit(ICostEvaluateContext context, Where target)
{
- var cond = context.GetArgumentType(target, Where.Cond);
- var x = context.GetArgumentType(target, Where.X);
- var y = context.GetArgumentType(target, Where.Y);
- var ret = context.GetReturnType();
+ var cond = context.GetArgumentType(target, Where.Cond);
+ var x = context.GetArgumentType(target, Where.X);
+ var y = context.GetArgumentType(target, Where.Y);
+ var ret = context.GetReturnType();
return new()
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(cond, x, y),
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs b/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs
deleted file mode 100644
index 83edfe5e5e..0000000000
--- a/src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) Canaan Inc. All rights reserved.
-// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using Nncase.IR;
-using Nncase.IR.NN;
-using Nncase.PatternMatch;
-using static Nncase.IR.F.NN;
-using static Nncase.IR.F.Tensors;
-using static Nncase.IR.TypePatternUtility;
-using static Nncase.PatternMatch.F.Math;
-using static Nncase.PatternMatch.F.NN;
-using static Nncase.PatternMatch.F.Tensors;
-using static Nncase.PatternMatch.Utility;
-
-namespace Nncase.Passes.Rules.Neutral;
-
-///
-/// Fold nop .
-///
-[RuleGenerator]
-public sealed partial class FoldPrePostReshapeSoftmax : IRewriteRule
-{
- ///
- public IPattern Pattern { get; } = IsReshape(
- "reshape",
- "reshapeCall",
- _ => true,
- IsSoftmax("softmax", IsReshape("rehsape2", "reshapeCall2", _ => true, IsWildcard("input"), IsTensorConst("shape2"))),
- IsTensorConst("shape1"));
-
- private Expr? GetReplace(Expr input)
- {
- return Softmax(input, 3);
- }
-}
diff --git a/tests/config.toml b/tests/config.toml
index 56f1c0f1bb..41c7216f3d 100644
--- a/tests/config.toml
+++ b/tests/config.toml
@@ -131,3 +131,16 @@ threshold = 0.999
[target.k230.mode.ptq]
enabled = true
threshold = 0.96
+
+[target.xpu]
+eval = false
+infer = true
+similarity_name = 'cosine'
+
+[target.xpu.mode.noptq]
+enabled = true
+threshold = 0.999
+
+[target.xpu.mode.ptq]
+enabled = false
+threshold = 0.9
\ No newline at end of file
diff --git a/tests/test_runner.py b/tests/test_runner.py
index 3aa5ce0777..58906311d2 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -277,7 +277,7 @@ def run(self, model_file: Union[List[str], str]):
actual = self.run_evaluator(compiler, tmp_dir)
else:
actual = self.run_inference(
- compiler, k_target, v_mode['enabled'], tmp_dir)
+ compiler, k_target, k_mode == "ptq" and v_mode['enabled'], tmp_dir)
target_dir = os.path.join(self.case_dir, stage, k_target)
os.makedirs(target_dir, exist_ok=True)
mode_dir = os.path.join(target_dir, k_mode)
diff --git a/toolchains/k800.linux.toolchain.cmake b/toolchains/k800.linux.toolchain.cmake
new file mode 100644
index 0000000000..660754fc95
--- /dev/null
+++ b/toolchains/k800.linux.toolchain.cmake
@@ -0,0 +1,29 @@
+set(CMAKE_SYSTEM_NAME Linux)
+set(CMAKE_SYSTEM_PROCESSOR riscv64)
+
+if(DEFINED ENV{RISCV_ROOT_PATH})
+ file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
+endif()
+
+if(NOT RISCV_ROOT_PATH)
+ message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
+endif()
+
+set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
+
+set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc")
+set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++")
+
+set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu/")
+
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(ENABLE_VULKAN_RUNTIME OFF)
+set(ENABLE_OPENMP OFF)
+set(ENABLE_VULKAN OFF)
+set(ENABLE_HALIDE OFF)
+set(BUILD_PYTHON_BINDING OFF)
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64gv")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gv")