Skip to content

Commit

Permalink
Feature/k800 qemu (#1145)
Browse files Browse the repository at this point in the history
* add cmake of k800 toolchain
* add xpu in config.toml
* distributed type for where
* fix noptq
  • Loading branch information
xhuohai committed Dec 20, 2023
1 parent 980f977 commit 3a53346
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 49 deletions.
11 changes: 8 additions & 3 deletions src/Nncase.Core/IR/Tensors/Where.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,22 @@ public sealed partial class Where : Op
/// <summary>
/// Gets condition.
/// </summary>
public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond");
public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond", ParameterKind.Input);

/// <summary>
/// Gets x.
/// </summary>
public static readonly ParameterInfo X = new(typeof(Where), 1, "x");
public static readonly ParameterInfo X = new(typeof(Where), 1, "x", ParameterKind.Input);

/// <summary>
/// Gets y.
/// </summary>
public static readonly ParameterInfo Y = new(typeof(Where), 2, "y");
public static readonly ParameterInfo Y = new(typeof(Where), 2, "y", ParameterKind.Input);

public bool IsTfWhere { get; }

public override string DisplayProperty()
{
return $"IsTfWhere: {IsTfWhere}";
}
}
74 changes: 67 additions & 7 deletions src/Nncase.Evaluator/Tensors/Where.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.IR;
Expand Down Expand Up @@ -45,9 +46,20 @@ public IValue Visit(IEvaluateContext context, Where where)
/// <inheritdoc/>
public IRType Visit(ITypeInferenceContext context, Where target)
{
var cond = context.CheckArgumentType<TensorType>(target, Where.Cond);
var x = context.CheckArgumentType<TensorType>(target, Where.X);
var y = context.CheckArgumentType<TensorType>(target, Where.Y);
var cond = context.CheckArgumentType<IRType>(target, Where.Cond);
var x = context.CheckArgumentType<IRType>(target, Where.X);
var y = context.CheckArgumentType<IRType>(target, Where.Y);

return (cond, x, y) switch
{
(DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target),
(TensorType a, TensorType b, TensorType c) => Visit(a, b, c, target),
_ => new InvalidType(cond.GetType().ToString()),
};
}

public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target)
{
if (target.IsTfWhere)
{
return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank));
Expand All @@ -56,12 +68,60 @@ public IRType Visit(ITypeInferenceContext context, Where target)
return TypeInference.BroadcastType(x.DType, cond, x, y);
}

public IRType Visit(DistributedType cond, DistributedType x, DistributedType y, Where target)
{
var invalid = new InvalidType($"{cond}, {x}, {y} not support");
if (cond.Placement != x.Placement || x.Placement != y.Placement)
{
return invalid;
}

if (target.IsTfWhere)
{
return invalid;
}

var targetType = (TensorType)TypeInference.BroadcastType(x.TensorType.DType, cond.TensorType, x.TensorType, y.TensorType);
if (cond.TensorType.Shape != targetType.Shape)
{
return invalid;
}

var ndsbp = new SBP[cond.Placement.Rank];

for (int i = 0; i < cond.Placement.Rank; i++)
{
switch (cond.NdSBP[i], x.NdSBP[i], y.NdSBP[i])
{
case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPSplit { Axis: int }):
ndsbp[i] = SBP.S(ic);
break;
case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPSplit { Axis: int }):
ndsbp[i] = SBP.S(ic);
break;
case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPBroadCast):
ndsbp[i] = SBP.S(ic);
break;
case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPBroadCast):
ndsbp[i] = SBP.S(ic);
break;
case (SBPBroadCast, SBPBroadCast, SBPBroadCast):
ndsbp[i] = SBP.B;
break;
default:
return invalid;
}
}

return new DistributedType(targetType, ndsbp, cond.Placement);
}

public Cost Visit(ICostEvaluateContext context, Where target)
{
var cond = context.GetArgumentType<TensorType>(target, Where.Cond);
var x = context.GetArgumentType<TensorType>(target, Where.X);
var y = context.GetArgumentType<TensorType>(target, Where.Y);
var ret = context.GetReturnType<TensorType>();
var cond = context.GetArgumentType<IRType>(target, Where.Cond);
var x = context.GetArgumentType<IRType>(target, Where.X);
var y = context.GetArgumentType<IRType>(target, Where.Y);
var ret = context.GetReturnType<IRType>();
return new()
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(cond, x, y),
Expand Down
38 changes: 0 additions & 38 deletions src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs

This file was deleted.

13 changes: 13 additions & 0 deletions tests/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,16 @@ threshold = 0.999
[target.k230.mode.ptq]
enabled = true
threshold = 0.96

[target.xpu]
eval = false
infer = true
similarity_name = 'cosine'

[target.xpu.mode.noptq]
enabled = true
threshold = 0.999

[target.xpu.mode.ptq]
enabled = false
threshold = 0.9
2 changes: 1 addition & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def run(self, model_file: Union[List[str], str]):
actual = self.run_evaluator(compiler, tmp_dir)
else:
actual = self.run_inference(
compiler, k_target, v_mode['enabled'], tmp_dir)
compiler, k_target, k_mode == "ptq" and v_mode['enabled'], tmp_dir)
target_dir = os.path.join(self.case_dir, stage, k_target)
os.makedirs(target_dir, exist_ok=True)
mode_dir = os.path.join(target_dir, k_mode)
Expand Down
29 changes: 29 additions & 0 deletions toolchains/k800.linux.toolchain.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)

if(DEFINED ENV{RISCV_ROOT_PATH})
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
endif()

if(NOT RISCV_ROOT_PATH)
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
endif()

set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")

set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++")

set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu/")

set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(ENABLE_VULKAN_RUNTIME OFF)
set(ENABLE_OPENMP OFF)
set(ENABLE_VULKAN OFF)
set(ENABLE_HALIDE OFF)
set(BUILD_PYTHON_BINDING OFF)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64gv")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gv")

0 comments on commit 3a53346

Please sign in to comment.