Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/k800 qemu #1145

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/Nncase.Core/IR/Tensors/Where.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,22 @@ public sealed partial class Where : Op
/// <summary>
/// Gets condition.
/// </summary>
public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond");
public static readonly ParameterInfo Cond = new(typeof(Where), 0, "cond", ParameterKind.Input);

/// <summary>
/// Gets x.
/// </summary>
public static readonly ParameterInfo X = new(typeof(Where), 1, "x");
public static readonly ParameterInfo X = new(typeof(Where), 1, "x", ParameterKind.Input);

/// <summary>
/// Gets y.
/// </summary>
public static readonly ParameterInfo Y = new(typeof(Where), 2, "y");
public static readonly ParameterInfo Y = new(typeof(Where), 2, "y", ParameterKind.Input);

public bool IsTfWhere { get; }

public override string DisplayProperty()
{
return $"IsTfWhere: {IsTfWhere}";
}
}
74 changes: 67 additions & 7 deletions src/Nncase.Evaluator/Tensors/Where.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.IR;
Expand Down Expand Up @@ -45,9 +46,20 @@
/// <inheritdoc/>
public IRType Visit(ITypeInferenceContext context, Where target)
{
var cond = context.CheckArgumentType<TensorType>(target, Where.Cond);
var x = context.CheckArgumentType<TensorType>(target, Where.X);
var y = context.CheckArgumentType<TensorType>(target, Where.Y);
var cond = context.CheckArgumentType<IRType>(target, Where.Cond);
var x = context.CheckArgumentType<IRType>(target, Where.X);
var y = context.CheckArgumentType<IRType>(target, Where.Y);

return (cond, x, y) switch
{
(DistributedType a, DistributedType b, DistributedType c) => Visit(a, b, c, target),

Check warning on line 55 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L55

Added line #L55 was not covered by tests
(TensorType a, TensorType b, TensorType c) => Visit(a, b, c, target),
_ => new InvalidType(cond.GetType().ToString()),

Check warning on line 57 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L57

Added line #L57 was not covered by tests
};
}

public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target)
{
if (target.IsTfWhere)
{
return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown, cond.Shape.Rank));
Expand All @@ -56,12 +68,60 @@
return TypeInference.BroadcastType(x.DType, cond, x, y);
}

public IRType Visit(DistributedType cond, DistributedType x, DistributedType y, Where target)
{
var invalid = new InvalidType($"{cond}, {x}, {y} not support");

Check warning on line 73 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L73

Added line #L73 was not covered by tests
if (cond.Placement != x.Placement || x.Placement != y.Placement)
{
return invalid;

Check warning on line 76 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L76

Added line #L76 was not covered by tests
}

if (target.IsTfWhere)
{
return invalid;

Check warning on line 81 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L81

Added line #L81 was not covered by tests
}

var targetType = (TensorType)TypeInference.BroadcastType(x.TensorType.DType, cond.TensorType, x.TensorType, y.TensorType);

Check warning on line 84 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L84

Added line #L84 was not covered by tests
if (cond.TensorType.Shape != targetType.Shape)
{
return invalid;

Check warning on line 87 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L87

Added line #L87 was not covered by tests
}

var ndsbp = new SBP[cond.Placement.Rank];

Check warning on line 90 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L90

Added line #L90 was not covered by tests

for (int i = 0; i < cond.Placement.Rank; i++)
{
switch (cond.NdSBP[i], x.NdSBP[i], y.NdSBP[i])
{
case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPSplit { Axis: int }):
ndsbp[i] = SBP.S(ic);
break;

Check warning on line 98 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L97-L98

Added lines #L97 - L98 were not covered by tests
case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPSplit { Axis: int }):
ndsbp[i] = SBP.S(ic);
break;

Check warning on line 101 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L100-L101

Added lines #L100 - L101 were not covered by tests
case (SBPSplit { Axis: int ic }, SBPSplit { Axis: int }, SBPBroadCast):
ndsbp[i] = SBP.S(ic);
break;

Check warning on line 104 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L103-L104

Added lines #L103 - L104 were not covered by tests
case (SBPSplit { Axis: int ic }, SBPBroadCast, SBPBroadCast):
ndsbp[i] = SBP.S(ic);
break;

Check warning on line 107 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L106-L107

Added lines #L106 - L107 were not covered by tests
case (SBPBroadCast, SBPBroadCast, SBPBroadCast):
ndsbp[i] = SBP.B;
break;

Check warning on line 110 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L109-L110

Added lines #L109 - L110 were not covered by tests
default:
return invalid;

Check warning on line 112 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L112

Added line #L112 was not covered by tests
}
}

return new DistributedType(targetType, ndsbp, cond.Placement);

Check warning on line 116 in src/Nncase.Evaluator/Tensors/Where.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/Tensors/Where.cs#L116

Added line #L116 was not covered by tests
}

public Cost Visit(ICostEvaluateContext context, Where target)
{
var cond = context.GetArgumentType<TensorType>(target, Where.Cond);
var x = context.GetArgumentType<TensorType>(target, Where.X);
var y = context.GetArgumentType<TensorType>(target, Where.Y);
var ret = context.GetReturnType<TensorType>();
var cond = context.GetArgumentType<IRType>(target, Where.Cond);
var x = context.GetArgumentType<IRType>(target, Where.X);
var y = context.GetArgumentType<IRType>(target, Where.Y);
var ret = context.GetReturnType<IRType>();
return new()
{
[CostFactorNames.MemoryLoad] = CostUtility.GetMemoryAccess(cond, x, y),
Expand Down
38 changes: 0 additions & 38 deletions src/Nncase.Passes/Rules/Neutral/FoldPrePostReshapeSoftmax.cs

This file was deleted.

13 changes: 13 additions & 0 deletions tests/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,16 @@ threshold = 0.999
[target.k230.mode.ptq]
enabled = true
threshold = 0.96

[target.xpu]
eval = false
infer = true
similarity_name = 'cosine'

[target.xpu.mode.noptq]
enabled = true
threshold = 0.999

[target.xpu.mode.ptq]
enabled = false
threshold = 0.9
2 changes: 1 addition & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def run(self, model_file: Union[List[str], str]):
actual = self.run_evaluator(compiler, tmp_dir)
else:
actual = self.run_inference(
compiler, k_target, v_mode['enabled'], tmp_dir)
compiler, k_target, k_mode == "ptq" and v_mode['enabled'], tmp_dir)
target_dir = os.path.join(self.case_dir, stage, k_target)
os.makedirs(target_dir, exist_ok=True)
mode_dir = os.path.join(target_dir, k_mode)
Expand Down
29 changes: 29 additions & 0 deletions toolchains/k800.linux.toolchain.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)

if(DEFINED ENV{RISCV_ROOT_PATH})
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
endif()

if(NOT RISCV_ROOT_PATH)
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
endif()

set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")

set(CMAKE_C_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++")

set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu/")

set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(ENABLE_VULKAN_RUNTIME OFF)
set(ENABLE_OPENMP OFF)
set(ENABLE_VULKAN OFF)
set(ENABLE_HALIDE OFF)
set(BUILD_PYTHON_BINDING OFF)

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=rv64gv")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=rv64gv")
Loading