Skip to content

Commit

Permalink
8310268: RISC-V: misaligned memory access in String.Compare intrinsic
Browse files Browse the repository at this point in the history
Backport-of: d6245b6832ccd1da04616e8ba4b90321b2551971
  • Loading branch information
Vladimir Kempik committed Aug 3, 2023
1 parent 53aceba commit 89875df
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 120 deletions.
34 changes: 13 additions & 21 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp
Expand Up @@ -868,9 +868,10 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
// load first parts of strings and finish initialization while loading
{
if (str1_isL == str2_isL) { // LL or UU
// check if str1 and str2 is same pointer
beq(str1, str2, DONE);
// load 8 bytes once to compare
ld(tmp1, Address(str1));
beq(str1, str2, DONE);
ld(tmp2, Address(str2));
mv(t0, STUB_THRESHOLD);
bge(cnt2, t0, STUB);
Expand Down Expand Up @@ -913,9 +914,8 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
addi(cnt1, cnt1, 8);
}
addi(cnt2, cnt2, isUL ? 4 : 8);
bne(tmp1, tmp2, DIFFERENCE);
bgez(cnt2, TAIL);
xorr(tmp3, tmp1, tmp2);
bnez(tmp3, DIFFERENCE);

// main loop
bind(NEXT_WORD);
Expand Down Expand Up @@ -944,38 +944,30 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
addi(cnt1, cnt1, 8);
addi(cnt2, cnt2, 4);
}
bgez(cnt2, TAIL);

xorr(tmp3, tmp1, tmp2);
beqz(tmp3, NEXT_WORD);
j(DIFFERENCE);
bne(tmp1, tmp2, DIFFERENCE);
bltz(cnt2, NEXT_WORD);
bind(TAIL);
xorr(tmp3, tmp1, tmp2);
bnez(tmp3, DIFFERENCE);
// Last longword. In the case where length == 4 we compare the
// same longword twice, but that's still faster than another
// conditional branch.
if (str1_isL == str2_isL) { // LL or UU
ld(tmp1, Address(str1));
ld(tmp2, Address(str2));
load_long_misaligned(tmp1, Address(str1), tmp3, isLL ? 1 : 2);
load_long_misaligned(tmp2, Address(str2), tmp3, isLL ? 1 : 2);
} else if (isLU) { // LU case
lwu(tmp1, Address(str1));
ld(tmp2, Address(str2));
load_int_misaligned(tmp1, Address(str1), tmp3, false);
load_long_misaligned(tmp2, Address(str2), tmp3, 2);
inflate_lo32(tmp3, tmp1);
mv(tmp1, tmp3);
} else { // UL case
lwu(tmp2, Address(str2));
ld(tmp1, Address(str1));
load_int_misaligned(tmp2, Address(str2), tmp3, false);
load_long_misaligned(tmp1, Address(str1), tmp3, 2);
inflate_lo32(tmp3, tmp2);
mv(tmp2, tmp3);
}
bind(TAIL_CHECK);
xorr(tmp3, tmp1, tmp2);
beqz(tmp3, DONE);
beq(tmp1, tmp2, DONE);

