From 9dfc6fc882eacf33236ab9ed1c85c8e77db846be Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Mon, 14 Dec 2020 20:37:50 +0800 Subject: [PATCH] LUCENE-9636: Extract AND operation to get a SIMD optimization (#2139) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 郭峰 --- .../lucene/codecs/lucene84/ForUtil.java | 82 ++++++++++--------- .../lucene/codecs/lucene84/gen_ForUtil.py | 28 ++++++- 2 files changed, 71 insertions(+), 39 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java index 74b72abb2730..eb07ec18f897 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java @@ -663,10 +663,11 @@ private static void decode5(DataInput in, long[] tmp, long[] longs) throws IOExc private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 12); shiftLongs(tmp, 12, longs, 0, 2, MASK8_6); + shiftLongs(tmp, 12, tmp, 0, 0, MASK8_2); for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 4; ++iter, tmpIdx += 3, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK8_2) << 4; - l0 |= (tmp[tmpIdx+1] & MASK8_2) << 2; - l0 |= (tmp[tmpIdx+2] & MASK8_2) << 0; + long l0 = tmp[tmpIdx+0] << 4; + l0 |= tmp[tmpIdx+1] << 2; + l0 |= tmp[tmpIdx+2] << 0; longs[longsIdx+0] = l0; } } @@ -674,14 +675,15 @@ private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOExc private static void decode7(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 14); shiftLongs(tmp, 14, longs, 0, 1, MASK8_7); + shiftLongs(tmp, 14, tmp, 0, 0, MASK8_1); for (int iter = 0, tmpIdx = 0, longsIdx = 14; iter < 2; ++iter, tmpIdx += 7, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK8_1) << 6; - l0 |= (tmp[tmpIdx+1] & MASK8_1) << 5; - l0 
|= (tmp[tmpIdx+2] & MASK8_1) << 4; - l0 |= (tmp[tmpIdx+3] & MASK8_1) << 3; - l0 |= (tmp[tmpIdx+4] & MASK8_1) << 2; - l0 |= (tmp[tmpIdx+5] & MASK8_1) << 1; - l0 |= (tmp[tmpIdx+6] & MASK8_1) << 0; + long l0 = tmp[tmpIdx+0] << 6; + l0 |= tmp[tmpIdx+1] << 5; + l0 |= tmp[tmpIdx+2] << 4; + l0 |= tmp[tmpIdx+3] << 3; + l0 |= tmp[tmpIdx+4] << 2; + l0 |= tmp[tmpIdx+5] << 1; + l0 |= tmp[tmpIdx+6] << 0; longs[longsIdx+0] = l0; } } @@ -766,10 +768,11 @@ private static void decode11(DataInput in, long[] tmp, long[] longs) throws IOEx private static void decode12(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 24); shiftLongs(tmp, 24, longs, 0, 4, MASK16_12); + shiftLongs(tmp, 24, tmp, 0, 0, MASK16_4); for (int iter = 0, tmpIdx = 0, longsIdx = 24; iter < 8; ++iter, tmpIdx += 3, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK16_4) << 8; - l0 |= (tmp[tmpIdx+1] & MASK16_4) << 4; - l0 |= (tmp[tmpIdx+2] & MASK16_4) << 0; + long l0 = tmp[tmpIdx+0] << 8; + l0 |= tmp[tmpIdx+1] << 4; + l0 |= tmp[tmpIdx+2] << 0; longs[longsIdx+0] = l0; } } @@ -802,14 +805,15 @@ private static void decode13(DataInput in, long[] tmp, long[] longs) throws IOEx private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 28); shiftLongs(tmp, 28, longs, 0, 2, MASK16_14); + shiftLongs(tmp, 28, tmp, 0, 0, MASK16_2); for (int iter = 0, tmpIdx = 0, longsIdx = 28; iter < 4; ++iter, tmpIdx += 7, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK16_2) << 12; - l0 |= (tmp[tmpIdx+1] & MASK16_2) << 10; - l0 |= (tmp[tmpIdx+2] & MASK16_2) << 8; - l0 |= (tmp[tmpIdx+3] & MASK16_2) << 6; - l0 |= (tmp[tmpIdx+4] & MASK16_2) << 4; - l0 |= (tmp[tmpIdx+5] & MASK16_2) << 2; - l0 |= (tmp[tmpIdx+6] & MASK16_2) << 0; + long l0 = tmp[tmpIdx+0] << 12; + l0 |= tmp[tmpIdx+1] << 10; + l0 |= tmp[tmpIdx+2] << 8; + l0 |= tmp[tmpIdx+3] << 6; + l0 |= tmp[tmpIdx+4] << 4; + l0 |= tmp[tmpIdx+5] << 2; + l0 |= tmp[tmpIdx+6] << 0; longs[longsIdx+0] = l0; } } @@ 
-817,22 +821,23 @@ private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOEx private static void decode15(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 30); shiftLongs(tmp, 30, longs, 0, 1, MASK16_15); + shiftLongs(tmp, 30, tmp, 0, 0, MASK16_1); for (int iter = 0, tmpIdx = 0, longsIdx = 30; iter < 2; ++iter, tmpIdx += 15, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK16_1) << 14; - l0 |= (tmp[tmpIdx+1] & MASK16_1) << 13; - l0 |= (tmp[tmpIdx+2] & MASK16_1) << 12; - l0 |= (tmp[tmpIdx+3] & MASK16_1) << 11; - l0 |= (tmp[tmpIdx+4] & MASK16_1) << 10; - l0 |= (tmp[tmpIdx+5] & MASK16_1) << 9; - l0 |= (tmp[tmpIdx+6] & MASK16_1) << 8; - l0 |= (tmp[tmpIdx+7] & MASK16_1) << 7; - l0 |= (tmp[tmpIdx+8] & MASK16_1) << 6; - l0 |= (tmp[tmpIdx+9] & MASK16_1) << 5; - l0 |= (tmp[tmpIdx+10] & MASK16_1) << 4; - l0 |= (tmp[tmpIdx+11] & MASK16_1) << 3; - l0 |= (tmp[tmpIdx+12] & MASK16_1) << 2; - l0 |= (tmp[tmpIdx+13] & MASK16_1) << 1; - l0 |= (tmp[tmpIdx+14] & MASK16_1) << 0; + long l0 = tmp[tmpIdx+0] << 14; + l0 |= tmp[tmpIdx+1] << 13; + l0 |= tmp[tmpIdx+2] << 12; + l0 |= tmp[tmpIdx+3] << 11; + l0 |= tmp[tmpIdx+4] << 10; + l0 |= tmp[tmpIdx+5] << 9; + l0 |= tmp[tmpIdx+6] << 8; + l0 |= tmp[tmpIdx+7] << 7; + l0 |= tmp[tmpIdx+8] << 6; + l0 |= tmp[tmpIdx+9] << 5; + l0 |= tmp[tmpIdx+10] << 4; + l0 |= tmp[tmpIdx+11] << 3; + l0 |= tmp[tmpIdx+12] << 2; + l0 |= tmp[tmpIdx+13] << 1; + l0 |= tmp[tmpIdx+14] << 0; longs[longsIdx+0] = l0; } } @@ -1117,10 +1122,11 @@ private static void decode23(DataInput in, long[] tmp, long[] longs) throws IOEx private static void decode24(DataInput in, long[] tmp, long[] longs) throws IOException { in.readLELongs(tmp, 0, 48); shiftLongs(tmp, 48, longs, 0, 8, MASK32_24); + shiftLongs(tmp, 48, tmp, 0, 0, MASK32_8); for (int iter = 0, tmpIdx = 0, longsIdx = 48; iter < 16; ++iter, tmpIdx += 3, longsIdx += 1) { - long l0 = (tmp[tmpIdx+0] & MASK32_8) << 16; - l0 |= (tmp[tmpIdx+1] & MASK32_8) << 8; - l0 |= 
(tmp[tmpIdx+2] & MASK32_8) << 0; + long l0 = tmp[tmpIdx+0] << 16; + l0 |= tmp[tmpIdx+1] << 8; + l0 |= tmp[tmpIdx+2] << 0; longs[longsIdx+0] = l0; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py index 260a6834ab37..94f31e24a95b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py @@ -366,6 +366,29 @@ """ +def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, o, num_values, f): + iteration = 1 + num_longs = bpv * num_values / remaining_bits_per_long + while num_longs % 2 == 0 and num_values % 2 == 0: + num_longs /= 2 + num_values /= 2 + iteration *= 2 + + f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long)) + f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values)) + tmp_idx = 0 + b = bpv + b -= remaining_bits_per_long + f.write(' long l0 = tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b)) + tmp_idx += 1 + while b >= remaining_bits_per_long: + b -= remaining_bits_per_long + f.write(' l0 |= tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b)) + tmp_idx += 1 + f.write(' longs[longsIdx+0] = l0;\n') + f.write(' }\n') + + def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, f): iteration = 1 num_longs = bpv * num_values / remaining_bits_per_long @@ -417,7 +440,10 @@ def writeDecode(bpv, f): o += bpv*2 shift -= bpv if shift + bpv > 0: - writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f) + if bpv % (next_primitive % bpv) == 0: + writeRemainderWithSIMDOptimize(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f) + else: + writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f) f.write(' }\n') 
f.write('\n')