Skip to content

Commit

Permalink
LUCENE-9636: Extract the AND (mask) operation out of the remainder loop so it can be SIMD-optimized (apache#2139)
Browse files Browse the repository at this point in the history
Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
  • Loading branch information
2 people authored and epugh@opensourceconnections.com committed Jan 15, 2021
1 parent 5de6681 commit 7529feb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 39 deletions.
82 changes: 44 additions & 38 deletions lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -663,25 +663,27 @@ private static void decode5(DataInput in, long[] tmp, long[] longs) throws IOExc
/**
 * Decode routine for 6 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 4 times, consuming tmp[0..11] three entries at a time and writing longs[12..15].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 12);
shiftLongs(tmp, 12, longs, 0, 2, MASK8_6);
// Added by this commit: mask all 12 tmp entries with MASK8_2 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 12, tmp, 0, 0, MASK8_2);
for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 4; ++iter, tmpIdx += 3, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK8_2) << 4;
l0 |= (tmp[tmpIdx+1] & MASK8_2) << 2;
l0 |= (tmp[tmpIdx+2] & MASK8_2) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 4;
l0 |= tmp[tmpIdx+1] << 2;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}

/**
 * Decode routine for 7 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 2 times, consuming tmp[0..13] seven entries at a time and writing longs[14..15].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode7(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 14);
shiftLongs(tmp, 14, longs, 0, 1, MASK8_7);
// Added by this commit: mask all 14 tmp entries with MASK8_1 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 14, tmp, 0, 0, MASK8_1);
for (int iter = 0, tmpIdx = 0, longsIdx = 14; iter < 2; ++iter, tmpIdx += 7, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK8_1) << 6;
l0 |= (tmp[tmpIdx+1] & MASK8_1) << 5;
l0 |= (tmp[tmpIdx+2] & MASK8_1) << 4;
l0 |= (tmp[tmpIdx+3] & MASK8_1) << 3;
l0 |= (tmp[tmpIdx+4] & MASK8_1) << 2;
l0 |= (tmp[tmpIdx+5] & MASK8_1) << 1;
l0 |= (tmp[tmpIdx+6] & MASK8_1) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 6;
l0 |= tmp[tmpIdx+1] << 5;
l0 |= tmp[tmpIdx+2] << 4;
l0 |= tmp[tmpIdx+3] << 3;
l0 |= tmp[tmpIdx+4] << 2;
l0 |= tmp[tmpIdx+5] << 1;
l0 |= tmp[tmpIdx+6] << 0;
longs[longsIdx+0] = l0;
}
}
Expand Down Expand Up @@ -766,10 +768,11 @@ private static void decode11(DataInput in, long[] tmp, long[] longs) throws IOEx
/**
 * Decode routine for 12 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 8 times, consuming tmp[0..23] three entries at a time and writing longs[24..31].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode12(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 24);
shiftLongs(tmp, 24, longs, 0, 4, MASK16_12);
// Added by this commit: mask all 24 tmp entries with MASK16_4 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 24, tmp, 0, 0, MASK16_4);
for (int iter = 0, tmpIdx = 0, longsIdx = 24; iter < 8; ++iter, tmpIdx += 3, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK16_4) << 8;
l0 |= (tmp[tmpIdx+1] & MASK16_4) << 4;
l0 |= (tmp[tmpIdx+2] & MASK16_4) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 8;
l0 |= tmp[tmpIdx+1] << 4;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}
Expand Down Expand Up @@ -802,37 +805,39 @@ private static void decode13(DataInput in, long[] tmp, long[] longs) throws IOEx
/**
 * Decode routine for 14 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 4 times, consuming tmp[0..27] seven entries at a time and writing longs[28..31].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 28);
shiftLongs(tmp, 28, longs, 0, 2, MASK16_14);
// Added by this commit: mask all 28 tmp entries with MASK16_2 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 28, tmp, 0, 0, MASK16_2);
for (int iter = 0, tmpIdx = 0, longsIdx = 28; iter < 4; ++iter, tmpIdx += 7, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK16_2) << 12;
l0 |= (tmp[tmpIdx+1] & MASK16_2) << 10;
l0 |= (tmp[tmpIdx+2] & MASK16_2) << 8;
l0 |= (tmp[tmpIdx+3] & MASK16_2) << 6;
l0 |= (tmp[tmpIdx+4] & MASK16_2) << 4;
l0 |= (tmp[tmpIdx+5] & MASK16_2) << 2;
l0 |= (tmp[tmpIdx+6] & MASK16_2) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 12;
l0 |= tmp[tmpIdx+1] << 10;
l0 |= tmp[tmpIdx+2] << 8;
l0 |= tmp[tmpIdx+3] << 6;
l0 |= tmp[tmpIdx+4] << 4;
l0 |= tmp[tmpIdx+5] << 2;
l0 |= tmp[tmpIdx+6] << 0;
longs[longsIdx+0] = l0;
}
}

/**
 * Decode routine for 15 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 2 times, consuming tmp[0..29] fifteen entries at a time and writing longs[30..31].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode15(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 30);
shiftLongs(tmp, 30, longs, 0, 1, MASK16_15);
// Added by this commit: mask all 30 tmp entries with MASK16_1 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 30, tmp, 0, 0, MASK16_1);
for (int iter = 0, tmpIdx = 0, longsIdx = 30; iter < 2; ++iter, tmpIdx += 15, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK16_1) << 14;
l0 |= (tmp[tmpIdx+1] & MASK16_1) << 13;
l0 |= (tmp[tmpIdx+2] & MASK16_1) << 12;
l0 |= (tmp[tmpIdx+3] & MASK16_1) << 11;
l0 |= (tmp[tmpIdx+4] & MASK16_1) << 10;
l0 |= (tmp[tmpIdx+5] & MASK16_1) << 9;
l0 |= (tmp[tmpIdx+6] & MASK16_1) << 8;
l0 |= (tmp[tmpIdx+7] & MASK16_1) << 7;
l0 |= (tmp[tmpIdx+8] & MASK16_1) << 6;
l0 |= (tmp[tmpIdx+9] & MASK16_1) << 5;
l0 |= (tmp[tmpIdx+10] & MASK16_1) << 4;
l0 |= (tmp[tmpIdx+11] & MASK16_1) << 3;
l0 |= (tmp[tmpIdx+12] & MASK16_1) << 2;
l0 |= (tmp[tmpIdx+13] & MASK16_1) << 1;
l0 |= (tmp[tmpIdx+14] & MASK16_1) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 14;
l0 |= tmp[tmpIdx+1] << 13;
l0 |= tmp[tmpIdx+2] << 12;
l0 |= tmp[tmpIdx+3] << 11;
l0 |= tmp[tmpIdx+4] << 10;
l0 |= tmp[tmpIdx+5] << 9;
l0 |= tmp[tmpIdx+6] << 8;
l0 |= tmp[tmpIdx+7] << 7;
l0 |= tmp[tmpIdx+8] << 6;
l0 |= tmp[tmpIdx+9] << 5;
l0 |= tmp[tmpIdx+10] << 4;
l0 |= tmp[tmpIdx+11] << 3;
l0 |= tmp[tmpIdx+12] << 2;
l0 |= tmp[tmpIdx+13] << 1;
l0 |= tmp[tmpIdx+14] << 0;
longs[longsIdx+0] = l0;
}
}
Expand Down Expand Up @@ -1117,10 +1122,11 @@ private static void decode23(DataInput in, long[] tmp, long[] longs) throws IOEx
/**
 * Decode routine for 24 bits per value, shown here as a rendered diff: the removed and the
 * added form of the remainder loop body both appear below, without +/- markers.
 * The loop runs 16 times, consuming tmp[0..47] three entries at a time and writing longs[48..63].
 * NOTE(review): shiftLongs is not defined in this view; presumably it stores
 * (src[i] >>> shift) & mask into the destination array -- confirm against its definition.
 */
private static void decode24(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 48);
shiftLongs(tmp, 48, longs, 0, 8, MASK32_24);
// Added by this commit: mask all 48 tmp entries with MASK32_8 in one in-place pass (shift 0),
// so the loop below no longer has to AND each element itself.
shiftLongs(tmp, 48, tmp, 0, 0, MASK32_8);
for (int iter = 0, tmpIdx = 0, longsIdx = 48; iter < 16; ++iter, tmpIdx += 3, longsIdx += 1) {
// Removed by this commit: per-element masking inside the loop.
long l0 = (tmp[tmpIdx+0] & MASK32_8) << 16;
l0 |= (tmp[tmpIdx+1] & MASK32_8) << 8;
l0 |= (tmp[tmpIdx+2] & MASK32_8) << 0;
// Added by this commit: tmp is already masked above, so only shift and OR remain.
long l0 = tmp[tmpIdx+0] << 16;
l0 |= tmp[tmpIdx+1] << 8;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,29 @@
"""

def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
  """Write the Java remainder loop with the mask hoisted out of the loop body.

  Emits one in-place ``shiftLongs(tmp, n, tmp, 0, 0, MASK<next_primitive>_<remaining>)``
  call followed by a ``for`` loop that combines already-masked ``tmp`` entries with
  shifts and ORs only, storing one value into ``longs`` per iteration.

  Args:
    bpv: bits per value; the first shift amount is ``bpv - remaining_bits_per_long``.
    next_primitive: first number in the emitted ``MASK%d_%d`` constant name.
    remaining_bits_per_long: second number in the mask name and the step by which
      successive shift amounts decrease.
    o: initial value of ``longsIdx`` in the emitted loop.
    num_values: number of values the remainder covers; halved together with the
      long count to find the per-iteration strides.
    f: object with a ``write(str)`` method that receives the generated Java.
  """
  iteration = 1
  # Floor division keeps the counters integral under Python 3 as well as Python 2;
  # with true division they would silently turn into floats.
  num_longs = bpv * num_values // remaining_bits_per_long
  # Factor out powers of two so the emitted loop body is as small as possible.
  while num_longs % 2 == 0 and num_values % 2 == 0:
    num_longs //= 2
    num_values //= 2
    iteration *= 2

  # NOTE(review): the leading whitespace inside the emitted strings is reproduced exactly
  # as it appears in this view; confirm it against the generator's intended Java indentation.
  # One masking pass over every tmp long the loop will touch (iteration * num_longs of them).
  f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
  f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
  tmp_idx = 0
  b = bpv
  b -= remaining_bits_per_long
  # Highest chunk first; tmp is pre-masked, so no '&' is emitted here.
  f.write(' long l0 = tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
  tmp_idx += 1
  while b >= remaining_bits_per_long:
    b -= remaining_bits_per_long
    f.write(' l0 |= tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
    tmp_idx += 1
  f.write(' longs[longsIdx+0] = l0;\n')
  f.write(' }\n')


def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
iteration = 1
num_longs = bpv * num_values / remaining_bits_per_long
Expand Down Expand Up @@ -417,7 +440,10 @@ def writeDecode(bpv, f):
o += bpv*2
shift -= bpv
if shift + bpv > 0:
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
if bpv % (next_primitive % bpv) == 0:
writeRemainderWithSIMDOptimize(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
else:
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
f.write(' }\n')
f.write('\n')

Expand Down

0 comments on commit 7529feb

Please sign in to comment.