Merge branch 'master' into fix-runtime
FusionBolt committed Nov 16, 2023
2 parents 7b99ec1 + 050de8e commit 0a06ba0
Showing 5 changed files with 193 additions and 95 deletions.
163 changes: 91 additions & 72 deletions src/Native/src/kernels/stackvm/reference/softmax.cpp
@@ -28,80 +28,99 @@ using namespace nncase::kernels::stackvm;
namespace {
// softmax(x) = exp(x - reduce_max(x)) / reduce_sum(exp(x - reduce_max(x)))
template <typename T>
result<void> softmax_impl(const T *input, T *output,
gsl::span<const size_t> in_shape,
gsl::span<const size_t> in_strides,
gsl::span<const size_t> out_strides, int64_t axis,
float beta, bool needLog = false) noexcept {
result<void>
softmax_impl(const T *input, T *output, gsl::span<const size_t> in_shape,
NNCASE_UNUSED gsl::span<const size_t> in_strides,
NNCASE_UNUSED gsl::span<const size_t> out_strides, int64_t axis,
float beta, bool needLog = false) noexcept {
size_t positive_axis = axis < 0 ? in_shape.size() + axis : axis;
dims_t axes{positive_axis};

auto reduced_shape =
kernels::detail::get_reduced_shape(in_shape, axes, true);
auto reduced_strides = get_default_strides(reduced_shape);
auto reduced_size = compute_size(reduced_shape);
std::vector<T> tmp(reduced_size, std::numeric_limits<T>::lowest());

// reduce_max
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(in_strides, index);
const auto in = input[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto out_idx = offset(reduced_strides, out_index);
auto &out = tmp[out_idx];

out = std::max(in, out);
return ok();
}));

// x - reduce_max
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(in_strides, index);
const auto in = input[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto max_idx = offset(reduced_strides, out_index);

auto out_idx = offset(out_strides, index);
output[out_idx] =
static_cast<T>(static_cast<float>(in - tmp[max_idx]) * beta);

return ok();
}));

// exp(x - reduce_max) and sum
tmp.assign(tmp.size(), static_cast<T>(0));
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
auto in_idx = offset(out_strides, index);
const auto in = output[in_idx];

const auto out_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto out_idx = offset(reduced_strides, out_index);
output[in_idx] = static_cast<T>(expf(static_cast<float>(in)));
tmp[out_idx] += static_cast<T>(output[in_idx]);

return ok();
}));

// div
try_(apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
const auto in_index =
kernels::detail::get_reduced_offset(index, axes, true);
auto in_idx = offset(reduced_strides, in_index);
auto in = tmp[in_idx];

auto out_idx = offset(out_strides, index);
auto &out = output[out_idx];
out /= in;
if (needLog) {
out = static_cast<T>(std::log(static_cast<float>(out)));

if (positive_axis == in_shape.size() - 1) {
size_t reduced_size = in_shape[positive_axis];
auto out_size = compute_size(in_shape) / reduced_size;
std::vector<T> tmp(reduced_size, std::numeric_limits<T>::lowest());

for (size_t i = 0; i < out_size; i++) {
auto in_ = input + i * reduced_size;
auto out_ = output + i * reduced_size;

// reduce_max
auto max_value = *in_;
for (size_t j = 0; j < reduced_size; j++) {
max_value = std::max(max_value, in_[j]);
}

// (x - reduce_max) * beta
for (size_t j = 0; j < reduced_size; j++) {
out_[j] = static_cast<T>((static_cast<float>(in_[j]) -
static_cast<float>(max_value)) *
beta);
}

// exp((x - reduce_max) * beta) and sum
T sum = 0;
for (size_t j = 0; j < reduced_size; j++) {
out_[j] = static_cast<T>(expf(static_cast<float>(out_[j])));
sum += out_[j];
}

// div
for (size_t j = 0; j < reduced_size; j++) {
out_[j] /= sum;
if (needLog) {
out_[j] =
static_cast<T>(std::log(static_cast<float>(out_[j])));
}
}
}
} else {
size_t axis_size = in_shape[positive_axis];
size_t reduced_size = 1;
for (size_t i = positive_axis + 1; i < in_shape.size(); i++) {
reduced_size *= in_shape[i];
}
return ok();
}));
auto out_size = compute_size(in_shape) / reduced_size / axis_size;

for (size_t i = 0; i < out_size; i++) {
std::vector<T> axis_sum(reduced_size, static_cast<T>(0));
std::vector<T> max_value(reduced_size,
std::numeric_limits<T>::lowest());
auto in_ = input + i * reduced_size * axis_size;
auto out_ = output + i * reduced_size * axis_size;

// reduce_max
for (size_t k = 0; k < axis_size; k++) {
auto in_k = in_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
max_value[j] = std::max(max_value[j], in_k[j]);
}
}

// exp((x - reduce_max) * beta) and sum
for (size_t k = 0; k < axis_size; k++) {
auto in_k = in_ + k * reduced_size;
auto out_k = out_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
out_k[j] =
static_cast<T>(expf((static_cast<float>(in_k[j]) -
static_cast<float>(max_value[j])) *
beta));
axis_sum[j] += out_k[j];
}
}

// div
for (size_t k = 0; k < axis_size; k++) {
auto out_k = out_ + k * reduced_size;
for (size_t j = 0; j < reduced_size; j++) {
out_k[j] /= axis_sum[j];
if (needLog)
out_k[j] = static_cast<T>(
std::log(static_cast<float>((out_k[j]))));
}
}
}
}

return ok();
}
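
Note on the rewrite: the old index-based `apply` version made four full strided passes (reduce_max, subtract, exp+sum, div); the new code splits into a fast path for the innermost axis, where rows are contiguous and a scalar running max/sum suffices, and a strided path for any other axis, which keeps per-element `max_value`/`axis_sum` vectors of the inner size. Both branches keep the max-subtraction trick from the comment above, which keeps the exp argument non-positive (for positive beta) so `expf` cannot overflow. A minimal standalone C++ sketch of the last-axis path — illustrative names, not the nncase kernel:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the contiguous last-axis branch: softmax over
// each row of `reduced_size` elements, with optional log-softmax.
void softmax_last_axis(const float *in, float *out, std::size_t rows,
                       std::size_t reduced_size, float beta, bool need_log) {
    for (std::size_t i = 0; i < rows; i++) {
        const float *in_ = in + i * reduced_size;
        float *out_ = out + i * reduced_size;

        // reduce_max: subtracting the row max keeps the exp argument
        // non-positive (for beta > 0), so exp cannot overflow.
        float max_value = in_[0];
        for (std::size_t j = 1; j < reduced_size; j++)
            max_value = std::max(max_value, in_[j]);

        // exp((x - reduce_max) * beta) and running sum
        float sum = 0.0f;
        for (std::size_t j = 0; j < reduced_size; j++) {
            out_[j] = std::exp((in_[j] - max_value) * beta);
            sum += out_[j];
        }

        // div (and optional log for log-softmax)
        for (std::size_t j = 0; j < reduced_size; j++) {
            out_[j] /= sum;
            if (need_log)
                out_[j] = std::log(out_[j]);
        }
    }
}

int main() {
    // Inputs this large would overflow a naive exp(x) implementation.
    std::vector<float> x = {1000.0f, 1001.0f, 1002.0f};
    std::vector<float> y(x.size());
    softmax_last_axis(x.data(), y.data(), 1, x.size(), 1.0f, false);
    for (float v : y)
        std::printf("%f ", v); // 0.090031 0.244728 0.665241
    return 0;
}
```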
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Expr.cs
@@ -116,6 +116,8 @@ public DataType CheckedDataType
{
case TensorType type:
return type.DType;
case DistributedType type:
return type.TensorType.DType;

default:
if (DumpScope.Current.IsEnabled(DumpFlags.Compile))
{
93 changes: 78 additions & 15 deletions src/Nncase.Core/Utilities/DistributedUtility.cs
@@ -16,7 +16,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens
var ndsbp = new List<SBP>();
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]))
{
ndsbp.Add(SBP.S(axis));
}
@@ -28,7 +28,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens

return ndsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).

Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}
@@ -53,7 +53,7 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT
candidateNdsbps[i].Add(SBP.B);
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
{
candidateNdsbps[i].Add(SBP.S(axis));
}
@@ -67,38 +67,101 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT

return candidateNdsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).

Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}

public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement, [MaybeNullWhen(false)] out TensorType distType)
public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement)
{
distType = null;
if (!tensorType.Shape.IsFixed)
{
return false;
}

var shape = tensorType.Shape.ToValueArray();
for (int i = 0; i < ndsbp.Length; i++)
var divisors = GetDivisors(new DistributedType(tensorType, new IRArray<SBP>(ndsbp.ToArray()), placement));

return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : IsDivideBy(tensorType.Shape[p.axis].FixedValue, p.d));
}

public static IReadOnlyList<int> GetDivisors(DistributedType distributedType)
{
var shape = distributedType.TensorType.Shape.ToValueArray();
var divisors = Enumerable.Repeat(0, shape.Length).ToArray();

for (int i = 0; i < distributedType.NdSBP.Count; i++)
{
if (ndsbp[i] is SBPSplit { Axis: int axis })
if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
{
if (!IsDivisible(shape[axis], placement.Hierarchy[i]))
if (divisors[axis] == 0)
{
return false;
divisors[axis] = 1;

}

shape[axis] /= placement.Hierarchy[i];
divisors[axis] *= distributedType.Placement.Hierarchy[i];

}
}

distType = tensorType with { Shape = shape };
return true;
return divisors;

}

public static bool TryGetDividedTensorType(DistributedType distributedType, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TensorType tensorType)
{
tensorType = null;
var divisors = GetDivisors(distributedType);

if (divisors.Select((d, i) => (d, i)).All(p => p.d == 0 || IsDivideExactly(distributedType.TensorType.Shape[p.i].FixedValue, p.d)))
{
tensorType = new TensorType(distributedType.TensorType.DType, distributedType.TensorType.Shape.Zip(divisors).Select(p => p.Second == 0 ? p.First.FixedValue : p.First.FixedValue / p.Second).ToArray());
return true;

}

return false;

}
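
For readers following the new utilities: `GetDivisors` folds the NdSBP over the placement hierarchy, so each tensor axis ends up with the product of the mesh extents that split it (0 meaning "not split"), and `TryGetDividedTensorType` then divides the shape only when every split is exact. A hedged C++ restatement of that bookkeeping, with the SBP/IR types simplified to plain integers:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified model: ndsbp[i] is the tensor axis split on mesh dimension i,
// or -1 for broadcast; hierarchy[i] is that mesh dimension's extent.
std::vector<int> get_divisors(const std::vector<int> &ndsbp,
                              const std::vector<int> &hierarchy,
                              std::size_t rank) {
    std::vector<int> divisors(rank, 0); // 0 = axis not split at all
    for (std::size_t i = 0; i < ndsbp.size(); i++) {
        int axis = ndsbp[i];
        if (axis < 0)
            continue;
        if (divisors[axis] == 0)
            divisors[axis] = 1;
        // An axis split on several mesh dimensions multiplies their extents.
        divisors[axis] *= hierarchy[i];
    }
    return divisors;
}

int main() {
    // Shape [8, 16] on a 2x4 mesh with axis 1 split on both mesh dimensions:
    // divisors = [0, 8], so the exactly-divided shard type is [8, 2].
    auto d = get_divisors({1, 1}, {2, 4}, 2);
    std::printf("divisors = [%d, %d]\n", d[0], d[1]);
    return 0;
}
```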

public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedType)
{
var shape = distributedType.TensorType.Shape.ToValueArray();

var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List<int>()).ToArray();
var ids = distributedType.Placement.Name.Select(c => new Var(c + "id", TensorType.Scalar(DataTypes.Int32))).ToArray();
var hierarchyStrides = TensorUtilities.GetStrides(distributedType.Placement.Hierarchy.ToArray());

for (int i = 0; i < distributedType.NdSBP.Count; i++)
{
if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
{
hierarchies[axis].Add(i);

}
}

return hierarchies.Select((divs, axis) =>
{
Expr dim;

if (divs.Any())
{

var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
var (res, rem) = Math.DivRem(shape[axis], divsor);

dim = IR.F.Math.Select(
TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1),
res,
res + rem);
}
else
{
dim = distributedType.TensorType.Shape[axis].FixedValue;
}
return dim;
}).ToArray();

}
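
`TryGetNonUniformDividedShape` relaxes the exact-division requirement: it emits a `Select` over the flattened mesh index so that every shard along a split axis gets `dim / divisor` elements and the last shard also absorbs the remainder. The arithmetic, restated as a scalar C++ sketch (the real method builds symbolic `Expr`s over the placement id variables):

```cpp
#include <cstdio>

// Shard `shard_index` of a dimension `dim` split `divisor` ways: all shards
// take dim / divisor elements except the last, which also takes the
// remainder (mirroring the Select(res, res + rem) above).
int shard_dim(int dim, int divisor, int shard_index) {
    int res = dim / divisor;
    int rem = dim % divisor;
    return shard_index < divisor - 1 ? res : res + rem;
}

int main() {
    // dim = 10 split 4 ways -> shard sizes 2, 2, 2, 4.
    for (int k = 0; k < 4; k++)
        std::printf("%d ", shard_dim(10, 4, k));
    return 0;
}
```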

public static bool IsDivideBy(int input, int divisor)
{
if (input >= divisor)
{
return true;

}

return false;

}
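
The rename is worth spelling out: the old `IsDivisible` (now `IsDivideExactly`, below) requires `input % divisor == 0`, while the looser `IsDivideBy` used by the candidate filters only requires `input >= divisor` — enough once non-uniform shards are allowed. In C++ terms:

```cpp
#include <cstdio>

bool is_divide_by(int input, int divisor) { // loose: splittable at all
    return input >= divisor;
}

bool is_divide_exactly(int input, int divisor) { // strict: equal shards only
    return input >= divisor && input % divisor == 0;
}

int main() {
    // 10 elements over 4 shards: allowed non-uniformly, but not exactly.
    std::printf("%d %d\n", is_divide_by(10, 4), is_divide_exactly(10, 4));
    return 0;
}
```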

public static bool IsDivisible(int input, int divisor)
public static bool IsDivideExactly(int input, int divisor)
{
if (input >= divisor && input % divisor == 0)
{
25 changes: 19 additions & 6 deletions src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs
@@ -175,13 +175,26 @@ protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext o
// Normalization
if (mean.Length != 0)
{
newInput = mean.Length switch
Expr meanCall;
Expr stdCall;
switch (mean.Length)
{
3 when inputShape.Length == 4 => (newInput - Tensor.From(mean, new[] { 1, mean.Length, 1, 1 })) /
Tensor.From(std, new[] { 1, std.Length, 1, 1 }),
_ => (newInput - Tensor.From(new float[] { mean[0] }, new[] { 1 })) /
Tensor.From(new float[] { std[0] }, new[] { 1 }),
};
case 3 when inputShape.Length == 4:
meanCall = (Expr)Tensor.From(mean, new[] { 1, mean.Length, 1, 1 });
stdCall = (Expr)Tensor.From(std, new[] { 1, std.Length, 1, 1 });
break;

default:
meanCall = (Expr)Tensor.From(new float[] { mean[0] }, new[] { 1 });
stdCall = (Expr)Tensor.From(new float[] { std[0] }, new[] { 1 });
break;
}

meanCall.Metadata.OutputNames = new[] { "Mean" };
stdCall.Metadata.OutputNames = new[] { "Std" };
var subMean = (newInput - meanCall).With(metadata: new IRMetadata() { OutputNames = new[] { input.Metadata.OutputNames?[0] + "_SubMean" } });
var divStd = (subMean / stdCall).With(metadata: new IRMetadata() { OutputNames = new[] { input.Metadata.OutputNames?[0] + "_DivStd" } });
newInput = divStd;

// newInput = Binary(BinaryOp.Div, Binary(BinaryOp.Sub, newInput, Tensor.From(mean, new []{1,3,1,1})), Const.FromTensor(std) );
}
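
The normalization itself is unchanged — per-channel `(x - mean) / std` with mean/std broadcast as `[1, C, 1, 1]` — but the pass now builds the `Sub` and `Div` as separate calls and tags them (and the constants) with `OutputNames`, so they show up legibly in dumps. A plain C++ sketch of the inserted computation, assuming NCHW layout as in the tensors above:

```cpp
#include <cstdio>

// Per-channel normalization over an NCHW buffer: v = (v - mean[c]) / std[c].
void normalize_nchw(float *x, int n, int c, int hw,
                    const float *mean, const float *std_dev) {
    for (int b = 0; b < n; b++)
        for (int ch = 0; ch < c; ch++)
            for (int i = 0; i < hw; i++) {
                float &v = x[(b * c + ch) * hw + i];
                v = (v - mean[ch]) / std_dev[ch]; // "_SubMean" then "_DivStd"
            }
}

int main() {
    float img[1 * 3 * 2] = {0, 1, 2, 3, 4, 5};
    const float mean[] = {0.0f, 2.0f, 4.0f};
    const float std_dev[] = {1.0f, 2.0f, 4.0f};
    normalize_nchw(img, 1, 3, 2, mean, std_dev);
    for (float v : img)
        std::printf("%g ", v); // 0 1 0 0.5 0 0.25
    return 0;
}
```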
5 changes: 3 additions & 2 deletions tests/importer/onnx_/basic/test_softmax.py
@@ -63,7 +63,8 @@ def _make_module(in_shape, axis, op_version):


in_shapes = [
[1, 3, 16, 16],
[2, 3, 8, 1],
[1, 3, 8, 5],
]

axes = [
@@ -79,7 +80,7 @@ def _make_module(in_shape, axis, op_version):
op_versions = [
1,
11,
# 13
13
]
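
(Re-enabling opset 13 is the interesting part: ONNX Softmax-13 normalizes along a single axis — default -1 — whereas opsets 1 and 11 flatten the input to 2-D around `axis` first, so these cases presumably exercise the kernel's new non-last-axis path.)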

