Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 45 additions & 40 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -556,20 +556,19 @@ public uint HashCore(uint seed, uint mask, in VBuffer<float> values)
return 0;
hash = Hashing.MurmurRound(hash, FloatUtils.GetBits(value == 0 ? 0 : value));
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

private readonly struct HashDouble : IHasher<double>
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]

public uint HashCoreOld(uint seed, uint mask, in double value)
{
if (double.IsNaN(value))
return 0;

return (Hashing.MixHash(HashRound(seed, value)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, true)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -578,7 +577,7 @@ public uint HashCore(uint seed, uint mask, in double value)
if (double.IsNaN(value))
return 0;

return (Hashing.MixHash(HashRound(seed, value), sizeof(double)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, false), sizeof(double)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -589,17 +588,19 @@ public uint HashCore(uint seed, uint mask, in VBuffer<double> values)
{
if (double.IsNaN(value))
return 0;
hash = HashRound(hash, value);
hash = HashRound(hash, value, false);
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(double)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private uint HashRound(uint seed, double value)
private uint HashRound(uint seed, double value, bool old)
{
ulong v = FloatUtils.GetBits(value == 0 ? 0 : value);
var hash = Hashing.MurmurRound(seed, Utils.GetLo(v));
var hi = Utils.GetHi(v);
if (old && hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand Down Expand Up @@ -648,7 +649,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<byte> values)
return 0;
hash = Hashing.MurmurRound(hash, value);
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -672,7 +673,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<ushort> values)
return 0;
hash = Hashing.MurmurRound(hash, value);
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -696,7 +697,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<uint> values)
return 0;
hash = Hashing.MurmurRound(hash, value);
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -707,15 +708,15 @@ public uint HashCoreOld(uint seed, uint mask, in ulong value)
{
if (value == 0)
return 0;
return (Hashing.MixHash(HashRound(seed, value)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, true)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in ulong value)
{
if (value == 0)
return 0;
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, false), sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -726,17 +727,17 @@ public uint HashCore(uint seed, uint mask, in VBuffer<ulong> values)
{
if (value == 0)
return 0;
hash = HashRound(hash, value);
hash = HashRound(hash, value, false);
}
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private uint HashRound(uint seed, ulong value)
private uint HashRound(uint seed, ulong value, bool old)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo(value));
var hi = Utils.GetHi(value);
if (hi == 0)
if (old && hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
Expand All @@ -758,7 +759,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<byte> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -778,7 +779,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<ushort> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -798,7 +799,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<uint> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -807,29 +808,31 @@ public uint HashCore(uint seed, uint mask, in VBuffer<uint> values)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCoreOld(uint seed, uint mask, in ulong value)
{
return (Hashing.MixHash(HashRound(seed, value)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, true)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in ulong value)
{
return (Hashing.MixHash(HashRound(seed, value), sizeof(ulong)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, false), sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in VBuffer<ulong> values)
{
var hash = seed;
foreach (var value in values.DenseValues())
hash = HashRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
hash = HashRound(hash, value, false);
return (Hashing.MixHash(hash, values.Length * sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private uint HashRound(uint seed, ulong value)
private uint HashRound(uint seed, ulong value, bool old)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo(value));
var hi = Utils.GetHi(value);
if (old && hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand All @@ -839,32 +842,32 @@ private uint HashRound(uint seed, ulong value)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCoreOld(uint seed, uint mask, in DataViewRowId value)
{
return (Hashing.MixHash(HashRound(seed, value)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, true)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in DataViewRowId value)
{
return (Hashing.MixHash(HashRound(seed, value), sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, false), 2 * sizeof(ulong)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in VBuffer<DataViewRowId> values)
{
var hash = seed;
foreach (var value in values.DenseValues())
hash = HashRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
hash = HashRound(hash, value, false);
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private uint HashRound(uint seed, DataViewRowId value)
private uint HashRound(uint seed, DataViewRowId value, bool old)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo(value.Low));
var hi = Utils.GetHi(value.Low);
if (hi != 0)
if (old && hi != 0)
hash = Hashing.MurmurRound(hash, hi);
if (value.High != 0)
if (old && value.High != 0)
{
hash = Hashing.MurmurRound(hash, Utils.GetLo(value.High));
hi = Utils.GetHi(value.High);
Expand All @@ -891,7 +894,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<bool> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, value ? 1u : 0u);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -911,7 +914,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<sbyte> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, (uint)value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -931,7 +934,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<short> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, (uint)value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -951,7 +954,7 @@ public uint HashCore(uint seed, uint mask, in VBuffer<int> values)
var hash = seed;
foreach (var value in values.DenseValues())
hash = Hashing.MurmurRound(hash, (uint)value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
return (Hashing.MixHash(hash, values.Length * sizeof(uint)) & mask) + 1;
}
}

Expand All @@ -960,29 +963,31 @@ public uint HashCore(uint seed, uint mask, in VBuffer<int> values)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCoreOld(uint seed, uint mask, in long value)
{
return (Hashing.MixHash(HashRound(seed, value)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, true)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in long value)
{
return (Hashing.MixHash(HashRound(seed, value), sizeof(long)) & mask) + 1;
return (Hashing.MixHash(HashRound(seed, value, false), sizeof(long)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public uint HashCore(uint seed, uint mask, in VBuffer<long> values)
{
var hash = seed;
foreach (var value in values.DenseValues())
hash = HashRound(hash, value);
return (Hashing.MixHash(hash, sizeof(uint)) & mask) + 1;
hash = HashRound(hash, value, false);
return (Hashing.MixHash(hash, values.Length * sizeof(long)) & mask) + 1;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private uint HashRound(uint seed, long value)
private uint HashRound(uint seed, long value, bool old)
{
var hash = Hashing.MurmurRound(seed, Utils.GetLo((ulong)value));
var hi = Utils.GetHi((ulong)value);
if (old && hi == 0)
return hash;
return Hashing.MurmurRound(hash, hi);
}
}
Expand Down
8 changes: 4 additions & 4 deletions test/BaselineOutput/Common/SavePipe/SavePipeHash-Data.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#@ col=VarComb:U4[128]:32-**
#@ }
SingleHash 31 27:Hash9 28:Hash10 29:Hash11 30:Hash12
14 14 14 14 6 0 13 24 47 44 32 16 40 22 24 32 16 40 22 56 59 22 56 59 53 22 56 22 120 0 3 112 33 31 117 22 120 51 31 39 51 31 39 51 31 39 51 31 39
0 0 1 4 0 13 0 32 16 20 52 31 44 24 24 52 31 44 56 59 56 56 59 56 22 56 22 24 123 3 0 112 50 41 36 117 123 75 127 51 47 109 108 51 47 109 108 51 47 109 108 51 47 109 108
14 14 11 4 6 6 0 24 60 20 32 47 44 22 22 24 32 47 44 22 22 56 22 22 56 53 53 22 22 150 6 0 38 22 68 68 68 68
74 3:10 6:6 9:10 12:20 15:22 18:20 21:22 24:22 27:53 31:6 36:35 38:47 43:51 45:22 50:66 52:96 57:66 59:96 64:66 66:96 71:66 73:96
14 14 14 14 6 0 13 24 47 44 8 31 17 22 24 32 16 40 22 56 59 35 23 23 53 22 56 22 120 0 7 112 33 31 117 22 120 51 31 39 51 31 39 51 31 39 17 51 35
0 0 1 4 0 13 0 32 16 20 49 51 54 24 24 52 31 44 56 59 56 23 23 23 22 56 22 24 123 3 7 112 50 41 36 117 123 75 127 51 47 109 108 51 47 109 108 51 47 109 108 17 91 57 49
14 14 11 4 6 6 0 24 60 20 8 6 54 22 22 24 32 47 44 22 22 56 35 35 23 53 53 22 22 150 6 7 38 22 68 68 68 5
74 3:10 6:6 9:10 12:40 15:22 18:20 21:22 24:35 27:53 31:3 36:35 38:47 43:51 45:22 50:66 52:96 57:66 59:96 64:66 66:96 71:2 73:55
40 changes: 21 additions & 19 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using Google.Protobuf;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -1200,29 +1199,32 @@ private class HashData
public uint Value { get; set; }
}

[Fact]
public void MurmurHashKeyTest()
[Theory]
[CombinatorialData]
public void MurmurHashKeyTest(
[CombinatorialValues(/*DataKind.Byte, DataKind.UInt16, */DataKind.UInt32/*, DataKind.UInt64*/)]DataKind keyType)
{
var mlContext = new MLContext();
var dataFile = DeleteOutputPath("KeysToOnnx.txt");
File.WriteAllLines(dataFile,
new[]
{
"2",
"5",
"19"
});

var samples = new[]
var data = ML.Data.LoadFromTextFile(dataFile, new[]
{
new HashData {Value = 232},
new HashData {Value = 42},
new HashData {Value = 0},
};

IDataView data = mlContext.Data.LoadFromEnumerable(samples);
new TextLoader.Column("Value", keyType, new[]
{
new TextLoader.Range(0)
}, new KeyCount(10))
});

var hashEstimator = mlContext.Transforms.Conversion.MapValueToKey("Value").Append(mlContext.Transforms.Conversion.Hash(new[]
{
new HashingEstimator.ColumnOptions(
"ValueHashed",
"Value")
}));
var hashEstimator = ML.Transforms.Conversion.Hash("ValueHashed", "Value");
var model = hashEstimator.Fit(data);
var transformedData = model.Transform(data);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
var onnxModel = ML.Model.ConvertToOnnxProtobuf(model, data);

var onnxFileName = "MurmurHashV2.onnx";
var onnxTextName = "MurmurHashV2.txt";
Expand All @@ -1236,7 +1238,7 @@ public void MurmurHashKeyTest()
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxEstimator = ML.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedColumns<uint>("ValueHashed", "ValueHashed", transformedData, onnxResult);
Expand Down
Loading