// Find the first different characters in the longwords and
// compute their difference.
bind(DIFFERENCE);
xorr(tmp3, tmp1, tmp2);
ctzc_bit(result, tmp3, isLL); // count zero from lsb to msb
srl(tmp1, tmp1, result);
srl(tmp2, tmp2, result);
Expand Down
30 changes: 10 additions & 20 deletions src/hotspot/cpu/riscv/macroAssembler_riscv.cpp
Expand Up @@ -3968,18 +3968,17 @@ void MacroAssembler::ctzc_bit(Register Rd, Register Rs, bool isLL, Register tmp1
void MacroAssembler::inflate_lo32(Register Rd, Register Rs, Register tmp1, Register tmp2) {
assert_different_registers(Rd, Rs, tmp1, tmp2);

mv(tmp1, 0xFF);
mv(Rd, zr);
for (int i = 0; i <= 3; i++) {
mv(tmp1, 0xFF000000); // first byte mask at lower word
andr(Rd, Rs, tmp1);
for (int i = 0; i < 2; i++) {
slli(Rd, Rd, wordSize);
srli(tmp1, tmp1, wordSize);
andr(tmp2, Rs, tmp1);
if (i) {
slli(tmp2, tmp2, i * 8);
}
orr(Rd, Rd, tmp2);
if (i != 3) {
slli(tmp1, tmp1, 8);
}
}
slli(Rd, Rd, wordSize);
andi(tmp2, Rs, 0xFF); // last byte mask at lower word
orr(Rd, Rd, tmp2);
}

// This instruction reads adjacent 4 bytes from the upper half of source register,
Expand All @@ -3988,17 +3987,8 @@ void MacroAssembler::inflate_lo32(Register Rd, Register Rs, Register tmp1, Regis
// Rd: 00A700A600A500A4
void MacroAssembler::inflate_hi32(Register Rd, Register Rs, Register tmp1, Register tmp2) {
assert_different_registers(Rd, Rs, tmp1, tmp2);

mv(tmp1, 0xFF00000000);
mv(Rd, zr);
for (int i = 0; i <= 3; i++) {
andr(tmp2, Rs, tmp1);
orr(Rd, Rd, tmp2);
srli(Rd, Rd, 8);
if (i != 3) {
slli(tmp1, tmp1, 8);
}
}
srli(Rs, Rs, 32); // only upper 32 bits are needed
inflate_lo32(Rd, Rs, tmp1, tmp2);
}

// The size of the blocks erased by the zero_blocks stub. We must
Expand Down
145 changes: 72 additions & 73 deletions src/hotspot/cpu/riscv/stubGenerator_riscv.cpp
Expand Up @@ -2310,24 +2310,21 @@ class StubGenerator: public StubCodeGenerator {
}

// code for comparing 8 characters of strings with Latin1 and Utf16 encoding
void compare_string_8_x_LU(Register tmpL, Register tmpU, Label &DIFF1,
Label &DIFF2) {
const Register strU = x12, curU = x7, strL = x29, tmp = x30;
__ ld(tmpL, Address(strL));
__ addi(strL, strL, 8);
void compare_string_8_x_LU(Register tmpL, Register tmpU, Register strL, Register strU, Label& DIFF) {
const Register tmp = x30, tmpLval = x12;
__ ld(tmpLval, Address(strL));
__ addi(strL, strL, wordSize);
__ ld(tmpU, Address(strU));
__ addi(strU, strU, 8);
__ inflate_lo32(tmp, tmpL);
__ mv(t0, tmp);
__ xorr(tmp, curU, t0);
__ bnez(tmp, DIFF2);

__ ld(curU, Address(strU));
__ addi(strU, strU, 8);
__ inflate_hi32(tmp, tmpL);
__ mv(t0, tmp);
__ xorr(tmp, tmpU, t0);
__ bnez(tmp, DIFF1);
__ addi(strU, strU, wordSize);
__ inflate_lo32(tmpL, tmpLval);
__ xorr(tmp, tmpU, tmpL);
__ bnez(tmp, DIFF);

__ ld(tmpU, Address(strU));
__ addi(strU, strU, wordSize);
__ inflate_hi32(tmpL, tmpLval);
__ xorr(tmp, tmpU, tmpL);
__ bnez(tmp, DIFF);
}

// x10 = result
Expand All @@ -2342,89 +2339,91 @@ class StubGenerator: public StubCodeGenerator {
__ align(CodeEntryAlignment);
StubCodeMark mark(this, "StubRoutines", isLU ? "compare_long_string_different_encoding LU" : "compare_long_string_different_encoding UL");
address entry = __ pc();
Label SMALL_LOOP, TAIL, TAIL_LOAD_16, LOAD_LAST, DIFF1, DIFF2,
DONE, CALCULATE_DIFFERENCE;
const Register result = x10, str1 = x11, cnt1 = x12, str2 = x13, cnt2 = x14,
tmp1 = x28, tmp2 = x29, tmp3 = x30, tmp4 = x7, tmp5 = x31;
RegSet spilled_regs = RegSet::of(tmp4, tmp5);
Label SMALL_LOOP, TAIL, LOAD_LAST, DONE, CALCULATE_DIFFERENCE;
const Register result = x10, str1 = x11, str2 = x13, cnt2 = x14,
tmp1 = x28, tmp2 = x29, tmp3 = x30, tmp4 = x12;

// cnt2 == amount of characters left to compare
// Check already loaded first 4 symbols
__ inflate_lo32(tmp3, isLU ? tmp1 : tmp2);
__ mv(isLU ? tmp1 : tmp2, tmp3);
__ addi(str1, str1, isLU ? wordSize / 2 : wordSize);
__ addi(str2, str2, isLU ? wordSize : wordSize / 2);
__ sub(cnt2, cnt2, 8); // Already loaded 4 symbols. Last 4 is special case.
__ push_reg(spilled_regs, sp);
__ sub(cnt2, cnt2, wordSize / 2); // Already loaded 4 symbols

if (isLU) {
__ add(str1, str1, cnt2);
__ shadd(str2, cnt2, str2, t0, 1);
} else {
__ shadd(str1, cnt2, str1, t0, 1);
__ add(str2, str2, cnt2);
}
__ xorr(tmp3, tmp1, tmp2);
__ mv(tmp5, tmp2);
__ bnez(tmp3, CALCULATE_DIFFERENCE);

Register strU = isLU ? str2 : str1,
strL = isLU ? str1 : str2,
tmpU = isLU ? tmp5 : tmp1, // where to keep U for comparison
tmpL = isLU ? tmp1 : tmp5; // where to keep L for comparison
tmpU = isLU ? tmp2 : tmp1, // where to keep U for comparison
tmpL = isLU ? tmp1 : tmp2; // where to keep L for comparison

__ sub(tmp2, strL, cnt2); // strL pointer to load from
__ slli(t0, cnt2, 1);
__ sub(cnt1, strU, t0); // strU pointer to load from
// make sure main loop is 8 byte-aligned, we should load another 4 bytes from strL
// cnt2 is >= 68 here, no need to check it for >= 0
__ lwu(tmpL, Address(strL));
__ addi(strL, strL, wordSize / 2);
__ ld(tmpU, Address(strU));
__ addi(strU, strU, wordSize);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ bnez(tmp3, CALCULATE_DIFFERENCE);
__ addi(cnt2, cnt2, -wordSize / 2);

__ ld(tmp4, Address(cnt1));
__ addi(cnt1, cnt1, 8);
__ beqz(cnt2, LOAD_LAST); // no characters left except last load
__ sub(cnt2, cnt2, 16);
// we are now 8-bytes aligned on strL
__ sub(cnt2, cnt2, wordSize * 2);
__ bltz(cnt2, TAIL);
__ bind(SMALL_LOOP); // smaller loop
__ sub(cnt2, cnt2, 16);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
__ sub(cnt2, cnt2, wordSize * 2);
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
__ bgez(cnt2, SMALL_LOOP);
__ addi(t0, cnt2, 16);
__ beqz(t0, LOAD_LAST);
__ bind(TAIL); // 1..15 characters left until last load (last 4 characters)
// Address of 8 bytes before last 4 characters in UTF-16 string
__ shadd(cnt1, cnt2, cnt1, t0, 1);
// Address of 16 bytes before last 4 characters in Latin1 string
__ add(tmp2, tmp2, cnt2);
__ ld(tmp4, Address(cnt1, -8));
// last 16 characters before last load
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
__ j(LOAD_LAST);
__ bind(DIFF2);
__ mv(tmpU, tmp4);
__ bind(DIFF1);
__ mv(tmpL, t0);
__ j(CALCULATE_DIFFERENCE);
__ bind(LOAD_LAST);
// Last 4 UTF-16 characters are already pre-loaded into tmp4 by compare_string_8_x_LU.
// No need to load it again
__ mv(tmpU, tmp4);
__ ld(tmpL, Address(strL));
__ addi(t0, cnt2, wordSize * 2);
__ beqz(t0, DONE);
__ bind(TAIL); // 1..15 characters left
// Aligned access. Load bytes in portions - 4, 2, 1.

__ addi(t0, cnt2, wordSize);
__ addi(cnt2, cnt2, wordSize * 2); // amount of characters left to process
__ bltz(t0, LOAD_LAST);
// remaining characters are greater than or equals to 8, we can do one compare_string_8_x_LU
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
__ addi(cnt2, cnt2, -wordSize);
__ beqz(cnt2, DONE); // no character left
__ bind(LOAD_LAST); // cnt2 = 1..7 characters left

__ addi(cnt2, cnt2, -wordSize); // cnt2 is now an offset in strL which points to last 8 bytes
__ slli(t0, cnt2, 1); // t0 is now an offset in strU which points to last 16 bytes
__ add(strL, strL, cnt2); // Address of last 8 bytes in Latin1 string
__ add(strU, strU, t0); // Address of last 16 bytes in UTF-16 string
__ load_int_misaligned(tmpL, Address(strL), t0, false);
__ load_long_misaligned(tmpU, Address(strU), t0, 2);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ beqz(tmp3, DONE);
__ bnez(tmp3, CALCULATE_DIFFERENCE);

__ addi(strL, strL, wordSize / 2); // Address of last 4 bytes in Latin1 string
__ addi(strU, strU, wordSize); // Address of last 8 bytes in UTF-16 string
__ load_int_misaligned(tmpL, Address(strL), t0, false);
__ load_long_misaligned(tmpU, Address(strU), t0, 2);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ bnez(tmp3, CALCULATE_DIFFERENCE);
__ j(DONE); // no character left

// Find the first different characters in the longwords and
// compute their difference.
__ bind(CALCULATE_DIFFERENCE);
__ ctzc_bit(tmp4, tmp3);
__ srl(tmp1, tmp1, tmp4);
__ srl(tmp5, tmp5, tmp4);
__ srl(tmp2, tmp2, tmp4);
__ andi(tmp1, tmp1, 0xFFFF);
__ andi(tmp5, tmp5, 0xFFFF);
__ sub(result, tmp1, tmp5);
__ andi(tmp2, tmp2, 0xFFFF);
__ sub(result, tmp1, tmp2);
__ bind(DONE);
__ pop_reg(spilled_regs, sp);
__ ret();
return entry;
}
Expand Down Expand Up @@ -2537,9 +2536,9 @@ class StubGenerator: public StubCodeGenerator {
__ xorr(tmp4, tmp1, tmp2);
__ bnez(tmp4, DIFF);
__ add(str1, str1, cnt2);
__ ld(tmp5, Address(str1));
__ load_long_misaligned(tmp5, Address(str1), tmp3, isLL ? 1 : 2);
__ add(str2, str2, cnt2);
__ ld(cnt1, Address(str2));
__ load_long_misaligned(cnt1, Address(str2), tmp3, isLL ? 1 : 2);
__ xorr(tmp4, tmp5, cnt1);
__ beqz(tmp4, LENGTH_DIFF);
// Find the first different characters in the longwords and
Expand Down
Expand Up @@ -24,20 +24,20 @@

/*
* @test
* @requires os.arch=="aarch64"
* @requires os.arch=="aarch64" | os.arch=="riscv64"
* @summary String::compareTo implementation uses different algorithms for
* different string length. This test creates string with specified
* size and longer string, which is same at beginning.
* Expecting length delta to be returned. Test class takes 2
* parameters: <string length>, <maximum string length delta>
* Input parameters for this test are set according to Aarch64
* Input parameters for this test are set according to Aarch64/RISC-V
* String::compareTo intrinsic implementation specifics. Aarch64
* implementation has 1, 4, 8 -bytes loops for length < 72 and
* 16, 32, 64 -characters loops for length >= 72. Code is also affected
* 16, 32, 64 -characters loops for length >= 72. Aarch64 Code is also affected
* by SoftwarePrefetchHintDistance vm flag value.
* @run main/othervm -XX:SoftwarePrefetchHintDistance=192 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90 192 193 208 209
* @run main/othervm -XX:SoftwarePrefetchHintDistance=16 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:SoftwarePrefetchHintDistance=-1 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=192 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90 192 193 208 209
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=16 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=-1 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
*/

package compiler.intrinsics.string;
Expand Down

1 comment on commit 89875df

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.