Skip to content

Commit

Permalink
GNNE-1974: Feature/non uniform (#1128)
Browse files Browse the repository at this point in the history
* add split

* fix tests
  • Loading branch information
zhen8838 committed Nov 16, 2023
1 parent 734cce0 commit 050de8e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Expr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ public DataType CheckedDataType
{
case TensorType type:
return type.DType;
case DistributedType type:
return type.TensorType.DType;
default:
if (DumpScope.Current.IsEnabled(DumpFlags.Compile))
{
Expand Down
93 changes: 78 additions & 15 deletions src/Nncase.Core/Utilities/DistributedUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens
var ndsbp = new List<SBP>();
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]))
{
ndsbp.Add(SBP.S(axis));
}
Expand All @@ -28,7 +28,7 @@ public static IReadOnlyList<IRArray<SBP>> GetLeafCandidateNDSBPs(TensorType tens

return ndsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}
Expand All @@ -53,7 +53,7 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT
candidateNdsbps[i].Add(SBP.B);
for (int axis = 0; axis < tensorType.Shape.Rank; axis++)
{
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis))
{
candidateNdsbps[i].Add(SBP.S(axis));
}
Expand All @@ -67,38 +67,101 @@ public static IReadOnlyList<IRArray<SBP>> GetPartialCandidateNDSBPs(DistributedT

return candidateNdsbps.CartesianProduct().
Select(ndsbp => ndsbp.ToArray()).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)).
Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).
Select(ndsbp => new IRArray<SBP>(ndsbp)).
ToArray();
}

public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement, [MaybeNullWhen(false)] out TensorType distType)
public static bool IsDistributable(TensorType tensorType, ReadOnlySpan<SBP> ndsbp, Placement placement)
{
distType = null;
if (!tensorType.Shape.IsFixed)
{
return false;
}

var shape = tensorType.Shape.ToValueArray();
for (int i = 0; i < ndsbp.Length; i++)
var divisors = GetDivisors(new DistributedType(tensorType, new IRArray<SBP>(ndsbp.ToArray()), placement));
return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : IsDivideBy(tensorType.Shape[p.axis].FixedValue, p.d));
}

public static IReadOnlyList<int> GetDivisors(DistributedType distributedType)
{
var shape = distributedType.TensorType.Shape.ToValueArray();
var divisors = Enumerable.Repeat(0, shape.Length).ToArray();
for (int i = 0; i < distributedType.NdSBP.Count; i++)
{
if (ndsbp[i] is SBPSplit { Axis: int axis })
if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
{
if (!IsDivisible(shape[axis], placement.Hierarchy[i]))
if (divisors[axis] == 0)
{
return false;
divisors[axis] = 1;
}

shape[axis] /= placement.Hierarchy[i];
divisors[axis] *= distributedType.Placement.Hierarchy[i];
}
}

distType = tensorType with { Shape = shape };
return true;
return divisors;
}

public static bool TryGetDividedTensorType(DistributedType distributedType, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TensorType tensorType)
{
tensorType = null;
var divisors = GetDivisors(distributedType);
if (divisors.Select((d, i) => (d, i)).All(p => p.d == 0 || IsDivideExactly(distributedType.TensorType.Shape[p.i].FixedValue, p.d)))
{
tensorType = new TensorType(distributedType.TensorType.DType, distributedType.TensorType.Shape.Zip(divisors).Select(p => p.Second == 0 ? p.First.FixedValue : p.First.FixedValue / p.Second).ToArray());
return true;
}

return false;
}

public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedType)
{
var shape = distributedType.TensorType.Shape.ToValueArray();
var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List<int>()).ToArray();
var ids = distributedType.Placement.Name.Select(c => new Var(c + "id", TensorType.Scalar(DataTypes.Int32))).ToArray();
var hierarchyStrides = TensorUtilities.GetStrides(distributedType.Placement.Hierarchy.ToArray());
for (int i = 0; i < distributedType.NdSBP.Count; i++)
{
if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis })
{
hierarchies[axis].Add(i);
}
}

return hierarchies.Select((divs, axis) =>
{
Expr dim;
if (divs.Any())
{
var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray());
var (res, rem) = Math.DivRem(shape[axis], divsor);
dim = IR.F.Math.Select(
TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1),
res,
res + rem);
}
else
{
dim = distributedType.TensorType.Shape[axis].FixedValue;
}
return dim;
}).ToArray();
}

public static bool IsDivideBy(int input, int divisor)
{
if (input >= divisor)
{
return true;
}

return false;
}

public static bool IsDivisible(int input, int divisor)
public static bool IsDivideExactly(int input, int divisor)
{
if (input >= divisor && input % divisor == 0)
{
Expand Down

0 comments on commit 050de8e

Please sign in to comment.